This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2961561430a AWS Lambda Executor (#50516)
2961561430a is described below
commit 2961561430af37d881ac27be4f404ae45810425e
Author: Niko Oliveira <[email protected]>
AuthorDate: Thu May 15 15:36:11 2025 -0700
AWS Lambda Executor (#50516)
This changeset includes:
- The Lambda executor code itself
- Tests for the executor
- Docs for the executor
---
.../tests/unit/always/test_project_structure.py | 2 +
providers/amazon/docs/executors/batch-executor.rst | 27 +
providers/amazon/docs/executors/ecs-executor.rst | 27 +
providers/amazon/docs/executors/general.rst | 48 +-
providers/amazon/docs/executors/index.rst | 1 +
.../amazon/docs/executors/lambda-executor.rst | 358 ++++++++
providers/amazon/docs/img/lambda_executor.png | Bin 0 -> 56488 bytes
providers/amazon/provider.yaml | 76 ++
.../amazon/aws/executors/aws_lambda/__init__.py | 21 +
.../aws/executors/aws_lambda/docker/Dockerfile | 107 +++
.../aws/executors/aws_lambda/docker/__init__.py | 16 +
.../amazon/aws/executors/aws_lambda/docker/app.py | 129 +++
.../aws/executors/aws_lambda/lambda_executor.py | 479 +++++++++++
.../amazon/aws/executors/aws_lambda/utils.py | 70 ++
.../airflow/providers/amazon/get_provider_info.py | 68 ++
.../amazon/aws/executors/aws_lambda/__init__.py | 16 +
.../executors/aws_lambda/test_lambda_executor.py | 928 +++++++++++++++++++++
17 files changed, 2354 insertions(+), 19 deletions(-)
diff --git a/airflow-core/tests/unit/always/test_project_structure.py
b/airflow-core/tests/unit/always/test_project_structure.py
index a1cfba80cf6..e81476b0216 100644
--- a/airflow-core/tests/unit/always/test_project_structure.py
+++ b/airflow-core/tests/unit/always/test_project_structure.py
@@ -69,6 +69,8 @@ class TestProjectStructure:
"providers/amazon/tests/unit/amazon/aws/executors/ecs/test_boto_schema.py",
"providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor_config.py",
"providers/amazon/tests/unit/amazon/aws/executors/ecs/test_utils.py",
+
"providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_utils.py",
+
"providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/docker/test_app.py",
"providers/amazon/tests/unit/amazon/aws/executors/utils/test_base_config_keys.py",
"providers/amazon/tests/unit/amazon/aws/operators/test_emr.py",
"providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker.py",
diff --git a/providers/amazon/docs/executors/batch-executor.rst
b/providers/amazon/docs/executors/batch-executor.rst
index c68f05537b7..920bc7e9e58 100644
--- a/providers/amazon/docs/executors/batch-executor.rst
+++ b/providers/amazon/docs/executors/batch-executor.rst
@@ -141,6 +141,33 @@ newly created Job Role and Execution role for the Job
Definition.
:start-after: .. BEGIN DOCKERFILE_AUTH_SECOND_METHOD
:end-before: .. END DOCKERFILE_AUTH_SECOND_METHOD
+.. include:: general.rst
+ :start-after: .. BEGIN BASE_IMAGE
+ :end-before: .. END BASE_IMAGE
+
+.. include:: general.rst
+ :start-after: .. BEGIN LOADING_DAGS_OVERVIEW
+ :end-before: .. END LOADING_DAGS_OVERVIEW
+
+.. include:: general.rst
+ :start-after: .. BEGIN LOADING_DAGS_FROM_S3
+ :end-before: .. END LOADING_DAGS_FROM_S3
+
+.. include:: general.rst
+ :start-after: .. BEGIN LOADING_DAGS_FROM_LOCAL
+ :end-before: .. END LOADING_DAGS_FROM_LOCAL
+
+.. include:: general.rst
+ :start-after: .. BEGIN DEPENDENCIES
+ :end-before: .. END DEPENDENCIES
+
+Building Image for AWS Batch Executor
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Detailed instructions on how to use the Docker image, that you have
+created via this readme, with the Batch Executor can be found
+:ref:`here <batch_setup_guide>`.
+
.. _batch_logging:
.. include:: general.rst
diff --git a/providers/amazon/docs/executors/ecs-executor.rst
b/providers/amazon/docs/executors/ecs-executor.rst
index fa1da53e0ad..d7451e55277 100644
--- a/providers/amazon/docs/executors/ecs-executor.rst
+++ b/providers/amazon/docs/executors/ecs-executor.rst
@@ -162,6 +162,33 @@ newly created Task Role and Task Execution role for the
Task Definition.
:start-after: .. BEGIN DOCKERFILE_AUTH_SECOND_METHOD
:end-before: .. END DOCKERFILE_AUTH_SECOND_METHOD
+.. include:: general.rst
+ :start-after: .. BEGIN BASE_IMAGE
+ :end-before: .. END BASE_IMAGE
+
+.. include:: general.rst
+ :start-after: .. BEGIN LOADING_DAGS_OVERVIEW
+ :end-before: .. END LOADING_DAGS_OVERVIEW
+
+.. include:: general.rst
+ :start-after: .. BEGIN LOADING_DAGS_FROM_S3
+ :end-before: .. END LOADING_DAGS_FROM_S3
+
+.. include:: general.rst
+ :start-after: .. BEGIN LOADING_DAGS_FROM_LOCAL
+ :end-before: .. END LOADING_DAGS_FROM_LOCAL
+
+.. include:: general.rst
+ :start-after: .. BEGIN DEPENDENCIES
+ :end-before: .. END DEPENDENCIES
+
+Building Image for AWS ECS Executor
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Detailed instructions on how to use the Docker image, that you have
+created via this readme, with the ECS Executor can be found
+:ref:`here <setup_guide>`.
+
.. _logging:
.. include:: general.rst
diff --git a/providers/amazon/docs/executors/general.rst
b/providers/amazon/docs/executors/general.rst
index 7df66e36ac4..4928e2d8f7a 100644
--- a/providers/amazon/docs/executors/general.rst
+++ b/providers/amazon/docs/executors/general.rst
@@ -42,8 +42,7 @@ An example Dockerfile can be found |dockerfileLink|, it
creates an
image that can be used by AWS |executorName| to run Airflow tasks using
the AWS |executorName| Executor in Apache Airflow. The image supports AWS
CLI/API
integration, allowing you to interact with AWS services within your
-Airflow environment. It also includes options to load DAGs (Directed
-Acyclic Graphs) from either an S3 bucket or a local folder.
+Airflow environment. It also includes options to load Dags from either an S3
bucket or a local folder.
Prerequisites
@@ -90,9 +89,10 @@ The second method is to use the build-time arguments
(``aws_access_key_id``, ``aws_secret_access_key``,
``aws_default_region``, and ``aws_session_token``).
-Note: This method is not recommended for use in production environments,
-because user credentials are stored in the container, which may be a
-security vulnerability.
+.. warning::
+ This method is not recommended for use in production environments,
+ because user credentials are stored in the container, which may be a
+ security vulnerability.
To pass AWS authentication information using these arguments, use the
``--build-arg`` option during the Docker build process. For example:
@@ -109,6 +109,10 @@ Replace ``YOUR_ACCESS_KEY``, ``YOUR_SECRET_KEY``,
``YOUR_SESSION_TOKEN``, and ``YOUR_DEFAULT_REGION`` with valid AWS
credentials.
+.. END DOCKERFILE_AUTH_SECOND_METHOD
+
+.. BEGIN BASE_IMAGE
+
Base Image
~~~~~~~~~~
@@ -142,15 +146,22 @@ version
<https://hub.docker.com/r/apache/airflow/tags?page=1&name=3.9>`__.
For example, the tag ``latest-python3.9`` specifies that the image will
have python 3.9 installed.
+.. END BASE_IMAGE
+
+.. BEGIN LOADING_DAGS_OVERVIEW
Loading DAGs
~~~~~~~~~~~~
-There are many ways to load DAGs on a container managed by |executorName|.
This Dockerfile
+There are many ways to load DAGs on a container used by |executorName|. This
Dockerfile
is preconfigured with two possible ways: copying from a local folder, or
downloading from an S3 bucket. Other methods of loading DAGs are
possible as well.
+.. END LOADING_DAGS_OVERVIEW
+
+.. BEGIN LOADING_DAGS_FROM_S3
+
From S3 Bucket
^^^^^^^^^^^^^^
@@ -176,6 +187,10 @@ build arguments.
--build-arg aws_session_token=YOUR_SESSION_TOKEN \
--build-arg s3_uri=YOUR_S3_URI .
+.. END LOADING_DAGS_FROM_S3
+
+.. BEGIN LOADING_DAGS_FROM_LOCAL
+
From Local Folder
^^^^^^^^^^^^^^^^^
@@ -194,6 +209,10 @@ If choosing to load DAGs onto a different path than
``/opt/airflow/dags``, then the new path will need to be updated in the
Airflow config.
+.. END LOADING_DAGS_FROM_LOCAL
+
+.. BEGIN DEPENDENCIES
+
Installing Python Dependencies
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -206,28 +225,19 @@ Uncomment the two appropriate lines in the Dockerfile
that copy the
``requirements.txt`` file to the container, and run ``pip install`` to
install the dependencies on the container.
-Building Image for AWS |executorName| Executor
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Detailed instructions on how to use the Docker image, that you have
-created via this readme, with the |executorName| Executor can be found
-:ref:`here <setup_guide>`.
-
-.. END DOCKERFILE_AUTH_SECOND_METHOD
+.. END DEPENDENCIES
.. BEGIN LOGGING
Logging
-------
-Airflow tasks executed via this executor run in containers within
+Airflow tasks executed via this executor run within
the configured VPC. This means that logs are not directly accessible to
-the Airflow Webserver and when containers are stopped, after task
-completion, the logs would be permanently lost.
+the Airflow UI and, after task completion, the logs would be permanently lost.
Remote logging should be employed when using the |executorName| executor to
persist
-your Airflow Task logs and make them viewable from the Airflow
-Webserver.
+your Airflow Task logs and make them viewable from the Airflow UI.
Configuring Remote Logging
~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/providers/amazon/docs/executors/index.rst
b/providers/amazon/docs/executors/index.rst
index 117cd1faccc..8117d4d969f 100644
--- a/providers/amazon/docs/executors/index.rst
+++ b/providers/amazon/docs/executors/index.rst
@@ -26,3 +26,4 @@ Amazon Executors
ECS Executor <ecs-executor>
Batch Executor (experimental) <batch-executor>
+ Lambda Executor (experimental) <lambda-executor>
diff --git a/providers/amazon/docs/executors/lambda-executor.rst
b/providers/amazon/docs/executors/lambda-executor.rst
new file mode 100644
index 00000000000..3d09931fd96
--- /dev/null
+++ b/providers/amazon/docs/executors/lambda-executor.rst
@@ -0,0 +1,358 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+.. warning::
+ The Lambda Executor is alpha/experimental at the moment and may be unstable
and subject to change without warning.
+.. |executorName| replace:: Lambda
+.. |dockerfileLink| replace:: `here
<https://github.com/apache/airflow/blob/main/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/Dockerfile>`__
+.. |appHandlerLink| replace:: `here
<https://github.com/apache/airflow/blob/main/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py>`__
+
+===================
+AWS Lambda Executor
+===================
+
+This is an Airflow executor powered by AWS Lambda. Each task that Airflow
schedules
+for execution is run as an asynchronous Lambda function invocation. Some
benefits
+of an executor like this include:
+
+1. Task isolation: No task can be a noisy neighbor for another. Resources are
+ isolated to each individual task. Any failures affect only the single task
+ running in that Lambda function. No single user can overload the environment
+ by triggering too many tasks.
+2. Lightweight execution: Lambda functions start quickly and have minimal
overhead,
+ making this executor ideal for short-running tasks that don't require
substantial
+ CPU or memory.
+3. Cost effective: Compute resources only exist for the lifetime of the
Airflow task
+ itself. Lambda's pay-per-use pricing model ensures you only pay for actual
task
+ execution time.
+4. Highly scalable: Lambda can scale to handle thousands of concurrent
executions
+ with no pre-provisioning.
+
+For a quick start guide please see :ref:`here <lambda_setup_guide>`; it will
+get you up and running with a basic configuration. Both an example Dockerfile
+(found |dockerfileLink|) and Lambda invocation handler (found |appHandlerLink|)
+are provided. The below sections provide more generic details about
configuration,
+the provided example Dockerfile and logging.
+
+How Does it Work?
+-----------------
+The executor leverages AWS Lambda for compute and SQS as a communication
channel.
+When a task is queued to run with the Lambda executor a Lambda function is
asynchronously
+invoked with a payload containing the task key being run and the Airflow CLI
+command/workload to run. The Lambda function executes the task and then must
+publish a message to the SQS results queue with the same task key it received
as well as
+the return code of the airflow task execution process. The Lambda executor
periodically
+polls this SQS results queue for messages and updates its internal state with
the results
+it receives. The Lambda function must also be configured with a Dead Letter
Queue (DLQ) to
+which Lambda will automatically send messages for invocations that fail to
be processed,
+for example due to the 15 minute max timeout of Lambda functions. The Lambda
executor will
+also poll this DLQ for messages and update the internal state with the results
it receives.
+
+.. image:: ../img/lambda_executor.png
+
+.. _lambda_config-options:
+
+Config Options
+--------------
+
+There are a number of configuration options available, which can either
+be set directly in the airflow.cfg file under an "aws_lambda_executor"
+section or via environment variables using the
+``AIRFLOW__AWS_LAMBDA_EXECUTOR__<OPTION_NAME>`` format, for example
+``AIRFLOW__AWS_LAMBDA_EXECUTOR__FUNCTION_NAME = "myLambdaFunction"``. For
+more information on how to set these options, see `Setting Configuration
+Options
<https://airflow.apache.org/docs/apache-airflow/stable/howto/set-config.html>`__
+
+.. note::
+ Configuration options must be consistent across all the hosts/environments
running the Airflow components (Scheduler, Webserver, Lambda function, etc).
See `here
<https://airflow.apache.org/docs/apache-airflow/stable/configurations-ref.html>`__
for more details on setting configurations.
+
+In the case of conflicts, the order of precedence from lowest to highest is:
+
+1. Load default values for options which have defaults.
+2. Load any values explicitly provided through airflow.cfg or
+ environment variables. These are checked with Airflow's config
+ precedence.
+
+Required config options:
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+- FUNCTION_NAME - Name or ARN of the AWS Lambda function to invoke. Required.
+- QUEUE_URL - URL of the SQS queue that will receive task execution results.
Required.
+- DEAD_LETTER_QUEUE_URL - URL of the Dead Letter Queue for processing Lambda
execution failures. Required.
+
+Optional config options:
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+- FUNCTION_QUALIFIER - Version or alias of the Lambda function to invoke.
Defaults to ``$LATEST``.
+- CONN_ID - The Airflow connection (i.e. credentials) used by the Lambda
+ executor to make API calls to AWS Lambda and SQS. Defaults to
``aws_default``.
+- REGION_NAME - The name of the AWS Region where AWS Lambda and SQS are
configured.
+- CHECK_HEALTH_ON_STARTUP - Whether to check the Lambda Executor health on
startup by testing connectivity to required AWS services. Defaults to ``True``.
+- MAX_INVOKE_ATTEMPTS - The maximum number of times the Lambda Executor
should attempt to invoke a function. This refers to instances where the
invocation fails (i.e. AWS API failures, throttling, etc). Defaults to ``3``.
+
+For a more detailed description of available options, including type
+hints and examples, see the ``config_templates`` folder in the Amazon
+provider package.
+
+.. _dockerfile_for_lambda_executor:
+
+.. include:: general.rst
+ :start-after: .. BEGIN DOCKERFILE
+ :end-before: .. END DOCKERFILE
+
+
+The most secure method is to use IAM roles. When creating a Lambda Function
+Definition, you are able to select an execution role. This role needs
+permissions to publish messages to the SQS queues and to write to
CloudWatchLogs
+or S3 if using AWS remote logging and/or using S3 to synchronize dags
+(e.g. ``CloudWatchLogsFullAccess`` or ``CloudWatchLogsFullAccessV2``).
+The AWS credentials used on the Scheduler need permissions to
+describe and invoke Lambda functions as well as to describe and read/delete
+SQS messages.
+
+To create a new role for use by Lambda follow the below:
+
+1. Navigate to the IAM page on the AWS console, and from the left hand
+ tab, under Access Management, select Roles.
+2. On the Roles page, click Create role on the top right hand corner.
+3. Under Trusted entity type, select AWS Service.
+4. Select Lambda from the drop down under Use case. Click Next.
+5. In the Permissions page, select the permissions the role will need,
+ depending on the actions it will perform (at least including the
+ permissions described above).
+ Click Next after selecting all the required permissions.
+6. Enter a name for the new role, and an optional description. Review
+ the Trusted Entities, and the permissions for the role. Add any tags
+ as necessary, and click Create role.
+
+When creating the Lambda Function (see the :ref:`setup guide
<lambda_setup_guide>` for more details), select the newly created role.
+
+.. include:: general.rst
+ :start-after: .. BEGIN DOCKERFILE_AUTH_SECOND_METHOD
+ :end-before: .. END DOCKERFILE_AUTH_SECOND_METHOD
+
+Base Image
+~~~~~~~~~~
+
+The Docker image is built upon the ``public.ecr.aws/lambda/python`` image.
+This base image contains the necessary components for the Lambda function to
+run. Apache Airflow is installed into the image via pip. You can modify the
+Dockerfile to install additional packages or different versions of Airflow as
+needed for your use case.
+
+It is also possible to build the image based on ``apache/airflow:latest``
+ and the Lambda runtime can be included separately (follow steps `here
<https://docs.aws.amazon.com/lambda/latest/dg/images-create.html#images-ric>`__).
+
+
+.. include:: general.rst
+ :start-after: .. BEGIN LOADING_DAGS_OVERVIEW
+ :end-before: .. END LOADING_DAGS_OVERVIEW
+
+From S3 Bucket
+^^^^^^^^^^^^^^
+
+Dags can be loaded from S3 when using the provided example app.py, which
+contains logic to synchronize the DAGs from S3 to the local filesystem of
+the Lambda function (see the app.py code |appHandlerLink|).
+
+To load Dags from an S3 bucket add ``--build-arg s3_uri=YOUR_S3_URI`` in
+the docker build command. Replace ``YOUR_S3_URI`` with the URI of your S3
+bucket/path containing your dags. Make sure you have the appropriate
+permissions to read from the bucket.
+
+.. code-block:: bash
+
+ docker build -t my-airflow-image --build-arg s3_uri=YOUR_S3_URI .
+
+.. include:: general.rst
+ :start-after: .. BEGIN LOADING_DAGS_FROM_LOCAL
+ :end-before: .. END LOADING_DAGS_FROM_LOCAL
+
+.. include:: general.rst
+ :start-after: .. BEGIN DEPENDENCIES
+ :end-before: .. END DEPENDENCIES
+
+Using the AWS Lambda Executor Docker Image
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Detailed instructions on how to use the Docker image that you have
+created via these steps can be found :ref:`here <lambda_setup_guide>`.
+
+.. _lamba_logging:
+
+.. include:: general.rst
+ :start-after: .. BEGIN LOGGING
+ :end-before: .. END LOGGING
+
+- The configuration options for Airflow remote logging should be
+ configured on all hosts and containers running Airflow. For example
+ the Webserver requires this config so that it can fetch logs from
+ the remote location and the Lambda container requires the config so that
+ it can upload the logs to the remote location. See
+ `here
<https://airflow.apache.org/docs/apache-airflow/stable/howto/set-config.html>`__
+ to read more about how to set Airflow configuration via config file
+ or environment variable exports.
+- Adding the Airflow remote logging config to Lambda can be done
+ in many ways. Some examples include, but are not limited to:
+
+ - Exported as environment variables directly in the Dockerfile (see
+ the Dockerfile section :ref:`above <dockerfile_for_lambda_executor>`)
+ - Updating the ``airflow.cfg`` file or copying/mounting/downloading a
+ custom ``airflow.cfg`` in the Dockerfile.
+ - Added in the Lambda Function definition in plain text or via
+ `Secrets/System
+ Manager
<https://docs.aws.amazon.com/secretsmanager/latest/userguide/retrieving-secrets_lambda.html>`__
+
+- Remember that you must have credentials configured within the container to
be able
+ to interact with the remote service for your logs (e.g. S3,
+ CloudWatch Logs, etc). This can be done in many ways. Some examples
+ include, but are not limited to:
+
+ - Export credentials into the Dockerfile directly (see the
+ Dockerfile section :ref:`above <dockerfile_for_lambda_executor>`)
+ - Configure an Airflow Connection and provide this as the `remote
+ logging conn
+ id
<https://airflow.apache.org/docs/apache-airflow/stable/configurations-ref.html#remote-log-conn-id>`__
+ (exported into the Lambda Function by any of the means listed above or
+ your preferred method). Airflow will then use these credentials
+ *specifically* for interacting with your chosen remote logging
+ destination.
+
+.. note::
+ Configuration options must be consistent across all the hosts/environments
running the Airflow components (Scheduler, Webserver, Lambda invocations, etc).
See `here
<https://airflow.apache.org/docs/apache-airflow/stable/configurations-ref.html>`__
for more details on setting configurations.
+
+.. _lambda_logging:
+
+Lambda Logging
+~~~~~~~~~~~~~~~~
+
+Lambda will send logging information to CloudWatch Logs for the Function
+invocations themselves. These logs will include the Airflow Task Operator
+logging and any other logging that occurs throughout the life of the process
+running in the Lambda function. This can be helpful for debugging issues with
+remote logging or while testing remote logging configuration.
+
+**Note: These logs will NOT be viewable from the Airflow Webserver UI.**
+
+.. _lambda_setup_guide:
+
+
+Setting up a Lambda Executor for Apache Airflow
+-----------------------------------------------
+
+There are 3 steps involved in getting a Lambda Executor to work in Apache
Airflow:
+
+1. Creating a database that Airflow and the tasks running in Lambda Functions
can
+ connect to (Note: In Airflow 3, tasks no longer directly access the DB).
+
+2. Creating and configuring a Lambda Function (and its associated SQS queues)
that
+ can run tasks from Airflow.
+
+3. Configuring Airflow to use the Lambda Executor
+
+There are different options for selecting a database backend. See `here
<https://airflow.apache.org/docs/apache-airflow/stable/howto/set-up-database.html>`_
for more information about the different options supported by Airflow. The
following guide will explain how to set up a PostgreSQL RDS Instance on AWS.
+
+.. include:: general.rst
+ :start-after: .. BEGIN DATABASE_CONNECTION
+ :end-before: .. END DATABASE_CONNECTION
+
+
+Creating a Lambda Function
+--------------------------
+
+To create the Lambda Function, first start with building the docker image
above (see the :ref:`Dockerfile <dockerfile_for_lambda_executor>` section for
instructions).
+
+Once the image is built, it needs to be put in a repository where it can be
pulled by Lambda. There are multiple ways
+to accomplish this. This guide will use Amazon Elastic Container Registry
(ECR).
+
+.. include:: general.rst
+ :start-after: .. BEGIN ECR_STEPS
+ :end-before: .. END ECR_STEPS
+
+Create SQS Queue
+~~~~~~~~~~~~~~~~
+
+You will also need two SQS queues: 1) a queue to receive the results of the Airflow Task
executions
+and 2) a Dead Letter Queue (DLQ) to handle failed invocations of the Lambda
function.
+
+Create SQS queues by following these steps (for each queue):
+
+1. Log in to your AWS Management Console and navigate to the Amazon SQS
Service.
+
+2. Click "Create queue".
+
+3. Select "Standard" as the queue type.
+
+4. Provide a name for the queue, and select the defaults that work for your
use case.
+
+5. Click Create.
+
+Create Function
+~~~~~~~~~~~~~~~
+
+Finally create the function:
+
+1. Log in to your AWS Management Console and navigate to the AWS Lambda
Service.
+
+2. Click "Create Function".
+
+3. Select "Container image" as the function type.
+
+4. Provide a name for the function, select the ECR repository and image tag
that you created in the previous steps, and
+ select the execution role that you created for use in the Lambda Function.
+
+5. Click Create.
+
+6. Once created, add the following environment variables to the container
under Configuration > Environment variables:
+
+- ``AIRFLOW__DATABASE__SQL_ALCHEMY_CONN``, with the value being the PostgreSQL
connection string in the following format using the values set during the
`Database section <#create-the-rds-db-instance>`_ above:
+
+.. code-block:: bash
+
+ postgresql+psycopg2://<username>:<password>@<endpoint>/<database_name>
+
+
+- ``AIRFLOW__AWS_LAMBDA_EXECUTOR__QUEUE_URL``, with the value being the URL of the
SQS queue created above.
+
+7. Add other configuration as necessary for Airflow generally (see `here
<https://airflow.apache.org/docs/apache-airflow/stable/configurations-ref.html>`__),
the Lambda executor (see :ref:`here <lambda_config-options>`) or for remote
logging (see :ref:`here <lamba_logging>`). Note that any configuration changes
should be made across the entire Airflow environment to keep configuration
consistent.
+
+8. Add the dead letter queue URL to the Lambda function under Configuration >
Asynchronous invocation > edit > Dead-letter queue. Select the SQS queue you
created above.
+
+9. If using Airflow 2, allow access to the database. There are many possible
methods, but one simple approach is to add a connection to the RDS database via
Configuration > RDS databases and follow the steps of the wizard.
+
+Configure Airflow
+~~~~~~~~~~~~~~~~~
+
+To configure Airflow to utilize the Lambda Executor and leverage the resources
we've set up, create a script (e.g., ``lambda_executor_config.sh``) with at
least the following contents:
+
+.. code-block:: bash
+
+ export
AIRFLOW__CORE__EXECUTOR='airflow.providers.amazon.aws.executors.aws_lambda.AwsLambdaExecutor'
+
+ export AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=<postgres-connection-string>
+
+ export AIRFLOW__AWS_LAMBDA_EXECUTOR__FUNCTION_NAME=<lambda-function-name>
+
+ export AIRFLOW__AWS_LAMBDA_EXECUTOR__QUEUE_URL=<sqs_queue_url>
+
+ export
AIRFLOW__AWS_LAMBDA_EXECUTOR__DEAD_LETTER_QUEUE_URL=<sqs_dlq_queue_url>
+
+
+.. include:: general.rst
+ :start-after: .. BEGIN INIT_DB
+ :end-before: .. END INIT_DB
diff --git a/providers/amazon/docs/img/lambda_executor.png
b/providers/amazon/docs/img/lambda_executor.png
new file mode 100644
index 00000000000..98773c916c1
Binary files /dev/null and b/providers/amazon/docs/img/lambda_executor.png
differ
diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml
index 64013749ac4..fb7bf428790 100644
--- a/providers/amazon/provider.yaml
+++ b/providers/amazon/provider.yaml
@@ -1004,6 +1004,82 @@ config:
example: '{"Tags": [{"Key": "key", "Value": "value"}]}'
default: ~
+ aws_lambda_executor:
+ description: |
+ This section only applies if you are using the AwsLambdaExecutor in
+ Airflow's ``[core.executor]`` configuration.
+ For more information see:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/lambda/client/invoke.html
+
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
+
https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/executors/lambda-executor.html
+ options:
+ conn_id:
+ description: |
+ The Airflow connection (i.e. credentials) used by the Lambda
executor to make API calls.
+ version_added: "9.9.0"
+ type: string
+ example: "aws_default"
+ default: "aws_default"
+ region_name:
+ description: |
+ The name of the AWS Region where AWS Lambda is configured.
+ version_added: "9.9.0"
+ type: string
+ example: "us-east-1"
+ default: ~
+ check_health_on_startup:
+ description: |
+ Whether or not to check the Lambda Executor health on startup.
+ version_added: "9.9.0"
+ type: boolean
+ example: "True"
+ default: "True"
+ max_run_task_attempts:
+ description: |
+ The maximum number of times the Lambda Executor should attempt to
start an Airflow task.
+ version_added: "9.9.0"
+ type: integer
+ example: "3"
+ default: "3"
+ queue_url:
+ description: |
+ The URL of the SQS queue to use for the Lambda executor. Required.
+ The Lambda executor will poll this queue for results of the lambda
function's Airflow Task.
+ version_added: "9.9.0"
+ type: string
+ example: "airflow-lambda-executor-results-queue"
+ default: ~
+ dead_letter_queue_url:
+ description: |
+ The URL of the SQS dead letter queue to use for the Lambda function.
The Lambda executor
+ will poll this queue for timeout/exception results of the lambda
function. Required.
+ version_added: "9.9.0"
+ type: string
+ example: "airflow-lambda-executor-dlq"
+ default: ~
+ function_name:
+ description: |
+ The name of the Lambda function to invoke. Required.
+ version_added: "9.9.0"
+ type: string
+ example: "airflow-lambda-executor-function"
+ default: ~
+ qualifier:
+ description: |
+ The version or alias of the Lambda function to invoke. If not
specified, the latest version is used.
+ version_added: "9.9.0"
+ type: string
+ example: "1"
+ default: ~
+ end_wait_timeout:
+ description: |
+ The number of seconds to wait for all lambda executor invocations to
complete when terminating the
+ executor/scheduler. Default of 0 means wait indefinitely.
+ version_added: "9.9.0"
+ type: integer
+ example: "500"
+ default: "0"
+
aws_ecs_executor:
description: |
This section only applies if you are using the AwsEcsExecutor in
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/__init__.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/__init__.py
new file mode 100644
index 00000000000..b3f8fdcb26e
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+__all__ = ["AwsLambdaExecutor"]
+
+from airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor import
AwsLambdaExecutor
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/Dockerfile
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/Dockerfile
new file mode 100644
index 00000000000..f5c4648caf4
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/Dockerfile
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Use the official AWS Lambda Python base image.
+# If you wish to use a different Python version, please update the line below
+FROM public.ecr.aws/lambda/python:3.12
+
+# hadolint ignore=DL3041
+RUN dnf -y install unzip \
+ && dnf clean all \
+ && rm -rf /var/cache/dnf
+
+# Install the AWS CLI.
+RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o
"awscliv2.zip" && \
+ unzip awscliv2.zip \
+ && ./aws/install \
+ && rm -rf ./aws awscliv2.zip
+
+## Install Airflow and dependencies.
+# The most current version of Airflow is installed by default, along with the
amazon and postgres
+# provider packages.
+# If you would like to install a specific version, you can use the
requirements.txt to change the
+# version along with installing your dependencies or update this Dockerfile to
install a specific
+# version of Airflow.
+
+# NOTE: If you change the below line, specifically removing the amazon extra,
please ensure boto3 is
+# installed via another method. Boto3 is required for the Lambda executor to
function properly.
+# hadolint ignore=SC2102
+RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir
apache-airflow[amazon,postgres]
+# /tmp is the only writable directory in Lambda, so we need to set the
AIRFLOW_HOME there.
+ENV AIRFLOW_HOME=/tmp/airflow
+# Dags are read-only, so they can be stored in opt (or another path provided).
+ARG container_dag_path=/opt/airflow/dags
+ENV AIRFLOW__CORE__DAGS_FOLDER=$container_dag_path
+RUN mkdir -p /opt/airflow/dags
+
+# Python dependencies can be installed by providing a requirements.txt.
+# If the file is in a different location, use the requirements_path build
argument to specify
+# the file path.
+ARG requirements_path=./requirements.txt
+ENV REQUIREMENTS_PATH=$requirements_path
+# Uncomment the two lines below to copy the requirements.txt file to the
container, and
+# install the dependencies.
+# COPY --chown=airflow:root $REQUIREMENTS_PATH /opt/airflow/requirements.txt
+# RUN pip install --no-cache-dir -r /opt/airflow/requirements.txt
+
+## AWS Authentication
+# The image requires access to AWS services. This Dockerfile supports 2 ways
to authenticate with AWS.
+# The first is using build arguments where you can provide the AWS credentials
as arguments
+# passed when building the image. The other option is to leverage the Lambda
execution role. Airflow
+# will default to using the Boto credential strategy, which will look for roles
from Lambda; this is the
+# preferred approach. See the Lambda Executor Airflow documentation for more
details.
+
+# If you would like to use an alternative method of authentication, feel free
to make the
+# necessary changes to this file.
+
+# Uncomment to use these arguments to provide AWS authentication information
if not using the Lambda
+# execution role.
+#ARG aws_access_key_id
+#ARG aws_secret_access_key
+#ARG aws_default_region
+#ARG aws_session_token
+
+#ENV AWS_ACCESS_KEY_ID=$aws_access_key_id
+#ENV AWS_SECRET_ACCESS_KEY=$aws_secret_access_key
+#ENV AWS_DEFAULT_REGION=$aws_default_region
+#ENV AWS_SESSION_TOKEN=$aws_session_token
+
+## Loading DAGs
+# This Dockerfile supports 2 ways to load DAGs onto the container.
+# One is to download them from S3 at runtime during the Lambda app invocation.
The other
+# is to copy the dags into the image at build time, this will make task
execution
+# much faster, since the DAGs will already be present, but the image will
need to be rebuilt
+# every time the DAGs are updated.
+# If you would like to use an alternative method of loading DAGs, feel free to
make the
+# necessary changes to this file.
+
+ARG host_dag_path=./dags
+ENV HOST_DAG_PATH=$host_dag_path
+# Uncomment the line below to copy the DAGs from the host to the container.
+# COPY $HOST_DAG_PATH $AIRFLOW__CORE__DAGS_FOLDER
+
+# Use these arguments to load DAGs at runtime. If you are using the provided
Lambda function (app.py),
+# it will check for this environment variable and download the DAGs from S3.
See the example app in Lambda
+# Executor documentation for details.
+ARG s3_uri
+ENV S3_URI=$s3_uri
+
+# Copy your Lambda function code into the Docker build directory/context.
+COPY app.py ${LAMBDA_TASK_ROOT}/
+
+# Specify the Lambda function handler (update if you provide a different
handler with a different name).
+CMD ["app.lambda_handler"]
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/__init__.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py
new file mode 100644
index 00000000000..9d9b5a5b15e
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import logging
+import os
+import subprocess
+from tempfile import mkdtemp
+
+import boto3
+
+"""
+Example Lambda function to execute an Airflow command or workload. Use or
modify this code as needed.
+"""
+
+log = logging.getLogger()
+log.setLevel(logging.INFO)
+
+
+# Get the S3 URI from the environment variable. Set either on the Lambda
function or in the
+# docker image used for the lambda invocations.
+S3_URI = os.environ.get("S3_URI", None)
+# Input and output keys
+TASK_KEY_KEY = "task_key"
+COMMAND_KEY = "command"
+RETURN_CODE_KEY = "return_code"
+
+
+def lambda_handler(event, context):
+ log.info("Received event: %s", event)
+ log.info("Received context: %s", context)
+
+ command = event.get(COMMAND_KEY)
+ task_key = event.get(TASK_KEY_KEY)
+
+ # Any pre-processing or validation of the command or use of the context
can be done here or above.
+
+ # Sync dags from s3 to the local dags directory
+ if S3_URI:
+ fetch_dags_from_s3(S3_URI)
+ # This function must be called; it executes the Airflow command and
reports the result to SQS.
+ run_and_report(command, task_key)
+
+ # Any post-processing or cleanup can be done here.
+
+
+def run_and_report(command, task_key):
+ """Execute the provided Airflow command or workload and report the result
via SQS."""
+ try:
+ log.info("Starting execution for task: %s", task_key)
+ result = subprocess.run(
+ command, shell=isinstance(command, str), stdout=subprocess.PIPE,
stderr=subprocess.STDOUT
+ )
+ return_code = result.returncode
+ log.info("Execution completed for task %s with return code %s",
task_key, return_code)
+ log.info("Output:")
+ log.info("%s", result.stdout.decode())
+ except Exception:
+ log.exception("Error executing task %s: ", task_key)
+ return_code = 1 # Non-zero indicates failure to run the task
+
+ queue_url = get_queue_url()
+ message = json.dumps({TASK_KEY_KEY: task_key, RETURN_CODE_KEY:
return_code})
+ try:
+ sqs_client = get_sqs_client()
+ sqs_client.send_message(QueueUrl=queue_url, MessageBody=message)
+ log.info("Sent result to SQS %s", message)
+ except Exception:
+ log.exception("Failed to send message to SQS for task %s", task_key)
+
+
+def get_sqs_client():
+ """Create an SQS client. Credentials and region are automatically picked
up from the environment."""
+ return boto3.client("sqs")
+
+
+def get_queue_url():
+ """
+ Get the SQS queue URL from the environment variable.
+
+ Set either on the Lambda function or in the image used for the lambda
invocations.
+ """
+ queue_url = os.environ.get("AIRFLOW__AWS_LAMBDA_EXECUTOR__QUEUE_URL",
os.environ.get("QUEUE_URL", None))
+ if not queue_url:
+ raise RuntimeError(
+ "No Queue URL detected (either
AIRFLOW__AWS_LAMBDA_EXECUTOR__QUEUE_URL or "
+ "QUEUE_URL); Will be unable to send task results. Exiting!"
+ )
+ return queue_url
+
+
+def fetch_dags_from_s3(s3_uri):
+ """Fetch DAGs from S3 and sync them to the local dags directory."""
+ log.info("Fetching DAGs from S3 URI: %s", s3_uri)
+ # Use a named temporary directory for the local dags folder, only tmp is
writeable in Lambda
+ local_dags_dir = mkdtemp(prefix="airflow_dags_")
+ log.info("Setting AIRFLOW__CORE__DAGS_FOLDER to: %s", local_dags_dir)
+ os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = local_dags_dir
+
+ # S3 URI format s3://bucket-name/path/to/dags/
+ bucket_name = s3_uri.split("/")[2]
+ prefix = "/".join(s3_uri.split("/")[3:])
+
+ s3_resource = boto3.resource("s3")
+ bucket = s3_resource.Bucket(bucket_name)
+
+ for obj in bucket.objects.filter(Prefix=prefix):
+ if obj.key.endswith("/"):
+ # Skip directories
+ continue
+ key = obj.key
+ local_path = os.path.join(local_dags_dir, os.path.basename(key))
+ log.info("Downloading %s to %s", key, local_path)
+ bucket.download_file(key, local_path)
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
new file mode 100644
index 00000000000..3f477095f44
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
@@ -0,0 +1,479 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import time
+from collections import deque
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from boto3.session import NoCredentialsError
+from botocore.utils import ClientError
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.executors.base_executor import BaseExecutor
+from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.providers.amazon.aws.executors.aws_lambda.utils import (
+ CONFIG_GROUP_NAME,
+ INVALID_CREDENTIALS_EXCEPTIONS,
+ AllLambdaConfigKeys,
+ CommandType,
+ LambdaQueuedTask,
+)
+from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry
import (
+ calculate_next_attempt_delay,
+ exponential_backoff_retry,
+)
+from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
+from airflow.providers.amazon.aws.hooks.sqs import SqsHook
+from airflow.stats import Stats
+from airflow.utils import timezone
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+ from airflow.executors import workloads
+ from airflow.models.taskinstance import TaskInstance
+
+
+class AwsLambdaExecutor(BaseExecutor):
+ """
+ An Airflow Executor that submits tasks to AWS Lambda asynchronously.
+
+ When execute_async() is called, the executor invokes a specified AWS
Lambda function (asynchronously)
+ with a payload that includes the task command and a unique task key.
+
+ The Lambda function writes its result directly to an SQS queue, which is
then polled by this executor
+ to update task state in Airflow.
+ """
+
+ if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
+ # In the v3 path, we store workloads, not commands as strings.
+ # TODO: TaskSDK: move this type change into BaseExecutor
+ queued_tasks: dict[TaskInstanceKey, workloads.All] # type:
ignore[assignment]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.pending_tasks: deque = deque()
+ self.running_tasks: dict[str, TaskInstanceKey] = {}
+ self.lambda_function_name = conf.get(CONFIG_GROUP_NAME,
AllLambdaConfigKeys.FUNCTION_NAME)
+ self.sqs_queue_url = conf.get(CONFIG_GROUP_NAME,
AllLambdaConfigKeys.QUEUE_URL)
+ self.dlq_url = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.DLQ_URL)
+ self.qualifier = conf.get(CONFIG_GROUP_NAME,
AllLambdaConfigKeys.QUALIFIER, fallback=None)
+ # Maximum number of attempts (the first invocation included) to invoke Lambda.
+ self.max_invoke_attempts = conf.get(
+ CONFIG_GROUP_NAME,
+ AllLambdaConfigKeys.MAX_INVOKE_ATTEMPTS,
+ )
+
+ self.attempts_since_last_successful_connection = 0
+ self.IS_BOTO_CONNECTION_HEALTHY = False
+ self.load_connections(check_connection=False)
+
+ def start(self):
+ """Call this when the Executor is run for the first time by the
scheduler."""
+ check_health = conf.getboolean(CONFIG_GROUP_NAME,
AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP)
+
+ if not check_health:
+ return
+
+ self.log.info("Starting Lambda Executor and determining health...")
+ try:
+ self.check_health()
+ except AirflowException:
+ self.log.error("Stopping the Airflow Scheduler from starting until
the issue is resolved.")
+ raise
+
+ def check_health(self):
+ """
+ Check the health of the Lambda and SQS connections.
+
+ For lambda: Use get_function to test if the lambda connection works
and the function can be
+ described.
+ For SQS: get_queue_attributes is used as a close analog to
describe to test if the SQS
+ connection is working.
+ """
+ self.IS_BOTO_CONNECTION_HEALTHY = False
+
+ def _check_queue(queue_url):
+ sqs_get_queue_attrs_response =
self.sqs_client.get_queue_attributes(
+ QueueUrl=queue_url,
AttributeNames=["ApproximateNumberOfMessages"]
+ )
+ approx_num_msgs =
sqs_get_queue_attrs_response.get("Attributes").get(
+ "ApproximateNumberOfMessages"
+ )
+ self.log.info(
+ "SQS connection is healthy and queue %s is present with %s
messages.",
+ queue_url,
+ approx_num_msgs,
+ )
+
+ self.log.info("Checking Lambda and SQS connections")
+ try:
+ # Check Lambda health
+ lambda_get_response =
self.lambda_client.get_function(FunctionName=self.lambda_function_name)
+ if self.lambda_function_name not in
lambda_get_response["Configuration"]["FunctionName"]:
+ raise AirflowException("Lambda function %s not found.",
self.lambda_function_name)
+ self.log.info(
+ "Lambda connection is healthy and function %s is present.",
self.lambda_function_name
+ )
+
+ # Check SQS results queue
+ _check_queue(self.sqs_queue_url)
+ # Check SQS dead letter queue
+ _check_queue(self.dlq_url)
+
+ # If we reach this point, both connections are healthy and all
resources are present
+ self.IS_BOTO_CONNECTION_HEALTHY = True
+ except Exception:
+ self.log.exception("Lambda Executor health check failed")
+ raise AirflowException(
+ "The Lambda executor will not be able to run Airflow tasks
until the issue is addressed."
+ )
+
+ def load_connections(self, check_connection: bool = True):
+ """
+ Retrieve the AWS connection via Hooks to leverage the Airflow
connection system.
+
+ :param check_connection: If True, check the health of the connection
after loading it.
+ """
+ self.log.info("Loading Connections")
+ aws_conn_id = conf.get(CONFIG_GROUP_NAME,
AllLambdaConfigKeys.AWS_CONN_ID)
+ region_name = conf.get(CONFIG_GROUP_NAME,
AllLambdaConfigKeys.REGION_NAME, fallback=None)
+ self.sqs_client = SqsHook(aws_conn_id=aws_conn_id,
region_name=region_name).conn
+ self.lambda_client = LambdaHook(aws_conn_id=aws_conn_id,
region_name=region_name).conn
+
+ self.attempts_since_last_successful_connection += 1
+ self.last_connection_reload = timezone.utcnow()
+
+ if check_connection:
+ self.check_health()
+ self.attempts_since_last_successful_connection = 0
+
+ def sync(self):
+ """
+ Sync the executor with the current state of tasks.
+
+ Check in on currently running tasks and attempt to run any new tasks
that have been queued.
+ """
+ if not self.IS_BOTO_CONNECTION_HEALTHY:
+ exponential_backoff_retry(
+ self.last_connection_reload,
+ self.attempts_since_last_successful_connection,
+ self.load_connections,
+ )
+ if not self.IS_BOTO_CONNECTION_HEALTHY:
+ return
+ try:
+ self.sync_running_tasks()
+ self.attempt_task_runs()
+ except (ClientError, NoCredentialsError) as error:
+ error_code = error.response["Error"]["Code"]
+ if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+ self.IS_BOTO_CONNECTION_HEALTHY = False
+ self.log.warning(
+ "AWS credentials are either missing or expired:
%s.\nRetrying connection", error
+ )
+ except Exception:
+ self.log.exception("An error occurred while syncing tasks")
+
+ def queue_workload(self, workload: workloads.All, session: Session | None)
-> None:
+ from airflow.executors import workloads
+
+ if not isinstance(workload, workloads.ExecuteTask):
+ raise RuntimeError(f"{type(self)} cannot handle workloads of type
{type(workload)}")
+ ti = workload.ti
+ self.queued_tasks[ti.key] = workload
+
+ def _process_workloads(self, workloads: list[workloads.All]) -> None:
+ from airflow.executors.workloads import ExecuteTask
+
+ for w in workloads:
+ if not isinstance(w, ExecuteTask):
+ raise RuntimeError(f"{type(self)} cannot handle workloads of
type {type(w)}")
+
+ command = [w]
+ key = w.ti.key
+ queue = w.ti.queue
+ executor_config = w.ti.executor_config or {}
+
+ del self.queued_tasks[key]
+ self.execute_async(key=key, command=command, queue=queue,
executor_config=executor_config) # type: ignore[arg-type]
+ self.running.add(key)
+
+ def execute_async(self, key: TaskInstanceKey, command: CommandType,
queue=None, executor_config=None):
+ """
+ Save the task to be executed in the next sync by inserting the
commands into a queue.
+
+ :param key: A unique task key (typically a tuple identifying the task
instance).
+ :param command: The command to execute as a sequence of strings, or a single-item list holding an ExecuteTask workload.
+ :param executor_config: (Unused) to keep the same signature as the
base.
+ :param queue: (Unused) to keep the same signature as the base.
+ """
+ if len(command) == 1:
+ from airflow.executors.workloads import ExecuteTask
+
+ if isinstance(command[0], ExecuteTask):
+ workload = command[0]
+ ser_input = workload.model_dump_json()
+ command = [
+ "python",
+ "-m",
+ "airflow.sdk.execution_time.execute_workload",
+ "--json-string",
+ ser_input,
+ ]
+ else:
+ raise RuntimeError(
+ f"LambdaExecutor doesn't know how to handle workload of
type: {type(command[0])}"
+ )
+
+ self.pending_tasks.append(
+ LambdaQueuedTask(
+ key, command, queue if queue else "", executor_config or {},
1, timezone.utcnow()
+ )
+ )
+
+ def attempt_task_runs(self):
+ """
+ Attempt to run tasks that are queued in the pending_tasks.
+
+ Each task is submitted to AWS Lambda with a payload containing the
task key and command.
+ The task key is used to track the task's state in Airflow.
+ """
+ queue_len = len(self.pending_tasks)
+ for _ in range(queue_len):
+ task_to_run = self.pending_tasks.popleft()
+ task_key = task_to_run.key
+ cmd = task_to_run.command
+ attempt_number = task_to_run.attempt_number
+ failure_reasons = []
+ ser_task_key = json.dumps(task_key._asdict())
+ payload = {
+ "task_key": ser_task_key,
+ "command": cmd,
+ }
+ if timezone.utcnow() < task_to_run.next_attempt_time:
+ self.pending_tasks.append(task_to_run)
+ continue
+
+ self.log.info("Submitting task %s to Lambda function %s",
task_key, self.lambda_function_name)
+
+ try:
+ invoke_kwargs = {
+ "FunctionName": self.lambda_function_name,
+ "InvocationType": "Event",
+ "Payload": json.dumps(payload),
+ }
+ if self.qualifier:
+ invoke_kwargs["Qualifier"] = self.qualifier
+ response = self.lambda_client.invoke(**invoke_kwargs)
+ except NoCredentialsError:
+ self.pending_tasks.append(task_to_run)
+ raise
+ except ClientError as e:
+ error_code = e.response["Error"]["Code"]
+ if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+ self.pending_tasks.append(task_to_run)
+ raise
+ failure_reasons.append(str(e))
+ except Exception as e:
+ # Failed to even get a response back from the Boto3 API or
something else went
+ # wrong. For any possible failure we want to add the
exception reasons to the
+ # failure list so that it is logged to the user and most
importantly the task is
+ # added back to the pending list to be retried later.
+ failure_reasons.append(str(e))
+
+ if failure_reasons:
+ # Make sure the number of attempts does not exceed max invoke
attempts
+ if int(attempt_number) < int(self.max_invoke_attempts):
+ task_to_run.attempt_number += 1
+ task_to_run.next_attempt_time = timezone.utcnow() +
calculate_next_attempt_delay(
+ attempt_number
+ )
+ self.pending_tasks.append(task_to_run)
+ else:
+ reasons_str = ", ".join(failure_reasons)
+ self.log.error(
+ "Lambda invoke %s has failed a maximum of %s times.
Marking as failed. Reasons: %s",
+ task_key,
+ attempt_number,
+ reasons_str,
+ )
+ self.log_task_event(
+ event="lambda invoke failure",
+ ti_key=task_key,
+ extra=(
+ f"Task could not be queued after {attempt_number}
attempts. "
+ f"Marking as failed. Reasons: {reasons_str}"
+ ),
+ )
+ self.fail(task_key)
+ else:
+ status_code = response.get("StatusCode")
+ self.log.info("Invoked Lambda for task %s with status %s",
task_key, status_code)
+ self.running_tasks[ser_task_key] = task_key
+ # Add the serialized task key as the info, this will be
assigned on the ti as the external_executor_id
+ self.running_state(task_key, ser_task_key)
+
+ def sync_running_tasks(self):
+ """
+ Poll the SQS queue for messages indicating task completion.
+
+ Each message is expected to contain a JSON payload with 'task_key' and
'return_code'.
+ Based on the return code, update the task state accordingly.
+ """
+ if not len(self.running_tasks):
+ self.log.debug("No running tasks to process.")
+ return
+
+ self.process_queue(self.sqs_queue_url)
+ if self.dlq_url and self.running_tasks:
+ self.process_queue(self.dlq_url)
+
+ def process_queue(self, queue_url: str):
+ """
+ Poll the SQS queue for messages indicating task completion.
+
+ Each message is expected to contain a JSON payload with 'task_key' and
'return_code'.
+
+ Based on the return code, update the task state accordingly.
+ """
+ response = self.sqs_client.receive_message(
+ QueueUrl=queue_url,
+ MaxNumberOfMessages=10,
+ )
+
+ messages = response.get("Messages", [])
+ # Pagination? Maybe we don't need it, since we always delete messages
after looking at them.
+ # But then that may delete messages that could have been adopted.
Let's leave it for now and see how it goes.
+ if messages and queue_url == self.dlq_url:
+ self.log.warning("%d messages received from the dead letter
queue", len(messages))
+
+ for message in messages:
+ receipt_handle = message["ReceiptHandle"]
+ body = json.loads(message["Body"])
+ return_code = body.get("return_code")
+ ser_task_key = body.get("task_key")
+ # Fetch the real task key from the running_tasks dict, using the
serialized task key.
+ try:
+ task_key = self.running_tasks[ser_task_key]
+ except KeyError:
+ self.log.warning(
+ "Received task %s from the queue which is not found in
running tasks. Removing message.",
+ ser_task_key,
+ )
+ task_key = None
+
+ if task_key:
+ if return_code == 0:
+ self.success(task_key)
+ self.log.info(
+ "Successful Lambda invocation for task %s received
from SQS queue.", task_key
+ )
+ else:
+ # In this case the Lambda likely started but failed at run
time since we got a non-zero
+ # return code. We could consider retrying these tasks
within the executor, because this _likely_
+ # means the Airflow task did not run to completion,
however we can't be sure (maybe the
+ # lambda runtime code has a bug and is returning a
non-zero when it actually passed?). So
+ # perhaps not retrying is the safest option.
+ self.fail(task_key)
+ self.log.error(
+ "Lambda invocation for task: %s has failed to run with
return code %s",
+ task_key,
+ return_code,
+ )
+ # Remove the task from the tracking mapping.
+ self.running_tasks.pop(ser_task_key)
+
+ # Delete the message from the queue.
+ self.sqs_client.delete_message(QueueUrl=queue_url,
ReceiptHandle=receipt_handle)
+
+ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) ->
Sequence[TaskInstance]:
+ """
+ Adopt task instances which have an external_executor_id (the
serialized task key).
+
+ Anything that is not adopted will be cleared by the scheduler and
becomes eligible for re-scheduling.
+
+ :param tis: The task instances to adopt.
+ """
+ with Stats.timer("lambda_executor.adopt_task_instances.duration"):
+ adopted_tis: list[TaskInstance] = []
+
+ if serialized_task_keys := [
+ (ti, ti.external_executor_id) for ti in tis if
ti.external_executor_id
+ ]:
+ for ti, ser_task_key in serialized_task_keys:
+ try:
+ task_key =
TaskInstanceKey.from_dict(json.loads(ser_task_key))
+ except Exception:
+ # If that task fails to deserialize, we should just
skip it.
+ self.log.exception(
+ "Task failed to be adopted because the key could
not be deserialized"
+ )
+ continue
+ self.running_tasks[ser_task_key] = task_key
+ adopted_tis.append(ti)
+
+ if adopted_tis:
+ tasks = [f"{task} in state {task.state}" for task in
adopted_tis]
+ task_instance_str = "\n\t".join(tasks)
+ self.log.info(
+ "Adopted the following %d tasks from a dead
executor:\n\t%s",
+ len(adopted_tis),
+ task_instance_str,
+ )
+
+ not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
+ return not_adopted_tis
+
+ def end(self, heartbeat_interval=10):
+ """
+ End execution. Poll until all outstanding tasks are marked as
completed.
+
+ This is a blocking call and async Lambda tasks can not be cancelled,
so this will wait until
+ all tasks are either completed or the timeout is reached.
+
+ :param heartbeat_interval: The interval in seconds to wait between
checks for task completion.
+ """
+ self.log.info("Received signal to end, waiting for outstanding tasks
to finish.")
+ time_to_wait = int(conf.get(CONFIG_GROUP_NAME,
AllLambdaConfigKeys.END_WAIT_TIMEOUT))
+ start_time = timezone.utcnow()
+ while True:
+ if time_to_wait:
+ current_time = timezone.utcnow()
+ elapsed_time = (current_time - start_time).total_seconds()
+ if elapsed_time > time_to_wait:
+ self.log.warning(
+ "Timed out waiting for tasks to finish. Some tasks may
not be handled gracefully"
+ " as the executor is force ending due to timeout."
+ )
+ break
+ self.sync()
+ if not self.running_tasks:
+ self.log.info("All tasks completed; executor ending.")
+ break
+ self.log.info("Waiting for %d task(s) to complete.",
len(self.running_tasks))
+ time.sleep(heartbeat_interval)
+
+ def terminate(self):
+ """Get called when the daemon receives a SIGTERM."""
+ self.log.warning("Terminating Lambda executor. In-flight tasks cannot
be stopped.")
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/utils.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/utils.py
new file mode 100644
index 00000000000..483b81508db
--- /dev/null
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/utils.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.executors.utils.base_config_keys import
BaseConfigKeys
+
+if TYPE_CHECKING:
+ from airflow.models.taskinstancekey import TaskInstanceKey
+
+
+CONFIG_GROUP_NAME = "aws_lambda_executor"
+INVALID_CREDENTIALS_EXCEPTIONS = [
+ "ExpiredTokenException",
+ "InvalidClientTokenId",
+ "UnrecognizedClientException",
+]
+
+
+@dataclass
+class LambdaQueuedTask:
+ """Represents a Lambda task that is queued. The task will be run in the
next heartbeat."""
+
+ key: TaskInstanceKey
+ command: CommandType
+ queue: str
+ executor_config: ExecutorConfigType
+ attempt_number: int
+ next_attempt_time: datetime.datetime
+
+
+class InvokeLambdaKwargsConfigKeys(BaseConfigKeys):
+ """Config keys loaded which are valid lambda invoke args."""
+
+ FUNCTION_NAME = "function_name"
+ QUALIFIER = "function_qualifier"
+
+
+class AllLambdaConfigKeys(InvokeLambdaKwargsConfigKeys):
+ """All config keys which are related to the Lambda Executor."""
+
+ AWS_CONN_ID = "conn_id"
+ CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
+ MAX_INVOKE_ATTEMPTS = "max_run_task_attempts"
+ REGION_NAME = "region_name"
+ QUEUE_URL = "queue_url"
+ DLQ_URL = "dead_letter_queue_url"
+ END_WAIT_TIMEOUT = "end_wait_timeout"
+
+
+CommandType = Sequence[str]
+ExecutorConfigType = dict[str, Any]
diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
index 136c907a08e..731e9bb377c 100644
--- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
+++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
@@ -1152,6 +1152,74 @@ def get_provider_info():
},
},
},
+ "aws_lambda_executor": {
+ "description": "This section only applies if you are using the
AwsLambdaExecutor in\nAirflow's ``[core.executor]`` configuration.\nFor more
information
see:\nhttps://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/lambda/client/invoke.html\nhttps://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html\nhttps://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/executors/lambda-executor.html\n",
+ "options": {
+ "conn_id": {
+ "description": "The Airflow connection (i.e.
credentials) used by the Lambda executor to make API calls.\n",
+ "version_added": "9.9.0",
+ "type": "string",
+ "example": "aws_default",
+ "default": "aws_default",
+ },
+ "region_name": {
+ "description": "The name of the AWS Region where
Amazon Lambda is configured.\n",
+ "version_added": "9.9.0",
+ "type": "string",
+ "example": "us-east-1",
+ "default": None,
+ },
+ "check_health_on_startup": {
+ "description": "Whether or not to check the Lambda
Executor health on startup.\n",
+ "version_added": "9.9.0",
+ "type": "boolean",
+ "example": "True",
+ "default": "True",
+ },
+ "max_run_task_attempts": {
+ "description": "The maximum number of times the Lambda
Executor should attempt to start an Airflow task.\n",
+ "version_added": "9.9.0",
+ "type": "integer",
+ "example": "3",
+ "default": "3",
+ },
+ "queue_url": {
+ "description": "The URL of the SQS queue to use for
the Lambda executor. Required.\nThe Lambda executor will poll this queue for
results of the lambda function's Airflow Task.\n",
+ "version_added": "9.9.0",
+ "type": "string",
+ "example": "airflow-lambda-executor-results-queue",
+ "default": None,
+ },
+ "dead_letter_queue_url": {
+ "description": "The URL of the SQS dead letter queue
to use for the Lambda function. The Lambda executor\nwill poll this queue for
timeout/exception results of the lambda function. Required\n",
+ "version_added": "9.9.0",
+ "type": "string",
+ "example": "airflow-lambda-executor-dlq",
+ "default": None,
+ },
+ "function_name": {
+ "description": "The name of the Lambda function to
invoke. Required.\n",
+ "version_added": "9.9.0",
+ "type": "string",
+ "example": "airflow-lambda-executor-function",
+ "default": None,
+ },
+ "qualifier": {
+ "description": "The version or alias of the Lambda
function to invoke. If not specified, the latest version is used.\n",
+ "version_added": "9.9.0",
+ "type": "string",
+ "example": "1",
+ "default": None,
+ },
+ "end_wait_timeout": {
+ "description": "The number of seconds to wait for all
lambda executor invocations to complete when terminating
the\nexecutor/scheduler. Default of 0 means wait indefinitely.\n",
+ "version_added": "9.9.0",
+ "type": "integer",
+ "example": "500",
+ "default": "0",
+ },
+ },
+ },
"aws_ecs_executor": {
"description": "This section only applies if you are using the
AwsEcsExecutor in\nAirflow's ``[core]`` configuration.\nFor more information on
any of these execution parameters, see the link
below:\nhttps://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/run_task.html\nFor
boto3 credential management,
see\nhttps://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html\n",
"options": {
diff --git
a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/__init__.py
b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git
a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
new file mode 100644
index 00000000000..f1123ca5e07
--- /dev/null
+++
b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
@@ -0,0 +1,928 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime as dt
+import json
+from unittest import mock
+
+import pytest
+from botocore.exceptions import ClientError
+from semver import VersionInfo
+
+from airflow.exceptions import AirflowException
+from airflow.executors.base_executor import BaseExecutor
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.providers.amazon.aws.executors.aws_lambda import lambda_executor
+from airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor import
AwsLambdaExecutor
+from airflow.providers.amazon.aws.executors.aws_lambda.utils import
CONFIG_GROUP_NAME, AllLambdaConfigKeys
+from airflow.utils import timezone
+from airflow.utils.state import TaskInstanceState
+from airflow.version import version as airflow_version_str
+
+from tests_common.test_utils.config import conf_vars
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
+airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3]))
+
+DEFAULT_QUEUE_URL = "queue-url"
+DEFAULT_DLQ_URL = "dlq-url"
+DEFAULT_FUNCTION_NAME = "function-name"
+
+
[email protected]
+def set_env_vars():
+ overrides: dict[tuple[str, str], str] = {
+ (CONFIG_GROUP_NAME, AllLambdaConfigKeys.REGION_NAME): "us-west-1",
+ (CONFIG_GROUP_NAME, AllLambdaConfigKeys.FUNCTION_NAME):
DEFAULT_FUNCTION_NAME,
+ (CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUEUE_URL): DEFAULT_QUEUE_URL,
+ (CONFIG_GROUP_NAME, AllLambdaConfigKeys.DLQ_URL): DEFAULT_DLQ_URL,
+ (CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUALIFIER): "1",
+ (CONFIG_GROUP_NAME, AllLambdaConfigKeys.MAX_INVOKE_ATTEMPTS): "3",
+ (CONFIG_GROUP_NAME, AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP):
"True",
+ }
+ with conf_vars(overrides):
+ yield
+
+
[email protected]
+def mock_airflow_key():
+ def _key():
+ key_mock = mock.Mock()
+ # Use a "random" value (memory id of the mock obj) so each key
serializes uniquely
+ key_mock._asdict = mock.Mock(return_value={"mock_key": id(key_mock)})
+ return key_mock
+
+ return _key
+
+
+def _generate_mock_cmd():
+ return ["airflow", "tasks", "run", "dag_id", "task_id", "run_id",
"--local"]
+
+
+# The following two fixtures look different because no existing test
+# cares if they have unique values, so the same value is always used.
[email protected]
+def mock_cmd():
+ return _generate_mock_cmd()
+
+
[email protected]
+def mock_executor(set_env_vars) -> AwsLambdaExecutor:
+ """Mock Lambda to a repeatable starting state."""
+ executor = AwsLambdaExecutor()
+ executor.IS_BOTO_CONNECTION_HEALTHY = True
+
+ # Replace boto3 clients with mocks
+ lambda_mock = mock.Mock(spec=executor.lambda_client)
+ lambda_mock.invoke.return_value = {"StatusCode": 0, "failures": []}
+ executor.lambda_client = lambda_mock
+
+ sqs_mock = mock.Mock(spec=executor.sqs_client)
+ sqs_mock.receive_message.return_value = {"Messages": []}
+ executor.sqs_client = sqs_mock
+
+ return executor
+
+
+class TestAwsLambdaExecutor:
+ @mock.patch(
+
"airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor.AwsLambdaExecutor.change_state"
+ )
+ def test_execute(self, change_state_mock, mock_airflow_key, mock_executor,
mock_cmd):
+ """Test execution from end-to-end."""
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+
+ assert len(mock_executor.pending_tasks) == 0
+ mock_executor.execute_async(airflow_key, mock_cmd)
+ assert len(mock_executor.pending_tasks) == 1
+
+ mock_executor.attempt_task_runs()
+ mock_executor.lambda_client.invoke.assert_called_once()
+
+ # Task is stored in active worker.
+ assert len(mock_executor.running_tasks) == 1
+ assert json.dumps(airflow_key._asdict()) in mock_executor.running_tasks
+ change_state_mock.assert_called_once_with(
+ airflow_key, TaskInstanceState.RUNNING, ser_airflow_key,
remove_running=False
+ )
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow
3+")
+ @mock.patch(
+
"airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor.AwsLambdaExecutor.change_state"
+ )
+ def test_task_sdk(self, change_state_mock, mock_airflow_key,
mock_executor, mock_cmd):
+ """Test task sdk execution from end-to-end."""
+ from airflow.executors.workloads import ExecuteTask
+
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+
+ workload = mock.Mock(spec=ExecuteTask)
+ workload.ti = mock.Mock(spec=TaskInstance)
+ workload.ti.key = airflow_key
+ ser_workload = json.dumps({"test_key": "test_value"})
+ workload.model_dump_json.return_value = ser_workload
+
+ mock_executor.queue_workload(workload, mock.Mock())
+
+ assert mock_executor.queued_tasks[workload.ti.key] == workload
+ assert len(mock_executor.pending_tasks) == 0
+ assert len(mock_executor.running) == 0
+ mock_executor._process_workloads([workload])
+ assert len(mock_executor.queued_tasks) == 0
+ assert len(mock_executor.running) == 1
+ assert workload.ti.key in mock_executor.running
+ assert len(mock_executor.pending_tasks) == 1
+ assert mock_executor.pending_tasks[0].command == [
+ "python",
+ "-m",
+ "airflow.sdk.execution_time.execute_workload",
+ "--json-string",
+ '{"test_key": "test_value"}',
+ ]
+
+ mock_executor.attempt_task_runs()
+ mock_executor.lambda_client.invoke.assert_called_once()
+ assert len(mock_executor.pending_tasks) == 0
+
+ # Task is stored in active worker.
+ assert len(mock_executor.running_tasks) == 1
+ assert mock_executor.running_tasks[ser_airflow_key] == workload.ti.key
+ change_state_mock.assert_called_once_with(
+ workload.ti.key, TaskInstanceState.RUNNING, ser_airflow_key,
remove_running=False
+ )
+
+ @mock.patch.object(lambda_executor, "calculate_next_attempt_delay",
return_value=dt.timedelta(seconds=0))
+ def test_success_execute_api_exception(self, mock_backoff, mock_executor,
mock_cmd, mock_airflow_key):
+ """Test what happens when Lambda throws an initial exception on
invoke, but ultimately passes on retries."""
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+
+ invoke_exception = Exception("Test exception")
+ invoke_success = {"StatusCode": 0, "failures": []}
+ mock_executor.lambda_client.invoke.side_effect = [invoke_exception,
invoke_exception, invoke_success]
+ mock_executor.execute_async(airflow_key, mock_cmd)
+ expected_retry_count = 2
+
+ # Fail 2 times
+ for _ in range(expected_retry_count):
+ mock_executor.attempt_task_runs()
+ # Task is not stored in active workers.
+ assert len(mock_executor.running_tasks) == 0
+
+ # Pass in last attempt
+ mock_executor.attempt_task_runs()
+ assert len(mock_executor.pending_tasks) == 0
+ assert ser_airflow_key in mock_executor.running_tasks
+ assert mock_backoff.call_count == expected_retry_count
+ for attempt_number in range(1, expected_retry_count):
+ mock_backoff.assert_has_calls([mock.call(attempt_number)])
+
+ def test_failed_execute_api_exception(self, mock_executor, mock_cmd,
mock_airflow_key):
+ """Test what happens when Lambda refuses to execute a task and throws
an exception"""
+ mock_airflow_key = mock_airflow_key()
+
+ mock_executor.lambda_client.invoke.side_effect = Exception("Test
exception")
+ mock_executor.execute_async(mock_airflow_key, mock_cmd)
+
+ # No matter what, don't schedule until invoke becomes successful.
+ for _ in range(int(mock_executor.max_invoke_attempts) * 2):
+ mock_executor.attempt_task_runs()
+ # Task is not stored in running tasks
+ assert len(mock_executor.running_tasks) == 0
+
+ def test_failed_execute_creds_exception(self, mock_executor, mock_cmd,
mock_airflow_key):
+ """Test what happens when Lambda refuses to execute a task and throws
an exception due to credentials"""
+ airflow_key = mock_airflow_key()
+
+ mock_executor.IS_BOTO_CONNECTION_HEALTHY = True
+ mock_executor.execute_async(airflow_key, mock_cmd)
+ assert mock_executor.pending_tasks[0].attempt_number == 1
+
+ error_to_raise = ClientError(
+ {"Error": {"Code": "ExpiredTokenException", "Message": "foobar"}},
"OperationName"
+ )
+ mock_executor.lambda_client.invoke.side_effect = error_to_raise
+
+ # Sync will ultimately call attempt_task_runs, which is the code under
test
+ mock_executor.sync()
+
+ # Task should end up back in the queue
+ assert mock_executor.pending_tasks[0].key == airflow_key
+ # The connection should get marked as unhealthy
+ assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY
+ # We retry on connection issues indefinitely, so the attempt number
should be 1
+ assert mock_executor.pending_tasks[0].attempt_number == 1
+
+ def test_failed_execute_client_error_exception(self, mock_executor,
mock_cmd, mock_airflow_key):
+ """Test what happens when Lambda refuses to execute a task and throws
an exception for non-credentials issue"""
+ airflow_key = mock_airflow_key()
+ mock_executor.IS_BOTO_CONNECTION_HEALTHY = True
+ mock_executor.execute_async(airflow_key, mock_cmd)
+ assert mock_executor.pending_tasks[0].attempt_number == 1
+
+ error_to_raise = ClientError(
+ {"Error": {"Code": "RandomeError", "Message": "foobar"}},
"OperationName"
+ )
+ mock_executor.lambda_client.invoke.side_effect = error_to_raise
+
+ # Sync will ultimately call attempt_task_runs, which is the code under
test
+ mock_executor.sync()
+
+ # Task should end up back in the queue
+ assert mock_executor.pending_tasks[0].key == airflow_key
+ # The connection should stay marked as healthy because the error is
something else
+ assert mock_executor.IS_BOTO_CONNECTION_HEALTHY
+ # Not a retry so increment attempts
+ assert mock_executor.pending_tasks[0].attempt_number == 2
+
+ @mock.patch.object(lambda_executor, "calculate_next_attempt_delay",
return_value=dt.timedelta(seconds=0))
+ def test_attempt_task_runs_attempts_when_tasks_fail(self, _,
mock_executor):
+ """
+ Test case when all tasks fail to run.
+
+ The executor should attempt each task exactly once per sync()
iteration.
+ It should preserve the order of tasks, and attempt each task up to
+ `max_invoke_attempts` times before dropping the task.
+ """
+ airflow_keys = [
+ TaskInstanceKey("a", "task_a", "c", 1, -1),
+ TaskInstanceKey("a", "task_b", "c", 1, -1),
+ ]
+ airflow_cmd1 = _generate_mock_cmd()
+ airflow_cmd2 = _generate_mock_cmd()
+ commands = [airflow_cmd1, airflow_cmd2]
+
+ failures = [Exception("Failure 1"), Exception("Failure 2")]
+
+ mock_executor.execute_async(airflow_keys[0], commands[0])
+ mock_executor.execute_async(airflow_keys[1], commands[1])
+
+ assert len(mock_executor.pending_tasks) == 2
+ assert len(mock_executor.running_tasks) == 0
+
+ mock_executor.lambda_client.invoke.side_effect = failures
+ mock_executor.attempt_task_runs()
+
+ for i in range(2):
+ payload =
json.loads(mock_executor.lambda_client.invoke.call_args_list[i].kwargs["Payload"])
+ assert airflow_keys[i].task_id in payload["task_key"]
+
+ assert len(mock_executor.pending_tasks) == 2
+ assert len(mock_executor.running_tasks) == 0
+
+ mock_executor.lambda_client.invoke.call_args_list.clear()
+
+ mock_executor.lambda_client.invoke.side_effect = failures
+ mock_executor.attempt_task_runs()
+
+ for i in range(2):
+ payload =
json.loads(mock_executor.lambda_client.invoke.call_args_list[i].kwargs["Payload"])
+ assert airflow_keys[i].task_id in payload["task_key"]
+
+ assert len(mock_executor.pending_tasks) == 2
+ assert len(mock_executor.running_tasks) == 0
+
+ mock_executor.lambda_client.invoke.call_args_list.clear()
+
+ mock_executor.lambda_client.invoke.side_effect = failures
+ mock_executor.attempt_task_runs()
+
+ assert (
+ len(mock_executor.pending_tasks) == 0
+ ) # Pending now zero since we've had three failures to invoke
+ assert len(mock_executor.running_tasks) == 0
+
+ if airflow_version >= (2, 10, 0):
+ events = [(x.event, x.task_id, x.try_number) for x in
mock_executor._task_event_logs]
+ assert events == [
+ ("lambda invoke failure", "task_a", 1),
+ ("lambda invoke failure", "task_b", 1),
+ ]
+
+ @mock.patch.object(lambda_executor, "calculate_next_attempt_delay",
return_value=dt.timedelta(seconds=0))
+ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _,
mock_executor):
+ """
+ Test case when one task fails to run, others succeed, and a new task
gets queued.
+
+ """
+ airflow_keys = [
+ TaskInstanceKey("a", "task_a", "c", 1, -1),
+ TaskInstanceKey("a", "task_b", "c", 1, -1),
+ ]
+ airflow_cmd1 = _generate_mock_cmd()
+ airflow_cmd2 = _generate_mock_cmd()
+ airflow_commands = [airflow_cmd1, airflow_cmd2]
+
+ success_response = {"StatusCode": 0, "failures": []}
+
+ responses = [Exception("Failure 1"), success_response]
+
+ mock_executor.execute_async(airflow_keys[0], airflow_commands[0])
+ mock_executor.execute_async(airflow_keys[1], airflow_commands[1])
+
+ assert len(mock_executor.pending_tasks) == 2
+
+ mock_executor.lambda_client.invoke.side_effect = responses
+ mock_executor.attempt_task_runs()
+
+ for i in range(2):
+ payload =
json.loads(mock_executor.lambda_client.invoke.call_args_list[i].kwargs["Payload"])
+ assert airflow_keys[i].task_id in payload["task_key"]
+
+ assert len(mock_executor.pending_tasks) == 1
+ assert len(mock_executor.running_tasks) == 1
+
+ mock_executor.lambda_client.invoke.call_args_list.clear()
+
+ # queue new task
+ airflow_keys[1] = TaskInstanceKey("a", "task_c", "c", 1, -1)
+ airflow_commands[1] = _generate_mock_cmd()
+ mock_executor.execute_async(airflow_keys[1], airflow_commands[1])
+
+ assert len(mock_executor.pending_tasks) == 2
+ # assert that the order of pending tasks is preserved i.e. the first
task is 1st etc.
+ assert mock_executor.pending_tasks[0].key == airflow_keys[0]
+ assert mock_executor.pending_tasks[0].command == airflow_commands[0]
+
+ responses = [Exception("Failure 1"), success_response]
+ mock_executor.lambda_client.invoke.side_effect = responses
+ mock_executor.attempt_task_runs()
+
+ for i in range(2):
+ payload =
json.loads(mock_executor.lambda_client.invoke.call_args_list[i].kwargs["Payload"])
+ assert airflow_keys[i].task_id in payload["task_key"]
+
+ assert len(mock_executor.pending_tasks) == 1
+ assert len(mock_executor.running_tasks) == 2
+
+ mock_executor.lambda_client.invoke.call_args_list.clear()
+
+ responses = [Exception("Failure 1")]
+ mock_executor.lambda_client.invoke.side_effect = responses
+ mock_executor.attempt_task_runs()
+
+ payload =
json.loads(mock_executor.lambda_client.invoke.call_args_list[0].kwargs["Payload"])
+ assert airflow_keys[0].task_id in payload["task_key"]
+
+ if airflow_version >= (2, 10, 0):
+ events = [(x.event, x.task_id, x.try_number) for x in
mock_executor._task_event_logs]
+ assert events == [("lambda invoke failure", "task_a", 1)]
+
+ @mock.patch.object(BaseExecutor, "fail")
+ @mock.patch.object(BaseExecutor, "success")
+ def test_sync_running_dlq(self, success_mock, fail_mock, mock_executor,
mock_airflow_key):
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+
+ mock_executor.running_tasks.clear()
+ mock_executor.running_tasks[ser_airflow_key] = airflow_key
+ mock_executor.sqs_client.receive_message.side_effect = [
+ {}, # First request from the results queue will be empty
+ {
+ # Second request from the DLQ will have a message
+ "Messages": [
+ {
+ "ReceiptHandle": "receipt_handle",
+ "Body": json.dumps(
+ {
+ "task_key": ser_airflow_key,
+ # DLQ messages will have the input (task_key,
command) instead of return_code
+ "command": "command",
+ }
+ ),
+ }
+ ]
+ },
+ ]
+
+ mock_executor.sync_running_tasks()
+ # Receive messages should be called twice
+ assert mock_executor.sqs_client.receive_message.call_count == 2
+ assert
mock_executor.sqs_client.receive_message.call_args_list[0].kwargs == {
+ "QueueUrl": DEFAULT_QUEUE_URL,
+ "MaxNumberOfMessages": 10,
+ }
+
+ assert
mock_executor.sqs_client.receive_message.call_args_list[1].kwargs == {
+ "QueueUrl": DEFAULT_DLQ_URL,
+ "MaxNumberOfMessages": 10,
+ }
+
+ # Task is not stored in active workers.
+ assert len(mock_executor.running_tasks) == 0
+ success_mock.assert_not_called()
+ fail_mock.assert_called_once()
+ assert mock_executor.sqs_client.delete_message.call_count == 1
+ assert
mock_executor.sqs_client.delete_message.call_args_list[0].kwargs == {
+ "QueueUrl": DEFAULT_DLQ_URL,
+ "ReceiptHandle": "receipt_handle",
+ }
+
+ @mock.patch.object(BaseExecutor, "fail")
+ @mock.patch.object(BaseExecutor, "success")
+ def test_sync_running_success(self, success_mock, fail_mock,
mock_executor, mock_airflow_key):
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+
+ mock_executor.running_tasks.clear()
+ mock_executor.running_tasks[ser_airflow_key] = airflow_key
+ # Success message
+ mock_executor.sqs_client.receive_message.return_value = {
+ "Messages": [
+ {
+ "ReceiptHandle": "receipt_handle",
+ "Body": json.dumps(
+ {
+ "task_key": ser_airflow_key,
+ "return_code": 0,
+ }
+ ),
+ }
+ ]
+ }
+
+ mock_executor.sync_running_tasks()
+ mock_executor.sqs_client.receive_message.assert_called_once()
+ assert
mock_executor.sqs_client.receive_message.call_args_list[0].kwargs == {
+ "QueueUrl": DEFAULT_QUEUE_URL,
+ "MaxNumberOfMessages": 10,
+ }
+
+ # Task is not stored in active workers.
+ assert len(mock_executor.running_tasks) == 0
+ # Task is immediately succeeded.
+ success_mock.assert_called_once()
+ fail_mock.assert_not_called()
+ assert mock_executor.sqs_client.delete_message.call_count == 1
+ assert
mock_executor.sqs_client.delete_message.call_args_list[0].kwargs == {
+ "QueueUrl": DEFAULT_QUEUE_URL,
+ "ReceiptHandle": "receipt_handle",
+ }
+
+ @mock.patch.object(BaseExecutor, "fail")
+ @mock.patch.object(BaseExecutor, "success")
+ def test_sync_running_fail(self, success_mock, fail_mock, mock_executor,
mock_airflow_key):
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+
+ mock_executor.running_tasks.clear()
+ mock_executor.running_tasks[ser_airflow_key] = airflow_key
+ # Failure message
+ mock_executor.sqs_client.receive_message.return_value = {
+ "Messages": [
+ {
+ "ReceiptHandle": "receipt_handle",
+ "Body": json.dumps(
+ {
+ "task_key": ser_airflow_key,
+ "return_code": 1, # Non-zero return code, task
failed
+ }
+ ),
+ }
+ ]
+ }
+
+ mock_executor.sync_running_tasks()
+ mock_executor.sqs_client.receive_message.assert_called_once()
+
+ # Task is not stored in active workers.
+ assert len(mock_executor.running_tasks) == 0
+ # Task is immediately failed.
+ success_mock.assert_not_called()
+ fail_mock.assert_called_once()
+ assert mock_executor.sqs_client.delete_message.call_count == 1
+
+ @mock.patch.object(BaseExecutor, "fail")
+ @mock.patch.object(BaseExecutor, "success")
+ def test_sync_running_short_circuit(self, success_mock, fail_mock,
mock_executor, mock_airflow_key):
+ mock_executor.running_tasks.clear()
+ # No running tasks, so we will short circuit
+
+ mock_executor.sync_running_tasks()
+ mock_executor.sqs_client.receive_message.assert_not_called()
+
+ # There are still no tasks stored in active workers.
+ assert len(mock_executor.running_tasks) == 0
+ # Nothing was running, so no task is succeeded or failed.
+ success_mock.assert_not_called()
+ fail_mock.assert_not_called()
+ assert mock_executor.sqs_client.delete_message.call_count == 0
+
+ @mock.patch.object(BaseExecutor, "fail")
+ @mock.patch.object(BaseExecutor, "success")
+ def test_sync_running_no_updates(self, success_mock, fail_mock,
mock_executor, mock_airflow_key):
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+
+ mock_executor.running_tasks.clear()
+ mock_executor.running_tasks[ser_airflow_key] = airflow_key
+ # No messages, so we will not loop
+ mock_executor.sqs_client.receive_message.return_value = {"Messages":
[]}
+
+ mock_executor.sync_running_tasks()
+ # Both the results queue and DLQ should have been checked
+ assert mock_executor.sqs_client.receive_message.call_count == 2
+
+ # Task is still stored in active workers.
+ assert len(mock_executor.running_tasks) == 1
+ # No results were received, so the task is neither succeeded nor failed.
+ success_mock.assert_not_called()
+ fail_mock.assert_not_called()
+ assert mock_executor.sqs_client.delete_message.call_count == 0
+
+ @mock.patch.object(BaseExecutor, "fail")
+ @mock.patch.object(BaseExecutor, "success")
+ def test_sync_running_two_tasks_one_relevant(
+ self, success_mock, fail_mock, mock_executor, mock_airflow_key
+ ):
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+ airflow_key_2 = mock_airflow_key()
+ ser_airflow_key_2 = json.dumps(airflow_key_2._asdict())
+
+ mock_executor.running_tasks.clear()
+ mock_executor.running_tasks[ser_airflow_key] = airflow_key
+ mock_executor.running_tasks[ser_airflow_key_2] = airflow_key_2
+ # Success message
+ mock_executor.sqs_client.receive_message.side_effect = [
+ {
+ "Messages": [
+ {
+ "ReceiptHandle": "receipt_handle",
+ "Body": json.dumps(
+ {
+ "task_key": ser_airflow_key,
+ "return_code": 0,
+ }
+ ),
+ }
+ ]
+ },
+ {}, # No messages from DLQ
+ ]
+
+ mock_executor.sync_running_tasks()
+ # Both the results queue and DLQ should have been checked
+ assert mock_executor.sqs_client.receive_message.call_count == 2
+
+ # One task left running
+ assert len(mock_executor.running_tasks) == 1
+ # Task one completed, task two is still running
+ assert ser_airflow_key_2 in mock_executor.running_tasks
+ # Task is immediately succeeded.
+ success_mock.assert_called_once()
+ fail_mock.assert_not_called()
+ assert mock_executor.sqs_client.delete_message.call_count == 1
+
+ @mock.patch.object(BaseExecutor, "fail")
+ @mock.patch.object(BaseExecutor, "success")
+ def test_sync_running_unknown_task(self, success_mock, fail_mock,
mock_executor, mock_airflow_key):
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+ airflow_key_2 = mock_airflow_key()
+ ser_airflow_key_2 = json.dumps(airflow_key_2._asdict())
+
+ mock_executor.running_tasks.clear()
+ # Only add one of the tasks to the running list, the other will be
unknown
+ mock_executor.running_tasks[ser_airflow_key] = airflow_key
+
+ # Receive the known task and unknown task
+ mock_executor.sqs_client.receive_message.return_value = {
+ "Messages": [
+ {
+ "ReceiptHandle": "receipt_handle",
+ "Body": json.dumps(
+ {
+ "task_key": ser_airflow_key,
+ "return_code": 0,
+ }
+ ),
+ },
+ {
+ "ReceiptHandle": "receipt_handle",
+ "Body": json.dumps(
+ {
+ "task_key": ser_airflow_key_2,
+ "return_code": 0,
+ }
+ ),
+ },
+ ]
+ }
+
+ mock_executor.sync_running_tasks()
+ mock_executor.sqs_client.receive_message.assert_called_once()
+
+ # The known task is set to succeeded, unknown task is dropped
+ assert len(mock_executor.running_tasks) == 0
+ success_mock.assert_called_once()
+ fail_mock.assert_not_called()
+ # Both messages from the queue should be deleted, both known and
unknown
+ assert mock_executor.sqs_client.delete_message.call_count == 2
+
+ def test_start_no_check_health(self, mock_executor):
+ mock_executor.check_health = mock.Mock()
+ with conf_vars({(CONFIG_GROUP_NAME,
AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP): "False"}):
+ mock_executor.start()
+
+ assert mock_executor.check_health.call_count == 0
+
+ def test_start_check_health_success(self, mock_executor):
+ mock_executor.check_health = mock.Mock()
+ with conf_vars({(CONFIG_GROUP_NAME,
AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP): "True"}):
+ mock_executor.start()
+
+ assert mock_executor.check_health.call_count == 1
+
+ def test_start_check_health_fail(self, mock_executor):
+ mock_executor.check_health = mock.Mock()
+ mock_executor.check_health.side_effect = AirflowException("Test
exception")
+ with conf_vars({(CONFIG_GROUP_NAME,
AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP): "True"}):
+ with pytest.raises(AirflowException):
+ mock_executor.start()
+
+ assert mock_executor.check_health.call_count == 1
+
+ def test_check_health_success(self, mock_executor):
+ mock_executor.IS_BOTO_CONNECTION_HEALTHY = False
+ mock_executor.sqs_client.get_queue_attributes.return_value = {
+ "Attributes": {"ApproximateNumberOfMessages": 0}
+ }
+ mock_executor.lambda_client.get_function.return_value = {
+ "Configuration": {
+ "FunctionName": DEFAULT_FUNCTION_NAME,
+ "State": "Active",
+ }
+ }
+ mock_executor.check_health()
+ assert mock_executor.sqs_client.get_queue_attributes.call_count == 2
+ assert mock_executor.lambda_client.get_function.call_count == 1
+ assert mock_executor.IS_BOTO_CONNECTION_HEALTHY
+
+ def test_check_health_lambda_fails(self, mock_executor):
+ mock_executor.IS_BOTO_CONNECTION_HEALTHY = False
+ mock_executor.sqs_client.get_queue_attributes.return_value = {
+ "Attributes": {"ApproximateNumberOfMessages": 0}
+ }
+ mock_executor.lambda_client.get_function.return_value = ClientError(
+ {"Error": {"Code": "ResourceNotFoundException", "Message":
"foobar"}}, "OperationName"
+ )
+
+ with pytest.raises(AirflowException):
+ mock_executor.check_health()
+ assert mock_executor.lambda_client.get_function.call_count == 1
+ # Lambda has already failed so SQS should not be called
+ assert mock_executor.sqs_client.get_queue_attributes.call_count == 0
+ assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY
+
+ def test_check_health_sqs_fails(self, mock_executor):
+ mock_executor.IS_BOTO_CONNECTION_HEALTHY = False
+ mock_executor.sqs_client.get_queue_attributes.return_value =
ClientError(
+ {"Error": {"Code": "ResourceNotFoundException", "Message":
"foobar"}}, "OperationName"
+ )
+ mock_executor.lambda_client.get_function.return_value = {
+ "Configuration": {
+ "FunctionName": DEFAULT_FUNCTION_NAME,
+ "State": "Active",
+ }
+ }
+ with pytest.raises(AirflowException):
+ mock_executor.check_health()
+ assert mock_executor.lambda_client.get_function.call_count == 1
+ # The Lambda check is made once, and SQS is queried once before the health check fails
+ assert mock_executor.sqs_client.get_queue_attributes.call_count == 1
+ assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY
+
+ def test_check_health_sqs_results_queue_success_dlq_fails(self,
mock_executor):
+ mock_executor.IS_BOTO_CONNECTION_HEALTHY = False
+ mock_executor.sqs_client.get_queue_attributes.side_effect = [
+ {"Attributes": {"ApproximateNumberOfMessages": 0}},
+ ClientError(
+ {"Error": {"Code": "ResourceNotFoundException", "Message":
"foobar"}}, "OperationName"
+ ),
+ ]
+ mock_executor.lambda_client.get_function.return_value = {
+ "Configuration": {
+ "FunctionName": DEFAULT_FUNCTION_NAME,
+ "State": "Active",
+ }
+ }
+ with pytest.raises(AirflowException):
+ mock_executor.check_health()
+ assert mock_executor.lambda_client.get_function.call_count == 1
+ # The first SQS queue check succeeds and the second fails, so SQS is called twice
+ assert mock_executor.sqs_client.get_queue_attributes.call_count == 2
+ assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY
+
+ def test_sync_already_unhealthy(self, mock_executor):
+ """sync() on an unhealthy connection calls load_connections and skips both task sync methods."""
+ # Something has set the connection to unhealthy (tested elsewhere)
+ mock_executor.IS_BOTO_CONNECTION_HEALTHY = False
+ mock_executor.sync_running_tasks = mock.Mock()
+ mock_executor.attempt_task_runs = mock.Mock()
+ # load_connections is replaced by a bare mock, so calling it cannot flip the flag to healthy
+ mock_executor.load_connections = mock.Mock()
+ # Set the last connection reload to be more than 60 seconds ago so
that we get a reload
+ mock_executor.last_connection_reload = timezone.utcnow() -
dt.timedelta(seconds=100)
+ # We should not be able to sync
+ mock_executor.sync()
+ assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY
+ mock_executor.sync_running_tasks.assert_not_called()
+ mock_executor.attempt_task_runs.assert_not_called()
+ mock_executor.load_connections.assert_called_once()
+
+ def test_sync_already_unhealthy_then_repaired(self, mock_executor):
+ """sync() runs both task sync methods once the mocked check_health marks the connection healthy."""
+ # Something has set the connection to unhealthy (tested elsewhere)
+ mock_executor.IS_BOTO_CONNECTION_HEALTHY = False
+ mock_executor.sync_running_tasks = mock.Mock()
+ mock_executor.attempt_task_runs = mock.Mock()
+
+ # Stand-in for check_health that simply marks the connection healthy again
+ def check_health_side_effect():
+ mock_executor.IS_BOTO_CONNECTION_HEALTHY = True
+
+ mock_executor.check_health =
mock.Mock(side_effect=check_health_side_effect)
+ # Set the last connection reload to be more than 60 seconds ago so
that we get a reload
+ mock_executor.last_connection_reload = timezone.utcnow() -
dt.timedelta(seconds=100)
+ # Sync should repair itself and continue to call the sync methods
+ mock_executor.sync()
+ assert mock_executor.IS_BOTO_CONNECTION_HEALTHY
+ mock_executor.sync_running_tasks.assert_called_once()
+ mock_executor.attempt_task_runs.assert_called_once()
+
+ @pytest.mark.parametrize(
+ "error_code",
+ [
+ "ExpiredTokenException",
+ "InvalidClientTokenId",
+ "UnrecognizedClientException",
+ ],
+ )
+ def test_sync_become_unhealthy_no_creds(self, error_code, mock_executor):
+ """sync() marks the connection unhealthy and logs a warning for each parametrized error code."""
+ # Start from a healthy connection; the error raised during sync should flip it
+ mock_executor.IS_BOTO_CONNECTION_HEALTHY = True
+ mock_executor.log.warning = mock.Mock()
+ mock_executor.attempt_task_runs = mock.Mock()
+ error_to_raise = ClientError({"Error": {"Code": error_code, "Message":
"foobar"}}, "OperationName")
+ mock_executor.sync_running_tasks =
mock.Mock(side_effect=error_to_raise)
+
+ # sync should catch the error and handle it, setting connection to
unhealthy
+ mock_executor.sync()
+ assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY
+ mock_executor.sync_running_tasks.assert_called_once()
+ mock_executor.attempt_task_runs.assert_not_called()
+ # Check that the substring "AWS credentials are either missing or
expired" was logged
+ mock_executor.log.warning.assert_called_once()
+ assert "AWS credentials are either missing or expired" in
mock_executor.log.warning.call_args[0][0]
+
+ def test_sync_exception(self, mock_executor):
+ """sync() logs an unexpected exception instead of raising it and leaves the connection healthy."""
+ # Start from a healthy connection; a generic exception should not change that
+ mock_executor.IS_BOTO_CONNECTION_HEALTHY = True
+ mock_executor.log.exception = mock.Mock()
+ mock_executor.attempt_task_runs = mock.Mock()
+ mock_executor.sync_running_tasks = mock.Mock(side_effect=Exception())
+
+ # sync should catch the error and log, don't kill scheduler by letting
it raise up higher.
+ mock_executor.sync()
+ # Not a credentials error that we can tell, so connection stays healthy
+ assert mock_executor.IS_BOTO_CONNECTION_HEALTHY
+ mock_executor.sync_running_tasks.assert_called_once()
+ mock_executor.attempt_task_runs.assert_not_called()
+ # Check that the substring "An error occurred while syncing tasks" was logged
+ mock_executor.log.exception.assert_called_once()
+ assert "An error occurred while syncing tasks" in
mock_executor.log.exception.call_args[0][0]
+
+ def test_try_adopt_task_instances(self, mock_executor, mock_airflow_key):
+ """Test that executor can adopt orphaned task instances from a
SchedulerJob shutdown event."""
+ # Two keys are built by hand, differing only in their second field ("task_a" vs "task_b");
+ # the mock_airflow_key argument is not used in this test.
+ airflow_key_1 = TaskInstanceKey("a", "task_a", "c", 1, -1)
+ ser_airflow_key_1 = json.dumps(airflow_key_1._asdict())
+ airflow_key_2 = TaskInstanceKey("a", "task_b", "c", 1, -1)
+ ser_airflow_key_2 = json.dumps(airflow_key_2._asdict())
+
+ orphaned_tasks = [
+ mock.Mock(spec=TaskInstance),
+ mock.Mock(spec=TaskInstance),
+ mock.Mock(spec=TaskInstance),
+ ]
+ # The JSON-serialized key is what each orphaned task carries as its external_executor_id
+ orphaned_tasks[0].external_executor_id = ser_airflow_key_1
+ orphaned_tasks[1].external_executor_id = ser_airflow_key_2
+ orphaned_tasks[
+ 2
+ ].external_executor_id = None # One orphaned task has no
external_executor_id, not adopted
+
+ for task in orphaned_tasks:
+ task.try_number = 1
+
+ not_adopted_tasks =
mock_executor.try_adopt_task_instances(orphaned_tasks)
+
+ # Two of the three tasks should be adopted.
+ assert len(orphaned_tasks) - 1 == len(mock_executor.running_tasks)
+ assert ser_airflow_key_1 in mock_executor.running_tasks
+
+ # running_tasks maps the serialized key back to the equivalent TaskInstanceKey
+ assert mock_executor.running_tasks[ser_airflow_key_1] == airflow_key_1
+ assert ser_airflow_key_2 in mock_executor.running_tasks
+ assert mock_executor.running_tasks[ser_airflow_key_2] == airflow_key_2
+
+ # The remaining one task is unable to be adopted.
+ assert len(not_adopted_tasks) == 1
+ assert not_adopted_tasks[0] == orphaned_tasks[2]
+
+ @mock.patch.object(BaseExecutor, "fail")
+ @mock.patch.object(BaseExecutor, "success")
+ def test_end(self, success_mock, fail_mock, mock_executor,
mock_airflow_key):
+ """Test that executor can end successfully; waiting for all tasks to
naturally exit."""
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+
+ mock_executor.running_tasks.clear()
+ mock_executor.running_tasks[ser_airflow_key] = airflow_key
+ # The first two responses are empty, so we loop again while waiting for tasks to finish;
+ # the third response carries the result message for the running task (return_code 0)
+ mock_executor.sqs_client.receive_message.side_effect = [
+ {},
+ {},
+ {
+ "Messages": [
+ {
+ "ReceiptHandle": "receipt_handle",
+ "Body": json.dumps(
+ {
+ "task_key": ser_airflow_key,
+ "return_code": 0,
+ }
+ ),
+ }
+ ]
+ },
+ ]
+ mock_executor.end(heartbeat_interval=0)
+ # Assert that the sqs_client mock method receive_message was called exactly three times
+ assert mock_executor.sqs_client.receive_message.call_count == 3
+
+ # Task is not stored in active workers.
+ assert len(mock_executor.running_tasks) == 0
+ success_mock.assert_called_once()
+ fail_mock.assert_not_called()
+ assert mock_executor.sqs_client.delete_message.call_count == 1
+
+
@mock.patch("airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor.timezone")
+ def test_end_timeout(self, mock_timezone, mock_executor, mock_airflow_key):
+ """Test that end() stops waiting and logs a warning on timeout, leaving the task in running_tasks."""
+ # Mock the sync method of the mock_executor object so we can count how
many times it was called
+ mock_executor.sync = mock.Mock()
+ mock_executor.log.warning = mock.Mock()
+ current_time = timezone.utcnow()
+ # Successive utcnow() calls on the patched executor-module timezone return these values
+ # in order: the same instant twice, then 5 and 10 seconds later
+ mock_timezone.utcnow.side_effect = [
+ current_time,
+ current_time,
+ current_time + dt.timedelta(seconds=5),
+ current_time + dt.timedelta(seconds=10),
+ ]
+
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+
+ mock_executor.running_tasks.clear()
+ mock_executor.running_tasks[ser_airflow_key] = airflow_key
+
+ # END_WAIT_TIMEOUT is overridden to "5" for the duration of the end() call
+ with conf_vars({(CONFIG_GROUP_NAME,
AllLambdaConfigKeys.END_WAIT_TIMEOUT): "5"}):
+ mock_executor.end(heartbeat_interval=0)
+
+ # Task is still stored in active workers.
+ assert len(mock_executor.running_tasks) == 1
+ assert mock_executor.sync.call_count == 2
+ mock_executor.log.warning.assert_called_once_with(
+ "Timed out waiting for tasks to finish. Some tasks may not be
handled gracefully"
+ " as the executor is force ending due to timeout."
+ )
+
+ def test_terminate(self, mock_executor, mock_airflow_key):
+ """Test that executor can terminate successfully."""
+ airflow_key = mock_airflow_key()
+ ser_airflow_key = json.dumps(airflow_key._asdict())
+
+ # Leave exactly one task registered as running before terminating
+ mock_executor.running_tasks.clear()
+ mock_executor.running_tasks[ser_airflow_key] = airflow_key
+ mock_executor.log.warning = mock.Mock()
+
+ mock_executor.terminate()
+ mock_executor.log.warning.assert_called_once_with(
+ "Terminating Lambda executor. In-flight tasks cannot be stopped."
+ )
+ # terminate() does not remove the in-flight task from running_tasks
+ assert len(mock_executor.running_tasks) == 1