This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0496f98e722 Adds retry to go-sdk connection (#64184)
0496f98e722 is described below

commit 0496f98e722805545b20cd7d4d9b8aac979e4031
Author: Aritra Basu <[email protected]>
AuthorDate: Sat Mar 28 18:40:42 2026 +0530

    Adds retry to go-sdk connection (#64184)
---
 go-sdk/edge/worker.go        | 30 +++++++++++++++---------------
 go-sdk/pkg/config/config.go  | 29 ++++++++++++++++++++++++++++-
 go-sdk/pkg/edgeapi/client.go | 41 +++++++++++++++++++++++++++++++++++++++--
 3 files changed, 82 insertions(+), 18 deletions(-)

diff --git a/go-sdk/edge/worker.go b/go-sdk/edge/worker.go
index e40560ef26c..48e7eec18fe 100644
--- a/go-sdk/edge/worker.go
+++ b/go-sdk/edge/worker.go
@@ -37,6 +37,7 @@ import (
        "github.com/apache/airflow/go-sdk/bundle/bundlev1"
        "github.com/apache/airflow/go-sdk/bundle/bundlev1/bundlev1client"
        "github.com/apache/airflow/go-sdk/pkg/bundles/shared"
+       "github.com/apache/airflow/go-sdk/pkg/config"
        "github.com/apache/airflow/go-sdk/pkg/edgeapi"
        "github.com/apache/airflow/go-sdk/pkg/logging"
        logserver "github.com/apache/airflow/go-sdk/pkg/logging/server"
@@ -69,8 +70,6 @@ var (
 )
 
 func Run(ctx context.Context) error {
-       apiURL := viper.GetString("edge.api_url")
-
        hostname := viper.GetString("edge.hostname")
 
        if hostname == "" {
@@ -80,10 +79,14 @@ func Run(ctx context.Context) error {
                        return err
                }
        }
+       var conf config.WorkerConfig
+       err := viper.Unmarshal(&conf)
+       if err != nil {
+               fmt.Println(err)
+       }
 
-       w, err := NewWorker(hostname, apiURL, viper.GetString("api_auth.secret_key"),
-               viper.GetStringSlice("queues"),
-       )
+       w, err := NewWorker(conf)
+       w.logger.Info("Config", "config", conf)
        if err != nil {
                return err
        }
@@ -106,13 +109,10 @@ func configOrDefault[T cast.Basic](key string, fallback T) T {
        return cast.To[T](x)
 }
 
-func NewWorker(
-       hostname string,
-       apiURL string,
-       apiJWTSecretKey string,
-       queues []string,
-) (*worker, error) {
-       client, err := edgeapi.NewClient(apiURL, edgeapi.WithEdgeAPIJWTKey([]byte(apiJWTSecretKey)))
+func NewWorker(conf config.WorkerConfig) (*worker, error) {
+       client, err := edgeapi.NewClient(conf.ApiURL,
+               edgeapi.WithEdgeAPIJWTKey([]byte(conf.ApiJWTSecretKey), conf.Issuer),
+               edgeapi.WithRetry(conf.ClientConfig))
        if err != nil {
                return nil, err
        }
@@ -140,8 +140,8 @@ func NewWorker(
        w := &worker{
                Discovery: shared.NewDiscovery(viper.GetString("bundles.folder"), nil),
 
-               hostname:        hostname,
-               queues:          queues,
+               hostname:        conf.Hostname,
+               queues:          conf.Queues,
                client:          client,
                sysInfo:         sysInfo,
                logger:          slog.Default().With("logger", "edge.worker"),
@@ -149,7 +149,7 @@ func NewWorker(
                activeWorkloads: map[uuid.UUID]bundlev1.ExecuteTaskWorkload{},
        }
 
-       w.logger.Info("Starting Go Edge worker", "queues", queues)
+       w.logger.Info("Starting Go Edge worker", "queues", conf.Queues)
 
        w.freeConcurrency.Store(maxConcurrency)
 
diff --git a/go-sdk/pkg/config/config.go b/go-sdk/pkg/config/config.go
index aedb80f9256..f0a29895834 100644
--- a/go-sdk/pkg/config/config.go
+++ b/go-sdk/pkg/config/config.go
@@ -23,6 +23,7 @@ import (
        "os"
        "path"
        "strings"
+       "time"
 
        "github.com/MatusOllah/slogcolor"
        "github.com/fatih/color"
@@ -38,6 +39,21 @@ type BundleConfig struct {
        BundlesFolder string `mapstructure:"bundles_folder"`
 }
 
+type WorkerConfig struct {
+       ClientConfig    ClientConfig `mapstructure:",squash"`
+       ApiJWTSecretKey string       `mapstructure:"API_AUTH.SECRET_KEY"`
+       Issuer          string       `mapstructure:"API_AUTH.ISSUER"`
+       Queues          []string     `mapstructure:"QUEUES"`
+       Hostname        string       `mapstructure:"EDGE.HOSTNAME"`
+       ApiURL          string       `mapstructure:"EDGE.API_URL"`
+}
+
+type ClientConfig struct {
+       RetryCount    int           `mapstructure:"EDGE.API_RETRIES"`
+       StartWaitTime time.Duration `mapstructure:"EDGE.API_RETRY_WAIT_MIN"`
+       MaxWaitTime   time.Duration `mapstructure:"EDGE.API_RETRY_WAIT_MAX"`
+}
+
 var envKeyReplacer *strings.Replacer = strings.NewReplacer(".", "__", "-", "_")
 
 func InitColor(rootCmd *cobra.Command) {
@@ -65,6 +81,7 @@ func Configure(cmd *cobra.Command) error {
        }
        // Bind the current command's flags to viper
        BindFlagsToViper(cmd, v)
+       setDefaults(v)
 
        logger := makeLogger(v)
        slog.SetDefault(logger)
@@ -122,7 +139,10 @@ func SetupViper(cfgFile string) (*viper.Viper, error) {
                viper.SetConfigName("go-sdk.yaml")
        }
 
-       viper.SetOptions(viper.ExperimentalBindStruct())
+       // We use __ as the delimiter, by default viper uses . which leads to
+       // unmarshalling failing, since viper tries to search for nested structs
+       // when the key has a .
+       viper.SetOptions(viper.ExperimentalBindStruct(), viper.KeyDelimiter("__"))
 
        // Attempt to read the config file, gracefully ignoring errors
        // caused by a config file not being found. Return an error
@@ -187,3 +207,10 @@ func BindFlagsToViper(cmd *cobra.Command, viper *viper.Viper) {
                }
        })
 }
+
+func setDefaults(viper *viper.Viper) {
+       viper.SetDefault("edge.api_retries", 10)
+       viper.SetDefault("edge.api_retry_wait_min", "1m")
+       viper.SetDefault("edge.api_retry_wait_max", "90m")
+       viper.SetDefault("api_auth.issuer", "airflow")
+}
diff --git a/go-sdk/pkg/edgeapi/client.go b/go-sdk/pkg/edgeapi/client.go
index b8dd0458162..59d17fc4096 100644
--- a/go-sdk/pkg/edgeapi/client.go
+++ b/go-sdk/pkg/edgeapi/client.go
@@ -18,18 +18,25 @@
 package edgeapi
 
 import (
+       "errors"
+       "net"
+       "net/http"
+       "os"
        "strings"
+       "syscall"
        "time"
 
        "github.com/golang-jwt/jwt/v5"
        "resty.dev/v3"
+
+       "github.com/apache/airflow/go-sdk/pkg/config"
 )
 
 //go:generate -command openapi-gen go run github.com/ashb/oapi-resty-codegen@latest --config oapi-codegen.yml
 
 //go:generate openapi-gen https://raw.githubusercontent.com/apache/airflow/refs/tags/providers-edge3/1.3.0/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
 
-func WithEdgeAPIJWTKey(key []byte) ClientOption {
+func WithEdgeAPIJWTKey(key []byte, issuer string) ClientOption {
        return func(c *Client) error {
                c.SetAuthScheme("")
 
@@ -39,7 +46,7 @@ func WithEdgeAPIJWTKey(key []byte) ClientOption {
                        now := time.Now().UTC().Unix()
                        t := jwt.NewWithClaims(jwt.SigningMethodHS512, jwt.MapClaims{
                                "method": endpointPath,
-                               "iss":    "airflow",
+                               "iss":    issuer,
                                "aud":    "api",
                                "iat":    now,
                                "nbf":    now,
@@ -61,3 +68,33 @@ func WithEdgeAPIJWTKey(key []byte) ClientOption {
                return nil
        }
 }
+
+func WithRetry(conf config.ClientConfig) ClientOption {
+       return func(c *Client) error {
+               c.SetRetryCount(conf.RetryCount).
+                       SetRetryWaitTime(conf.StartWaitTime).
+                       SetRetryMaxWaitTime(conf.MaxWaitTime).
+                       AddRetryConditions(func(r *resty.Response, err error) bool {
+                               var opErr *net.OpError
+
+                               if errors.As(err, &opErr) {
+                                       if opErr.Temporary() || opErr.Timeout() {
+                                               c.Logger().Warnf("Retrying request %v", err)
+                                               return true
+                                       }
+                                       if sysErr, ok := opErr.Err.(*os.SyscallError); ok {
+                                               if sysErr.Err == syscall.ECONNREFUSED {
+                                                       c.Logger().Warnf("Retrying request %v", err)
+                                                       return true
+                                               }
+                                       }
+                               }
+                               if r.StatusCode() == http.StatusBadGateway {
+                                       c.Logger().Warnf("Retrying request %v", err)
+                                       return true
+                               }
+                               return false
+                       }).SetAllowNonIdempotentRetry(true)
+               return nil
+       }
+}

Reply via email to