This is an automated email from the ASF dual-hosted git repository.
jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 826f4712698 Go-SDK: Implement coordinator-mode runtime entry point and
task runner (#67318)
826f4712698 is described below
commit 826f4712698f79a9f4da58bb0e34e7f5009cd098
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Wed Jun 3 11:11:23 2026 +0800
Go-SDK: Implement coordinator-mode runtime entry point and task runner
(#67318)
* go-sdk: Implement coordinator-mode runtime entry point and task runner
Wire the supervisor-launched runtime that speaks ADR 0003's coordinator
protocol. execution.Serve dials the comm and logs sockets the supervisor
passes via the new --comm/--logs flags, installs SocketLogHandler so slog
records reach the supervisor, reads StartupDetails, and drives a single
TaskInstance through task_runner.Run. The runner injects a
CoordinatorClient into the user task function via
sdkcontext.SdkClientContextKey so tasks written against the existing
sdk.Client API run unchanged. bundlev1server.Serve grows a mode selector
so the same binary still serves go-plugin when no coordinator flags are
present, and exits non-zero on partial --comm/--logs misuse.
DAG-file parsing is intentionally not part of this stack -- it will land
in a follow-up once the parsing protocol settles.
* Add coordinator-mode client round-trip integration test
The existing end-to-end test only drove a no-op task, so the comm
dispatcher's request/response multiplexing -- the core of coordinator
mode, where a running task calls back into the supervisor mid-execution
-- was never exercised against the real Serve entry point.
Add a test that runs Serve against a fake supervisor over a loopback
socket pair and registers a task that pulls a variable (GetVariable) and
returns a value (triggering a return-value SetXCom push). The fake
supervisor must answer both runtime-initiated requests, on distinct
frame ids, before the terminal SucceedTask frame arrives, covering the
full multiplexed round trip. A connection deadline keeps a regression
(e.g. the env-var fast path swallowing the request, or a dispatcher
deadlock) from hanging until the Go test timeout.
* Log coordinator-mode runtime failures and document exit contract
When the coordinator runtime fails after connecting to the supervisor
(bundle registration error, an unreadable or undecodable initial frame,
a supervisor-reported error, or an unexpected first message type) it
returned a bare error and relied on the process exit code alone to
signal failure. The supervisor only sees the closed comm socket, with no
explanation in its log stream, making such failures hard to diagnose.
Log the reason at Error on each post-connect failure path so it reaches
the supervisor over the already-connected logs socket, and document the
failure-signaling contract on Serve: a non-nil return must become a
non-zero process exit, which the supervisor records as FAILED (or
UP_FOR_RETRY), and a structured TaskState frame is only honored when the
process exits 0. This makes the fail-closed behavior explicit so a later
refactor cannot accidentally turn a startup failure into a success.
Adds a test that a RegisterDags failure makes Serve return the error
without writing a terminal frame, leaving the supervisor to observe the
comm socket closing.
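
The exit contract can be summarised as a small decision function. This is a sketch of the rule as the message states it, not the supervisor's implementation (which is Python); finalState, its parameters, and the "unspecified" marker are all hypothetical. The message does not say what happens on exit 0 with no frame, so the sketch flags that case explicitly instead of guessing.

```go
package main

import "fmt"

// finalState models the stated rule: the exit code is consulted first,
// and a terminal frame's state is only honored on exit 0.
// frameState is "" when no terminal frame arrived.
func finalState(exitCode int, frameState string, retriesLeft bool) string {
	if exitCode != 0 {
		// Fail closed: any frame that was sent is ignored.
		if retriesLeft {
			return "up_for_retry"
		}
		return "failed"
	}
	if frameState != "" {
		return frameState
	}
	// Not covered by the commit message; labelled rather than guessed.
	return "unspecified"
}

func main() {
	fmt.Println(finalState(1, "success", false)) // non-zero exit overrides the frame
	fmt.Println(finalState(1, "", true))         // startup failure with retries left
	fmt.Println(finalState(0, "success", false)) // frame honored on clean exit
}
```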
* Cancel coordinator-mode tasks on SIGINT/SIGTERM
The coordinator runtime ran each task under a fresh context.Background(),
so a supervisor shutdown could not reach the task: SIGTERM terminated the
process outright and a cooperative task had no chance to observe the
shutdown and return. The final-frame write was also unbounded, so a
half-open comm socket could wedge the runtime indefinitely.
Derive the task's root context from signal.NotifyContext(SIGINT,
SIGTERM) in Serve and thread it through RunTask into the user task, so a
ctx-aware task returns promptly on a supervisor shutdown; a task that
ignores ctx is still stopped by the supervisor's follow-up SIGKILL.
Bound the terminal frame write with a deadline so a wedged socket fails
fast (non-zero exit) instead of hanging. An in-flight client call is
unaffected: it already unblocks on supervisor disconnect via the comm
dispatcher's ErrDispatcherClosed path.
Adds a test that a cancelled root context reaches the user task through
RunTask and is reported as a failed terminal state.
* Track up for retry transition as TODO
---
go-sdk/bundle/bundlev1/bundlev1server/server.go | 128 +++++--
go-sdk/example/bundle/main.go | 14 +-
go-sdk/pkg/execution/integration_test.go | 424 ++++++++++++++++++++++++
go-sdk/pkg/execution/server.go | 199 +++++++++++
go-sdk/pkg/execution/task_runner.go | 132 ++++++++
5 files changed, 873 insertions(+), 24 deletions(-)
diff --git a/go-sdk/bundle/bundlev1/bundlev1server/server.go b/go-sdk/bundle/bundlev1/bundlev1server/server.go
index 67d212ff069..02036742d09 100644
--- a/go-sdk/bundle/bundlev1/bundlev1server/server.go
+++ b/go-sdk/bundle/bundlev1/bundlev1server/server.go
@@ -19,6 +19,7 @@ package bundlev1server
import (
"encoding/json"
+ "errors"
"fmt"
"log/slog"
"os"
@@ -32,9 +33,34 @@ import (
"github.com/apache/airflow/go-sdk/bundle/bundlev1/bundlev1server/impl"
"github.com/apache/airflow/go-sdk/pkg/bundles/shared"
"github.com/apache/airflow/go-sdk/pkg/config"
+ "github.com/apache/airflow/go-sdk/pkg/execution"
)
-var versionInfo *bool = flag.Bool("bundle-metadata", false, "show the embedded bundle info")
+// ErrCoordinatorFlagsIncomplete is returned by [Serve] when exactly one of
+// --comm or --logs is supplied. Both flags select coordinator mode and must
+// be set together; callers (typically main) can check for this sentinel to
+// print usage before exiting non-zero.
+var ErrCoordinatorFlagsIncomplete = errors.New(
+ "--comm and --logs must be supplied together",
+)
+
+// CLI Flags.
+// The --bundle-metadata flag prints the embedded bundle info in airflow-metadata.yaml spec format.
+// The --comm and --logs flags select the coordinator-mode protocol.
+// All three are read by Serve to choose a server mode below.
+var (
+ versionInfo = flag.Bool("bundle-metadata", false, "show the embedded bundle info")
+ commAddr = flag.String(
+ "comm",
+ "",
+ "host:port of the supervisor's coordinator comm channel (selects coordinator mode)",
+ )
+ logsAddr = flag.String(
+ "logs",
+ "",
+ "host:port of the supervisor's coordinator logs channel (selects coordinator mode)",
+ )
+)
// ServeOpt is an interface for defining options that can be passed to the
// Serve function. Each implementation modifies the ServeConfig being
@@ -52,24 +78,35 @@ func (s serveConfigFunc) ApplyServeOpt(in *ServerConfig) error {
type ServerConfig struct{}
-// Serve is the entrypoint for your bundle, and sets it up ready for Airflow's Go Worker to use
+// serveMode tags the protocol the binary will speak this run.
+type serveMode int
+
+const (
+ modePlugin serveMode = iota // go-plugin gRPC (existing Edge Worker path)
+ modeMetadataDump // --bundle-metadata: print BundleInfo JSON
+ modeCoordinator // --comm/--logs: msgpack-over-IPC (ADR 0003)
+ modeCoordinatorUsageError // misuse: print usage and exit non-zero
+)
+
+// Serve is the entrypoint for your bundle, and sets it up ready for Airflow's
+// Go Worker (go-plugin) or Python supervisor (coordinator protocol) to use.
//
-// Zero or more options to configure the server may also be passed. There are no options yet, this is to allow
-// future changes without breaking compatibility
+// The mode is decided from CLI flags and process environment. Callers should
+// surface the returned error so misuse (e.g. only one of --comm/--logs
+// supplied) produces a non-zero exit:
+//
+// func main() {
+// if err := bundlev1server.Serve(&myBundle{}); err != nil {
+// log.Fatal(err)
+// }
+// }
+//
+// Zero or more options to configure the server may also be passed. There are
+// no options yet; the parameter exists to allow future additions without
+// breaking compatibility.
func Serve(bundle bundlev1.BundleProvider, opts ...ServeOpt) error {
config.SetupViper("")
- hcLogger := hclog.New(&hclog.LoggerOptions{
- Level: hclog.Trace,
- Output: os.Stderr,
- JSONFormat: true,
- IncludeLocation: true,
- AdditionalLocationOffset: 3,
- })
-
- log := slog.New(hclogslog.Adapt(hcLogger))
- slog.SetDefault(log)
-
flag.Parse()
serveConfig := &ServerConfig{}
@@ -77,16 +114,63 @@ func Serve(bundle bundlev1.BundleProvider, opts ...ServeOpt) error {
c.ApplyServeOpt(serveConfig)
}
+ switch decideMode() {
+ case modeMetadataDump:
+ return dumpBundleMetadata(bundle)
+ case modeCoordinator:
+ // In coordinator mode the supervisor reads the logs channel for
+ // structured records, so configuring the hclog/stderr default
+ // logger here is unnecessary — execution.Serve installs its own
+ // slog handler against the logs socket before any user code runs.
+ return execution.Serve(bundle, *commAddr, *logsAddr)
+ case modePlugin:
+ installPluginLogger()
+ return servePlugin(bundle)
+ case modeCoordinatorUsageError:
+ return ErrCoordinatorFlagsIncomplete
+ }
+ return nil
+}
+
+func decideMode() serveMode {
if *versionInfo {
- meta := bundle.GetBundleVersion()
- data, err := json.MarshalIndent(meta, "", " ")
- if err != nil {
- return err
- }
- fmt.Println(string(data))
- return nil
+ return modeMetadataDump
+ }
+ commSet := *commAddr != ""
+ logsSet := *logsAddr != ""
+ if commSet && logsSet {
+ return modeCoordinator
+ }
+ if commSet || logsSet {
+ // Partial use is a hard error: both flags are required.
+ return modeCoordinatorUsageError
+ }
+ return modePlugin
+}
+
+func dumpBundleMetadata(bundle bundlev1.BundleProvider) error {
+ meta := bundle.GetBundleVersion()
+ data, err := json.MarshalIndent(meta, "", " ")
+ if err != nil {
+ return err
}
+ fmt.Println(string(data))
+ return nil
+}
+
+func installPluginLogger() {
+ hcLogger := hclog.New(&hclog.LoggerOptions{
+ Level: hclog.Trace,
+ Output: os.Stderr,
+ JSONFormat: true,
+ IncludeLocation: true,
+ AdditionalLocationOffset: 3,
+ })
+ log := slog.New(hclogslog.Adapt(hcLogger))
+ slog.SetDefault(log)
+}
+func servePlugin(bundle bundlev1.BundleProvider) error {
pluginConfig := &plugin.ServeConfig{
HandshakeConfig: shared.Handshake,
Plugins: plugin.PluginSet{
diff --git a/go-sdk/example/bundle/main.go b/go-sdk/example/bundle/main.go
index 5e970da5495..350c101b8b0 100644
--- a/go-sdk/example/bundle/main.go
+++ b/go-sdk/example/bundle/main.go
@@ -20,6 +20,7 @@ package main
import (
"context"
"fmt"
+ "log"
"log/slog"
"runtime"
"time"
@@ -54,7 +55,9 @@ func (m *myBundle) RegisterDags(dagbag v1.Registry) error {
}
func main() {
- bundlev1server.Serve(&myBundle{})
+ if err := bundlev1server.Serve(&myBundle{}); err != nil {
+ log.Fatal(err)
+ }
}
func extract(ctx context.Context, client sdk.Client, log *slog.Logger) (any, error) {
@@ -63,7 +66,14 @@ func extract(ctx context.Context, client sdk.Client, log *slog.Logger) (any, err
if err != nil {
log.ErrorContext(ctx, "unable to get conn", "error", err)
} else {
- log.InfoContext(ctx, "got conn", "conn", conn)
+ // Log only non-sensitive fields; conn.Password and any secrets in
+ // conn.Extra must never reach the log stream.
+ log.InfoContext(ctx, "got conn",
+ "conn_id", conn.ID,
+ "conn_type", conn.Type,
+ "host", conn.Host,
+ "port", conn.Port,
+ )
}
for range 10 {
diff --git a/go-sdk/pkg/execution/integration_test.go b/go-sdk/pkg/execution/integration_test.go
new file mode 100644
index 00000000000..ed0f437cff6
--- /dev/null
+++ b/go-sdk/pkg/execution/integration_test.go
@@ -0,0 +1,424 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package execution
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "log/slog"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/apache/airflow/go-sdk/bundle/bundlev1"
+ "github.com/apache/airflow/go-sdk/sdk"
+)
+
+// --- Test task functions ---
+
+func failingTask() error {
+ return errors.New("task failed intentionally")
+}
+
+func panicTask() error {
+ panic("something went wrong")
+}
+
+func simpleTask() error {
+ return nil
+}
+
+// buildBundle wires a bundlev1.Registry from a closure and returns it as a
+// bundlev1.Bundle (the materialised registry).
+func buildBundle(t *testing.T, register func(bundlev1.Registry)) bundlev1.Bundle {
+ t.Helper()
+ reg := bundlev1.New()
+ register(reg)
+ return reg
+}
+
+// --- Tests ---
+
+func TestTaskRunnerSuccess(t *testing.T) {
+ bundle := buildBundle(t, func(r bundlev1.Registry) {
+ r.AddDag("test_dag").AddTask(simpleTask)
+ })
+
+ details := &StartupDetails{
+ TI: TaskInstanceInfo{
+ ID: "550e8400-e29b-41d4-a716-446655440000",
+ DagID: "test_dag",
+ TaskID: "simpleTask",
+ RunID: "run1",
+ MapIndex: -1,
+ },
+ BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"},
+ }
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger)
+
+ result := RunTask(context.Background(), bundle, details, comm, logger)
+ assert.Equal(t, "SucceedTask", result["type"])
+}
+
+func TestTaskRunnerFailure(t *testing.T) {
+ bundle := buildBundle(t, func(r bundlev1.Registry) {
+ r.AddDag("test_dag").AddTask(failingTask)
+ })
+
+ details := &StartupDetails{
+ TI: TaskInstanceInfo{
+ ID: "550e8400-e29b-41d4-a716-446655440000",
+ DagID: "test_dag",
+ TaskID: "failingTask",
+ RunID: "run1",
+ MapIndex: -1,
+ },
+ BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"},
+ }
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger)
+
+ result := RunTask(context.Background(), bundle, details, comm, logger)
+ assert.Equal(t, "TaskState", result["type"])
+ assert.Equal(t, "failed", result["state"])
+}
+
+func TestTaskRunnerTaskNotFound(t *testing.T) {
+ bundle := buildBundle(t, func(r bundlev1.Registry) {
+ r.AddDag("test_dag").AddTask(simpleTask)
+ })
+
+ details := &StartupDetails{
+ TI: TaskInstanceInfo{
+ ID: "550e8400-e29b-41d4-a716-446655440000",
+ DagID: "test_dag",
+ TaskID: "nonexistent",
+ RunID: "run1",
+ },
+ BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"},
+ }
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger)
+
+ result := RunTask(context.Background(), bundle, details, comm, logger)
+ assert.Equal(t, "TaskState", result["type"])
+ assert.Equal(t, "removed", result["state"])
+}
+
+func TestTaskRunnerPanic(t *testing.T) {
+ bundle := buildBundle(t, func(r bundlev1.Registry) {
+ r.AddDag("test_dag").AddTask(panicTask)
+ })
+
+ details := &StartupDetails{
+ TI: TaskInstanceInfo{
+ ID: "550e8400-e29b-41d4-a716-446655440000",
+ DagID: "test_dag",
+ TaskID: "panicTask",
+ RunID: "run1",
+ MapIndex: -1,
+ },
+ BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"},
+ }
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger)
+
+ result := RunTask(context.Background(), bundle, details, comm, logger)
+ assert.Equal(t, "TaskState", result["type"])
+ assert.Equal(t, "failed", result["state"])
+}
+
+func TestRunTaskHonorsContextCancellation(t *testing.T) {
+ bundle := buildBundle(t, func(r bundlev1.Registry) {
+ r.AddDag("test_dag").AddTaskWithName("ctxcheck",
+ func(ctx context.Context) error { return ctx.Err() })
+ })
+
+ details := &StartupDetails{
+ TI: TaskInstanceInfo{
+ ID: "550e8400-e29b-41d4-a716-446655440000",
+ DagID: "test_dag",
+ TaskID: "ctxcheck",
+ RunID: "run1",
+ MapIndex: -1,
+ },
+ BundleInfo: BundleInfoMsg{Name: "test", Version: "1.0"},
+ }
+
+ // A cancelled root context must reach the user task through RunTask's
+ // threading; the task surfaces ctx.Err(), which RunTask maps to failed.
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+ comm := NewCoordinatorComm(bytes.NewReader(nil), io.Discard, logger)
+
+ result := RunTask(ctx, bundle, details, comm, logger)
+ assert.Equal(t, "TaskState", result["type"])
+ assert.Equal(t, "failed", result["state"])
+}
+
+// --- End-to-end Serve test against a fake supervisor ---
+
+// fakeProvider implements bundlev1.BundleProvider; it lets a test inject the
+// registration closure and a synthetic version.
+type fakeProvider struct {
+ register func(bundlev1.Registry) error
+}
+
+func (f *fakeProvider) GetBundleVersion() bundlev1.BundleInfo {
+ v := "1.0"
+ return bundlev1.BundleInfo{Name: "fake", Version: &v}
+}
+
+func (f *fakeProvider) RegisterDags(reg bundlev1.Registry) error {
+ if f.register == nil {
+ return nil
+ }
+ return f.register(reg)
+}
+
+func startSupervisor(
+ t *testing.T,
+) (commAddr, logsAddr string, commCh, logsCh chan net.Conn, cleanup func()) {
+ t.Helper()
+ commLn, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ logsLn, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+
+ commCh = make(chan net.Conn, 1)
+ logsCh = make(chan net.Conn, 1)
+ go func() {
+ c, err := commLn.Accept()
+ if err == nil {
+ commCh <- c
+ }
+ close(commCh)
+ }()
+ go func() {
+ c, err := logsLn.Accept()
+ if err == nil {
+ logsCh <- c
+ }
+ close(logsCh)
+ }()
+ cleanup = func() {
+ commLn.Close()
+ logsLn.Close()
+ }
+ return commLn.Addr().String(), logsLn.Addr().String(), commCh, logsCh, cleanup
+}
+
+func TestServeStartupDetailsEndToEnd(t *testing.T) {
+ commAddr, logsAddr, commCh, logsCh, cleanup := startSupervisor(t)
+ defer cleanup()
+
+ provider := &fakeProvider{
+ register: func(r bundlev1.Registry) error {
+ r.AddDag("dag1").AddTask(simpleTask)
+ return nil
+ },
+ }
+
+ done := make(chan error, 1)
+ go func() { done <- Serve(provider, commAddr, logsAddr) }()
+
+ commConn := <-commCh
+ defer commConn.Close()
+ logsConn := <-logsCh
+ defer logsConn.Close()
+
+ payload, err := encodeRequest(0, map[string]any{
+ "type": "StartupDetails",
+ "ti": map[string]any{
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "dag_id": "dag1",
+ "task_id": "simpleTask",
+ "run_id": "run1",
+ "try_number": 1,
+ },
+ "bundle_info": map[string]any{"name": "fake", "version": "1.0"},
+ })
+ require.NoError(t, err)
+ require.NoError(t, writeFrame(commConn, payload))
+
+ frame, err := readFrame(commConn)
+ require.NoError(t, err)
+ require.Nil(t, frame.Err)
+ assert.Equal(t, "SucceedTask", frame.Body["type"])
+
+ select {
+ case err := <-done:
+ require.NoError(t, err)
+ case <-time.After(2 * time.Second):
+ t.Fatal("Serve did not return after task completion")
+ }
+}
+
+// TestServeClientRoundTripEndToEnd drives a task that calls back into the
+// supervisor mid-execution, so the comm dispatcher's request/response
+// multiplexing is exercised against the real Serve rather than only the
+// no-op task path. The registered task pulls a variable (GetVariable) and
+// returns a value (which triggers a return-value SetXCom push); the fake
+// supervisor must answer both runtime-initiated requests before the terminal
+// SucceedTask frame is sent.
+func TestServeClientRoundTripEndToEnd(t *testing.T) {
+ commAddr, logsAddr, commCh, logsCh, cleanup := startSupervisor(t)
+ defer cleanup()
+
+ // Unique key so the GetVariable env-var fast path
+ // (AIRFLOW_VAR_<KEY>) cannot short-circuit the socket round trip.
+ const varKey = "go_sdk_round_trip_only_key"
+
+ var gotVar string
+ provider := &fakeProvider{
+ register: func(r bundlev1.Registry) error {
+ r.AddDag("dag1").AddTaskWithName("getvar",
+ func(ctx context.Context, c sdk.Client) (string, error) {
+ v, err := c.GetVariable(ctx, varKey)
+ if err != nil {
+ return "", err
+ }
+ gotVar = v
+ return "xval", nil
+ })
+ return nil
+ },
+ }
+
+ done := make(chan error, 1)
+ go func() { done <- Serve(provider, commAddr, logsAddr) }()
+
+ commConn := <-commCh
+ defer commConn.Close()
+ logsConn := <-logsCh
+ defer logsConn.Close()
+
+ // Bound every read/write so a regression (e.g. the env-var fast path
+ // swallowing the request, or a dispatcher deadlock) fails fast instead of
+ // hanging until the Go test timeout.
+ require.NoError(t, commConn.SetDeadline(time.Now().Add(10*time.Second)))
+
+ // 1. Kick off task execution.
+ startup, err := encodeRequest(0, map[string]any{
+ "type": "StartupDetails",
+ "ti": map[string]any{
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "dag_id": "dag1",
+ "task_id": "getvar",
+ "run_id": "run1",
+ "try_number": 1,
+ },
+ "bundle_info": map[string]any{"name": "fake", "version": "1.0"},
+ })
+ require.NoError(t, err)
+ require.NoError(t, writeFrame(commConn, startup))
+
+ // 2. The task's GetVariable call blocks until the supervisor answers.
+ varReq, err := readFrame(commConn)
+ require.NoError(t, err)
+ require.Nil(t, varReq.Err)
+ assert.Equal(t, "GetVariable", varReq.Body["type"])
+ assert.Equal(t, varKey, varReq.Body["key"])
+
+ varReply, err := encodeRequest(varReq.ID, map[string]any{
+ "type": "VariableResult",
+ "key": varKey,
+ "value": "hello",
+ })
+ require.NoError(t, err)
+ require.NoError(t, writeFrame(commConn, varReply))
+
+ // 3. Returning a value triggers a return-value XCom push; answer it with
+ // an empty (non-error) response so PushXCom unblocks.
+ xcomReq, err := readFrame(commConn)
+ require.NoError(t, err)
+ require.Nil(t, xcomReq.Err)
+ assert.Equal(t, "SetXCom", xcomReq.Body["type"])
+ assert.Equal(t, "return_value", xcomReq.Body["key"])
+ assert.Equal(t, "xval", xcomReq.Body["value"])
+ assert.NotEqual(t, varReq.ID, xcomReq.ID, "second runtime request must use a fresh frame id")
+
+ xcomReply, err := encodeRequest(xcomReq.ID, map[string]any{})
+ require.NoError(t, err)
+ require.NoError(t, writeFrame(commConn, xcomReply))
+
+ // 4. With both calls answered, the task finishes and Serve ships the
+ // terminal SucceedTask frame on the StartupDetails frame id.
+ term, err := readFrame(commConn)
+ require.NoError(t, err)
+ require.Nil(t, term.Err)
+ assert.Equal(t, "SucceedTask", term.Body["type"])
+
+ select {
+ case err := <-done:
+ require.NoError(t, err)
+ case <-time.After(2 * time.Second):
+ t.Fatal("Serve did not return after task completion")
+ }
+
+ assert.Equal(t, "hello", gotVar)
+}
+
+// TestServeRegisterDagsFailureClosesComm asserts the failure-signaling
+// contract: when bundle registration fails after the sockets are connected,
+// Serve returns the error (so the caller exits non-zero) without writing a
+// terminal frame. The supervisor observes the failure as the comm socket
+// closing rather than as a TaskState message.
+func TestServeRegisterDagsFailureClosesComm(t *testing.T) {
+ commAddr, logsAddr, commCh, logsCh, cleanup := startSupervisor(t)
+ defer cleanup()
+
+ wantErr := errors.New("boom registering dags")
+ provider := &fakeProvider{
+ register: func(bundlev1.Registry) error { return wantErr },
+ }
+
+ done := make(chan error, 1)
+ go func() { done <- Serve(provider, commAddr, logsAddr) }()
+
+ commConn := <-commCh
+ defer commConn.Close()
+ logsConn := <-logsCh
+ defer logsConn.Close()
+
+ select {
+ case err := <-done:
+ require.Error(t, err)
+ assert.ErrorIs(t, err, wantErr)
+ case <-time.After(2 * time.Second):
+ t.Fatal("Serve did not return after RegisterDags failure")
+ }
+
+ // No terminal frame was sent: the next read on the comm socket sees the
+ // connection close instead of a decodable frame.
+ require.NoError(t, commConn.SetReadDeadline(time.Now().Add(time.Second)))
+ _, err := readFrame(commConn)
+ require.Error(t, err)
+}
diff --git a/go-sdk/pkg/execution/server.go b/go-sdk/pkg/execution/server.go
new file mode 100644
index 00000000000..ade60d22956
--- /dev/null
+++ b/go-sdk/pkg/execution/server.go
@@ -0,0 +1,199 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Package execution implements the SDK coordinator-protocol runtime
+// (msgpack-over-IPC). It is the second mode of bundlev1server.Serve: when
+// the bundle binary is launched with --comm/--logs by the Airflow supervisor
+// (Python ExecutableCoordinator), bundlev1server.Serve dispatches here.
+//
+// The first inbound frame on the comm socket is a StartupDetails message
+// that drives multi-round task execution.
+//
+// See go-sdk/adr/0003-coordinator-protocol-msgpack-ipc.md.
+package execution
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "net"
+ "os/signal"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/apache/airflow/go-sdk/bundle/bundlev1"
+)
+
+// dialTimeout bounds how long execution.Serve waits to reach the supervisor's
+// comm and logs sockets. The supervisor opens the listeners before spawning
+// the bundle binary, so the dials normally succeed in milliseconds; the
+// timeout exists so an unreachable address fails fast instead of hanging the
+// runtime indefinitely.
+const dialTimeout = 30 * time.Second
+
+// terminalSendTimeout bounds the write of the final TaskState/SucceedTask
+// frame. The supervisor normally drains the comm socket promptly, but a
+// half-open connection (the supervisor gone without a clean close) could
+// otherwise wedge the runtime on a blocked write; the deadline turns that
+// into a fast failure -- and thus a non-zero exit -- instead of a hang.
+const terminalSendTimeout = 30 * time.Second
+
+// Serve runs the bundle binary in coordinator mode. It dials the supervisor's
+// comm and logs sockets, installs an slog handler that writes JSON-line
+// records to the logs connection, and dispatches on the first frame.
+//
+// Serve returns nil on a clean shutdown: the task ran and its terminal
+// TaskState/SucceedTask frame was delivered, and the caller should exit 0. A
+// non-nil error indicates a protocol-level failure (connection loss,
+// malformed frames, unknown first message type) that happens before or
+// instead of delivering a terminal frame.
+//
+// Failure-signaling contract: the caller (main) must turn a non-nil error
+// into a non-zero process exit. The supervisor derives the task's final state
+// primarily from the child's exit code -- a non-zero exit is recorded as
+// FAILED (or UP_FOR_RETRY when retries are configured), and a structured
+// TaskState frame is only honored when the process exits 0 (see the Python
+// supervisor's ActivitySubprocess.final_state). So an early error return here
+// fails closed without needing to send a frame; the post-connect paths below
+// log the reason at Error first so it still reaches the supervisor's log
+// stream over the already-connected logs socket.
+func Serve(provider bundlev1.BundleProvider, commAddr, logsAddr string) error {
+ if commAddr == "" {
+ return fmt.Errorf("missing --comm=host:port argument")
+ }
+ if logsAddr == "" {
+ return fmt.Errorf("missing --logs=host:port argument")
+ }
+
+ // A supervisor shutdown arrives as SIGTERM (escalated to SIGKILL after a
+ // grace period). Trap SIGINT/SIGTERM into a context so a cooperative,
+ // ctx-aware task can observe the shutdown and return promptly. A task that
+ // ignores ctx is still stopped by the supervisor's follow-up SIGKILL, so
+ // trapping the signal here does not strand a non-cooperative task.
+ ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+ defer stop()
+
+ // Buffer log records until the logs socket is connected. Anything the
+ // runtime emits between Connect-time and the first frame still gets
+ // flushed.
+ logHandler := NewSocketLogHandler(nil, slog.LevelDebug)
+ logger := slog.New(logHandler)
+ slog.SetDefault(logger)
+
+ // Connect to both sockets concurrently so the supervisor can accept them
+ // in either order.
+ dialer := &net.Dialer{Timeout: dialTimeout}
+ var commConn, logsConn net.Conn
+ var commErr, logsErr error
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ commConn, commErr = dialer.Dial("tcp", commAddr)
+ }()
+ go func() {
+ defer wg.Done()
+ logsConn, logsErr = dialer.Dial("tcp", logsAddr)
+ }()
+ wg.Wait()
+
+ // Either dial may succeed while the other fails; close any orphaned
+ // connection before returning so we don't leak an open TCP socket.
+ if commErr != nil {
+ if logsConn != nil {
+ logsConn.Close()
+ }
+ return fmt.Errorf("connecting to comm socket %s: %w", commAddr, commErr)
+ }
+ if logsErr != nil {
+ commConn.Close()
+ return fmt.Errorf("connecting to logs socket %s: %w", logsAddr, logsErr)
+ }
+ defer commConn.Close()
+ defer logsConn.Close()
+
+ logHandler.Connect(logsConn)
+ logger.Debug("Connected", "comm", commAddr, "logs", logsAddr)
+
+ // Materialise the bundle (RegisterDags) up front. Both protocol paths
+ // need the registry, and doing it once before the first frame keeps the
+ // dispatcher simple.
+ bundle, err := materialiseBundle(provider)
+ if err != nil {
+ logger.Error("Bundle registration failed", "error", err)
+ return fmt.Errorf("registering dags: %w", err)
+ }
+
+ comm := NewCoordinatorComm(commConn, commConn, logger)
+
+ frame, err := comm.ReadMessage()
+ if err != nil {
+ logger.Error("Failed to read initial message from supervisor", "error", err)
+ return fmt.Errorf("reading initial message: %w", err)
+ }
+
+ if frame.Err != nil {
+ errResp := decodeErrorResponse(frame.Err)
+ if errResp != nil {
+ logger.Error("Supervisor reported an error on the initial frame",
+ "error", errResp.Error,
+ "detail", errResp.Detail,
+ )
+ return fmt.Errorf(
+ "received error from supervisor: [%s] %v",
+ errResp.Error,
+ errResp.Detail,
+ )
+ }
+ }
+
+ body, err := decodeIncomingBody(frame.Body)
+ if err != nil {
+ logger.Error("Failed to decode initial message", "error", err)
+ return fmt.Errorf("decoding initial message: %w", err)
+ }
+
+ switch msg := body.(type) {
+ case *StartupDetails:
+ logger.Debug("Task execution mode",
+ "dag_id", msg.TI.DagID,
+ "task_id", msg.TI.TaskID,
+ )
+ result := RunTask(ctx, bundle, msg, comm, logger)
+ // Bound the terminal write so a wedged socket cannot hang shutdown.
+ _ = commConn.SetWriteDeadline(time.Now().Add(terminalSendTimeout))
+ if err := comm.SendRequest(frame.ID, result); err != nil {
+ return fmt.Errorf("sending task result: %w", err)
+ }
+ logger.Debug("Task execution complete")
+
+ default:
+ logger.Error("Unexpected initial message type", "type", fmt.Sprintf("%T", body))
+ return fmt.Errorf("unexpected initial message type: %T", body)
+ }
+
+ return nil
+}
+
+func materialiseBundle(provider bundlev1.BundleProvider) (bundlev1.Bundle, error) {
+ reg := bundlev1.New()
+ if err := provider.RegisterDags(reg); err != nil {
+ return nil, err
+ }
+ return reg, nil
+}
diff --git a/go-sdk/pkg/execution/task_runner.go b/go-sdk/pkg/execution/task_runner.go
new file mode 100644
index 00000000000..dba184b588e
--- /dev/null
+++ b/go-sdk/pkg/execution/task_runner.go
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package execution
+
+import (
+ "context"
+ "log/slog"
+ "runtime/debug"
+ "time"
+
+ "github.com/google/uuid"
+
+ "github.com/apache/airflow/go-sdk/bundle/bundlev1"
+ "github.com/apache/airflow/go-sdk/pkg/api"
+ "github.com/apache/airflow/go-sdk/pkg/sdkcontext"
+ "github.com/apache/airflow/go-sdk/sdk"
+)
+
+// RunTask executes a task based on StartupDetails received from the supervisor.
+//
+// It looks up the task in the bundle, creates a CoordinatorClient for SDK
+// calls, executes the task, and returns a terminal message body
+// (SucceedTaskMsg or TaskStateMsg) ready to ship as the final response frame.
+//
+// The supervisor owns the Execution-API state transitions in coordinator
+// mode, so we deliberately bypass worker.ExecuteTaskWorkload (which drives
+// Run / UpdateState itself) and only invoke the user's task function.
+//
+// ctx is the task's root context; Serve derives it from SIGINT/SIGTERM, so a
+// cooperative task that honors ctx returns promptly on a supervisor shutdown.
+func RunTask(
+ ctx context.Context,
+ bundle bundlev1.Bundle,
+ details *StartupDetails,
+ comm *CoordinatorComm,
+ logger *slog.Logger,
+) map[string]any {
+ task, exists := bundle.LookupTask(details.TI.DagID, details.TI.TaskID)
+ if !exists {
+ logger.Error("Task not registered",
+ "dag_id", details.TI.DagID,
+ "task_id", details.TI.TaskID,
+ )
+ return TaskStateMsg{State: TaskStateRemoved, EndDate: time.Now().UTC()}.toMap()
+ }
+
+ client := NewCoordinatorClient(comm)
+
+ // taskFunction.sendXcom reads the workload from context to get the task
+ // instance ids; populate it in the same shape the gRPC path uses.
+ tiUUID, err := uuid.Parse(details.TI.ID)
+ if err != nil {
+ logger.Error("Invalid task instance UUID from supervisor",
+ "dag_id", details.TI.DagID,
+ "task_id", details.TI.TaskID,
+ "ti_id", details.TI.ID,
+ "error", err,
+ )
+ return TaskStateMsg{State: TaskStateFailed, EndDate: time.Now().UTC()}.toMap()
+ }
+ mapIndex := details.TI.MapIndex
+ workload := api.ExecuteTaskWorkload{
+ TI: api.TaskInstance{
+ Id: tiUUID,
+ DagId: details.TI.DagID,
+ RunId: details.TI.RunID,
+ TaskId: details.TI.TaskID,
+ TryNumber: details.TI.TryNumber,
+ MapIndex: &mapIndex,
+ },
+ BundleInfo: api.BundleInfo{
+ Name: details.BundleInfo.Name,
+ Version: &details.BundleInfo.Version,
+ },
+ }
+
+ ctx = context.WithValue(ctx, sdkcontext.WorkloadContextKey, workload)
+ ctx = context.WithValue(ctx, sdkcontext.SdkClientContextKey, sdk.Client(client))
+
+ return executeTask(ctx, task, logger)
+}
+
+// executeTask runs the task and handles success, failure, and panics.
+func executeTask(
+ ctx context.Context,
+ task bundlev1.Task,
+ logger *slog.Logger,
+) (result map[string]any) {
+ defer func() {
+ if r := recover(); r != nil {
+ logger.Error("Recovered panic in task",
+ "error", r,
+ "stack", string(debug.Stack()),
+ )
+ result = TaskStateMsg{
+ State: TaskStateFailed,
+ EndDate: time.Now().UTC(),
+ }.toMap()
+ }
+ }()
+
+ if err := task.Execute(ctx, logger); err != nil {
+ logger.ErrorContext(ctx, "Task failed", "error", err)
+ // TODO(https://github.com/apache/airflow/issues/67797): emit RetryTask
+ // (UP_FOR_RETRY) when ti_context.should_retry is set. Today every
+ // failure maps to terminal FAILED because the supervisor honors this
+ // frame on exit 0 and we never send RetryTask, so retries are lost.
+ return TaskStateMsg{
+ State: TaskStateFailed,
+ EndDate: time.Now().UTC(),
+ }.toMap()
+ }
+
+ return SucceedTaskMsg{
+ EndDate: time.Now().UTC(),
+ }.toMap()
+}