This is an automated email from the ASF dual-hosted git repository.

abeizn pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-devlake.git


The following commit(s) were added to refs/heads/main by this push:
     new 511ed3b19 refactor: refactor starrocks, improve performance, solve bugs (#4475)
511ed3b19 is described below

commit 511ed3b19e1b9f11a251d73fac5e1b3aeb003019
Author: abeizn <[email protected]>
AuthorDate: Thu Feb 23 20:05:59 2023 +0800

    refactor: refactor starrocks, improve performance, solve bugs (#4475)
    
    * refactor: refactor starrocks, improve performance, solve bugs
---
 backend/plugins/starrocks/impl/impl.go   |   2 +-
 backend/plugins/starrocks/tasks/tasks.go | 552 ++++++++++++++++---------------
 2 files changed, 283 insertions(+), 271 deletions(-)

diff --git a/backend/plugins/starrocks/impl/impl.go b/backend/plugins/starrocks/impl/impl.go
index 00eed64a1..129c989ed 100644
--- a/backend/plugins/starrocks/impl/impl.go
+++ b/backend/plugins/starrocks/impl/impl.go
@@ -34,7 +34,7 @@ var _ plugin.PluginModel = (*StarRocks)(nil)
 
 func (s StarRocks) SubTaskMetas() []plugin.SubTaskMeta {
        return []plugin.SubTaskMeta{
-               tasks.LoadDataTaskMeta,
+               tasks.ExportDataTaskMeta,
        }
 }
 
diff --git a/backend/plugins/starrocks/tasks/tasks.go b/backend/plugins/starrocks/tasks/tasks.go
index 1b504e1c9..648271631 100644
--- a/backend/plugins/starrocks/tasks/tasks.go
+++ b/backend/plugins/starrocks/tasks/tasks.go
@@ -19,14 +19,8 @@ package tasks
 
 import (
        "bytes"
-       "database/sql"
        "encoding/json"
        "fmt"
-       "github.com/apache/incubator-devlake/core/dal"
-       "github.com/apache/incubator-devlake/core/errors"
-       "github.com/apache/incubator-devlake/core/plugin"
-       "github.com/apache/incubator-devlake/impls/dalgorm"
-       "github.com/apache/incubator-devlake/plugins/starrocks/utils"
        "io"
        "net/http"
        "net/url"
@@ -34,10 +28,17 @@ import (
        "strings"
        "time"
 
+       "github.com/apache/incubator-devlake/core/dal"
+       "github.com/apache/incubator-devlake/core/errors"
+       "github.com/apache/incubator-devlake/core/plugin"
+       "github.com/apache/incubator-devlake/impls/dalgorm"
+       "github.com/apache/incubator-devlake/plugins/starrocks/utils"
+
        "github.com/lib/pq"
        "gorm.io/driver/mysql"
        "gorm.io/driver/postgres"
        "gorm.io/gorm"
+       "gorm.io/gorm/clause"
 )
 
 type Table struct {
@@ -48,107 +49,61 @@ func (t *Table) TableName() string {
        return t.name
 }
 
-func LoadData(c plugin.SubTaskContext) errors.Error {
-       var db dal.Dal
+type DataConfigParams struct {
+       Ctx           plugin.SubTaskContext
+       Config        *StarRocksConfig
+       SrcDb         dal.Dal
+       DestDb        dal.Dal
+       SrcTableName  string
+       DestTableName string
+}
+
+func ExportData(c plugin.SubTaskContext) errors.Error {
+       logger := c.GetLogger()
        config := c.GetData().(*StarRocksConfig)
-       if config.SourceDsn != "" && config.SourceType != "" {
-               var o *gorm.DB
-               var err error
-               if config.SourceType == "mysql" {
-                       o, err = gorm.Open(mysql.Open(config.SourceDsn))
-                       if err != nil {
-                               return errors.Convert(err)
-                       }
-               } else if config.SourceType == "postgres" {
-                       o, err = gorm.Open(postgres.Open(config.SourceDsn))
-                       if err != nil {
-                               return errors.Convert(err)
-                       }
-               } else {
-                       return errors.NotFound.New(fmt.Sprintf("unsupported 
source type %s", config.SourceType))
-               }
-               db = dalgorm.NewDalgorm(o)
-               sqlDB, err := o.DB()
-               if err != nil {
-                       return errors.Convert(err)
-               }
-               defer sqlDB.Close()
-       } else {
-               db = c.GetDal()
+
+       // 1. Get db instance
+       db, err := getDbInstance(c)
+       if err != nil {
+               return errors.Convert(err)
        }
-       var starrocksTables []string
-       if config.DomainLayer != "" {
-               starrocksTables = 
utils.GetTablesByDomainLayer(config.DomainLayer)
-               if starrocksTables == nil {
-                       return errors.NotFound.New(fmt.Sprintf("no table found 
by domain layer: %s", config.DomainLayer))
-               }
-       } else {
-               tables := config.Tables
-               allTables, err := db.AllTables()
-               if err != nil {
-                       return err
-               }
-               if len(tables) == 0 {
-                       starrocksTables = allTables
-               } else {
-                       for _, table := range allTables {
-                               for _, r := range tables {
-                                       var ok bool
-                                       ok, err = 
errors.Convert01(regexp.Match(r, []byte(table)))
-                                       if err != nil {
-                                               return err
-                                       }
-                                       if ok {
-                                               starrocksTables = 
append(starrocksTables, table)
-                                       }
-                               }
-                       }
-               }
+       // 2. Filter out the tables to export
+       starrocksTables, err := getExportingTables(c, db)
+       if err != nil {
+               return errors.Convert(err)
        }
-
-       starrocks, err := sql.Open("mysql", 
fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", 
config.User, config.Password, config.Host, config.Port, config.Database))
+       // 3. copy devlake data to starrocks
+       sr, err := 
gorm.Open(mysql.Open(fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
 config.User, config.Password, config.Host, config.Port, config.Database)))
        if err != nil {
                return errors.Convert(err)
        }
-       defer starrocks.Close()
+       starrocksDb := dalgorm.NewDalgorm(sr)
 
        for _, table := range starrocksTables {
-               starrocksTable := strings.TrimLeft(table, "_")
-               starrocksTmpTable := fmt.Sprintf("%s_tmp", starrocksTable)
-               var columnMap map[string]string
-               var orderBy string
-               var skip bool
-               columnMap, orderBy, skip, err = createTmpTable(starrocks, db, 
starrocksTable, starrocksTmpTable, table, c, config)
+               select {
+               case <-c.GetContext().Done():
+                       return errors.Convert(c.GetContext().Err())
+               default:
+               }
+
+               dc := DataConfigParams{
+                       Ctx:           c,
+                       Config:        config,
+                       SrcDb:         db,
+                       DestDb:        starrocksDb,
+                       SrcTableName:  table,
+                       DestTableName: strings.TrimLeft(table, "_"),
+               }
+               columnMap, orderBy, skip, err := createTmpTableInStarrocks(&dc)
                if skip {
-                       c.GetLogger().Info(fmt.Sprintf("table %s is up to date, 
so skip it", table))
+                       logger.Info(fmt.Sprintf("table %s is up to date, so 
skip it", table))
                        continue
                }
                if err != nil {
-                       c.GetLogger().Error(err, "create table %s in starrocks 
error", table)
+                       logger.Error(err, "create table %s in starrocks error", 
table)
                        return errors.Convert(err)
                }
-               if db.Dialect() == "postgres" {
-                       err = db.Exec("begin transaction isolation level 
repeatable read")
-                       if err != nil {
-                               return errors.Convert(err)
-                       }
-               } else if db.Dialect() == "mysql" {
-                       err = db.Exec("set session transaction isolation level 
repeatable read")
-                       if err != nil {
-                               return errors.Convert(err)
-                       }
-                       err = errors.Convert(db.Exec("start transaction"))
-                       if err != nil {
-                               return errors.Convert(err)
-                       }
-               } else {
-                       return errors.NotFound.New(fmt.Sprintf("unsupported 
dialect %s", db.Dialect()))
-               }
-               err = errors.Convert(loadData(starrocks, c, starrocksTable, 
starrocksTmpTable, table, columnMap, db, config, orderBy))
-               if err != nil {
-                       return errors.Convert(err)
-               }
-               err = errors.Convert(db.Exec("commit"))
+               err = copyDataToDst(&dc, columnMap, orderBy)
                if err != nil {
                        return errors.Convert(err)
                }
@@ -156,26 +111,33 @@ func LoadData(c plugin.SubTaskContext) errors.Error {
        return nil
 }
 
-func createTmpTable(starrocks *sql.DB, db dal.Dal, starrocksTable string, 
starrocksTmpTable string, table string, c plugin.SubTaskContext, config 
*StarRocksConfig) (map[string]string, string, bool, errors.Error) {
+// create temp table for dealing with some complex logic
+func createTmpTableInStarrocks(dc *DataConfigParams) (map[string]string, 
string, bool, error) {
+       logger := dc.Ctx.GetLogger()
+       config := dc.Config
+       db := dc.SrcDb
+       starrocksDb := dc.DestDb
+       table := dc.SrcTableName
+       starrocksTable := dc.DestTableName
+       starrocksTmpTable := fmt.Sprintf("%s_tmp", starrocksTable)
+
        columnMetas, err := db.GetColumns(&Table{name: table}, nil)
        updateColumn := config.UpdateColumn
        columnMap := make(map[string]string)
        if err != nil {
                if strings.Contains(err.Error(), "cached plan must not change 
result type") {
-                       c.GetLogger().Warn(err, "skip err: cached plan must not 
change result type")
+                       logger.Warn(err, "skip err: cached plan must not change 
result type")
                        columnMetas, err = db.GetColumns(&Table{name: table}, 
nil)
                        if err != nil {
-                               return nil, "", false, errors.Convert(err)
+                               return nil, "", false, err
                        }
                } else {
-                       return nil, "", false, errors.Convert(err)
+                       return nil, "", false, err
                }
        }
 
-       var pks []string
-       var orders []string
-       var columns []string
-       var separator string
+       var pks, orders, columns []string
+       var separator, firstcm, firstcmName string
        if db.Dialect() == "postgres" {
                separator = "\""
        } else if db.Dialect() == "mysql" {
@@ -183,57 +145,29 @@ func createTmpTable(starrocks *sql.DB, db dal.Dal, starrocksTable string, starro
        } else {
                return nil, "", false, 
errors.NotFound.New(fmt.Sprintf("unsupported dialect %s", db.Dialect()))
        }
-       firstcm := ""
-       firstcmName := ""
-       var rowsInStarRocks *sql.Rows
-       var rowsInPostgres dal.Rows
-       defer func() {
-               if rowsInStarRocks != nil {
-                       rowsInStarRocks.Close()
-               }
-               if rowsInPostgres != nil {
-                       rowsInPostgres.Close()
-               }
-       }()
        for _, cm := range columnMetas {
                name := cm.Name()
                if name == updateColumn {
                        // check update column to detect skip or not
-                       rowsInPostgres, err = db.Cursor(
-                               dal.From(table),
-                               dal.Select(updateColumn),
-                               dal.Limit(1),
-                               dal.Orderby(fmt.Sprintf("%s desc", 
updateColumn)),
-                       )
+                       var updatedFrom time.Time
+                       err = db.All(&updatedFrom, dal.Select(updateColumn), 
dal.From(table), dal.Limit(1), dal.Orderby(fmt.Sprintf("%s desc", 
updateColumn)))
                        if err != nil {
                                return nil, "", false, err
                        }
-                       var updatedFrom time.Time
-                       if rowsInPostgres.Next() {
-                               err = 
errors.Convert(rowsInPostgres.Scan(&updatedFrom))
-                               if err != nil {
+
+                       var updatedTo time.Time
+                       err = starrocksDb.All(&updatedTo, 
dal.Select(updateColumn), dal.From(starrocksTable), dal.Limit(1), 
dal.Orderby(fmt.Sprintf("%s desc", updateColumn)))
+                       if err != nil {
+                               if !strings.Contains(err.Error(), "Unknown 
table") {
                                        return nil, "", false, err
                                }
-                       }
-                       var starrocksErr error
-                       rowsInStarRocks, starrocksErr = 
starrocks.Query(fmt.Sprintf("select %s from %s order by %s desc limit 1", 
updateColumn, starrocksTable, updateColumn))
-                       if starrocksErr != nil {
-                               if !strings.Contains(starrocksErr.Error(), 
"Unknown table") {
-                                       return nil, "", false, 
errors.Convert(starrocksErr)
-                               }
                        } else {
-                               var updatedTo time.Time
-                               if rowsInStarRocks.Next() {
-                                       err = 
errors.Convert(rowsInStarRocks.Scan(&updatedTo))
-                                       if err != nil {
-                                               return nil, "", false, err
-                                       }
-                               }
                                if updatedFrom.Equal(updatedTo) {
                                        return nil, "", true, nil
                                }
                        }
                }
+
                columnDatatype, ok := cm.ColumnType()
                if !ok {
                        return columnMap, "", false, 
errors.Default.New(fmt.Sprintf("Get [%s] ColumeType Failed", name))
@@ -271,85 +205,156 @@ func createTmpTable(starrocks *sql.DB, db dal.Dal, starrocksTable string, starro
                        extra = v
                }
        }
-       tableSql := fmt.Sprintf("drop table if exists %s; create table if not 
exists `%s` ( %s ) %s", starrocksTmpTable, starrocksTmpTable, 
strings.Join(columns, ","), extra)
-       c.GetLogger().Debug(tableSql)
-       _, err = errors.Convert01(starrocks.Exec(tableSql))
+       tableSql := fmt.Sprintf("DROP TABLE IF EXISTS %s; CREATE TABLE IF NOT 
EXISTS `%s` ( %s ) %s", starrocksTmpTable, starrocksTmpTable, 
strings.Join(columns, ","), extra)
+       logger.Debug(tableSql)
+       err = starrocksDb.Exec(tableSql)
        return columnMap, orderBy, false, err
 }
 
-func loadData(starrocks *sql.DB, c plugin.SubTaskContext, starrocksTable, 
starrocksTmpTable, table string, columnMap map[string]string, db dal.Dal, 
config *StarRocksConfig, orderBy string) error {
-       offset := 0
+// put data to final dst database
+func copyDataToDst(dc *DataConfigParams, columnMap map[string]string, orderBy 
string) error {
+       c := dc.Ctx
+       logger := dc.Ctx.GetLogger()
+       config := dc.Config
+       db := dc.SrcDb
+       starrocksDb := dc.DestDb
+       table := dc.SrcTableName
+       starrocksTable := dc.DestTableName
+       starrocksTmpTable := fmt.Sprintf("%s_tmp", starrocksTable)
+
+       var offset int
        var err error
-       for {
-               var data []map[string]interface{}
-               // select data from db
-               err = func() error {
-                       var rows dal.Rows
-                       rows, err = db.Cursor(
-                               dal.From(table),
-                               dal.Orderby(orderBy),
-                               dal.Limit(config.BatchSize),
-                               dal.Offset(offset),
-                       )
-                       if err != nil {
-                               return err
-                       }
-                       defer rows.Close()
-                       cols, err := rows.Columns()
-                       if err != nil {
-                               return err
-                       }
-                       for rows.Next() {
-                               row := make(map[string]interface{})
-                               columns := make([]interface{}, len(cols))
-                               columnPointers := make([]interface{}, len(cols))
-                               for i := range columns {
-                                       dataType := columnMap[cols[i]]
-                                       if strings.HasPrefix(dataType, "array") 
{
-                                               var arr []string
-                                               columns[i] = &arr
-                                               columnPointers[i] = 
pq.Array(&arr)
-                                       } else {
-                                               columnPointers[i] = &columns[i]
-                                       }
-                               }
-                               err = rows.Scan(columnPointers...)
-                               if err != nil {
-                                       return err
-                               }
-                               for i, colName := range cols {
-                                       row[colName] = columns[i]
-                               }
-                               data = append(data, row)
+       var rows dal.Rows
+       rows, err = db.Cursor(
+               dal.From(table),
+               dal.Orderby(orderBy),
+       )
+       if err != nil {
+               return err
+       }
+       defer rows.Close()
+
+       var data []map[string]interface{}
+       cols, err := (rows).Columns()
+       if err != nil {
+               return err
+       }
+
+       var batchCount int
+       for rows.Next() {
+               select {
+               case <-c.GetContext().Done():
+                       return c.GetContext().Err()
+               default:
+               }
+               row := make(map[string]interface{})
+               columns := make([]interface{}, len(cols))
+               columnPointers := make([]interface{}, len(cols))
+               for i := range columns {
+                       dataType := columnMap[cols[i]]
+                       if strings.HasPrefix(dataType, "array") {
+                               var arr []string
+                               columns[i] = &arr
+                               columnPointers[i] = pq.Array(&arr)
+                       } else {
+                               columnPointers[i] = &columns[i]
                        }
-                       return nil
-               }()
+               }
+               err = rows.Scan(columnPointers...)
                if err != nil {
                        return err
                }
-               if len(data) == 0 {
-                       c.GetLogger().Warn(nil, "no data found in table %s 
already, limit: %d, offset: %d, so break", table, config.BatchSize, offset)
-                       break
+               for i, colName := range cols {
+                       row[colName] = columns[i]
                }
-               // insert data to tmp table
-               loadURL := fmt.Sprintf("http://%s:%d/api/%s/%s/_stream_load";, 
config.BeHost, config.BePort, config.Database, starrocksTmpTable)
-               headers := map[string]string{
-                       "format":            "json",
-                       "strip_outer_array": "true",
-                       "Expect":            "100-continue",
-                       "ignore_json_size":  "true",
-                       "Connection":        "close",
+               data = append(data, row)
+               batchCount += 1
+               if batchCount == config.BatchSize {
+                       err = putBatchData(c, starrocksTmpTable, table, data, 
config, offset)
+                       if err != nil {
+                               return err
+                       }
+                       batchCount = 0
+                       data = nil
                }
-               jsonData, err := json.Marshal(data)
+       }
+       if batchCount != 0 {
+               err = putBatchData(c, starrocksTmpTable, table, data, config, 
offset)
                if err != nil {
                        return err
                }
-               client := http.Client{
-                       CheckRedirect: func(req *http.Request, via 
[]*http.Request) error {
-                               return http.ErrUseLastResponse
-                       },
+       }
+
+       // drop old table
+       err = starrocksDb.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: 
starrocksTable})
+       if err != nil {
+               return err
+       }
+       // rename tmp table to old table
+       err = starrocksDb.Exec("ALTER TABLE ? RENAME ?", clause.Table{Name: 
starrocksTmpTable}, clause.Table{Name: starrocksTable})
+       if err != nil {
+               return err
+       }
+
+       // check data count
+       sourceCount, err := db.Count(dal.From(table))
+       if err != nil {
+               return err
+       }
+       starrocksCount, err := starrocksDb.Count(dal.From(starrocksTable))
+       if err != nil {
+               return err
+       }
+       if sourceCount != starrocksCount {
+               logger.Warn(nil, "source count %d not equal to starrocks count 
%d", sourceCount, starrocksCount)
+       }
+       logger.Info("load %s to starrocks success", table)
+       return nil
+}
+
+// put batch size data to database
+func putBatchData(c plugin.SubTaskContext, starrocksTmpTable, table string, 
data []map[string]interface{}, config *StarRocksConfig, offset int) error {
+       logger := c.GetLogger()
+       // insert data to tmp table
+       loadURL := fmt.Sprintf("http://%s:%d/api/%s/%s/_stream_load";, 
config.BeHost, config.BePort, config.Database, starrocksTmpTable)
+       headers := map[string]string{
+               "format":            "json",
+               "strip_outer_array": "true",
+               "Expect":            "100-continue",
+               "ignore_json_size":  "true",
+               "Connection":        "close",
+       }
+       jsonData, err := json.Marshal(data)
+       if err != nil {
+               return err
+       }
+       client := http.Client{
+               CheckRedirect: func(req *http.Request, via []*http.Request) 
error {
+                       return http.ErrUseLastResponse
+               },
+       }
+       req, err := http.NewRequest(http.MethodPut, loadURL, 
bytes.NewBuffer(jsonData))
+       if err != nil {
+               return err
+       }
+       req.SetBasicAuth(config.User, config.Password)
+       for k, v := range headers {
+               req.Header.Set(k, v)
+       }
+       resp, err := client.Do(req)
+       if err != nil {
+               return err
+       }
+       defer resp.Body.Close()
+       var b []byte
+
+       if resp.StatusCode == 307 {
+               var location *url.URL
+               location, err = resp.Location()
+               if err != nil {
+                       return err
                }
-               req, err := http.NewRequest(http.MethodPut, loadURL, 
bytes.NewBuffer(jsonData))
+               req, err = http.NewRequest(http.MethodPut, location.String(), 
bytes.NewBuffer(jsonData))
                if err != nil {
                        return err
                }
@@ -357,96 +362,103 @@ func loadData(starrocks *sql.DB, c plugin.SubTaskContext, starrocksTable, starro
                for k, v := range headers {
                        req.Header.Set(k, v)
                }
-               resp, err := client.Do(req)
+               respRetry, err := client.Do(req)
                if err != nil {
                        return err
                }
-               if resp.StatusCode == 307 {
-                       var location *url.URL
-                       location, err = resp.Location()
-                       if err != nil {
-                               return err
-                       }
-                       req, err = http.NewRequest(http.MethodPut, 
location.String(), bytes.NewBuffer(jsonData))
-                       if err != nil {
-                               return err
-                       }
-                       req.SetBasicAuth(config.User, config.Password)
-                       for k, v := range headers {
-                               req.Header.Set(k, v)
-                       }
-                       resp, err = client.Do(req)
-               }
-               if err != nil {
-                       return err
-               }
-               b, err := io.ReadAll(resp.Body)
+               defer respRetry.Body.Close()
+               b, err = io.ReadAll(respRetry.Body)
                if err != nil {
                        return err
                }
-               var result map[string]interface{}
-               err = json.Unmarshal(b, &result)
+       } else {
+               b, err = io.ReadAll(resp.Body)
                if err != nil {
                        return err
                }
-               if resp.StatusCode != http.StatusOK {
-                       c.GetLogger().Error(nil, "[%s]: %s", resp.StatusCode, 
string(b))
-               }
-               if result["Status"] != "Success" {
-                       c.GetLogger().Error(nil, "load %s failed: %s", table, 
string(b))
-               } else {
-                       c.GetLogger().Debug("load %s success: %s, limit: %d, 
offset: %d", table, b, config.BatchSize, offset)
-               }
-               offset += len(data)
        }
-       // drop old table
-       _, err = starrocks.Exec(fmt.Sprintf("drop table if exists %s", 
starrocksTable))
+
+       var result map[string]interface{}
+       err = json.Unmarshal(b, &result)
        if err != nil {
                return err
        }
-       // rename tmp table to old table
-       _, err = starrocks.Exec(fmt.Sprintf("alter table %s rename %s", 
starrocksTmpTable, starrocksTable))
-       if err != nil {
-               return err
+       if resp.StatusCode != http.StatusOK {
+               logger.Error(nil, "[%s]: %s", resp.StatusCode, string(b))
        }
-       // check data count
-       rows, err := db.Cursor(
-               dal.Select("count(*)"),
-               dal.From(table),
-       )
-       if err != nil {
-               return err
+       if result["Status"] != "Success" {
+               logger.Error(nil, "load %s failed: %s", table, string(b))
+       } else {
+               logger.Debug("load %s success: %s, limit: %d, offset: %d", 
table, b, config.BatchSize, offset)
        }
-       defer rows.Close()
-       var sourceCount int
-       for rows.Next() {
-               err = rows.Scan(&sourceCount)
-               if err != nil {
-                       return err
+       return nil
+}
+
+// get db instance
+func getDbInstance(c plugin.SubTaskContext) (db dal.Dal, err error) {
+       config := c.GetData().(*StarRocksConfig)
+       if config.SourceDsn != "" && config.SourceType != "" {
+               var o *gorm.DB
+               switch config.SourceType {
+               case "mysql":
+                       o, err = gorm.Open(mysql.Open(config.SourceDsn))
+                       if err != nil {
+                               return nil, err
+                       }
+               case "postgres":
+                       o, err = gorm.Open(postgres.Open(config.SourceDsn))
+                       if err != nil {
+                               return nil, err
+                       }
+               default:
+                       return nil, 
errors.NotFound.New(fmt.Sprintf("unsupported source type %s", 
config.SourceType))
                }
+               db = dalgorm.NewDalgorm(o)
+       } else {
+               db = c.GetDal()
        }
-       rowsStarRocks, err := starrocks.Query(fmt.Sprintf("select count(*) from 
%s", starrocksTable))
-       if err != nil {
-               return err
-       }
-       defer rowsStarRocks.Close()
-       var starrocksCount int
-       for rowsStarRocks.Next() {
-               err = rowsStarRocks.Scan(&starrocksCount)
+
+       return db, nil
+
+}
+
+// get imported tables
+func getExportingTables(c plugin.SubTaskContext, db dal.Dal) (starrocksTables 
[]string, err error) {
+       config := c.GetData().(*StarRocksConfig)
+       if config.DomainLayer != "" {
+               starrocksTables = 
utils.GetTablesByDomainLayer(config.DomainLayer)
+               if starrocksTables == nil {
+                       return nil, errors.NotFound.New(fmt.Sprintf("no table 
found by domain layer: %s", config.DomainLayer))
+               }
+       } else {
+               tables := config.Tables
+               allTables, err := db.AllTables()
                if err != nil {
-                       return err
+                       return nil, err
+               }
+               if len(tables) == 0 {
+                       starrocksTables = allTables
+               } else {
+                       for _, table := range allTables {
+                               for _, r := range tables {
+                                       var ok bool
+                                       ok, err := regexp.Match(r, 
[]byte(table))
+                                       if err != nil {
+                                               return nil, err
+                                       }
+                                       if ok {
+                                               starrocksTables = 
append(starrocksTables, table)
+                                       }
+                               }
+                       }
                }
        }
-       if sourceCount != starrocksCount {
-               c.GetLogger().Warn(nil, "source count %d not equal to starrocks 
count %d", sourceCount, starrocksCount)
-       }
-       c.GetLogger().Info("load %s to starrocks success", table)
-       return nil
+       return starrocksTables, nil
 }
 
-var LoadDataTaskMeta = plugin.SubTaskMeta{
-       Name:             "LoadData",
-       EntryPoint:       LoadData,
+var ExportDataTaskMeta = plugin.SubTaskMeta{
+       Name:             "ExportData",
+       EntryPoint:       ExportData,
        EnabledByDefault: true,
        Description:      "Load data to StarRocks",
 }

Reply via email to