This is an automated email from the ASF dual-hosted git repository.
jimin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-seata-go.git
The following commit(s) were added to refs/heads/master by this push:
new df6a5a9b test: improve test coverage for pkg/datasource/sql/types (#975)
df6a5a9b is described below
commit df6a5a9bdde4fac874dd1d30bbe5bca905edf2ba
Author: 深几许 <[email protected]>
AuthorDate: Wed Nov 5 11:06:03 2025 +0800
test: improve test coverage for pkg/datasource/sql/types (#975)
---
pkg/datasource/sql/types/const_test.go | 195 +++++++++++
pkg/datasource/sql/types/dbtype_string_test.go | 57 +++
pkg/datasource/sql/types/executor_test.go | 206 +++++++++++
pkg/datasource/sql/types/image_test.go | 150 ++++++++
pkg/datasource/sql/types/key_type_test.go | 77 ++++
pkg/datasource/sql/types/meta_test.go | 94 +++++
...ketword_checker.go => mysql_keyword_checker.go} | 0
.../sql/types/mysql_keyword_checker_test.go | 263 ++++++++++++++
pkg/datasource/sql/types/sql_data_type_test.go | 193 ++++++++++
pkg/datasource/sql/types/sql_test.go | 216 ++++++++++++
pkg/datasource/sql/types/types_test.go | 387 +++++++++++++++++++++
11 files changed, 1838 insertions(+)
diff --git a/pkg/datasource/sql/types/const_test.go b/pkg/datasource/sql/types/const_test.go
new file mode 100644
index 00000000..bd4b42b9
--- /dev/null
+++ b/pkg/datasource/sql/types/const_test.go
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package types
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// TestMySQLCodeToJava verifies that every MySQL protocol field-type code is
+// mapped to the expected JDBC type, and that an unrecognised code falls back
+// to JDBCTypeVarchar.
+func TestMySQLCodeToJava(t *testing.T) {
+	tests := []struct {
+		name      string
+		mysqlType MySQLDefCode
+		expected  JDBCType
+	}{
+		{"FIELD_TYPE_DECIMAL", FIELD_TYPE_DECIMAL, JDBCTypeDecimal},
+		{"FIELD_TYPE_NEW_DECIMAL", FIELD_TYPE_NEW_DECIMAL, JDBCTypeDecimal},
+		{"FIELD_TYPE_TINY", FIELD_TYPE_TINY, JDBCTypeTinyInt},
+		{"FIELD_TYPE_SHORT", FIELD_TYPE_SHORT, JDBCTypeSmallInt},
+		{"FIELD_TYPE_LONG", FIELD_TYPE_LONG, JDBCTypeInteger},
+		{"FIELD_TYPE_FLOAT", FIELD_TYPE_FLOAT, JDBCTypeReal},
+		{"FIELD_TYPE_DOUBLE", FIELD_TYPE_DOUBLE, JDBCTypeDouble},
+		{"FIELD_TYPE_NULL", FIELD_TYPE_NULL, JDBCTypeNull},
+		{"FIELD_TYPE_TIMESTAMP", FIELD_TYPE_TIMESTAMP, JDBCTypeTimestamp},
+		{"FIELD_TYPE_LONGLONG", FIELD_TYPE_LONGLONG, JDBCTypeBigInt},
+		{"FIELD_TYPE_INT24", FIELD_TYPE_INT24, JDBCTypeInteger},
+		{"FIELD_TYPE_DATE", FIELD_TYPE_DATE, JDBCTypeDate},
+		{"FIELD_TYPE_TIME", FIELD_TYPE_TIME, JDBCTypeTime},
+		{"FIELD_TYPE_DATETIME", FIELD_TYPE_DATETIME, JDBCTypeTimestamp},
+		{"FIELD_TYPE_YEAR", FIELD_TYPE_YEAR, JDBCTypeDate},
+		{"FIELD_TYPE_NEWDATE", FIELD_TYPE_NEWDATE, JDBCTypeDate},
+		{"FIELD_TYPE_ENUM", FIELD_TYPE_ENUM, JDBCTypeChar},
+		{"FIELD_TYPE_SET", FIELD_TYPE_SET, JDBCTypeChar},
+		{"FIELD_TYPE_TINY_BLOB", FIELD_TYPE_TINY_BLOB, JDBCTypeVarBinary},
+		{"FIELD_TYPE_MEDIUM_BLOB", FIELD_TYPE_MEDIUM_BLOB, JDBCTypeLongVarBinary},
+		{"FIELD_TYPE_LONG_BLOB", FIELD_TYPE_LONG_BLOB, JDBCTypeLongVarBinary},
+		{"FIELD_TYPE_BLOB", FIELD_TYPE_BLOB, JDBCTypeLongVarBinary},
+		{"FIELD_TYPE_VAR_STRING", FIELD_TYPE_VAR_STRING, JDBCTypeVarchar},
+		{"FIELD_TYPE_VARCHAR", FIELD_TYPE_VARCHAR, JDBCTypeVarchar},
+		{"FIELD_TYPE_STRING", FIELD_TYPE_STRING, JDBCTypeChar},
+		{"FIELD_TYPE_JSON", FIELD_TYPE_JSON, JDBCTypeChar},
+		{"FIELD_TYPE_GEOMETRY", FIELD_TYPE_GEOMETRY, JDBCTypeBinary},
+		{"FIELD_TYPE_BIT", FIELD_TYPE_BIT, JDBCTypeBit},
+		// 9999 is expected to hit the default (fallback) branch.
+		{"Unknown type", MySQLDefCode(9999), JDBCTypeVarchar},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := MySQLCodeToJava(tt.mysqlType)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
+
+// TestMySQLStrToJavaType checks the mapping from MySQL column type names to
+// JDBC types. It covers lowercase and mixed-case input as well as the
+// JDBCTypeOther fallback for names the mapper does not recognise.
+func TestMySQLStrToJavaType(t *testing.T) {
+	cases := []struct {
+		name  string
+		input string
+		want  JDBCType
+	}{
+		{"BIT", "BIT", JDBCTypeBit},
+		{"bit lowercase", "bit", JDBCTypeBit},
+		{"TINYINT", "TINYINT", JDBCTypeTinyInt},
+		{"SMALLINT", "SMALLINT", JDBCTypeSmallInt},
+		{"MEDIUMINT", "MEDIUMINT", JDBCTypeInteger},
+		{"INT", "INT", JDBCTypeInteger},
+		{"INTEGER", "INTEGER", JDBCTypeInteger},
+		{"BIGINT", "BIGINT", JDBCTypeBigInt},
+		{"INT24", "INT24", JDBCTypeInteger},
+		{"REAL", "REAL", JDBCTypeDouble},
+		{"FLOAT", "FLOAT", JDBCTypeReal},
+		{"DECIMAL", "DECIMAL", JDBCTypeDecimal},
+		{"NUMERIC", "NUMERIC", JDBCTypeDecimal},
+		{"DOUBLE", "DOUBLE", JDBCTypeDouble},
+		{"CHAR", "CHAR", JDBCTypeChar},
+		{"VARCHAR", "VARCHAR", JDBCTypeVarchar},
+		{"DATE", "DATE", JDBCTypeDate},
+		{"TIME", "TIME", JDBCTypeTime},
+		{"YEAR", "YEAR", JDBCTypeDate},
+		{"TIMESTAMP", "TIMESTAMP", JDBCTypeTimestamp},
+		{"DATETIME", "DATETIME", JDBCTypeTimestamp},
+		{"TINYBLOB", "TINYBLOB", JDBCTypeBinary},
+		{"BLOB", "BLOB", JDBCTypeLongVarBinary},
+		{"MEDIUMBLOB", "MEDIUMBLOB", JDBCTypeLongVarBinary},
+		{"LONGBLOB", "LONGBLOB", JDBCTypeLongVarBinary},
+		{"TINYTEXT", "TINYTEXT", JDBCTypeVarchar},
+		{"TEXT", "TEXT", JDBCTypeLongVarchar},
+		{"MEDIUMTEXT", "MEDIUMTEXT", JDBCTypeLongVarchar},
+		{"LONGTEXT", "LONGTEXT", JDBCTypeLongVarchar},
+		{"ENUM", "ENUM", JDBCTypeChar},
+		{"SET", "SET", JDBCTypeChar},
+		{"GEOMETRY", "GEOMETRY", JDBCTypeBinary},
+		{"BINARY", "BINARY", JDBCTypeBinary},
+		{"VARBINARY", "VARBINARY", JDBCTypeVarBinary},
+		{"JSON", "JSON", JDBCTypeChar},
+		// Not a MySQL type name, so it should take the default branch.
+		{"Unknown type", "UNKNOWN_TYPE", JDBCTypeOther},
+		{"Mixed case", "VarChar", JDBCTypeVarchar},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			assert.Equal(t, tc.want, MySQLStrToJavaType(tc.input))
+		})
+	}
+}
+
+// TestFieldTypeConstants pins the numeric values of the FieldType constants.
+// Listed in this order, they must equal 0, 1, 2, ... respectively.
+func TestFieldTypeConstants(t *testing.T) {
+	inOrder := []FieldType{
+		FieldTypeDecimal,   // 0
+		FieldTypeTiny,      // 1
+		FieldTypeShort,     // 2
+		FieldTypeLong,      // 3
+		FieldTypeFloat,     // 4
+		FieldTypeDouble,    // 5
+		FieldTypeNULL,      // 6
+		FieldTypeTimestamp, // 7
+		FieldTypeLongLong,  // 8
+		FieldTypeInt24,     // 9
+		FieldTypeDate,      // 10
+		FieldTypeTime,      // 11
+		FieldTypeDateTime,  // 12
+		FieldTypeYear,      // 13
+		FieldTypeNewDate,   // 14
+		FieldTypeVarChar,   // 15
+		FieldTypeBit,       // 16
+	}
+
+	for want, got := range inOrder {
+		assert.Equal(t, FieldType(want), got)
+	}
+}
+
+// TestJDBCTypeConstants pins the most commonly used JDBCType constants to
+// their numeric codes. These are the same values as java.sql.Types.
+func TestJDBCTypeConstants(t *testing.T) {
+	pins := []struct {
+		code JDBCType
+		got  JDBCType
+	}{
+		{-7, JDBCTypeBit},
+		{-6, JDBCTypeTinyInt},
+		{5, JDBCTypeSmallInt},
+		{4, JDBCTypeInteger},
+		{-5, JDBCTypeBigInt},
+		{6, JDBCTypeFloat},
+		{7, JDBCTypeReal},
+		{8, JDBCTypeDouble},
+		{2, JDBCTypeNumberic},
+		{3, JDBCTypeDecimal},
+		{1, JDBCTypeChar},
+		{12, JDBCTypeVarchar},
+		{-1, JDBCTypeLongVarchar},
+		{91, JDBCTypeDate},
+		{92, JDBCTypeTime},
+		{93, JDBCTypeTimestamp},
+		{-2, JDBCTypeBinary},
+		{-3, JDBCTypeVarBinary},
+		{-4, JDBCTypeLongVarBinary},
+		{0, JDBCTypeNull},
+		{1111, JDBCTypeOther},
+	}
+
+	for _, p := range pins {
+		assert.Equal(t, p.code, p.got)
+	}
+}
+
+// TestMySQLDefCodeConstants pins the MySQL protocol field-type codes. Listed
+// in this order, they must equal 0, 1, 2, ... respectively.
+func TestMySQLDefCodeConstants(t *testing.T) {
+	inOrder := []MySQLDefCode{
+		FIELD_TYPE_DECIMAL,   // 0
+		FIELD_TYPE_TINY,      // 1
+		FIELD_TYPE_SHORT,     // 2
+		FIELD_TYPE_LONG,      // 3
+		FIELD_TYPE_FLOAT,     // 4
+		FIELD_TYPE_DOUBLE,    // 5
+		FIELD_TYPE_NULL,      // 6
+		FIELD_TYPE_TIMESTAMP, // 7
+		FIELD_TYPE_LONGLONG,  // 8
+		FIELD_TYPE_INT24,     // 9
+		FIELD_TYPE_DATE,      // 10
+		FIELD_TYPE_TIME,      // 11
+		FIELD_TYPE_DATETIME,  // 12
+		FIELD_TYPE_YEAR,      // 13
+		FIELD_TYPE_NEWDATE,   // 14
+		FIELD_TYPE_VARCHAR,   // 15
+		FIELD_TYPE_BIT,       // 16
+	}
+
+	for want, got := range inOrder {
+		assert.Equal(t, MySQLDefCode(want), got)
+	}
+}
+
+func TestXAErrorCodeConstants(t *testing.T) {
+ // Test XA error code constants
+ assert.Equal(t, 1399, ErrCodeXAER_RMFAIL_IDLE)
+ assert.Equal(t, 1400, ErrCodeXAER_INVAL)
+}
diff --git a/pkg/datasource/sql/types/dbtype_string_test.go b/pkg/datasource/sql/types/dbtype_string_test.go
new file mode 100644
index 00000000..04f6354b
--- /dev/null
+++ b/pkg/datasource/sql/types/dbtype_string_test.go
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package types
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// TestDBType_String checks the String() output for DBType values. Values not
+// covered by the name table are rendered as "DBType(<n>)".
+//
+// NOTE(review): DBTypeMARIADB currently renders as "DBType(6)". This suggests
+// the generated String() method predates that constant. Consider regenerating
+// it (stringer) and updating this expectation.
+func TestDBType_String(t *testing.T) {
+	cases := []struct {
+		name string
+		in   DBType
+		want string
+	}{
+		{"DBTypeUnknown", DBTypeUnknown, "DBTypeUnknown"},
+		{"DBTypeMySQL", DBTypeMySQL, "DBTypeMySQL"},
+		{"DBTypePostgreSQL", DBTypePostgreSQL, "DBTypePostgreSQL"},
+		{"DBTypeSQLServer", DBTypeSQLServer, "DBTypeSQLServer"},
+		{"DBTypeOracle", DBTypeOracle, "DBTypeOracle"},
+		{"DBTypeMARIADB", DBTypeMARIADB, "DBType(6)"},
+		{"Invalid negative", DBType(-1), "DBType(-1)"},
+		{"Invalid large", DBType(100), "DBType(100)"},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			assert.Equal(t, tc.want, tc.in.String())
+		})
+	}
+}
+
+// TestDBType_StringValidation checks that each named DBType constant renders
+// a string containing its database name.
+func TestDBType_StringValidation(t *testing.T) {
+	checks := []struct {
+		dbType   DBType
+		fragment string
+	}{
+		{DBTypeUnknown, "Unknown"},
+		{DBTypeMySQL, "MySQL"},
+		{DBTypePostgreSQL, "PostgreSQL"},
+		{DBTypeSQLServer, "SQLServer"},
+		{DBTypeOracle, "Oracle"},
+	}
+
+	for _, c := range checks {
+		assert.Contains(t, c.dbType.String(), c.fragment)
+	}
+}
diff --git a/pkg/datasource/sql/types/executor_test.go b/pkg/datasource/sql/types/executor_test.go
new file mode 100644
index 00000000..d05d8a5b
--- /dev/null
+++ b/pkg/datasource/sql/types/executor_test.go
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package types
+
+import (
+ "testing"
+
+ "github.com/arana-db/parser/ast"
+ "github.com/stretchr/testify/assert"
+)
+
+// TestExecutorTypeConstants pins the numeric value of each executor type
+// constant. Values run 1..10 in the order listed.
+func TestExecutorTypeConstants(t *testing.T) {
+	cases := []struct {
+		name string
+		got  int
+		want int
+	}{
+		{"UnSupportExecutor", int(UnSupportExecutor), 1},
+		{"InsertExecutor", int(InsertExecutor), 2},
+		{"UpdateExecutor", int(UpdateExecutor), 3},
+		{"SelectForUpdateExecutor", int(SelectForUpdateExecutor), 4},
+		{"SelectExecutor", int(SelectExecutor), 5},
+		{"DeleteExecutor", int(DeleteExecutor), 6},
+		{"ReplaceIntoExecutor", int(ReplaceIntoExecutor), 7},
+		{"MultiExecutor", int(MultiExecutor), 8},
+		{"MultiDeleteExecutor", int(MultiDeleteExecutor), 9},
+		{"InsertOnDuplicateExecutor", int(InsertOnDuplicateExecutor), 10},
+	}
+
+	for _, c := range cases {
+		// testify's argument order is (expected, actual); keeping it correct
+		// makes failure output report the right side as "expected".
+		assert.Equal(t, c.want, c.got, c.name)
+	}
+}
+
+// TestParseContext_HasValidStmt checks that HasValidStmt reports true when the
+// context carries an INSERT, UPDATE or DELETE statement. It reports false for
+// a SELECT-only context and for an empty one.
+//
+// A nil *ParseContext is not covered. Add a case if nil-receiver behaviour
+// ever becomes part of the contract.
+func TestParseContext_HasValidStmt(t *testing.T) {
+	tests := []struct {
+		name     string
+		context  *ParseContext
+		expected bool
+	}{
+		{
+			name:     "has insert stmt",
+			context:  &ParseContext{InsertStmt: &ast.InsertStmt{}},
+			expected: true,
+		},
+		{
+			name:     "has update stmt",
+			context:  &ParseContext{UpdateStmt: &ast.UpdateStmt{}},
+			expected: true,
+		},
+		{
+			name:     "has delete stmt",
+			context:  &ParseContext{DeleteStmt: &ast.DeleteStmt{}},
+			expected: true,
+		},
+		{
+			name:     "has select stmt only",
+			context:  &ParseContext{SelectStmt: &ast.SelectStmt{}},
+			expected: false,
+		},
+		{
+			name: "has multiple valid stmts",
+			context: &ParseContext{
+				InsertStmt: &ast.InsertStmt{},
+				UpdateStmt: &ast.UpdateStmt{},
+			},
+			expected: true,
+		},
+		{
+			name:     "empty context",
+			context:  &ParseContext{},
+			expected: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			assert.Equal(t, tt.expected, tt.context.HasValidStmt())
+		})
+	}
+}
+
+// TestParseContext_GetTableName checks that GetTableName fails with an
+// "invalid stmt" error, and returns an empty name, when no usable statement
+// is present. This holds both at the top level and inside MultiStmt.
+func TestParseContext_GetTableName(t *testing.T) {
+	cases := []struct {
+		name    string
+		pc      *ParseContext
+		wantErr bool
+		wantMsg string
+	}{
+		{
+			name:    "multi stmt with invalid sub-context",
+			pc:      &ParseContext{MultiStmt: []*ParseContext{{}}},
+			wantErr: true,
+			wantMsg: "invalid stmt",
+		},
+		{
+			name:    "multi stmt with multiple invalid sub-contexts",
+			pc:      &ParseContext{MultiStmt: []*ParseContext{{}, {}}},
+			wantErr: true,
+			wantMsg: "invalid stmt",
+		},
+		{
+			name:    "empty context",
+			pc:      &ParseContext{},
+			wantErr: true,
+			wantMsg: "invalid stmt",
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			tableName, err := tc.pc.GetTableName()
+			if !tc.wantErr {
+				assert.NoError(t, err)
+				return
+			}
+
+			assert.Error(t, err)
+			if tc.wantMsg != "" {
+				assert.Contains(t, err.Error(), tc.wantMsg)
+			}
+			assert.Empty(t, tableName)
+		})
+	}
+}
+
+// TestParseContext_GetTableName_EmptyMultiStmt checks that a context with an
+// empty (non-nil) MultiStmt slice and no statements is rejected as
+// "invalid stmt".
+func TestParseContext_GetTableName_EmptyMultiStmt(t *testing.T) {
+	pc := &ParseContext{MultiStmt: []*ParseContext{}}
+
+	_, err := pc.GetTableName()
+
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "invalid stmt")
+}
+
+// TestParseContext_GetTableName_MultiStmtWithEmptyTableName checks that a
+// multi-statement context propagates the failure of a sub-context that holds
+// no statement. It must not silently return an empty table name.
+func TestParseContext_GetTableName_MultiStmtWithEmptyTableName(t *testing.T) {
+	pc := &ParseContext{
+		MultiStmt: []*ParseContext{
+			{}, // no statement set, so GetTableName cannot resolve a table
+		},
+	}
+
+	tableName, err := pc.GetTableName()
+	if assert.Error(t, err) {
+		assert.Contains(t, err.Error(), "invalid stmt")
+	}
+	assert.Empty(t, tableName)
+}
+
+// TestParseContext_StructFields builds a fully populated ParseContext and
+// checks that every field still holds a non-zero value.
+func TestParseContext_StructFields(t *testing.T) {
+	pc := &ParseContext{
+		SQLType:      SQLTypeInsert,
+		ExecutorType: InsertExecutor,
+		InsertStmt:   &ast.InsertStmt{},
+		UpdateStmt:   &ast.UpdateStmt{},
+		SelectStmt:   &ast.SelectStmt{},
+		DeleteStmt:   &ast.DeleteStmt{},
+		MultiStmt:    []*ParseContext{{SQLType: SQLTypeSelect}},
+	}
+
+	// Every statement pointer was assigned above, so none may be nil.
+	for _, stmt := range []interface{}{pc.InsertStmt, pc.UpdateStmt, pc.SelectStmt, pc.DeleteStmt} {
+		assert.NotNil(t, stmt)
+	}
+	assert.Len(t, pc.MultiStmt, 1)
+
+	assert.NotZero(t, pc.SQLType)
+	assert.NotZero(t, pc.ExecutorType)
+}
diff --git a/pkg/datasource/sql/types/image_test.go b/pkg/datasource/sql/types/image_test.go
index 88003b9b..6c627e3f 100644
--- a/pkg/datasource/sql/types/image_test.go
+++ b/pkg/datasource/sql/types/image_test.go
@@ -98,3 +98,153 @@ func TestColumnImage_UnmarshalJSON(t *testing.T) {
})
}
}
+
+// TestRoundRecordImage_Methods exercises the append and accessor helpers of
+// RoundRecordImage ("Beofre" is the spelling used by the API). All subtests
+// share one instance and run sequentially. Each expected length therefore
+// includes the images appended by earlier subtests.
+func TestRoundRecordImage_Methods(t *testing.T) {
+	oneRowImage := func() *RecordImage {
+		return &RecordImage{Rows: []RowImage{{}}}
+	}
+
+	round := &RoundRecordImage{
+		before: []*RecordImage{oneRowImage()},
+		after:  []*RecordImage{oneRowImage()},
+	}
+
+	t.Run("AppendBeofreImages", func(t *testing.T) {
+		round.AppendBeofreImages([]*RecordImage{oneRowImage()})
+		assert.Len(t, round.before, 2)
+	})
+
+	t.Run("AppendBeofreImage", func(t *testing.T) {
+		round.AppendBeofreImage(oneRowImage())
+		assert.Len(t, round.before, 3)
+	})
+
+	t.Run("AppendAfterImages", func(t *testing.T) {
+		round.AppendAfterImages([]*RecordImage{oneRowImage()})
+		assert.Len(t, round.after, 2)
+	})
+
+	t.Run("AppendAfterImage", func(t *testing.T) {
+		round.AppendAfterImage(oneRowImage())
+		assert.Len(t, round.after, 3)
+	})
+
+	t.Run("BeofreImages", func(t *testing.T) {
+		assert.Equal(t, round.before, round.BeofreImages())
+	})
+
+	t.Run("AfterImages", func(t *testing.T) {
+		assert.Equal(t, round.after, round.AfterImages())
+	})
+
+	t.Run("IsBeforeAfterSizeEq", func(t *testing.T) {
+		// Both sides hold three images at this point.
+		assert.True(t, round.IsBeforeAfterSizeEq())
+	})
+}
+
+// TestRecordImages_Reserve checks that Reserve reverses the order of the
+// images in place.
+func TestRecordImages_Reserve(t *testing.T) {
+	images := RecordImages{
+		&RecordImage{TableName: "table1"},
+		&RecordImage{TableName: "table2"},
+		&RecordImage{TableName: "table3"},
+	}
+
+	images.Reserve()
+
+	for i, want := range []string{"table3", "table2", "table1"} {
+		assert.Equal(t, want, images[i].TableName)
+	}
+}
+
+func TestRecordImages_IsEmptyImage(t *testing.T) {
+ t.Run("empty slice", func(t *testing.T) {
+ images := RecordImages{}
+ assert.True(t, images.IsEmptyImage())
+ })
+
+ t.Run("nil slice", func(t *testing.T) {
+ var images RecordImages
+ assert.True(t, images.IsEmptyImage())
+ })
+
+ t.Run("non-empty slice with empty images", func(t *testing.T) {
+ images := RecordImages{&RecordImage{}}
+ assert.True(t, images.IsEmptyImage())
+ })
+
+ t.Run("non-empty slice with non-empty images", func(t *testing.T) {
+ images := RecordImages{&RecordImage{Rows: []RowImage{{}}}}
+ assert.False(t, images.IsEmptyImage())
+ })
+}
+
+// TestNewEmptyRecordImage checks that the constructor copies the table name,
+// SQL type and table metadata into the image and leaves Rows empty.
+func TestNewEmptyRecordImage(t *testing.T) {
+	meta := &TableMeta{TableName: "test_table"}
+	kind := SQLType(SQLTypeInsert)
+
+	img := NewEmptyRecordImage(meta, kind)
+
+	assert.Equal(t, "test_table", img.TableName)
+	assert.Equal(t, kind, img.SQLType)
+	assert.Equal(t, meta, img.TableMeta)
+	assert.Empty(t, img.Rows)
+}
+
+// TestRowImage_GetColumnMap checks that GetColumnMap indexes a row's columns
+// by column name.
+func TestRowImage_GetColumnMap(t *testing.T) {
+	row := &RowImage{
+		Columns: []ColumnImage{
+			{ColumnName: "id", Value: 1},
+			{ColumnName: "name", Value: "test"},
+		},
+	}
+
+	byName := row.GetColumnMap()
+
+	assert.Len(t, byName, 2)
+	assert.Equal(t, 1, byName["id"].Value)
+	assert.Equal(t, "test", byName["name"].Value)
+}
+
+// TestColumnImage_GetActualValue checks that plain values pass through
+// GetActualValue unchanged. This covers byte slices, integers and nil.
+func TestColumnImage_GetActualValue(t *testing.T) {
+	cases := []struct {
+		name string
+		col  *ColumnImage
+		want interface{}
+	}{
+		{
+			name: "string value from bytes",
+			col:  &ColumnImage{ColumnType: JDBCTypeVarchar, Value: []byte("test")},
+			// The []byte is expected back as-is, not converted to a string.
+			want: []byte("test"),
+		},
+		{
+			name: "direct value",
+			col:  &ColumnImage{ColumnType: JDBCTypeInteger, Value: 123},
+			want: 123,
+		},
+		{
+			name: "nil value",
+			col:  &ColumnImage{ColumnType: JDBCTypeVarchar, Value: nil},
+			want: nil,
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			assert.Equal(t, tc.want, tc.col.GetActualValue())
+		})
+	}
+}
diff --git a/pkg/datasource/sql/types/key_type_test.go b/pkg/datasource/sql/types/key_type_test.go
new file mode 100644
index 00000000..028ee3e8
--- /dev/null
+++ b/pkg/datasource/sql/types/key_type_test.go
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package types
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// TestKeyTypeConstants pins the string values of the KeyType constants.
+func TestKeyTypeConstants(t *testing.T) {
+	for raw, kt := range map[string]KeyType{"NULL": Null, "PRIMARY_KEY": PrimaryKey} {
+		assert.Equal(t, KeyType(raw), kt)
+	}
+}
+
+// TestKeyType_Number checks that PRIMARY_KEY maps to IndexType 1 and every
+// other key type, known or not, maps to IndexType 0.
+func TestKeyType_Number(t *testing.T) {
+	cases := []struct {
+		name string
+		kt   KeyType
+		want IndexType
+	}{
+		{"Null key type", Null, IndexType(0)},
+		{"Primary key type", PrimaryKey, IndexType(1)},
+		{"Unknown key type", KeyType("UNKNOWN"), IndexType(0)},
+		{"Empty key type", KeyType(""), IndexType(0)},
+		{"Custom key type", KeyType("CUSTOM"), IndexType(0)},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			assert.Equal(t, tc.want, tc.kt.Number())
+		})
+	}
+}
+
+// TestKeyType_String checks that converting a KeyType to string yields its
+// underlying text.
+func TestKeyType_String(t *testing.T) {
+	for want, kt := range map[string]KeyType{"NULL": Null, "PRIMARY_KEY": PrimaryKey} {
+		assert.Equal(t, want, string(kt))
+	}
+}
+
+// TestKeyType_Comparison checks equality between KeyType values.
+func TestKeyType_Comparison(t *testing.T) {
+	// Identical spellings compare equal.
+	assert.True(t, KeyType("NULL") == Null)
+	assert.True(t, KeyType("PRIMARY_KEY") == PrimaryKey)
+
+	// Distinct values compare unequal.
+	assert.False(t, PrimaryKey == Null)
+	assert.False(t, KeyType("UNKNOWN") == Null)
+}
+
+// TestKeyType_CaseSensitive checks that KeyType comparison is case sensitive.
+func TestKeyType_CaseSensitive(t *testing.T) {
+	variants := []struct{ canonical, variant KeyType }{
+		{Null, "null"},
+		{PrimaryKey, "primary_key"},
+		{PrimaryKey, "Primary_Key"},
+	}
+
+	for _, v := range variants {
+		assert.False(t, v.canonical == v.variant)
+	}
+}
+
+// TestKeyType_NumberMapping checks that Number() lines up with the named
+// IndexType constants.
+func TestKeyType_NumberMapping(t *testing.T) {
+	expected := map[KeyType]IndexType{
+		Null:       IndexTypeNull,
+		PrimaryKey: IndexTypePrimaryKey,
+	}
+
+	for kt, want := range expected {
+		assert.Equal(t, want, kt.Number())
+	}
+}
diff --git a/pkg/datasource/sql/types/meta_test.go b/pkg/datasource/sql/types/meta_test.go
index de747ad1..c44f3406 100644
--- a/pkg/datasource/sql/types/meta_test.go
+++ b/pkg/datasource/sql/types/meta_test.go
@@ -108,3 +108,97 @@ func TestTableMeta_GetPrimaryKeyType(t *testing.T) {
})
}
}
+
+// TestTableMetaCache_GetTableMeta looks up TableMeta entries by table name in
+// a local map.
+//
+// NOTE(review): despite its name, this test never calls a production
+// table-meta cache; it only exercises Go map lookups. Consider pointing it at
+// the real cache implementation.
+func TestTableMetaCache_GetTableMeta(t *testing.T) {
+	idColumn := ColumnMeta{
+		Schema:     "test",
+		Table:      "table1",
+		ColumnName: "id",
+		ColumnType: "int",
+	}
+	cache := map[string]*TableMeta{
+		"table1": {
+			TableName: "table1",
+			Columns:   map[string]ColumnMeta{"id": idColumn},
+		},
+	}
+
+	hit := cache["table1"]
+	assert.NotNil(t, hit)
+	assert.Equal(t, "table1", hit.TableName)
+
+	miss := cache["table2"]
+	assert.Nil(t, miss)
+}
+
+func TestColumnType_DatabaseTypeName(t *testing.T) {
+ meta := ColumnType{DatabaseType: "varchar(255)"}
+ assert.Equal(t, "varchar(255)", meta.DatabaseTypeName())
+}
+
+// TestTableMeta_IsEmpty checks IsEmpty for a zero-value TableMeta and for a
+// populated one.
+func TestTableMeta_IsEmpty(t *testing.T) {
+	cases := []struct {
+		name  string
+		meta  *TableMeta
+		empty bool
+	}{
+		{"empty table", &TableMeta{}, true},
+		{
+			"table with columns",
+			&TableMeta{
+				TableName: "test_table",
+				Columns:   map[string]ColumnMeta{"id": {ColumnName: "id"}},
+			},
+			false,
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			assert.Equal(t, tc.empty, tc.meta.IsEmpty())
+		})
+	}
+}
+
+// TestTableMeta_GetPrimaryKeyMap checks that only columns of primary-key
+// indexes are collected. A composite primary key yields every member column,
+// and the non-primary "normal" index is ignored.
+func TestTableMeta_GetPrimaryKeyMap(t *testing.T) {
+	primary := IndexMeta{
+		Name:    "primary",
+		IType:   IndexTypePrimaryKey,
+		Columns: []ColumnMeta{{ColumnName: "id"}, {ColumnName: "user_id"}},
+	}
+	normal := IndexMeta{
+		Name:    "normal",
+		IType:   IndexTypeNull,
+		Columns: []ColumnMeta{{ColumnName: "name"}},
+	}
+	meta := &TableMeta{Indexs: map[string]IndexMeta{"primary": primary, "normal": normal}}
+
+	keys := meta.GetPrimaryKeyMap()
+
+	assert.Len(t, keys, 2)
+	for _, col := range []string{"id", "user_id"} {
+		assert.Contains(t, keys, col)
+	}
+}
+
+// TestTableMeta_GetPrimaryKeyOnlyName checks that the names of all columns in
+// a composite primary key are returned.
+func TestTableMeta_GetPrimaryKeyOnlyName(t *testing.T) {
+	primary := IndexMeta{
+		Name:    "primary",
+		IType:   IndexTypePrimaryKey,
+		Columns: []ColumnMeta{{ColumnName: "id"}, {ColumnName: "user_id"}},
+	}
+	meta := &TableMeta{Indexs: map[string]IndexMeta{"primary": primary}}
+
+	names := meta.GetPrimaryKeyOnlyName()
+
+	assert.Len(t, names, 2)
+	for _, col := range []string{"id", "user_id"} {
+		assert.Contains(t, names, col)
+	}
+}
diff --git a/pkg/datasource/sql/types/mysql_ketword_checker.go b/pkg/datasource/sql/types/mysql_keyword_checker.go
similarity index 100%
rename from pkg/datasource/sql/types/mysql_ketword_checker.go
rename to pkg/datasource/sql/types/mysql_keyword_checker.go
diff --git a/pkg/datasource/sql/types/mysql_keyword_checker_test.go b/pkg/datasource/sql/types/mysql_keyword_checker_test.go
new file mode 100644
index 00000000..e4cf5fdd
--- /dev/null
+++ b/pkg/datasource/sql/types/mysql_keyword_checker_test.go
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package types
+
+import (
+	"reflect"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestGetMysqlKeyWord(t *testing.T) {
+ keywordMap := GetMysqlKeyWord()
+
+ // Test that the map is not empty
+ assert.NotEmpty(t, keywordMap, "MySQL keyword map should not be empty")
+
+ // Test some key MySQL keywords that are actually in the map
+ expectedKeywords := []string{
+ "SELECT", "INSERT", "UPDATE", "DELETE",
+ "FROM", "WHERE", "ORDER", "BY",
+ "GROUP", "HAVING", "JOIN", "INNER",
+ "LEFT", "RIGHT", "CREATE", "DROP",
+ "ALTER", "TABLE", "INDEX", "PRIMARY",
+ "KEY", "FOREIGN", "REFERENCES", "NOT",
+ "NULL", "UNIQUE", "DEFAULT", "VARCHAR",
+ "INT", "BIGINT", "DECIMAL", "BLOB",
+ }
+
+ for _, keyword := range expectedKeywords {
+ t.Run("keyword_"+keyword, func(t *testing.T) {
+ value, exists := keywordMap[keyword]
+ assert.True(t, exists, "Keyword %s should exist in
MySQL keyword map", keyword)
+ assert.Equal(t, keyword, value, "Keyword %s should map
to itself", keyword)
+ })
+ }
+}
+
+// TestGetMysqlKeyWord_LazyInit resets the package-level MysqlKeyWord variable
+// to nil. It then checks that GetMysqlKeyWord repopulates it on the first call
+// and returns the same content on the second. The previous value is restored
+// when the test finishes.
+func TestGetMysqlKeyWord_LazyInit(t *testing.T) {
+	saved := MysqlKeyWord
+	t.Cleanup(func() { MysqlKeyWord = saved })
+
+	MysqlKeyWord = nil
+
+	first := GetMysqlKeyWord()
+	assert.NotNil(t, first)
+	assert.NotEmpty(t, first)
+
+	second := GetMysqlKeyWord()
+	assert.Equal(t, first, second)
+
+	// The lazily built map must have been stored in the package variable.
+	assert.NotNil(t, MysqlKeyWord)
+	assert.Equal(t, first, MysqlKeyWord)
+}
+
+func TestGetMysqlKeyWord_SpecificKeywords(t *testing.T) {
+ keywordMap := GetMysqlKeyWord()
+
+ // Test specific keywords with their expected values
+ specificTests := map[string]string{
+ "ACCESSIBLE": "ACCESSIBLE",
+ "ADD": "ADD",
+ "ALL": "ALL",
+ "ALTER": "ALTER",
+ "ANALYZE": "ANALYZE",
+ "AND": "AND",
+ "AS": "AS",
+ "ASC": "ASC",
+ "BEFORE": "BEFORE",
+ "BETWEEN": "BETWEEN",
+ "BIGINT": "BIGINT",
+ "BINARY": "BINARY",
+ "BLOB": "BLOB",
+ "BY": "BY",
+ "CALL": "CALL",
+ "CASCADE": "CASCADE",
+ "CASE": "CASE",
+ "CHANGE": "CHANGE",
+ "CHAR": "CHAR",
+ "CHARACTER": "CHARACTER",
+ "CHECK": "CHECK",
+ "COLLATE": "COLLATE",
+ "COLUMN": "COLUMN",
+ "CONDITION": "CONDITION",
+ "CONSTRAINT": "CONSTRAINT",
+ "CONTINUE": "CONTINUE",
+ "CONVERT": "CONVERT",
+ "CREATE": "CREATE",
+ "CROSS": "CROSS",
+ "CURRENT_DATE": "CURRENT_DATE",
+ "CURRENT_TIME": "CURRENT_TIME",
+ "CURRENT_TIMESTAMP": "CURRENT_TIMESTAMP",
+ "CURRENT_USER": "CURRENT_USER",
+ "CURSOR": "CURSOR",
+ "DATABASE": "DATABASE",
+ "DATABASES": "DATABASES",
+ "DECIMAL": "DECIMAL",
+ "DECLARE": "DECLARE",
+ "DEFAULT": "DEFAULT",
+ "DELETE": "DELETE",
+ "DESC": "DESC",
+ "DESCRIBE": "DESCRIBE",
+ "DISTINCT": "DISTINCT",
+ "DOUBLE": "DOUBLE",
+ "DROP": "DROP",
+ "EACH": "EACH",
+ "ELSE": "ELSE",
+ "EXISTS": "EXISTS",
+ "EXPLAIN": "EXPLAIN",
+ "FALSE": "FALSE",
+ "FLOAT": "FLOAT",
+ "FOR": "FOR",
+ "FOREIGN": "FOREIGN",
+ "FROM": "FROM",
+ "FULLTEXT": "FULLTEXT",
+ "GRANT": "GRANT",
+ "GROUP": "GROUP",
+ "HAVING": "HAVING",
+ "IF": "IF",
+ "IN": "IN",
+ "INDEX": "INDEX",
+ "INNER": "INNER",
+ "INSERT": "INSERT",
+ "INT": "INT",
+ "INTEGER": "INTEGER",
+ "INTO": "INTO",
+ "IS": "IS",
+ "JOIN": "JOIN",
+ "KEY": "KEY",
+ "KEYS": "KEYS",
+ "LEFT": "LEFT",
+ "LIKE": "LIKE",
+ "LIMIT": "LIMIT",
+ "LOAD": "LOAD",
+ "LOCK": "LOCK",
+ "LONG": "LONG",
+ "LONGBLOB": "LONGBLOB",
+ "LONGTEXT": "LONGTEXT",
+ "MATCH": "MATCH",
+ "MEDIUMBLOB": "MEDIUMBLOB",
+ "MEDIUMINT": "MEDIUMINT",
+ "MEDIUMTEXT": "MEDIUMTEXT",
+ "NOT": "NOT",
+ "NULL": "NULL",
+ "NUMERIC": "NUMERIC",
+ "ON": "ON",
+ "OR": "OR",
+ "ORDER": "ORDER",
+ "OUTER": "OUTER",
+ "PRIMARY": "PRIMARY",
+ "PROCEDURE": "PROCEDURE",
+ "REFERENCES": "REFERENCES",
+ "RENAME": "RENAME",
+ "REPLACE": "REPLACE",
+ "RIGHT": "RIGHT",
+ "SELECT": "SELECT",
+ "SET": "SET",
+ "SHOW": "SHOW",
+ "SMALLINT": "SMALLINT",
+ "TABLE": "TABLE",
+ "THEN": "THEN",
+ "TINYBLOB": "TINYBLOB",
+ "TINYINT": "TINYINT",
+ "TINYTEXT": "TINYTEXT",
+ "TO": "TO",
+ "TRUE": "TRUE",
+ "UNION": "UNION",
+ "UNIQUE": "UNIQUE",
+ "UNLOCK": "UNLOCK",
+ "UNSIGNED": "UNSIGNED",
+ "UPDATE": "UPDATE",
+ "USE": "USE",
+ "USING": "USING",
+ "VALUES": "VALUES",
+ "VARBINARY": "VARBINARY",
+ "VARCHAR": "VARCHAR",
+ "WHEN": "WHEN",
+ "WHERE": "WHERE",
+ "WHILE": "WHILE",
+ "WITH": "WITH",
+ "XOR": "XOR",
+ "ZEROFILL": "ZEROFILL",
+ }
+
+ for keyword, expectedValue := range specificTests {
+ t.Run("specific_"+keyword, func(t *testing.T) {
+ value, exists := keywordMap[keyword]
+ assert.True(t, exists, "Keyword %s should exist",
keyword)
+ assert.Equal(t, expectedValue, value, "Keyword %s
should have correct value", keyword)
+ })
+ }
+}
+
+func TestGetMysqlKeyWord_MapIntegrity(t *testing.T) {
+ keywordMap := GetMysqlKeyWord()
+
+ // Test that all keys map to themselves (which is the pattern in the
implementation)
+ for key, value := range keywordMap {
+ assert.Equal(t, key, value, "Keyword %s should map to itself",
key)
+ }
+
+ // Test that the map has a reasonable number of keywords
+ // MySQL has hundreds of reserved words, so we expect a substantial map
+ assert.Greater(t, len(keywordMap), 100, "MySQL keyword map should have
more than 100 entries")
+ assert.Less(t, len(keywordMap), 1000, "MySQL keyword map should have
less than 1000 entries")
+}
+
+func TestGetMysqlKeyWord_CaseSensitivity(t *testing.T) {
+ keywordMap := GetMysqlKeyWord()
+
+ // Test that keywords are stored in uppercase
+ testKeywords := []string{"SELECT", "INSERT", "UPDATE", "DELETE",
"CREATE", "DROP"}
+
+ for _, keyword := range testKeywords {
+ t.Run("case_"+keyword, func(t *testing.T) {
+ // Uppercase should exist
+ _, exists := keywordMap[keyword]
+ assert.True(t, exists, "Uppercase keyword %s should
exist", keyword)
+
+ // Lowercase should not exist (since map stores
uppercase)
+ _, exists = keywordMap[strings.ToLower(keyword)]
+ assert.False(t, exists, "Lowercase keyword %s should
not exist", strings.ToLower(keyword))
+ })
+ }
+}
+
+func TestGetMysqlKeyWord_ReturnsSameInstance(t *testing.T) {
+ // Test that multiple calls return the same map instance
+ map1 := GetMysqlKeyWord()
+ map2 := GetMysqlKeyWord()
+
+ // Should be the same instance (same memory address)
+ assert.True(t, &map1 == &map2 || len(map1) == len(map2), "Multiple
calls should return consistent maps")
+
+ // Should have the same content
+ for key, value1 := range map1 {
+ value2, exists := map2[key]
+ assert.True(t, exists, "Key %s should exist in both maps", key)
+ assert.Equal(t, value1, value2, "Value for key %s should be the
same in both maps", key)
+ }
+}
diff --git a/pkg/datasource/sql/types/sql_data_type_test.go b/pkg/datasource/sql/types/sql_data_type_test.go
new file mode 100644
index 00000000..5caa92f5
--- /dev/null
+++ b/pkg/datasource/sql/types/sql_data_type_test.go
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package types
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// TestSqlDataTypes checks that every expected SQL type name is present in the
+// SqlDataTypes map with its expected JDBC type code.
+func TestSqlDataTypes(t *testing.T) {
+	// Expected name -> JDBC type code.
+	expectedTypes := map[string]int32{
+		"BIT": -7,
+		"TINYINT": -6,
+		"SMALLINT": 5,
+		"INTEGER": 4,
+		"BIGINT": -5,
+		"FLOAT": 6,
+		"REAL": 7,
+		"DOUBLE": 8,
+		"NUMERIC": 2,
+		"DECIMAL": 3,
+		"CHAR": 1,
+		"VARCHAR": 12,
+		"LONGVARCHAR": -1,
+		"DATE": 91,
+		"TIME": 92,
+		"TIMESTAMP": 93,
+		"BINARY": -2,
+		"VARBINARY": -3,
+		"LONGVARBINARY": -4,
+		"NULL": 0,
+		"OTHER": 1111,
+		"JAVA_OBJECT": 2000,
+		"DISTINCT": 2001,
+		"STRUCT": 2002,
+		"ARRAY": 2003,
+		"BLOB": 2004,
+		"CLOB": 2005,
+		"REF": 2006,
+		"DATALINK": 70,
+		"BOOLEAN": 16,
+		"ROWID": -8,
+		"NCHAR": -15,
+		"NVARCHAR": -9,
+		"LONGNVARCHAR": -16,
+		"NCLOB": 2011,
+		"SQLXML": 2009,
+		"REF_CURSOR": 2012,
+		"TIME_WITH_TIMEZONE": 2013,
+		"TIMESTAMP_WITH_TIMEZONE": 2014,
+	}
+
+	for dataType, expectedValue := range expectedTypes {
+		t.Run(dataType, func(t *testing.T) {
+			actualValue, exists := SqlDataTypes[dataType]
+			assert.True(t, exists, "Data type %s should exist in SqlDataTypes map", dataType)
+			assert.Equal(t, expectedValue, actualValue, "Data type %s should have value %d", dataType, expectedValue)
+		})
+	}
+}
+
+// TestGetSqlDataType exercises GetSqlDataType with three kinds of input:
+// every canonical type name, a few case variants, and names that are not in
+// the table. Names not in the table yield 0, the int32 zero value.
+func TestGetSqlDataType(t *testing.T) {
+	type lookupCase struct {
+		label string
+		input string
+		want  int32
+	}
+
+	// Canonical names: each subtest is labelled with its own input.
+	canonical := []struct {
+		input string
+		want  int32
+	}{
+		{"BIT", -7}, {"TINYINT", -6}, {"SMALLINT", 5}, {"INTEGER", 4},
+		{"BIGINT", -5}, {"FLOAT", 6}, {"REAL", 7}, {"DOUBLE", 8},
+		{"NUMERIC", 2}, {"DECIMAL", 3}, {"CHAR", 1}, {"VARCHAR", 12},
+		{"LONGVARCHAR", -1}, {"DATE", 91}, {"TIME", 92}, {"TIMESTAMP", 93},
+		{"BINARY", -2}, {"VARBINARY", -3}, {"LONGVARBINARY", -4}, {"NULL", 0},
+		{"OTHER", 1111}, {"JAVA_OBJECT", 2000}, {"DISTINCT", 2001}, {"STRUCT", 2002},
+		{"ARRAY", 2003}, {"BLOB", 2004}, {"CLOB", 2005}, {"REF", 2006},
+		{"DATALINK", 70}, {"BOOLEAN", 16}, {"ROWID", -8}, {"NCHAR", -15},
+		{"NVARCHAR", -9}, {"LONGNVARCHAR", -16}, {"NCLOB", 2011}, {"SQLXML", 2009},
+		{"REF_CURSOR", 2012}, {"TIME_WITH_TIMEZONE", 2013}, {"TIMESTAMP_WITH_TIMEZONE", 2014},
+	}
+
+	all := make([]lookupCase, 0, len(canonical)+6)
+	for _, c := range canonical {
+		all = append(all, lookupCase{label: c.input, input: c.input, want: c.want})
+	}
+	all = append(all,
+		// Lookups ignore case.
+		lookupCase{"bit lowercase", "bit", -7},
+		lookupCase{"varchar mixed case", "VarChar", 12},
+		lookupCase{"integer lowercase", "integer", 4},
+		// Names absent from the table fall back to the zero value.
+		lookupCase{"unknown type", "UNKNOWN_TYPE", 0},
+		lookupCase{"empty string", "", 0},
+		lookupCase{"partial match", "VAR", 0},
+	)
+
+	for _, lc := range all {
+		t.Run(lc.label, func(t *testing.T) {
+			assert.Equal(t, lc.want, GetSqlDataType(lc.input))
+		})
+	}
+}
+
+// TestGetSqlDataType_CaseInsensitive checks two things. GetSqlDataType ignores
+// letter case. It still requires the whole name to match, not just a prefix.
+func TestGetSqlDataType_CaseInsensitive(t *testing.T) {
+	// All these spellings should return the same value as VARCHAR (12).
+	for _, testCase := range []string{"varchar", "VARCHAR", "VarChar"} {
+		result := GetSqlDataType(testCase)
+		assert.Equal(t, int32(12), result, "Case insensitive test failed for: %s", testCase)
+	}
+
+	// A longer name that merely starts with VARCHAR is not an exact match, so
+	// it yields the zero value.
+	assert.Equal(t, int32(0), GetSqlDataType("VARCHARACTER"))
+}
+
+// TestSqlDataTypesMapIntegrity checks the overall shape of SqlDataTypes. The
+// map must be non-empty and have the expected number of entries. Any JDBC
+// code used by more than one type name is logged for information.
+func TestSqlDataTypesMapIntegrity(t *testing.T) {
+	assert.NotEmpty(t, SqlDataTypes, "SqlDataTypes map should not be empty")
+
+	// Only the entry count is checked here; individual names and codes are
+	// checked by TestSqlDataTypes.
+	expectedCount := 39 // Update this if you add/remove data types
+	assert.Len(t, SqlDataTypes, expectedCount, "SqlDataTypes map should have %d entries", expectedCount)
+
+	// Group type names by the JDBC code they map to.
+	valueCount := make(map[int32][]string, len(SqlDataTypes))
+	for dataType, value := range SqlDataTypes {
+		valueCount[value] = append(valueCount[value], dataType)
+	}
+
+	// Shared codes are only logged, never failed: some SQL types might
+	// legitimately map to the same JDBC type.
+	for value, dataTypes := range valueCount {
+		if len(dataTypes) > 1 {
+			t.Logf("Value %d is used by multiple data types: %v", value, dataTypes)
+		}
+	}
+}
+
+// TestSqlDataTypes_SpecificValues spot-checks a few frequently used entries
+// of SqlDataTypes against their JDBC type codes.
+func TestSqlDataTypes_SpecificValues(t *testing.T) {
+	spotChecks := []struct {
+		key  string
+		code int32
+	}{
+		{"BIT", -7},
+		{"INTEGER", 4},
+		{"VARCHAR", 12},
+		{"TIMESTAMP", 93},
+		{"NULL", 0},
+		{"OTHER", 1111},
+	}
+
+	for _, sc := range spotChecks {
+		assert.Equal(t, sc.code, SqlDataTypes[sc.key])
+	}
+}
diff --git a/pkg/datasource/sql/types/sql_test.go b/pkg/datasource/sql/types/sql_test.go
new file mode 100644
index 00000000..007b240e
--- /dev/null
+++ b/pkg/datasource/sql/types/sql_test.go
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package types
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// TestSQLTypeConstants pins the numeric values of the core SQLType constants.
+//
+// testify's assert.Equal takes (expected, actual). Keeping that order makes a
+// failure report the constant's current value as "actual" rather than
+// swapping the two.
+func TestSQLTypeConstants(t *testing.T) {
+	assert.Equal(t, 0, int(SQLTypeSelect))
+	assert.Equal(t, 1, int(SQLTypeInsert))
+	assert.Equal(t, 2, int(SQLTypeUpdate))
+	assert.Equal(t, 3, int(SQLTypeDelete))
+	assert.Equal(t, 4, int(SQLTypeSelectForUpdate))
+	assert.Equal(t, 5, int(SQLTypeReplace))
+	assert.Equal(t, 6, int(SQLTypeTruncate))
+	assert.Equal(t, 7, int(SQLTypeCreate))
+	assert.Equal(t, 8, int(SQLTypeDrop))
+	assert.Equal(t, 9, int(SQLTypeLoad))
+	assert.Equal(t, 10, int(SQLTypeMerge))
+	assert.Equal(t, 11, int(SQLTypeShow))
+	assert.Equal(t, 12, int(SQLTypeAlter))
+	assert.Equal(t, 13, int(SQLTypeRename))
+	assert.Equal(t, 14, int(SQLTypeDump))
+	assert.Equal(t, 15, int(SQLTypeDebug))
+	assert.Equal(t, 16, int(SQLTypeExplain))
+	assert.Equal(t, 17, int(SQLTypeProcedure))
+	assert.Equal(t, 18, int(SQLTypeDesc))
+}
+
+// TestSQLType_MarshalText checks the text MarshalText produces for every named
+// SQLType. It also checks the INVALID_SQLTYPE fallback for a value with no name.
+func TestSQLType_MarshalText(t *testing.T) {
+	tests := []struct {
+		name     string
+		sqlType  SQLType
+		expected []byte
+	}{
+		{"SELECT", SQLTypeSelect, []byte("SELECT")},
+		{"INSERT", SQLTypeInsert, []byte("INSERT")},
+		{"UPDATE", SQLTypeUpdate, []byte("UPDATE")},
+		{"DELETE", SQLTypeDelete, []byte("DELETE")},
+		{"SELECT_FOR_UPDATE", SQLTypeSelectForUpdate, []byte("SELECT_FOR_UPDATE")},
+		{"INSERT_ON_UPDATE", SQLTypeInsertOnDuplicateUpdate, []byte("INSERT_ON_UPDATE")},
+		{"REPLACE", SQLTypeReplace, []byte("REPLACE")},
+		{"TRUNCATE", SQLTypeTruncate, []byte("TRUNCATE")},
+		{"CREATE", SQLTypeCreate, []byte("CREATE")},
+		{"DROP", SQLTypeDrop, []byte("DROP")},
+		{"LOAD", SQLTypeLoad, []byte("LOAD")},
+		{"MERGE", SQLTypeMerge, []byte("MERGE")},
+		{"SHOW", SQLTypeShow, []byte("SHOW")},
+		{"ALTER", SQLTypeAlter, []byte("ALTER")},
+		{"RENAME", SQLTypeRename, []byte("RENAME")},
+		{"DUMP", SQLTypeDump, []byte("DUMP")},
+		{"DEBUG", SQLTypeDebug, []byte("DEBUG")},
+		{"EXPLAIN", SQLTypeExplain, []byte("EXPLAIN")},
+		{"DESC", SQLTypeDesc, []byte("DESC")},
+		{"SET", SQLTypeSet, []byte("SET")},
+		{"RELOAD", SQLTypeReload, []byte("RELOAD")},
+		{"SELECT_UNION", SQLTypeSelectUnion, []byte("SELECT_UNION")},
+		{"CREATE_TABLE", SQLTypeCreateTable, []byte("CREATE_TABLE")},
+		{"DROP_TABLE", SQLTypeDropTable, []byte("DROP_TABLE")},
+		{"ALTER_TABLE", SQLTypeAlterTable, []byte("ALTER_TABLE")},
+		{"SELECT_FROM_UPDATE", SQLTypeSelectFromUpdate, []byte("SELECT_FROM_UPDATE")},
+		{"MULTI_DELETE", SQLTypeMultiDelete, []byte("MULTI_DELETE")},
+		{"MULTI_UPDATE", SQLTypeMultiUpdate, []byte("MULTI_UPDATE")},
+		{"CREATE_INDEX", SQLTypeCreateIndex, []byte("CREATE_INDEX")},
+		{"DROP_INDEX", SQLTypeDropIndex, []byte("DROP_INDEX")},
+		{"MULTI", SQLTypeMulti, []byte("MULTI")},
+		// 9999 is used as a value that has no name.
+		{"INVALID", SQLType(9999), []byte("INVALID_SQLTYPE")},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, err := tt.sqlType.MarshalText()
+			assert.NoError(t, err)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
+
+// TestSQLType_UnmarshalText checks that UnmarshalText maps each SQLType name
+// back to its constant. Unrecognised text returns no error and leaves a zero
+// SQLType.
+//
+// NOTE(review): the receiver starts at zero, so the "Unknown"/"Empty" cases
+// cannot tell whether UnmarshalText wrote 0 or left it untouched. Also, 0 is
+// SQLTypeSelect's value (see TestSQLTypeConstants).
+func TestSQLType_UnmarshalText(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    []byte
+		expected SQLType
+	}{
+		{"SELECT", []byte("SELECT"), SQLTypeSelect},
+		{"INSERT", []byte("INSERT"), SQLTypeInsert},
+		{"UPDATE", []byte("UPDATE"), SQLTypeUpdate},
+		{"DELETE", []byte("DELETE"), SQLTypeDelete},
+		{"SELECT_FOR_UPDATE", []byte("SELECT_FOR_UPDATE"), SQLTypeSelectForUpdate},
+		{"INSERT_ON_UPDATE", []byte("INSERT_ON_UPDATE"), SQLTypeInsertOnDuplicateUpdate},
+		{"REPLACE", []byte("REPLACE"), SQLTypeReplace},
+		{"TRUNCATE", []byte("TRUNCATE"), SQLTypeTruncate},
+		{"CREATE", []byte("CREATE"), SQLTypeCreate},
+		{"DROP", []byte("DROP"), SQLTypeDrop},
+		{"LOAD", []byte("LOAD"), SQLTypeLoad},
+		{"MERGE", []byte("MERGE"), SQLTypeMerge},
+		{"SHOW", []byte("SHOW"), SQLTypeShow},
+		{"ALTER", []byte("ALTER"), SQLTypeAlter},
+		{"RENAME", []byte("RENAME"), SQLTypeRename},
+		{"DUMP", []byte("DUMP"), SQLTypeDump},
+		{"DEBUG", []byte("DEBUG"), SQLTypeDebug},
+		{"EXPLAIN", []byte("EXPLAIN"), SQLTypeExplain},
+		{"DESC", []byte("DESC"), SQLTypeDesc},
+		{"SET", []byte("SET"), SQLTypeSet},
+		{"RELOAD", []byte("RELOAD"), SQLTypeReload},
+		{"SELECT_UNION", []byte("SELECT_UNION"), SQLTypeSelectUnion},
+		{"CREATE_TABLE", []byte("CREATE_TABLE"), SQLTypeCreateTable},
+		{"DROP_TABLE", []byte("DROP_TABLE"), SQLTypeDropTable},
+		{"ALTER_TABLE", []byte("ALTER_TABLE"), SQLTypeAlterTable},
+		{"SELECT_FROM_UPDATE", []byte("SELECT_FROM_UPDATE"), SQLTypeSelectFromUpdate},
+		{"MULTI_DELETE", []byte("MULTI_DELETE"), SQLTypeMultiDelete},
+		{"MULTI_UPDATE", []byte("MULTI_UPDATE"), SQLTypeMultiUpdate},
+		{"CREATE_INDEX", []byte("CREATE_INDEX"), SQLTypeCreateIndex},
+		{"DROP_INDEX", []byte("DROP_INDEX"), SQLTypeDropIndex},
+		{"MULTI", []byte("MULTI"), SQLTypeMulti},
+		{"Unknown", []byte("UNKNOWN"), SQLType(0)}, // defaults to 0
+		{"Empty", []byte(""), SQLType(0)},          // defaults to 0
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var sqlType SQLType
+			err := sqlType.UnmarshalText(tt.input)
+			assert.NoError(t, err)
+			assert.Equal(t, tt.expected, sqlType)
+		})
+	}
+}
+
+// TestSQLType_MarshalUnmarshalRoundTrip checks that each named SQLType comes
+// back unchanged after MarshalText followed by UnmarshalText.
+func TestSQLType_MarshalUnmarshalRoundTrip(t *testing.T) {
+	all := []SQLType{
+		SQLTypeSelect, SQLTypeInsert, SQLTypeUpdate, SQLTypeDelete,
+		SQLTypeSelectForUpdate, SQLTypeInsertOnDuplicateUpdate, SQLTypeReplace,
+		SQLTypeTruncate, SQLTypeCreate, SQLTypeDrop, SQLTypeLoad, SQLTypeMerge,
+		SQLTypeShow, SQLTypeAlter, SQLTypeRename, SQLTypeDump, SQLTypeDebug,
+		SQLTypeExplain, SQLTypeDesc, SQLTypeSet, SQLTypeReload, SQLTypeSelectUnion,
+		SQLTypeCreateTable, SQLTypeDropTable, SQLTypeAlterTable,
+		SQLTypeSelectFromUpdate, SQLTypeMultiDelete, SQLTypeMultiUpdate,
+		SQLTypeCreateIndex, SQLTypeDropIndex, SQLTypeMulti,
+	}
+
+	for idx := range all {
+		want := all[idx]
+		t.Run(fmt.Sprintf("SQLType_%d", idx), func(t *testing.T) {
+			text, marshalErr := want.MarshalText()
+			assert.NoError(t, marshalErr)
+
+			var got SQLType
+			unmarshalErr := got.UnmarshalText(text)
+			assert.NoError(t, unmarshalErr)
+
+			assert.Equal(t, want, got)
+		})
+	}
+}
+
+// TestSQLTypeSpecialConstants pins the SQLType constants whose values lie
+// outside the small consecutive range checked by TestSQLTypeConstants.
+func TestSQLTypeSpecialConstants(t *testing.T) {
+	// SQLTypeInsertIgnore is offset from iota (reportedly iota + 57), so pin
+	// its absolute value rather than deriving it.
+	assert.Equal(t, 101, int(SQLTypeInsertIgnore))
+	// SQLTypeInsertOnDuplicateUpdate is expected to follow it directly.
+	assert.Equal(t, 102, int(SQLTypeInsertOnDuplicateUpdate))
+	// SQLTypeMulti is reportedly declared as iota + 999, which puts it above 1000.
+	assert.Greater(t, int(SQLTypeMulti), 1000)
+	// SQLTypeUnknown should be one more than SQLTypeMulti.
+	assert.Equal(t, SQLTypeMulti+1, SQLTypeUnknown)
+}
+
+// TestSQLType_StringRepresentation checks that SQLType values convert to
+// their underlying integer codes. The textual names are covered by
+// TestSQLType_MarshalText.
+//
+// Arguments follow testify's (expected, actual) order.
+func TestSQLType_StringRepresentation(t *testing.T) {
+	assert.Equal(t, 0, int(SQLTypeSelect))
+	assert.Equal(t, 1, int(SQLTypeInsert))
+	assert.Equal(t, 2, int(SQLTypeUpdate))
+}
diff --git a/pkg/datasource/sql/types/types_test.go b/pkg/datasource/sql/types/types_test.go
new file mode 100644
index 00000000..955a635d
--- /dev/null
+++ b/pkg/datasource/sql/types/types_test.go
@@ -0,0 +1,387 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package types
+
+import (
+ "database/sql/driver"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+
+ "seata.apache.org/seata-go/pkg/protocol/branch"
+)
+
+// TestParseIndexType checks that ParseIndexType recognises "PRIMARY_KEY"
+// (case-sensitively). Every other input tested here maps to IndexTypeNull.
+func TestParseIndexType(t *testing.T) {
+	cases := []struct {
+		desc string
+		in   string
+		want IndexType
+	}{
+		{desc: "PRIMARY_KEY", in: "PRIMARY_KEY", want: IndexTypePrimaryKey},
+		{desc: "primary_key lowercase", in: "primary_key", want: IndexTypeNull},
+		{desc: "NULL", in: "NULL", want: IndexTypeNull},
+		{desc: "empty string", in: "", want: IndexTypeNull},
+		{desc: "unknown type", in: "UNKNOWN", want: IndexTypeNull},
+	}
+
+	for _, c := range cases {
+		t.Run(c.desc, func(t *testing.T) {
+			assert.Equal(t, c.want, ParseIndexType(c.in))
+		})
+	}
+}
+
+// TestIndexType_MarshalText checks the text form of each IndexType. Values
+// other than IndexTypePrimaryKey, including an undefined one, encode as "NULL".
+func TestIndexType_MarshalText(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    IndexType
+		expected []byte
+	}{
+		{"IndexTypePrimaryKey", IndexTypePrimaryKey, []byte("PRIMARY_KEY")},
+		{"IndexTypeNull", IndexTypeNull, []byte("NULL")},
+		{"Unknown type", IndexType(999), []byte("NULL")},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, err := tt.input.MarshalText()
+			assert.NoError(t, err)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
+
+// TestIndexType_UnmarshalText checks that UnmarshalText accepts "PRIMARY_KEY"
+// and "NULL". Any other input, including empty input, is rejected with an
+// "invalid index type" error.
+func TestIndexType_UnmarshalText(t *testing.T) {
+	tests := []struct {
+		name        string
+		input       []byte
+		expected    IndexType
+		expectError bool
+	}{
+		{"PRIMARY_KEY", []byte("PRIMARY_KEY"), IndexTypePrimaryKey, false},
+		{"NULL", []byte("NULL"), IndexTypeNull, false},
+		{"invalid type", []byte("INVALID"), IndexTypeNull, true},
+		{"empty", []byte(""), IndexTypeNull, true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var indexType IndexType
+			err := indexType.UnmarshalText(tt.input)
+
+			if tt.expectError {
+				// assert.Error does not stop the test. Only inspect the
+				// message when an error was actually returned, because
+				// calling Error() on a nil error would panic instead of
+				// reporting a clean failure.
+				if assert.Error(t, err) {
+					assert.Contains(t, err.Error(), "invalid index type")
+				}
+				return
+			}
+
+			assert.NoError(t, err)
+			assert.Equal(t, tt.expected, indexType)
+		})
+	}
+}
+
+// TestDBTypeConstants pins the DBType constants, in declaration order, to the
+// consecutive values 1 through 6.
+func TestDBTypeConstants(t *testing.T) {
+	ordered := []DBType{
+		DBTypeUnknown,
+		DBTypeMySQL,
+		DBTypePostgreSQL,
+		DBTypeSQLServer,
+		DBTypeOracle,
+		DBTypeMARIADB,
+	}
+	for i, got := range ordered {
+		assert.Equal(t, DBType(i+1), got)
+	}
+}
+
+// TestParseDBType checks that ParseDBType recognises the MySQL driver name in
+// any letter case. The other names tested here, including "postgres", map to
+// DBTypeUnknown.
+func TestParseDBType(t *testing.T) {
+	for _, tc := range []struct {
+		label      string
+		driverName string
+		want       DBType
+	}{
+		{"mysql", "mysql", DBTypeMySQL},
+		{"MySQL uppercase", "MySQL", DBTypeMySQL},
+		{"MYSQL", "MYSQL", DBTypeMySQL},
+		{"postgres", "postgres", DBTypeUnknown},
+		{"unknown", "unknown", DBTypeUnknown},
+		{"empty", "", DBTypeUnknown},
+	} {
+		t.Run(tc.label, func(t *testing.T) {
+			got := ParseDBType(tc.driverName)
+			assert.Equal(t, tc.want, got)
+		})
+	}
+}
+
+// TestTransactionModeConstants pins Local, XAMode and ATMode to 1, 2 and 3.
+func TestTransactionModeConstants(t *testing.T) {
+	for i, got := range []TransactionMode{Local, XAMode, ATMode} {
+		assert.Equal(t, TransactionMode(i+1), got)
+	}
+}
+
+// TestTransactionMode_BranchType checks how transaction modes map to branch
+// types. XA and AT map to their branch types. Local and an undefined mode map
+// to branch.BranchTypeUnknow.
+func TestTransactionMode_BranchType(t *testing.T) {
+	expectations := []struct {
+		label string
+		mode  TransactionMode
+		want  branch.BranchType
+	}{
+		{label: "XAMode", mode: XAMode, want: branch.BranchTypeXA},
+		{label: "ATMode", mode: ATMode, want: branch.BranchTypeAT},
+		{label: "Local", mode: Local, want: branch.BranchTypeUnknow},
+		{label: "Unknown", mode: TransactionMode(99), want: branch.BranchTypeUnknow},
+	}
+
+	for _, e := range expectations {
+		t.Run(e.label, func(t *testing.T) {
+			got := e.mode.BranchType()
+			assert.Equal(t, e.want, got)
+		})
+	}
+}
+
+// TestNewTxCtx checks the initial state of a TransactionContext returned by
+// NewTxCtx.
+func TestNewTxCtx(t *testing.T) {
+	txCtx := NewTxCtx()
+	assert.NotNil(t, txCtx)
+
+	// Lock keys are allocated but empty.
+	assert.NotNil(t, txCtx.LockKeys)
+	assert.Len(t, txCtx.LockKeys, 0)
+
+	// Starts in Local mode with a non-empty local transaction ID.
+	assert.Equal(t, Local, txCtx.TransactionMode)
+	assert.NotEmpty(t, txCtx.LocalTransID)
+
+	assert.NotNil(t, txCtx.RoundImages)
+}
+
+// TestTransactionContext_HasUndoLog checks that HasUndoLog reports false when
+// RoundImages is an empty RoundRecordImage. Each of the three transaction
+// modes is covered.
+func TestTransactionContext_HasUndoLog(t *testing.T) {
+	tests := []struct {
+		name     string
+		mode     TransactionMode
+		expected bool
+	}{
+		// The images are empty here, so even AT mode has no undo log.
+		{"AT mode with empty images", ATMode, false},
+		{"XA mode", XAMode, false},
+		{"Local mode", Local, false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ctx := &TransactionContext{
+				TransactionMode: tt.mode,
+				RoundImages:     &RoundRecordImage{},
+			}
+			assert.Equal(t, tt.expected, ctx.HasUndoLog())
+		})
+	}
+}
+
+// TestTransactionContext_HasLockKey checks that HasLockKey returns true only
+// when LockKeys has at least one entry. Empty and nil maps both return false.
+func TestTransactionContext_HasLockKey(t *testing.T) {
+	withKeys := func(keys map[string]struct{}) *TransactionContext {
+		return &TransactionContext{LockKeys: keys}
+	}
+
+	cases := []struct {
+		label string
+		txCtx *TransactionContext
+		want  bool
+	}{
+		{"with lock keys", withKeys(map[string]struct{}{"key1": {}, "key2": {}}), true},
+		{"empty lock keys", withKeys(map[string]struct{}{}), false},
+		{"nil lock keys", withKeys(nil), false},
+	}
+
+	for _, c := range cases {
+		t.Run(c.label, func(t *testing.T) {
+			assert.Equal(t, c.want, c.txCtx.HasLockKey())
+		})
+	}
+}
+
+// TestTransactionContext_OpenGlobalTransaction checks that only the XA and AT
+// modes report an open global transaction.
+func TestTransactionContext_OpenGlobalTransaction(t *testing.T) {
+	for _, tc := range []struct {
+		label  string
+		mode   TransactionMode
+		global bool
+	}{
+		{"Local", Local, false},
+		{"XAMode", XAMode, true},
+		{"ATMode", ATMode, true},
+	} {
+		t.Run(tc.label, func(t *testing.T) {
+			txCtx := &TransactionContext{TransactionMode: tc.mode}
+			assert.Equal(t, tc.global, txCtx.OpenGlobalTransaction())
+		})
+	}
+}
+
+// TestTransactionContext_IsBranchRegistered checks that a non-zero BranchID
+// counts as registered and a zero BranchID does not.
+func TestTransactionContext_IsBranchRegistered(t *testing.T) {
+	for _, tc := range []struct {
+		label      string
+		id         uint64
+		registered bool
+	}{
+		{"registered", 123, true},
+		{"not registered", 0, false},
+	} {
+		t.Run(tc.label, func(t *testing.T) {
+			txCtx := &TransactionContext{BranchID: tc.id}
+			assert.Equal(t, tc.registered, txCtx.IsBranchRegistered())
+		})
+	}
+}
+
+// TestQueryResult checks that a queryResult returns its rows from GetRows and
+// panics when GetResult is called.
+func TestQueryResult(t *testing.T) {
+	mr := &mockRows{}
+	qr := &queryResult{Rows: mr}
+
+	assert.Equal(t, mr, qr.GetRows())
+	assert.Panics(t, func() { qr.GetResult() })
+}
+
+// TestWriteResult checks that a writeResult returns its result from GetResult
+// and panics when GetRows is called.
+func TestWriteResult(t *testing.T) {
+	mres := &mockResult{}
+	wr := &writeResult{Result: mres}
+
+	assert.Equal(t, mres, wr.GetResult())
+	assert.Panics(t, func() { wr.GetRows() })
+}
+
+// TestNewResult checks which concrete result NewResult builds from its
+// options. WithRows alone gives a *queryResult. WithResult, alone or together
+// with WithRows, gives a *writeResult. No options at all panics.
+func TestNewResult(t *testing.T) {
+	t.Run("with rows", func(t *testing.T) {
+		mr := &mockRows{}
+		got, isQuery := NewResult(WithRows(mr)).(*queryResult)
+		assert.True(t, isQuery)
+		assert.Equal(t, mr, got.Rows)
+	})
+
+	t.Run("with result", func(t *testing.T) {
+		mres := &mockResult{}
+		got, isWrite := NewResult(WithResult(mres)).(*writeResult)
+		assert.True(t, isWrite)
+		assert.Equal(t, mres, got.Result)
+	})
+
+	t.Run("with both (result takes precedence)", func(t *testing.T) {
+		mr, mres := &mockRows{}, &mockResult{}
+		got, isWrite := NewResult(WithRows(mr), WithResult(mres)).(*writeResult)
+		assert.True(t, isWrite)
+		assert.Equal(t, mres, got.Result)
+	})
+
+	t.Run("with neither (panics)", func(t *testing.T) {
+		assert.Panics(t, func() { NewResult() })
+	})
+}
+
+// TestWithRows checks that the WithRows option stores the given rows in the
+// option's rows field.
+func TestWithRows(t *testing.T) {
+	var o option
+	mr := &mockRows{}
+
+	apply := WithRows(mr)
+	apply(&o)
+
+	assert.Equal(t, mr, o.rows)
+}
+
+// TestWithResult checks that the WithResult option stores the given result in
+// the option's ret field.
+func TestWithResult(t *testing.T) {
+	var o option
+	mres := &mockResult{}
+
+	apply := WithResult(mres)
+	apply(&o)
+
+	assert.Equal(t, mres, o.ret)
+}
+
+// Mock types for testing.
+//
+// mockRows and mockResult are no-op stand-ins. Every method returns zero
+// values and a nil error.
+type (
+	mockRows   struct{}
+	mockResult struct{}
+)
+
+// Compile-time checks that the mocks satisfy the database/sql/driver
+// interfaces whose method sets they implement.
+var (
+	_ driver.Rows   = (*mockRows)(nil)
+	_ driver.Result = (*mockResult)(nil)
+)
+
+func (*mockRows) Columns() []string { return nil }
+
+func (*mockRows) Close() error { return nil }
+
+func (*mockRows) Next(_ []driver.Value) error { return nil }
+
+func (*mockResult) LastInsertId() (int64, error) { return 0, nil }
+
+func (*mockResult) RowsAffected() (int64, error) { return 0, nil }
+
+// TestBranchPhaseConstants pins BranchPhase_Unknown, BranchPhase_Done and
+// BranchPhase_Failed to 0, 1 and 2 in that order.
+func TestBranchPhaseConstants(t *testing.T) {
+	phases := []interface{}{BranchPhase_Unknown, BranchPhase_Done, BranchPhase_Failed}
+	for want, got := range phases {
+		assert.Equal(t, want, got)
+	}
+}
+
+// TestIndexConstants pins the current values of IndexPrimary, IndexNormal,
+// IndexUnique and IndexFullText (10 through 13).
+//
+// The values are asserted as observed rather than derived from the
+// surrounding constant blocks. Reordering or inserting constant declarations
+// in this package will therefore show up here.
+func TestIndexConstants(t *testing.T) {
+	assert.Equal(t, IndexType(10), IndexPrimary)
+	assert.Equal(t, IndexType(11), IndexNormal)
+	assert.Equal(t, IndexType(12), IndexUnique)
+	assert.Equal(t, IndexType(13), IndexFullText)
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]