kaxil commented on code in PR #67315: URL: https://github.com/apache/airflow/pull/67315#discussion_r3295541071
########## go-sdk/pkg/execution/frames.go: ########## @@ -0,0 +1,257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/vmihailenco/msgpack/v5" +) + +// maxFrameSize caps the payload length a single frame may declare. A +// malformed length prefix from a corrupted stream (or hostile peer) would +// otherwise let readFrame allocate up to 4 GiB before the read failed. +// 64 MiB is far above any legitimate StartupDetails or XCom payload while +// still preventing accidental OOM. +const maxFrameSize = 64 * 1024 * 1024 Review Comment: Python caps at 4 GiB (`2**32`) per `comms.py:152`; Go caps at 64 MiB here. The asymmetry means a Python supervisor can hand the Go runtime a legitimate frame (large XCom, large StartupDetails) that this side rejects with no per-side documentation explaining why. If 64 MiB is intentional (DoS guard, runtime memory budget), add a sentence saying so -- and consider exposing it as an exported const so the supervisor side can pre-check sends and refuse upstream with a clearer error instead of getting the frame yanked on the wire. 
PR description also claims "64 MiB is far above any legitimate StartupDetails or XCom payload" -- worth either lowering Python's cap to match, or noting why the two ends diverge. ########## go-sdk/pkg/execution/messages.go: ########## @@ -0,0 +1,405 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "fmt" + "time" +) + +// Inbound messages (Supervisor -> Runtime). + +// TaskInstanceInfo holds task instance details from StartupDetails. 
+type TaskInstanceInfo struct { + ID string + TaskID string + DagID string + RunID string + TryNumber int + DagVersionID string + MapIndex int + ContextCarrier map[string]any +} + +func decodeTaskInstanceInfo(m map[string]any) (TaskInstanceInfo, error) { + if m == nil { + return TaskInstanceInfo{}, fmt.Errorf("nil task instance map") + } + id, err := mapString(m, "id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.id: %w", err) + } + taskID, err := mapString(m, "task_id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.task_id: %w", err) + } + dagID, err := mapString(m, "dag_id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.dag_id: %w", err) + } + runID, err := mapString(m, "run_id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.run_id: %w", err) + } + tryNumber := mapIntOr(m, "try_number", 1) + dagVersionID := mapStringOr(m, "dag_version_id", "") + mapIndex := mapIntOr(m, "map_index", -1) + contextCarrier := mapMap(m, "context_carrier") + + return TaskInstanceInfo{ + ID: id, + TaskID: taskID, + DagID: dagID, + RunID: runID, + TryNumber: tryNumber, + DagVersionID: dagVersionID, + MapIndex: mapIndex, + ContextCarrier: contextCarrier, + }, nil +} + +// BundleInfoMsg holds bundle identification from StartupDetails. +type BundleInfoMsg struct { + Name string + Version string +} + +func decodeBundleInfo(m map[string]any) BundleInfoMsg { + if m == nil { + return BundleInfoMsg{} + } + return BundleInfoMsg{ + Name: mapStringOr(m, "name", ""), + Version: mapStringOr(m, "version", ""), + } +} + +// TIRunContext holds the runtime context for a task instance. 
+type TIRunContext struct { + LogicalDate *time.Time + DataIntervalStart *time.Time + DataIntervalEnd *time.Time +} + +func decodeTIRunContext(m map[string]any) (TIRunContext, error) { + if m == nil { + return TIRunContext{}, nil + } + ctx := TIRunContext{} + for _, f := range []struct { + key string + dst **time.Time + }{ + {"logical_date", &ctx.LogicalDate}, + {"data_interval_start", &ctx.DataIntervalStart}, + {"data_interval_end", &ctx.DataIntervalEnd}, + } { + raw, present := m[f.key] + if !present || raw == nil { + continue + } + t, err := asTime(raw) + if err != nil { + return TIRunContext{}, fmt.Errorf("ti_context.%s: %w", f.key, err) + } + *f.dst = &t + } + return ctx, nil +} + +// StartupDetails is sent by the supervisor to initiate task execution. +type StartupDetails struct { + TI TaskInstanceInfo + DagRelPath string + BundleInfo BundleInfoMsg + StartDate time.Time + TIContext TIRunContext + SentryIntegration string +} + +func decodeStartupDetails(m map[string]any) (*StartupDetails, error) { + tiMap := mapMap(m, "ti") + ti, err := decodeTaskInstanceInfo(tiMap) + if err != nil { + return nil, fmt.Errorf("decoding ti: %w", err) + } + + dagRelPath := mapStringOr(m, "dag_rel_path", "") + bundleInfo := decodeBundleInfo(mapMap(m, "bundle_info")) + + var startDate time.Time + if raw, present := m["start_date"]; present && raw != nil { + startDate, err = asTime(raw) + if err != nil { + return nil, fmt.Errorf("start_date: %w", err) + } + } + + tiContext, err := decodeTIRunContext(mapMap(m, "ti_context")) + if err != nil { + return nil, fmt.Errorf("decoding ti_context: %w", err) + } + sentryIntegration := mapStringOr(m, "sentry_integration", "") + + return &StartupDetails{ + TI: ti, + DagRelPath: dagRelPath, + BundleInfo: bundleInfo, + StartDate: startDate, + TIContext: tiContext, + SentryIntegration: sentryIntegration, + }, nil +} + +// Response types (for runtime-initiated requests). + +// ConnectionResult is the response to GetConnection. 
+type ConnectionResult struct { + ConnID string + ConnType string + Host string + Schema string + Login string + Password string + Port int + Extra string +} + +func decodeConnectionResult(m map[string]any) (*ConnectionResult, error) { + return &ConnectionResult{ + ConnID: mapStringOr(m, "conn_id", ""), + ConnType: mapStringOr(m, "conn_type", ""), + Host: mapStringOr(m, "host", ""), + Schema: mapStringOr(m, "schema", ""), + Login: mapStringOr(m, "login", ""), + Password: mapStringOr(m, "password", ""), + Port: mapIntOr(m, "port", 0), + Extra: mapStringOr(m, "extra", ""), + }, nil +} + +// VariableResult is the response to GetVariable. +type VariableResult struct { + Key string + Value any +} + +func decodeVariableResult(m map[string]any) (*VariableResult, error) { + return &VariableResult{ + Key: mapStringOr(m, "key", ""), + Value: m["value"], + }, nil +} + +// XComResult is the response to GetXCom. +type XComResult struct { + Key string + Value any +} + +func decodeXComResult(m map[string]any) (*XComResult, error) { + return &XComResult{ + Key: mapStringOr(m, "key", ""), + Value: m["value"], + }, nil +} + +// ErrorResponse represents an error returned by the supervisor. +type ErrorResponse struct { + Error string + Detail any +} + +func decodeErrorResponse(m map[string]any) *ErrorResponse { + if m == nil { + return nil + } + return &ErrorResponse{ + Error: mapStringOr(m, "error", ""), + Detail: m["detail"], + } +} + +// Outbound messages (Runtime -> Supervisor). + +// GetConnectionMsg is sent to request a connection from the supervisor. +type GetConnectionMsg struct { + ConnID string +} + +func (m GetConnectionMsg) toMap() map[string]any { + return map[string]any{ + "type": "GetConnection", + "conn_id": m.ConnID, + } +} + +// GetVariableMsg is sent to request a variable from the supervisor. 
+type GetVariableMsg struct { + Key string +} + +func (m GetVariableMsg) toMap() map[string]any { + return map[string]any{ + "type": "GetVariable", + "key": m.Key, + } +} + +// GetXComMsg is sent to request an XCom value from the supervisor. +type GetXComMsg struct { + Key string + DagID string + TaskID string + RunID string + MapIndex *int + IncludePriorDates bool +} + +func (m GetXComMsg) toMap() map[string]any { + result := map[string]any{ + "type": "GetXCom", + "key": m.Key, + "dag_id": m.DagID, + "task_id": m.TaskID, + "run_id": m.RunID, + "include_prior_dates": m.IncludePriorDates, + } + if m.MapIndex != nil { + result["map_index"] = *m.MapIndex + } + return result +} + +// SetXComMsg is sent to set an XCom value. +type SetXComMsg struct { + Key string + Value any + DagID string + TaskID string + RunID string + MapIndex int + MappedLength *int +} + +func (m SetXComMsg) toMap() map[string]any { + result := map[string]any{ + "type": "SetXCom", + "key": m.Key, + "value": m.Value, + "dag_id": m.DagID, + "task_id": m.TaskID, + "run_id": m.RunID, + "map_index": m.MapIndex, + } + if m.MappedLength != nil { + result["mapped_length"] = *m.MappedLength + } + return result +} + +// SucceedTaskMsg is sent as a terminal message when a task succeeds. +type SucceedTaskMsg struct { + EndDate time.Time + TaskOutlets []any + OutletEvents []any +} + +func (m SucceedTaskMsg) toMap() map[string]any { + taskOutlets := m.TaskOutlets + if taskOutlets == nil { + taskOutlets = []any{} + } + outletEvents := m.OutletEvents + if outletEvents == nil { + outletEvents = []any{} + } + return map[string]any{ + "type": "SucceedTask", + "end_date": m.EndDate.UTC().Format(time.RFC3339), Review Comment: `time.RFC3339` drops sub-second precision; `asTime` on line 401 parses with `RFC3339Nano`. 
Python's `AwareDatetime` accepts both, so this works, but you lose nanoseconds on terminal-state timestamps -- which matters for ordering closely-spaced events (state transitions, retries within the same second). Same issue on line 363 (`TaskStateMsg`). Suggest `time.RFC3339Nano` for symmetry with the inbound parser, so the encoder doesn't quietly truncate data the decoder would happily round-trip. ########## go-sdk/pkg/execution/frames.go: ########## @@ -0,0 +1,257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/vmihailenco/msgpack/v5" +) + +// maxFrameSize caps the payload length a single frame may declare. A +// malformed length prefix from a corrupted stream (or hostile peer) would +// otherwise let readFrame allocate up to 4 GiB before the read failed. +// 64 MiB is far above any legitimate StartupDetails or XCom payload while +// still preventing accidental OOM. +const maxFrameSize = 64 * 1024 * 1024 + +// IncomingFrame represents a decoded frame received from the comm socket. 
+type IncomingFrame struct { + ID int + Body map[string]any + Err map[string]any // non-nil only for response frames (3-element arrays) +} + +// encodeRequest encodes a request frame (2-element msgpack array: [id, body]). +func encodeRequest(id int, body map[string]any) ([]byte, error) { Review Comment: This always emits 2-element `[id, body]` frames with no `context_carrier`. Python's `_RequestFrame` (`comms.py:159-174`) carries a third W3C trace-context element, and the supervisor restores it on incoming requests (`supervisor.py:811`) *before* any outbound HTTP call -- which is what makes supervisor-side spans for those calls chain under the task span in OTEL. Without `context_carrier` propagation, every outbound HTTP triggered by a Go-runtime request will start a detached span tree. Fine to defer for this scaffolding PR, but worth a `// TODO(context-carrier): match Python _RequestFrame three-tuple` so the follow-up comms-layer PR (#67317) doesn't ship without it. ########## go-sdk/pkg/execution/frames.go: ########## @@ -0,0 +1,257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package execution + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/vmihailenco/msgpack/v5" +) + +// maxFrameSize caps the payload length a single frame may declare. A +// malformed length prefix from a corrupted stream (or hostile peer) would +// otherwise let readFrame allocate up to 4 GiB before the read failed. +// 64 MiB is far above any legitimate StartupDetails or XCom payload while +// still preventing accidental OOM. +const maxFrameSize = 64 * 1024 * 1024 + +// IncomingFrame represents a decoded frame received from the comm socket. +type IncomingFrame struct { + ID int + Body map[string]any + Err map[string]any // non-nil only for response frames (3-element arrays) +} + +// encodeRequest encodes a request frame (2-element msgpack array: [id, body]). +func encodeRequest(id int, body map[string]any) ([]byte, error) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + enc.UseCompactInts(true) + + if err := enc.EncodeArrayLen(2); err != nil { + return nil, err + } + if err := enc.EncodeInt(int64(id)); err != nil { + return nil, err + } + if err := enc.Encode(body); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// writeFrame writes a length-prefixed msgpack payload to the writer. +// Format: [4-byte big-endian length][payload bytes]. +// +// The prefix and payload are concatenated into a single buffer and written +// in one Write call so we never leave a half-framed message on the wire if +// an io.Writer implementation does a short write between the two halves. +func writeFrame(w io.Writer, payload []byte) error { + buf := make([]byte, 4+len(payload)) + binary.BigEndian.PutUint32(buf[:4], uint32(len(payload))) Review Comment: `uint32(len(payload))` silently wraps on payloads >= 4 GiB. Python's `_FrameMixin.as_bytes` (`task-sdk/src/airflow/sdk/execution_time/comms.py:152`) explicitly raises `OverflowError` for the same condition. 
The 64 MiB read-side cap below would catch incoming garbage, but a Go runtime that ever tries to *send* a 4+ GiB payload (think large XCom value, batched-state msg) would emit a corrupt length prefix and silently desynchronise the peer rather than failing loudly here. Suggest an explicit `if len(payload) >= maxFrameSize { return fmt.Errorf(...) }` guard before the conversion, mirroring the Python OverflowError. ########## go-sdk/pkg/execution/frames.go: ########## @@ -0,0 +1,257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/vmihailenco/msgpack/v5" +) + +// maxFrameSize caps the payload length a single frame may declare. A +// malformed length prefix from a corrupted stream (or hostile peer) would +// otherwise let readFrame allocate up to 4 GiB before the read failed. +// 64 MiB is far above any legitimate StartupDetails or XCom payload while +// still preventing accidental OOM. +const maxFrameSize = 64 * 1024 * 1024 + +// IncomingFrame represents a decoded frame received from the comm socket. 
+type IncomingFrame struct { + ID int + Body map[string]any + Err map[string]any // non-nil only for response frames (3-element arrays) +} + +// encodeRequest encodes a request frame (2-element msgpack array: [id, body]). +func encodeRequest(id int, body map[string]any) ([]byte, error) { + var buf bytes.Buffer + enc := msgpack.NewEncoder(&buf) + enc.UseCompactInts(true) + + if err := enc.EncodeArrayLen(2); err != nil { + return nil, err + } + if err := enc.EncodeInt(int64(id)); err != nil { + return nil, err + } + if err := enc.Encode(body); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// writeFrame writes a length-prefixed msgpack payload to the writer. +// Format: [4-byte big-endian length][payload bytes]. +// +// The prefix and payload are concatenated into a single buffer and written +// in one Write call so we never leave a half-framed message on the wire if +// an io.Writer implementation does a short write between the two halves. +func writeFrame(w io.Writer, payload []byte) error { + buf := make([]byte, 4+len(payload)) + binary.BigEndian.PutUint32(buf[:4], uint32(len(payload))) + copy(buf[4:], payload) + n, err := w.Write(buf) + if err != nil { + return fmt.Errorf("writing frame: %w", err) + } + if n < len(buf) { + return fmt.Errorf("writing frame: %w", io.ErrShortWrite) + } + return nil +} + +// readFrame reads one length-prefixed msgpack frame from the reader and decodes it. +func readFrame(r io.Reader) (IncomingFrame, error) { + // Read 4-byte big-endian length prefix. + prefix := make([]byte, 4) + if _, err := io.ReadFull(r, prefix); err != nil { + return IncomingFrame{}, fmt.Errorf("reading length prefix: %w", err) + } + payloadLen := binary.BigEndian.Uint32(prefix) + if payloadLen > maxFrameSize { + return IncomingFrame{}, fmt.Errorf( + "frame payload length %d exceeds max %d", + payloadLen, + maxFrameSize, + ) + } + + // Read the payload. 
The maxFrameSize guard above keeps the value well + // within int range on every supported platform, so the conversion is + // safe. + payload := make([]byte, int(payloadLen)) + if _, err := io.ReadFull(r, payload); err != nil { + return IncomingFrame{}, fmt.Errorf("reading payload (%d bytes): %w", payloadLen, err) + } + + return decodeFrame(payload) +} + +// decodeFrame decodes a msgpack payload into an IncomingFrame. +func decodeFrame(data []byte) (IncomingFrame, error) { + dec := msgpack.NewDecoder(bytes.NewReader(data)) + + arrLen, err := dec.DecodeArrayLen() + if err != nil { + return IncomingFrame{}, fmt.Errorf("decoding array header: %w", err) + } + if arrLen < 2 { + return IncomingFrame{}, fmt.Errorf("unexpected frame arity %d, need at least 2", arrLen) + } + + id64, err := dec.DecodeInt64() + if err != nil { + return IncomingFrame{}, fmt.Errorf("decoding frame id: %w", err) + } + + // Decode the body element. + bodyRaw, err := dec.DecodeInterface() + if err != nil { + return IncomingFrame{}, fmt.Errorf("decoding body: %w", err) + } + body, ok := toStringMap(bodyRaw) + if bodyRaw != nil && !ok { + return IncomingFrame{}, fmt.Errorf("body element: expected map, got %T", bodyRaw) + } + + // For response frames (3-element), decode the error element. + var errMap map[string]any + if arrLen >= 3 { + errRaw, err := dec.DecodeInterface() + if err != nil { + return IncomingFrame{}, fmt.Errorf("decoding error element: %w", err) + } + errMap, ok = toStringMap(errRaw) + if errRaw != nil && !ok { + return IncomingFrame{}, fmt.Errorf("error element: expected map, got %T", errRaw) + } + } + + return IncomingFrame{ + ID: int(id64), + Body: body, + Err: errMap, + }, nil +} + +// toStringMap converts a decoded interface{} to map[string]any. +// Returns nil, false if the input is nil or not a map. 
+func toStringMap(v any) (map[string]any, bool) { + if v == nil { + return nil, false + } + switch m := v.(type) { + case map[string]any: + return m, true + case map[any]any: + result := make(map[string]any, len(m)) + for k, val := range m { + result[fmt.Sprint(k)] = val + } + return result, true + default: + return nil, false + } +} + +// mapString extracts a string value from a map. +func mapString(m map[string]any, key string) (string, error) { + v, ok := m[key] + if !ok { + return "", fmt.Errorf("missing key %q", key) + } + s, ok := v.(string) + if !ok { + return "", fmt.Errorf("key %q: expected string, got %T", key, v) + } + return s, nil +} + +// mapIntOr extracts an int value from a map, returning the default if missing. +func mapIntOr(m map[string]any, key string, def int) int { + v, ok := m[key] + if !ok { + return def + } + n, err := toInt(v) + if err != nil { + return def Review Comment: `mapIntOr` returns the default for both "missing key" AND "wrong type", but the docstring on line 188 only mentions "missing". `try_number` in Python's `TaskInstance` model is required (`int`), so if a future supervisor change accidentally sends it as a string, decoding silently reports `try_number=1` and the runtime keeps going with wrong audit data -- no error, no signal. Either return `(int, error)` like `mapString` does (with the err carrying the offending type), or split into `mapInt` (errors on bad type) and `mapIntOr` (default only on miss). The current behaviour swallows a class of supervisor/runtime version-drift bugs. ########## go-sdk/pkg/execution/messages.go: ########## @@ -0,0 +1,405 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package execution + +import ( + "fmt" + "time" +) + +// Inbound messages (Supervisor -> Runtime). + +// TaskInstanceInfo holds task instance details from StartupDetails. +type TaskInstanceInfo struct { + ID string + TaskID string + DagID string + RunID string + TryNumber int + DagVersionID string + MapIndex int + ContextCarrier map[string]any +} + +func decodeTaskInstanceInfo(m map[string]any) (TaskInstanceInfo, error) { + if m == nil { + return TaskInstanceInfo{}, fmt.Errorf("nil task instance map") + } + id, err := mapString(m, "id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.id: %w", err) + } + taskID, err := mapString(m, "task_id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.task_id: %w", err) + } + dagID, err := mapString(m, "dag_id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.dag_id: %w", err) + } + runID, err := mapString(m, "run_id") + if err != nil { + return TaskInstanceInfo{}, fmt.Errorf("ti.run_id: %w", err) + } + tryNumber := mapIntOr(m, "try_number", 1) + dagVersionID := mapStringOr(m, "dag_version_id", "") + mapIndex := mapIntOr(m, "map_index", -1) + contextCarrier := mapMap(m, "context_carrier") + + return TaskInstanceInfo{ + ID: id, + TaskID: taskID, + DagID: dagID, + RunID: runID, + TryNumber: tryNumber, + DagVersionID: dagVersionID, + MapIndex: mapIndex, + ContextCarrier: contextCarrier, + }, nil +} + +// 
BundleInfoMsg holds bundle identification from StartupDetails. +type BundleInfoMsg struct { + Name string + Version string +} + +func decodeBundleInfo(m map[string]any) BundleInfoMsg { + if m == nil { + return BundleInfoMsg{} + } + return BundleInfoMsg{ + Name: mapStringOr(m, "name", ""), + Version: mapStringOr(m, "version", ""), + } +} + +// TIRunContext holds the runtime context for a task instance. +type TIRunContext struct { + LogicalDate *time.Time + DataIntervalStart *time.Time + DataIntervalEnd *time.Time +} + +func decodeTIRunContext(m map[string]any) (TIRunContext, error) { + if m == nil { + return TIRunContext{}, nil + } + ctx := TIRunContext{} + for _, f := range []struct { + key string + dst **time.Time + }{ + {"logical_date", &ctx.LogicalDate}, + {"data_interval_start", &ctx.DataIntervalStart}, + {"data_interval_end", &ctx.DataIntervalEnd}, + } { + raw, present := m[f.key] + if !present || raw == nil { + continue + } + t, err := asTime(raw) + if err != nil { + return TIRunContext{}, fmt.Errorf("ti_context.%s: %w", f.key, err) + } + *f.dst = &t + } + return ctx, nil +} + +// StartupDetails is sent by the supervisor to initiate task execution. 
+type StartupDetails struct { + TI TaskInstanceInfo + DagRelPath string + BundleInfo BundleInfoMsg + StartDate time.Time + TIContext TIRunContext + SentryIntegration string +} + +func decodeStartupDetails(m map[string]any) (*StartupDetails, error) { + tiMap := mapMap(m, "ti") + ti, err := decodeTaskInstanceInfo(tiMap) + if err != nil { + return nil, fmt.Errorf("decoding ti: %w", err) + } + + dagRelPath := mapStringOr(m, "dag_rel_path", "") + bundleInfo := decodeBundleInfo(mapMap(m, "bundle_info")) + + var startDate time.Time + if raw, present := m["start_date"]; present && raw != nil { + startDate, err = asTime(raw) + if err != nil { + return nil, fmt.Errorf("start_date: %w", err) + } + } + + tiContext, err := decodeTIRunContext(mapMap(m, "ti_context")) + if err != nil { + return nil, fmt.Errorf("decoding ti_context: %w", err) + } + sentryIntegration := mapStringOr(m, "sentry_integration", "") + + return &StartupDetails{ + TI: ti, + DagRelPath: dagRelPath, + BundleInfo: bundleInfo, + StartDate: startDate, + TIContext: tiContext, + SentryIntegration: sentryIntegration, + }, nil +} + +// Response types (for runtime-initiated requests). + +// ConnectionResult is the response to GetConnection. +type ConnectionResult struct { + ConnID string + ConnType string + Host string + Schema string + Login string + Password string + Port int + Extra string +} + +func decodeConnectionResult(m map[string]any) (*ConnectionResult, error) { + return &ConnectionResult{ + ConnID: mapStringOr(m, "conn_id", ""), + ConnType: mapStringOr(m, "conn_type", ""), + Host: mapStringOr(m, "host", ""), + Schema: mapStringOr(m, "schema", ""), + Login: mapStringOr(m, "login", ""), + Password: mapStringOr(m, "password", ""), + Port: mapIntOr(m, "port", 0), + Extra: mapStringOr(m, "extra", ""), + }, nil +} + +// VariableResult is the response to GetVariable. 
+type VariableResult struct { + Key string + Value any +} + +func decodeVariableResult(m map[string]any) (*VariableResult, error) { + return &VariableResult{ + Key: mapStringOr(m, "key", ""), + Value: m["value"], + }, nil +} + +// XComResult is the response to GetXCom. +type XComResult struct { + Key string + Value any +} + +func decodeXComResult(m map[string]any) (*XComResult, error) { + return &XComResult{ + Key: mapStringOr(m, "key", ""), + Value: m["value"], + }, nil +} + +// ErrorResponse represents an error returned by the supervisor. +type ErrorResponse struct { + Error string + Detail any +} + +func decodeErrorResponse(m map[string]any) *ErrorResponse { + if m == nil { + return nil + } + return &ErrorResponse{ + Error: mapStringOr(m, "error", ""), + Detail: m["detail"], + } +} + +// Outbound messages (Runtime -> Supervisor). + +// GetConnectionMsg is sent to request a connection from the supervisor. +type GetConnectionMsg struct { + ConnID string +} + +func (m GetConnectionMsg) toMap() map[string]any { + return map[string]any{ + "type": "GetConnection", + "conn_id": m.ConnID, + } +} + +// GetVariableMsg is sent to request a variable from the supervisor. +type GetVariableMsg struct { + Key string +} + +func (m GetVariableMsg) toMap() map[string]any { + return map[string]any{ + "type": "GetVariable", + "key": m.Key, + } +} + +// GetXComMsg is sent to request an XCom value from the supervisor. +type GetXComMsg struct { + Key string + DagID string + TaskID string + RunID string + MapIndex *int + IncludePriorDates bool +} + +func (m GetXComMsg) toMap() map[string]any { + result := map[string]any{ + "type": "GetXCom", + "key": m.Key, + "dag_id": m.DagID, + "task_id": m.TaskID, + "run_id": m.RunID, + "include_prior_dates": m.IncludePriorDates, + } + if m.MapIndex != nil { + result["map_index"] = *m.MapIndex + } + return result +} + +// SetXComMsg is sent to set an XCom value. 
+type SetXComMsg struct { + Key string + Value any + DagID string + TaskID string + RunID string + MapIndex int Review Comment: Python's `SetXCom.map_index` is `int | None` (`comms.py:902`); here it's `int` with -1 as the implicit sentinel. For unmapped tasks both sides treat -1 the same, but you've lost the ability to express `None` explicitly on the wire. If a caller ever needs to express "all map indexes" semantics or "unset" distinctly from -1, the current encoding can't represent it. `GetXComMsg.MapIndex` above already uses `*int` -- suggest matching, or at minimum a doc comment pinning the `-1 means unmapped` convention so future readers don't pick the wrong default. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org
