Copilot commented on code in PR #922: URL: https://github.com/apache/incubator-seata-go/pull/922#discussion_r2453716729
########## pkg/datasource/sql/undo/builder/postgresql_multi_update_undo_log_builder.go: ########## @@ -0,0 +1,530 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package builder + +import ( + "context" + "database/sql/driver" + "fmt" + "strings" + + "github.com/arana-db/parser/ast" + "github.com/arana-db/parser/format" + "github.com/auxten/postgresql-parser/pkg/sql/sem/tree" + + "seata.apache.org/seata-go/pkg/datasource/sql/datasource" + "seata.apache.org/seata-go/pkg/datasource/sql/parser" + "seata.apache.org/seata-go/pkg/datasource/sql/types" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" + "seata.apache.org/seata-go/pkg/util/bytes" +) + +type PostgreSQLMultiUpdateUndoLogBuilder struct { + BasicUndoLogBuilder +} + +func GetPostgreSQLMultiUpdateUndoLogBuilder() undo.UndoLogBuilder { + return &PostgreSQLMultiUpdateUndoLogBuilder{} +} + +func (p *PostgreSQLMultiUpdateUndoLogBuilder) GetExecutorType() types.ExecutorType { + return types.UpdateExecutor +} + +func (p *PostgreSQLMultiUpdateUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) { + vals := execCtx.Values + if vals == nil { + vals = make([]driver.Value, 0) + for _, param := range execCtx.NamedValues { + vals = append(vals, param.Value) + } + } + + updateStatements, err := p.splitStatementsAdvanced(execCtx.Query) + if err != nil { + return nil, fmt.Errorf("failed to split statements: %v", err) + } + if len(updateStatements) == 0 { + return []*types.RecordImage{}, nil + } + + var allImages []*types.RecordImage + argOffset := 0 + + for i, updateSQL := range updateStatements { + updateSQL = strings.TrimSpace(updateSQL) + if updateSQL == "" { + return nil, fmt.Errorf("update statement %d is empty", i) + } + + stmtArgs, newOffset, err := p.extractArgsForStatementAdvanced(updateSQL, vals, argOffset) + if err != nil { + return nil, fmt.Errorf("failed to extract args for statement %d: %v", i, err) + } + argOffset = newOffset + + selectSQL, selectArgs, tableName, err := p.buildSingleBeforeImageSQL(updateSQL, stmtArgs) + if err != nil { + return nil, fmt.Errorf("failed to build before image SQL for statement %d: %v", i, err) + } + + if tableName == "" { + return nil, fmt.Errorf("table name is empty for statement %d", i) + } + + schemaName, tableNameOnly := p.parseTableName(tableName) + + var metaData *types.TableMeta + if schemaName != "" { + metaData, err = datasource.GetTableCache(types.DBTypePostgreSQL).GetTableMeta(ctx, schemaName, tableNameOnly) + } else { + metaData, err = datasource.GetTableCache(types.DBTypePostgreSQL).GetTableMeta(ctx, execCtx.DBName, tableNameOnly) + } + + if err != nil { + return nil, fmt.Errorf("failed to get table meta for %s: %v", tableName, err) + } + + stmt, err := execCtx.Conn.Prepare(selectSQL) + if err != nil { + return nil, fmt.Errorf("failed to prepare statement %d: %v", i, err) + } + defer stmt.Close() + + rows, err := stmt.Query(selectArgs) + if err != nil { + return nil, fmt.Errorf("failed to query statement %d: %v", i, err) + } + defer rows.Close() + + image, err := p.buildRecordImages(rows, metaData) + if err != nil { + return nil, fmt.Errorf("failed to build record images for statement %d: %v", i, err) + } + + lockKey := p.buildLockKey2(image, *metaData) + execCtx.TxCtx.LockKeys[lockKey] = struct{}{} + + allImages = append(allImages, image) + } + + return allImages, nil +} + +func (p *PostgreSQLMultiUpdateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { + if len(beforeImages) == 0 { + return []*types.RecordImage{}, nil + } + + updateStatements := p.splitStatements(execCtx.Query) + if len(updateStatements) == 0 { + return []*types.RecordImage{}, nil + } + + if len(beforeImages) != len(updateStatements) { + return nil, fmt.Errorf("mismatch between before images count (%d) and update statements count (%d)", len(beforeImages), len(updateStatements)) + } + + var afterImages []*types.RecordImage + + for i, beforeImage := range beforeImages { + if beforeImage == nil { + return nil, fmt.Errorf("before image %d is nil", i) + } + + updateSQL := strings.TrimSpace(updateStatements[i]) + if updateSQL == "" { + return nil, fmt.Errorf("update statement %d is empty", i) + } + + _, _, tableName, err := p.buildSingleBeforeImageSQL(updateSQL, []driver.Value{}) + if err != nil { + return nil, fmt.Errorf("failed to parse table name from statement %d: %v", i, err) + } + + schemaName, tableNameOnly := p.parseTableName(tableName) + + var metaData *types.TableMeta + if schemaName != "" { + metaData, err = datasource.GetTableCache(types.DBTypePostgreSQL).GetTableMeta(ctx, schemaName, tableNameOnly) + } else { + metaData, err = datasource.GetTableCache(types.DBTypePostgreSQL).GetTableMeta(ctx, execCtx.DBName, tableNameOnly) + } + + if err != nil { + return nil, fmt.Errorf("failed to get table meta for %s: %v", tableName, err) + } + + selectSQL, selectArgs := p.buildAfterImageSQL(beforeImage, *metaData) + + stmt, err := execCtx.Conn.Prepare(selectSQL) + if err != nil { + return nil, fmt.Errorf("failed to prepare after image SQL: %v", err) + } + defer stmt.Close() + + rows, err := stmt.Query(selectArgs) + if err != nil { + return nil, fmt.Errorf("failed to query after image: %v", err) + } + defer rows.Close() + + image, err := p.buildRecordImages(rows, metaData) + if err != nil { + return nil, fmt.Errorf("failed to build record images: %v", err) + } + + afterImages = append(afterImages, image) + } + + return afterImages, nil +} + +func (p *PostgreSQLMultiUpdateUndoLogBuilder) splitStatements(query string) []string { + statements := strings.Split(query, ";") + var result []string + for _, stmt := range statements { + stmt = strings.TrimSpace(stmt) + if stmt != "" && len(stmt) >= 6 && strings.ToUpper(stmt[:6]) == "UPDATE" { + result = append(result, stmt) + } + } + return result +} + +func (p *PostgreSQLMultiUpdateUndoLogBuilder) splitStatementsAdvanced(query string) ([]string, error) { + var statements []string + var current strings.Builder + inSingleQuote := false + inDoubleQuote := false + inComment := false + i := 0 + + for i < len(query) { + ch := query[i] + + if !inComment && !inSingleQuote && !inDoubleQuote { + if ch == '-' && i+1 < len(query) && query[i+1] == '-' { + inComment = true + current.WriteByte(ch) + i++ + continue + } + if ch == '/' && i+1 < len(query) && query[i+1] == '*' { + inComment = true + current.WriteByte(ch) + i++ + continue + } + } + + if inComment { + current.WriteByte(ch) + if ch == '\n' || (ch == '*' && i+1 < len(query) && query[i+1] == '/') { + inComment = false + if ch == '*' { + i++ + current.WriteByte('/') + } + } + i++ + continue + } + + if ch == '\'' && !inDoubleQuote { + if i+1 < len(query) && query[i+1] == '\'' { + current.WriteByte(ch) + i++ + current.WriteByte(ch) + } else { + inSingleQuote = !inSingleQuote + current.WriteByte(ch) + } + } else if ch == '"' && !inSingleQuote { + if i+1 < len(query) && query[i+1] == '"' { + current.WriteByte(ch) + i++ + current.WriteByte(ch) + } else { + inDoubleQuote = !inDoubleQuote + current.WriteByte(ch) + } + } else if ch == ';' && !inSingleQuote && !inDoubleQuote { + stmt := strings.TrimSpace(current.String()) + if stmt != "" && len(stmt) >= 6 && strings.ToUpper(stmt[:6]) == "UPDATE" { + statements = append(statements, stmt) + } + current.Reset() + } else { + current.WriteByte(ch) + } + i++ + } + + finalStmt := strings.TrimSpace(current.String()) + if finalStmt != "" && len(finalStmt) >= 6 && strings.ToUpper(finalStmt[:6]) == "UPDATE" { + statements = append(statements, finalStmt) + } + + return statements, nil +} + +func (p *PostgreSQLMultiUpdateUndoLogBuilder) extractArgsForStatement(stmt string, allArgs []driver.Value, offset int) ([]driver.Value, int) { + paramCount := strings.Count(stmt, "?") + strings.Count(stmt, "$") Review Comment: Incorrect parameter counting logic: counting all '$' characters will produce false positives for PostgreSQL placeholders like `$1`, `$2`, etc. This should count numbered placeholders (`$\d+`) rather than raw '$' symbols. ########## pkg/datasource/sql/undo/builder/postgresql_insert_undo_log_builder.go: ########## @@ -0,0 +1,477 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package builder + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "strings" + + "github.com/arana-db/parser/ast" + "github.com/auxten/postgresql-parser/pkg/sql/sem/tree" + + "seata.apache.org/seata-go/pkg/datasource/sql/types" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" + "seata.apache.org/seata-go/pkg/util/log" +) + +type PostgreSQLInsertUndoLogBuilder struct { + BasicUndoLogBuilder + // InsertResult after insert sql + InsertResult types.ExecResult + IncrementStep int +} + +func GetPostgreSQLInsertUndoLogBuilder() undo.UndoLogBuilder { + return &PostgreSQLInsertUndoLogBuilder{ + BasicUndoLogBuilder: BasicUndoLogBuilder{}, + } +} + +func (p *PostgreSQLInsertUndoLogBuilder) GetExecutorType() types.ExecutorType { + return types.InsertExecutor +} + +func (p *PostgreSQLInsertUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) { + return []*types.RecordImage{}, nil +} + +func (p *PostgreSQLInsertUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { + if execCtx == nil || execCtx.ParseContext == nil { + return nil, nil + } + + var tableName string + var err error + + // Support both MySQL and PostgreSQL AST + if execCtx.ParseContext.InsertStmt != nil { + // MySQL AST + tableName = execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O + } else if execCtx.ParseContext.AuxtenInsertStmt != nil { + // PostgreSQL AST + tableName, err = execCtx.ParseContext.GetTableName() + if err != nil { + return nil, err + } + } else { + return nil, nil + } + + metaData := execCtx.MetaDataMap[tableName] + selectSQL, selectArgs, err := p.buildAfterImageSQL(ctx, execCtx) + if err != nil { + return nil, err + } + + stmt, err := execCtx.Conn.Prepare(selectSQL) + if err != nil { + log.Errorf("build prepare stmt: %+v", err) + return nil, err + } + + rows, err := stmt.Query(selectArgs) + if err != nil { + log.Errorf("stmt query: %+v", err) + return nil, err + } + + image, err := p.buildRecordImages(rows, &metaData) + if err != nil { + return nil, err + } + + return []*types.RecordImage{image}, nil +} + +// buildAfterImageSQL build select SQL for PostgreSQL after insert +func (p *PostgreSQLInsertUndoLogBuilder) buildAfterImageSQL(ctx context.Context, execCtx *types.ExecContext) (string, []driver.Value, error) { + if execCtx == nil || execCtx.ParseContext == nil { + return "", nil, fmt.Errorf("can't find execCtx or ParseContext") + } + + var tableName string + var err error + + // Get table name from appropriate AST + if execCtx.ParseContext.InsertStmt != nil { + // MySQL AST + tableName = execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O + } else if execCtx.ParseContext.AuxtenInsertStmt != nil { + // PostgreSQL AST + tableName, err = execCtx.ParseContext.GetTableName() + if err != nil { + return "", nil, err + } + } else { + return "", nil, fmt.Errorf("can't find valid insert statement") + } + + if execCtx.MetaDataMap == nil { + return "", nil, fmt.Errorf("can't find MetaDataMap") + } + + meta := execCtx.MetaDataMap[tableName] + pkValuesMap, err := p.getPkValues(execCtx, execCtx.ParseContext, meta) + if err != nil { + return "", nil, err + } + + return p.buildSelectSQLByPKValues(tableName, meta.GetPrimaryKeyOnlyName(), pkValuesMap) +} + +func (p *PostgreSQLInsertUndoLogBuilder) getPkValues(execCtx *types.ExecContext, parseCtx *types.ParseContext, meta types.TableMeta) (map[string][]driver.Value, error) { + pkValuesMap := make(map[string][]driver.Value) + pkColumns := meta.GetPrimaryKeyOnlyName() + + if p.InsertResult != nil { + rows := p.InsertResult.GetRows() + if rows != nil { + pkValuesFromReturning, err := p.extractPkValuesFromReturning(rows, pkColumns) + if err == nil && len(pkValuesFromReturning) > 0 { + return pkValuesFromReturning, nil + } + } + } + + pkValuesFromInsert, err := p.extractPkValuesFromInsert(parseCtx, meta) + if err == nil && len(pkValuesFromInsert) > 0 { + return pkValuesFromInsert, nil + } + + if len(pkColumns) == 1 && p.InsertResult != nil && p.InsertResult.GetResult() != nil { + result := p.InsertResult.GetResult() + lastInsertId, err := result.LastInsertId() + if err == nil && lastInsertId > 0 { + rowsAffected, err := result.RowsAffected() + if err != nil { + return nil, err + } + + values := make([]driver.Value, 0) + for i := int64(0); i < rowsAffected; i++ { + values = append(values, lastInsertId+i) + } + pkValuesMap[pkColumns[0]] = values + return pkValuesMap, nil + } + } + + seqValues, err := p.getPkValuesFromSequence(execCtx, meta) + if err == nil && len(seqValues) > 0 { + return seqValues, nil + } + + return nil, fmt.Errorf("PostgreSQL insert primary key detection failed: cannot determine primary key values. Recommend using INSERT ... RETURNING to capture primary key values") +} + +func (p *PostgreSQLInsertUndoLogBuilder) extractPkValuesFromInsert(parseCtx *types.ParseContext, meta types.TableMeta) (map[string][]driver.Value, error) { + pkColumns := meta.GetPrimaryKeyOnlyName() + pkValuesMap := make(map[string][]driver.Value) + + if parseCtx.AuxtenInsertStmt != nil { + insertStmt := parseCtx.AuxtenInsertStmt + if insertStmt.Columns != nil && insertStmt.Rows != nil { + colIndexMap := make(map[string]int) + for i, col := range insertStmt.Columns { + colName := col.String() + colName = strings.Trim(colName, `" `) + colIndexMap[strings.ToLower(colName)] = i + } + + for _, pkCol := range pkColumns { + pkColLower := strings.ToLower(pkCol) + if colIdx, exists := colIndexMap[pkColLower]; exists { + values := make([]driver.Value, 0) + if selectStmt, ok := insertStmt.Rows.Select.(*tree.ValuesClause); ok { + for _, row := range selectStmt.Rows { + if colIdx < len(row) { + if datum, ok := row[colIdx].(*tree.StrVal); ok { + values = append(values, datum.RawString()) + } else if datum, ok := row[colIdx].(*tree.NumVal); ok { + values = append(values, datum.String()) + } + } + } + } + if len(values) > 0 { + pkValuesMap[pkCol] = values + } + } + } + } + } else if parseCtx.InsertStmt != nil { + insertStmt := parseCtx.InsertStmt + if insertStmt.Columns != nil && insertStmt.Lists != nil { + colIndexMap := make(map[string]int) + for i, col := range insertStmt.Columns { + colName := col.Name.O + colIndexMap[strings.ToLower(colName)] = i + } + + for _, pkCol := range pkColumns { + pkColLower := strings.ToLower(pkCol) + if colIdx, exists := colIndexMap[pkColLower]; exists { + values := make([]driver.Value, 0) + for _, row := range insertStmt.Lists { + if colIdx < len(row) { Review Comment: Empty conditional block - the code checks if `colIdx < len(row)` but performs no action. This appears to be incomplete logic that should append values to the `values` slice. ```suggestion if colIdx < len(row) { // Extract value using AST's Format method var sb strings.Builder row[colIdx].Format(&sb) values = append(values, sb.String()) ``` ########## pkg/datasource/sql/undo/builder/postgresql_multi_delete_undo_log_builder.go: ########## @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package builder + +import ( + "context" + "database/sql/driver" + "fmt" + "strings" + + "github.com/arana-db/parser/ast" + "github.com/arana-db/parser/format" + "github.com/auxten/postgresql-parser/pkg/sql/sem/tree" + + "seata.apache.org/seata-go/pkg/datasource/sql/datasource" + "seata.apache.org/seata-go/pkg/datasource/sql/parser" + "seata.apache.org/seata-go/pkg/datasource/sql/types" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" + "seata.apache.org/seata-go/pkg/util/bytes" + "seata.apache.org/seata-go/pkg/util/log" +) + +type PostgreSQLMultiDeleteUndoLogBuilder struct { + BasicUndoLogBuilder +} + +func GetPostgreSQLMultiDeleteUndoLogBuilder() undo.UndoLogBuilder { + return &PostgreSQLMultiDeleteUndoLogBuilder{} +} + +func (p *PostgreSQLMultiDeleteUndoLogBuilder) GetExecutorType() types.ExecutorType { + return types.MultiDeleteExecutor +} + +func (p *PostgreSQLMultiDeleteUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) { + vals := execCtx.Values + if vals == nil { + vals = make([]driver.Value, 0) + for _, param := range execCtx.NamedValues { + vals = append(vals, param.Value) + } + } + + deleteStatements := p.splitStatements(execCtx.Query) + if len(deleteStatements) == 0 { + return []*types.RecordImage{}, nil + } + + var allImages []*types.RecordImage + argOffset := 0 + + for _, deleteSQL := range deleteStatements { + deleteSQL = strings.TrimSpace(deleteSQL) + if deleteSQL == "" { + continue + } + + stmtArgs, newOffset := p.extractArgsForStatement(deleteSQL, vals, argOffset) + argOffset = newOffset + + selectSQL, selectArgs, tableName, err := p.buildSingleBeforeImageSQL(deleteSQL, stmtArgs) + if err != nil { + log.Errorf("failed to build before image SQL for statement %s: %v", deleteSQL, err) + continue + } + + if tableName == "" { + continue + } + + schemaName, tableNameOnly := p.parseTableName(tableName) + + var metaData *types.TableMeta + if schemaName != "" { + metaData, err = datasource.GetTableCache(types.DBTypePostgreSQL).GetTableMeta(ctx, schemaName, tableNameOnly) + } else { + metaData, err = datasource.GetTableCache(types.DBTypePostgreSQL).GetTableMeta(ctx, execCtx.DBName, tableNameOnly) + } + + if err != nil { + log.Errorf("failed to get table meta for %s: %v", tableName, err) + continue + } + + stmt, err := execCtx.Conn.Prepare(selectSQL) + if err != nil { + log.Errorf("build prepare stmt: %+v", err) + continue + } + + rows, err := stmt.Query(selectArgs) + if err != nil { + log.Errorf("stmt query: %+v", err) + stmt.Close() + continue + } + + image, err := p.buildRecordImages(rows, metaData) + if err != nil { + log.Errorf("failed to build record images: %v", err) + rows.Close() + stmt.Close() + continue + } + + lockKey := p.buildLockKey2(image, *metaData) + execCtx.TxCtx.LockKeys[lockKey] = struct{}{} + + allImages = append(allImages, image) + + rows.Close() + stmt.Close() + } + + return allImages, nil +} + +func (p *PostgreSQLMultiDeleteUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { + return []*types.RecordImage{}, nil +} + +func (p *PostgreSQLMultiDeleteUndoLogBuilder) splitStatements(query string) []string { + statements := strings.Split(query, ";") + var result []string + for _, stmt := range statements { + stmt = strings.TrimSpace(stmt) + if stmt != "" && len(stmt) >= 6 && strings.ToUpper(stmt[:6]) == "DELETE" { + result = append(result, stmt) + } + } + return result +} + +func (p *PostgreSQLMultiDeleteUndoLogBuilder) extractArgsForStatement(stmt string, allArgs []driver.Value, offset int) ([]driver.Value, int) { + paramCount := strings.Count(stmt, "?") + strings.Count(stmt, "$") Review Comment: Same issue as Comment 1: counting raw '$' characters instead of numbered PostgreSQL placeholders will cause incorrect parameter extraction. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
