This is an automated email from the ASF dual-hosted git repository.

maxyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/cloudberry-go-libs.git

commit 7a51d647f4bdd1faa2edd68d75955040bec265ca
Author: Nihal Jain <[email protected]>
AuthorDate: Mon May 6 14:48:06 2024 +0530

    Add `SelectContext` and `QueryContext` to DBConn
    
    These new methods allow us to pass a context when executing queries. This
    can be helpful in situations where we want to cancel a running query by
    cancelling the context passed to it.
---
 dbconn/dbconn.go      | 24 +++++++++++++
 dbconn/dbconn_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 118 insertions(+)

diff --git a/dbconn/dbconn.go b/dbconn/dbconn.go
index e98ceb2..91f5efa 100644
--- a/dbconn/dbconn.go
+++ b/dbconn/dbconn.go
@@ -352,6 +352,18 @@ func (dbconn *DBConn) Select(destination interface{}, query string, whichConn ..
        return dbconn.ConnPool[connNum].Select(destination, query)
 }
 
+// SelectContext is the context-aware counterpart of Select. It runs query on
+// the connection chosen by whichConn (the first one when omitted), inside that
+// connection's open transaction if there is one, and scans the resulting rows
+// into destination. Cancelling ctx lets the caller abandon the query.
+func (dbconn *DBConn) SelectContext(ctx context.Context, destination interface{}, query string, whichConn ...int) error {
+       connNum := dbconn.ValidateConnNum(whichConn...)
+       if dbconn.Tx[connNum] != nil {
+               return dbconn.Tx[connNum].SelectContext(ctx, destination, query)
+       }
+       return dbconn.ConnPool[connNum].SelectContext(ctx, destination, query)
+}
+
 func (dbconn *DBConn) QueryWithArgs(query string, args ...interface{}) (*sqlx.Rows, error) {
        if dbconn.Tx[0] != nil {
                return dbconn.Tx[0].Queryx(query, args...)
@@ -367,6 +379,18 @@ func (dbconn *DBConn) Query(query string, whichConn ...int) (*sqlx.Rows, error)
        return dbconn.ConnPool[connNum].Queryx(query)
 }
 
+// QueryContext is the context-aware counterpart of Query. It runs query on
+// the connection chosen by whichConn (the first one when omitted), inside that
+// connection's open transaction if there is one. The caller must close the
+// returned rows; cancelling ctx lets the caller abandon the query.
+func (dbconn *DBConn) QueryContext(ctx context.Context, query string, whichConn ...int) (*sqlx.Rows, error) {
+       connNum := dbconn.ValidateConnNum(whichConn...)
+       if dbconn.Tx[connNum] != nil {
+               return dbconn.Tx[connNum].QueryxContext(ctx, query)
+       }
+       return dbconn.ConnPool[connNum].QueryxContext(ctx, query)
+}
+
 /*
  * Ensure there isn't a mismatch between the connection pool size and number of
  * jobs, and default to using the first connection if no number is given.
diff --git a/dbconn/dbconn_test.go b/dbconn/dbconn_test.go
index b58c67b..c08e33c 100644
--- a/dbconn/dbconn_test.go
+++ b/dbconn/dbconn_test.go
@@ -360,6 +360,100 @@ var _ = Describe("dbconn/dbconn tests", func() {
                        Expect(testSlice[1].Tablename).To(Equal("table2"))
                })
        })
+       Describe("DBConn.SelectContext", func() {
+               type schemaTableRow struct {
+                       Schemaname string
+                       Tablename  string
+               }
+               const twoColumnQuery = "SELECT schemaname, tablename FROM two_columns ORDER BY schemaname LIMIT 2"
+
+               It("executes a SELECT outside of a transaction", func() {
+                       two_col_rows := sqlmock.NewRows([]string{"schemaname", "tablename"}).
+                               AddRow("schema1", "table1").
+                               AddRow("schema2", "table2")
+                       mock.ExpectQuery("SELECT (.*)").WillReturnRows(two_col_rows)
+
+                       testSlice := make([]schemaTableRow, 0)
+
+                       ctx, cancel := context.WithCancel(context.Background())
+                       defer cancel()
+
+                       err := connection.SelectContext(ctx, &testSlice, twoColumnQuery)
+
+                       Expect(err).ToNot(HaveOccurred())
+                       Expect(testSlice).To(HaveLen(2))
+                       Expect(testSlice[0].Schemaname).To(Equal("schema1"))
+                       Expect(testSlice[0].Tablename).To(Equal("table1"))
+                       Expect(testSlice[1].Schemaname).To(Equal("schema2"))
+                       Expect(testSlice[1].Tablename).To(Equal("table2"))
+               })
+               It("errors out when the context is cancelled", func() {
+                       two_col_rows := sqlmock.NewRows([]string{"schemaname", "tablename"}).
+                               AddRow("schema1", "table1").
+                               AddRow("schema2", "table2")
+                       mock.ExpectQuery("SELECT (.*)").WillReturnRows(two_col_rows)
+
+                       testSlice := make([]schemaTableRow, 0)
+
+                       // Cancel up front so SelectContext has to fail instead of reading rows.
+                       ctx, cancel := context.WithCancel(context.Background())
+                       cancel()
+                       err := connection.SelectContext(ctx, &testSlice, twoColumnQuery)
+
+                       // MatchError also fails on a nil error, so no separate HaveOccurred check is needed.
+                       Expect(err).To(MatchError(context.Canceled))
+               })
+       })
+       Describe("DBConn.QueryContext", func() {
+               It("executes a QUERY and returns the correct rows", func() {
+                       two_col_rows := sqlmock.NewRows([]string{"schemaname", "tablename"}).
+                               AddRow("schema1", "table1").
+                               AddRow("schema2", "table2")
+                       mock.ExpectQuery("SELECT (.*)").WillReturnRows(two_col_rows)
+
+                       type testSlice struct {
+                               Schemaname string
+                               Tablename  string
+                       }
+                       ctx, cancel := context.WithCancel(context.Background())
+                       defer cancel()
+
+                       rows, err := connection.QueryContext(ctx, "SELECT schemaname, tablename FROM two_columns ORDER BY schemaname LIMIT 2")
+                       // rows is nil when QueryContext fails, so check err before deferring Close.
+                       Expect(err).ToNot(HaveOccurred())
+                       defer rows.Close()
+
+                       columns, err := rows.Columns()
+                       Expect(err).ToNot(HaveOccurred())
+                       var result []testSlice
+                       for rows.Next() {
+                               var row testSlice
+                               Expect(rows.StructScan(&row)).To(Succeed())
+                               result = append(result, row)
+                       }
+                       Expect(rows.Err()).ToNot(HaveOccurred())
+                       Expect(len(result)).To(Equal(2))
+                       Expect(columns).To(Equal([]string{"schemaname", "tablename"}))
+                       Expect(result[0].Schemaname).To(Equal("schema1"))
+                       Expect(result[0].Tablename).To(Equal("table1"))
+                       Expect(result[1].Schemaname).To(Equal("schema2"))
+                       Expect(result[1].Tablename).To(Equal("table2"))
+               })
+               It("errors out when the context is cancelled", func() {
+                       two_col_rows := sqlmock.NewRows([]string{"schemaname", "tablename"}).
+                               AddRow("schema1", "table1").
+                               AddRow("schema2", "table2")
+                       mock.ExpectQuery("SELECT (.*)").WillReturnRows(two_col_rows)
+
+                       ctx, cancel := context.WithCancel(context.Background())
+                       cancel()
+
+                       _, err := connection.QueryContext(ctx, "SELECT schemaname, tablename FROM two_columns ORDER BY schemaname LIMIT 2")
+
+                       Expect(err).To(HaveOccurred())
+                       Expect(err).Should(MatchError(context.Canceled))
+               })
+       })
        Describe("DBConn.MustBegin", func() {
                It("successfully executes a BEGIN outside a transaction", 
func() {
                        ExpectBegin(mock)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to