This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 268ff81e3bf Change the interface by which Go tasks get access to 
variables etc. (#54743)
268ff81e3bf is described below

commit 268ff81e3bfba41533fc895d6a1710aae71b8a57
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Tue Sep 2 17:44:23 2025 +0100

    Change the interface by which Go tasks get access to variables etc. (#54743)
    
    The previous approach (a module-level function that pulled the API client out
    of the context) was sort of thrown together, and coming back to it with fresh
    eyes I didn't like it, for three reasons:
    
    1. It makes it harder for library consumers (i.e. Go task authors) to write
       unit tests of their task functions.
    
       Instead of being able to simply call the function with a mock that
       implements the right interface, they would have had to create a mock of the
       API client (which is lower level) and then put that in the context under
       the right key -- all of which is too tightly coupled to the current
       implementation and shouldn't be exposed to users.
    
    2. It was essentially using global variables.
    
       Now, we are still pulling the HTTP API client out of the Context, but that
       could very easily be changed in future to be stored as a field on the
       `client` struct.
    
    3. By having the function accept an argument of Client interface it is much
       clearer what the function needs, and what it's doing.
    
       It is already an accepted pattern in Python that task functions get called
       with values for function arguments with "special" names (`ti`, for example,
       in Python), so this is not a new pattern for us. This is similar to how
       Temporal supports [passing arguments][1] to Activities (= our Tasks).
    
    This PR doesn't introduce any new features, but sets us up to add
    Connection and XCom support in a future PR.
    
    [1]: https://docs.temporal.io/develop/go/core-application#activity-parameters
    
    Co-authored-by: Nick Stenning <[email protected]>
---
 go-sdk/.mockery.yml                   |   3 +
 go-sdk/example/main.go                |   4 +-
 go-sdk/example/main_test.go           |  57 +++++++++++++++++++
 go-sdk/sdk/{variable.go => client.go} |  44 +++++++++++++--
 go-sdk/sdk/client_test.go             | 100 ++++++++++++++++++++++++++++++++++
 go-sdk/sdk/{variable.go => errors.go} |  25 ++-------
 go-sdk/sdk/sdk.go                     |  52 ++++++++++++++++++
 7 files changed, 257 insertions(+), 28 deletions(-)

diff --git a/go-sdk/.mockery.yml b/go-sdk/.mockery.yml
index b663e1b8603..1b7fc313d2c 100644
--- a/go-sdk/.mockery.yml
+++ b/go-sdk/.mockery.yml
@@ -23,3 +23,6 @@ all: true
 recursive: true
 packages:
   github.com/apache/airflow/go-sdk:
+  github.com/apache/airflow/go-sdk/sdk:
+    config:
+      all: false
diff --git a/go-sdk/example/main.go b/go-sdk/example/main.go
index 34022ce0061..da0a618e241 100644
--- a/go-sdk/example/main.go
+++ b/go-sdk/example/main.go
@@ -56,9 +56,9 @@ func extract(ctx context.Context, log *slog.Logger) error {
        return nil
 }
 
-func transform(ctx context.Context, log *slog.Logger) error {
+func transform(ctx context.Context, client sdk.Client, log *slog.Logger) error 
{
        key := "my_variable"
-       val, err := sdk.VariableGet(ctx, key)
+       val, err := client.GetVariable(ctx, key)
        if err != nil {
                return err
        }
diff --git a/go-sdk/example/main_test.go b/go-sdk/example/main_test.go
new file mode 100644
index 00000000000..54ff86c1044
--- /dev/null
+++ b/go-sdk/example/main_test.go
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package main
+
+import (
+       "context"
+       "log/slog"
+       "testing"
+
+       "github.com/stretchr/testify/assert"
+
+       "github.com/apache/airflow/go-sdk/sdk"
+)
+
+// This file is an example of how to write unit tests for your own Go tasks:
+// call the task function directly, passing a mock for the SDK client it uses.
+
+// mockVars is a hand-written test double for sdk.VariableClient. It knows a
+// single variable, "my_variable", whose value is "value1".
+type mockVars struct{}
+
+// GetVariable implements sdk.VariableClient.
+//
+// It returns "value1" for "my_variable" and sdk.VariableNotFound for every
+// other key, so callers can detect a missing variable with errors.Is.
+func (m *mockVars) GetVariable(ctx context.Context, key string) (string, 
error) {
+       switch key {
+       case "my_variable":
+               return "value1", nil
+       default:
+               return "", sdk.VariableNotFound
+       }
+}
+
+// UnmarshalJSONVariable implements sdk.VariableClient.
+//
+// This mock does not support it and panics if called.
+func (m *mockVars) UnmarshalJSONVariable(ctx context.Context, key string, 
pointer any) error {
+       panic("unimplemented")
+}
+
+// Compile-time check that mockVars satisfies sdk.VariableClient.
+var _ sdk.VariableClient = (*mockVars)(nil)
+
+// Test_transform calls the transform task function directly, passing mockVars
+// in place of the real SDK client, and checks that it returns no error.
+func Test_transform(t *testing.T) {
+       log := slog.Default()
+       // This is not the best test, but it is a good proof of concept -- you can just call the function.
+       err := transform(context.Background(), &mockVars{}, log)
+       assert.NoError(t, err)
+}
diff --git a/go-sdk/sdk/variable.go b/go-sdk/sdk/client.go
similarity index 51%
copy from go-sdk/sdk/variable.go
copy to go-sdk/sdk/client.go
index 715778a5f08..971e29c165f 100644
--- a/go-sdk/sdk/variable.go
+++ b/go-sdk/sdk/client.go
@@ -15,29 +15,61 @@
 // specific language governing permissions and limitations
 // under the License.
 
+/*
+Package sdk provides access to the Airflow objects (Variables, Connection, 
XCom etc) during run time for tasks.
+*/
 package sdk
 
 import (
        "context"
+       "encoding/json"
        "errors"
        "fmt"
+       "os"
+       "strings"
 
        "github.com/apache/airflow/go-sdk/pkg/api"
        "github.com/apache/airflow/go-sdk/pkg/sdkcontext"
 )
 
-func VariableGet(ctx context.Context, key string) (string, error) {
-       client := ctx.Value(sdkcontext.ApiClientContextKey).(api.ClientInterface)
+// client is the default [Client] implementation returned by NewClient.
+type client struct{}
 
-       resp, err := client.Variables().Get(ctx, key)
+var _ Client = (*client)(nil)
+
+// NewClient returns the default [Client] implementation.
+func NewClient() Client {
+       return &client{}
+}
+
+// variableFromEnv looks up the environment variable named VariableEnvPrefix
+// followed by the upper-cased key (e.g. "my_var" -> "AIRFLOW_VAR_MY_VAR").
+func variableFromEnv(key string) (string, bool) {
+       return os.LookupEnv(VariableEnvPrefix + strings.ToUpper(key))
+}
+
+// GetVariable implements [VariableClient]. The environment is consulted first;
+// otherwise the API client stored in ctx under sdkcontext.ApiClientContextKey
+// is asked. A 404 from the API is reported as a wrapped VariableNotFound.
+func (*client) GetVariable(ctx context.Context, key string) (string, error) {
+       // TODO: Let the lookup priority be configurable like it is in Python SDK
+       if env, ok := variableFromEnv(key); ok {
+               return env, nil
+       }
+
+       // Checked assertion: a context without an API client is a setup error the
+       // caller can report, not a reason to panic inside the task.
+       httpClient, ok := ctx.Value(sdkcontext.ApiClientContextKey).(api.ClientInterface)
+       if !ok {
+               return "", fmt.Errorf("getting variable %q: no API client in context", key)
+       }
+
+       resp, err := httpClient.Variables().Get(ctx, key)
        if err != nil {
                var httpError *api.GeneralHTTPError
                if errors.As(err, &httpError) && httpError.Response.StatusCode() == 404 {
-                       // TODO: return a custom error message!
-                       return "", fmt.Errorf("variable %q not found: %d", key, httpError.Response.StatusCode())
+                       // Wrap the sentinel so callers can detect it with errors.Is.
+                       err = fmt.Errorf("%w: %q", VariableNotFound, key)
                }
                return "", err
        }
-       // TODO: Handle deserialization etc!
+       if resp == nil {
+               return "", fmt.Errorf("getting variable %q: empty response from API server", key)
+       }
+       if resp.Value == nil {
+               // Value is a pointer; report a null value as "" instead of dereferencing nil.
+               return "", nil
+       }
        return *resp.Value, nil
 }
+
+// UnmarshalJSONVariable implements [VariableClient]. It fetches the variable
+// with GetVariable and decodes its value as JSON into pointer, which must be a
+// non-nil pointer as required by json.Unmarshal.
+//
+// Lookup errors (including a wrapped VariableNotFound) are returned as-is;
+// decoding errors are wrapped with the variable key and remain inspectable
+// with errors.As (e.g. *json.SyntaxError).
+func (c *client) UnmarshalJSONVariable(ctx context.Context, key string, pointer any) error {
+       val, err := c.GetVariable(ctx, key)
+       if err != nil {
+               return err
+       }
+
+       if err = json.Unmarshal([]byte(val), pointer); err != nil {
+               return fmt.Errorf("decoding variable %q as JSON: %w", key, err)
+       }
+       return nil
+}
diff --git a/go-sdk/sdk/client_test.go b/go-sdk/sdk/client_test.go
new file mode 100644
index 00000000000..970140d4f23
--- /dev/null
+++ b/go-sdk/sdk/client_test.go
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package sdk
+
+import (
+       "context"
+       "fmt"
+       "net/http"
+       "testing"
+
+       "github.com/stretchr/testify/mock"
+       "github.com/stretchr/testify/suite"
+       "resty.dev/v3"
+
+       "github.com/apache/airflow/go-sdk/pkg/api"
+       apiMock "github.com/apache/airflow/go-sdk/pkg/api/mocks"
+       "github.com/apache/airflow/go-sdk/pkg/sdkcontext"
+)
+
+// ClientSuite exercises the default client against mockery-generated API mocks.
+//
+// SetupTest stores apiClient in ctx under sdkcontext.ApiClientContextKey,
+// which is where client.GetVariable looks it up, and makes
+// apiClient.Variables() return variablesClient.
+type ClientSuite struct {
+       suite.Suite
+
+       apiClient       *apiMock.ClientInterface
+       variablesClient *apiMock.VariablesClient
+       ctx             context.Context
+}
+
+// AnyContext matches any context.Context argument in mock expectations.
+var AnyContext any = mock.MatchedBy(func(_ context.Context) bool { return true 
})
+
+// makeHTTPError builds an *api.GeneralHTTPError whose response carries the
+// given status code and a "<code> <message>" status line.
+func makeHTTPError(status int, statusMessage string) error {
+       return &api.GeneralHTTPError{
+               Response: &resty.Response{
+                       RawResponse: &http.Response{
+                               Status:     fmt.Sprintf("%d %s", status, 
statusMessage),
+                               StatusCode: status,
+                       },
+               },
+       }
+}
+
+// TestClientSuite runs every ClientSuite test method.
+func TestClientSuite(t *testing.T) {
+       suite.Run(t, &ClientSuite{})
+}
+
+// SetupTest creates fresh mocks for each test. The Variables() expectation is
+// optional (Maybe) so tests that never reach the API, such as the
+// environment-variable path, do not fail on an unmet expectation.
+func (s *ClientSuite) SetupTest() {
+       c := apiMock.NewClientInterface(s.T())
+       vars := apiMock.NewVariablesClient(s.T())
+       c.EXPECT().Variables().Maybe().Return(vars)
+
+       s.apiClient = c
+       s.variablesClient = vars
+       s.ctx = context.WithValue(context.Background(), 
sdkcontext.ApiClientContextKey, c)
+}
+
+// TestGetVariable checks that the value returned by the API is passed through
+// verbatim: the embedded quotes show no JSON decoding happens here.
+func (s *ClientSuite) TestGetVariable() {
+       key := "my_var"
+       expected := `some"raw"value`
+       s.variablesClient.EXPECT().
+               Get(AnyContext, key).
+               Return(&api.VariableResponse{Value: &expected}, nil)
+
+       c := &client{}
+       val, err := c.GetVariable(s.ctx, key)
+       s.Require().NoError(err)
+       s.Assert().Equal(expected, val)
+}
+
+// TestGetVariable_404Error checks that a 404 from the API is reported as an
+// error wrapping VariableNotFound that also names the requested key.
+func (s *ClientSuite) TestGetVariable_404Error() {
+       key := "my_var"
+       s.variablesClient.EXPECT().Get(AnyContext, key).Return(nil, makeHTTPError(404, "Not Found"))
+
+       c := &client{}
+       _, err := c.GetVariable(s.ctx, key)
+       // errors.Is is the contract documented on VariableClient.GetVariable; the
+       // message check ensures the key is included for debugging.
+       s.Require().ErrorIs(err, VariableNotFound)
+       s.Assert().ErrorContains(err, `variable not found: "my_var"`)
+}
+
+// TestGetVariable_EnvFirst checks that an AIRFLOW_VAR_<KEY> environment
+// variable takes precedence and that the API is not consulted at all.
+func (s *ClientSuite) TestGetVariable_EnvFirst() {
+       s.T().Setenv("AIRFLOW_VAR_MY_VAR", "value1")
+
+       c := &client{}
+       val, err := c.GetVariable(s.ctx, "my_var")
+       s.Require().NoError(err)
+       s.Assert().Equal("value1", val)
+       s.variablesClient.AssertNotCalled(s.T(), "Get")
+}
diff --git a/go-sdk/sdk/variable.go b/go-sdk/sdk/errors.go
similarity index 55%
rename from go-sdk/sdk/variable.go
rename to go-sdk/sdk/errors.go
index 715778a5f08..e3fb80cf70c 100644
--- a/go-sdk/sdk/variable.go
+++ b/go-sdk/sdk/errors.go
@@ -18,26 +18,11 @@
 package sdk
 
 import (
-       "context"
        "errors"
-       "fmt"
-
-       "github.com/apache/airflow/go-sdk/pkg/api"
-       "github.com/apache/airflow/go-sdk/pkg/sdkcontext"
 )
 
-func VariableGet(ctx context.Context, key string) (string, error) {
-       client := 
ctx.Value(sdkcontext.ApiClientContextKey).(api.ClientInterface)
-
-       resp, err := client.Variables().Get(ctx, key)
-       if err != nil {
-               var httpError *api.GeneralHTTPError
-               if errors.As(err, &httpError) && 
httpError.Response.StatusCode() == 404 {
-                       // TODO: return a custom error message!
-                       return "", fmt.Errorf("variable %q not found: %d", key, 
httpError.Response.StatusCode())
-               }
-               return "", err
-       }
-       // TODO: Handle deserialization etc!
-       return *resp.Value, nil
-}
+// VariableNotFound signals that a variable does not exist, as opposed to a
+// failure communicating with the API server.
+//
+// Test for it with errors.Is, since returned errors may wrap it; see the
+// GetVariable method of [VariableClient] for an example.
+var VariableNotFound = errors.New("variable not found")
diff --git a/go-sdk/sdk/sdk.go b/go-sdk/sdk/sdk.go
new file mode 100644
index 00000000000..2fbd0e89f23
--- /dev/null
+++ b/go-sdk/sdk/sdk.go
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/*
+Package sdk provides access to the Airflow objects (Variables, Connection, 
XCom etc) during run time for tasks.
+*/
+package sdk
+
+import (
+       "context"
+)
+
+// Environment variable name prefixes. A Variable's value can be supplied in the
+// environment variable named VariableEnvPrefix followed by the upper-cased key
+// (e.g. AIRFLOW_VAR_MY_VAR for "my_var"). ConnectionEnvPrefix presumably plays
+// the same role for Connections (NOTE(review): confirm once Connection support
+// lands).
+const (
+       VariableEnvPrefix   = "AIRFLOW_VAR_"
+       ConnectionEnvPrefix = "AIRFLOW_CONN_"
+)
+
+// VariableClient provides access to Airflow Variables from a running task.
+type VariableClient interface {
+       // GetVariable returns the value of an Airflow Variable.
+       //
+       // It first looks for an environment variable named VariableEnvPrefix
+       // followed by the upper-cased key (e.g. AIRFLOW_VAR_MY_VAR for "my_var"),
+       // and if that is not set it falls back to asking the API server.
+       //
+       // If the variable is not found, the returned error wraps [VariableNotFound]:
+       //
+       //              val, err := client.GetVariable(ctx, "my-var")
+       //              switch {
+       //              case errors.Is(err, VariableNotFound):
+       //                              // Handle not found: use a default, return a custom error, etc.
+       //              case err != nil:
+       //                              // Other errors, such as HTTP network timeouts.
+       //              }
+       GetVariable(ctx context.Context, key string) (string, error)
+       // UnmarshalJSONVariable fetches the variable as GetVariable does and
+       // decodes its value as JSON into pointer (as with json.Unmarshal).
+       UnmarshalJSONVariable(ctx context.Context, key string, pointer any) 
error
+}
+
+// Client gives task functions access to Airflow objects at run time. Task
+// functions take it as a parameter, so unit tests can pass in a mock
+// implementation instead. It currently provides only Variable access.
+type Client interface {
+       VariableClient
+}

Reply via email to