This is an automated email from the ASF dual-hosted git repository. onikolas pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 2961561430a AWS Lambda Executor (#50516) 2961561430a is described below commit 2961561430af37d881ac27be4f404ae45810425e Author: Niko Oliveira <oniko...@amazon.com> AuthorDate: Thu May 15 15:36:11 2025 -0700 AWS Lambda Executor (#50516) This changeset includes: - The Lambda executor code itself - Tests for the executor - Docs for the executor --- .../tests/unit/always/test_project_structure.py | 2 + providers/amazon/docs/executors/batch-executor.rst | 27 + providers/amazon/docs/executors/ecs-executor.rst | 27 + providers/amazon/docs/executors/general.rst | 48 +- providers/amazon/docs/executors/index.rst | 1 + .../amazon/docs/executors/lambda-executor.rst | 358 ++++++++ providers/amazon/docs/img/lambda_executor.png | Bin 0 -> 56488 bytes providers/amazon/provider.yaml | 76 ++ .../amazon/aws/executors/aws_lambda/__init__.py | 21 + .../aws/executors/aws_lambda/docker/Dockerfile | 107 +++ .../aws/executors/aws_lambda/docker/__init__.py | 16 + .../amazon/aws/executors/aws_lambda/docker/app.py | 129 +++ .../aws/executors/aws_lambda/lambda_executor.py | 479 +++++++++++ .../amazon/aws/executors/aws_lambda/utils.py | 70 ++ .../airflow/providers/amazon/get_provider_info.py | 68 ++ .../amazon/aws/executors/aws_lambda/__init__.py | 16 + .../executors/aws_lambda/test_lambda_executor.py | 928 +++++++++++++++++++++ 17 files changed, 2354 insertions(+), 19 deletions(-) diff --git a/airflow-core/tests/unit/always/test_project_structure.py b/airflow-core/tests/unit/always/test_project_structure.py index a1cfba80cf6..e81476b0216 100644 --- a/airflow-core/tests/unit/always/test_project_structure.py +++ b/airflow-core/tests/unit/always/test_project_structure.py @@ -69,6 +69,8 @@ class TestProjectStructure: "providers/amazon/tests/unit/amazon/aws/executors/ecs/test_boto_schema.py", "providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor_config.py", "providers/amazon/tests/unit/amazon/aws/executors/ecs/test_utils.py", + "providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_utils.py", + "providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/docker/test_app.py", "providers/amazon/tests/unit/amazon/aws/executors/utils/test_base_config_keys.py", "providers/amazon/tests/unit/amazon/aws/operators/test_emr.py", "providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker.py", diff --git a/providers/amazon/docs/executors/batch-executor.rst b/providers/amazon/docs/executors/batch-executor.rst index c68f05537b7..920bc7e9e58 100644 --- a/providers/amazon/docs/executors/batch-executor.rst +++ b/providers/amazon/docs/executors/batch-executor.rst @@ -141,6 +141,33 @@ newly created Job Role and Execution role for the Job Definition. :start-after: .. BEGIN DOCKERFILE_AUTH_SECOND_METHOD :end-before: .. END DOCKERFILE_AUTH_SECOND_METHOD +.. include:: general.rst + :start-after: .. BEGIN BASE_IMAGE + :end-before: .. END BASE_IMAGE + +.. include:: general.rst + :start-after: .. BEGIN LOADING_DAGS_OVERVIEW + :end-before: .. END LOADING_DAGS_OVERVIEW + +.. include:: general.rst + :start-after: .. BEGIN LOADING_DAGS_FROM_S3 + :end-before: .. END LOADING_DAGS_FROM_S3 + +.. include:: general.rst + :start-after: .. BEGIN LOADING_DAGS_FROM_LOCAL + :end-before: .. END LOADING_DAGS_FROM_LOCAL + +.. include:: general.rst + :start-after: .. BEGIN DEPENDENCIES + :end-before: .. 
END DEPENDENCIES + +Building Image for AWS Batch Executor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Detailed instructions on how to use the Docker image, that you have +created via this readme, with the Batch Executor can be found +:ref:`here <batch_setup_guide>`. + .. _batch_logging: .. include:: general.rst diff --git a/providers/amazon/docs/executors/ecs-executor.rst b/providers/amazon/docs/executors/ecs-executor.rst index fa1da53e0ad..d7451e55277 100644 --- a/providers/amazon/docs/executors/ecs-executor.rst +++ b/providers/amazon/docs/executors/ecs-executor.rst @@ -162,6 +162,33 @@ newly created Task Role and Task Execution role for the Task Definition. :start-after: .. BEGIN DOCKERFILE_AUTH_SECOND_METHOD :end-before: .. END DOCKERFILE_AUTH_SECOND_METHOD +.. include:: general.rst + :start-after: .. BEGIN BASE_IMAGE + :end-before: .. END BASE_IMAGE + +.. include:: general.rst + :start-after: .. BEGIN LOADING_DAGS_OVERVIEW + :end-before: .. END LOADING_DAGS_OVERVIEW + +.. include:: general.rst + :start-after: .. BEGIN LOADING_DAGS_FROM_S3 + :end-before: .. END LOADING_DAGS_FROM_S3 + +.. include:: general.rst + :start-after: .. BEGIN LOADING_DAGS_FROM_LOCAL + :end-before: .. END LOADING_DAGS_FROM_LOCAL + +.. include:: general.rst + :start-after: .. BEGIN DEPENDENCIES + :end-before: .. END DEPENDENCIES + +Building Image for AWS ECS Executor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Detailed instructions on how to use the Docker image, that you have +created via this readme, with the ECS Executor can be found +:ref:`here <setup_guide>`. + .. _logging: .. include:: general.rst diff --git a/providers/amazon/docs/executors/general.rst b/providers/amazon/docs/executors/general.rst index 7df66e36ac4..4928e2d8f7a 100644 --- a/providers/amazon/docs/executors/general.rst +++ b/providers/amazon/docs/executors/general.rst @@ -42,8 +42,7 @@ An example Dockerfile can be found |dockerfileLink|, it creates an image that can be used by AWS |executorName| to run Airflow tasks using the AWS |executorName| Executor in Apache Airflow. The image supports AWS CLI/API integration, allowing you to interact with AWS services within your -Airflow environment. It also includes options to load DAGs (Directed -Acyclic Graphs) from either an S3 bucket or a local folder. +Airflow environment. It also includes options to load Dags from either an S3 bucket or a local folder. Prerequisites @@ -90,9 +89,10 @@ The second method is to use the build-time arguments (``aws_access_key_id``, ``aws_secret_access_key``, ``aws_default_region``, and ``aws_session_token``). -Note: This method is not recommended for use in production environments, -because user credentials are stored in the container, which may be a -security vulnerability. +.. warning:: + This method is not recommended for use in production environments, + because user credentials are stored in the container, which may be a + security vulnerability. To pass AWS authentication information using these arguments, use the ``--build-arg`` option during the Docker build process. For example: @@ -109,6 +109,10 @@ Replace ``YOUR_ACCESS_KEY``, ``YOUR_SECRET_KEY``, ``YOUR_SESSION_TOKEN``, and ``YOUR_DEFAULT_REGION`` with valid AWS credentials. +.. END DOCKERFILE_AUTH_SECOND_METHOD + +.. BEGIN BASE_IMAGE + Base Image ~~~~~~~~~~ @@ -142,15 +146,22 @@ version <https://hub.docker.com/r/apache/airflow/tags?page=1&name=3.9>`__. For example, the tag ``latest-python3.9`` specifies that the image will have python 3.9 installed. +.. END BASE_IMAGE + +.. 
BEGIN LOADING_DAGS_OVERVIEW Loading DAGs ~~~~~~~~~~~~ -There are many ways to load DAGs on a container managed by |executorName|. This Dockerfile +There are many ways to load DAGs on a container used by |executorName|. This Dockerfile is preconfigured with two possible ways: copying from a local folder, or downloading from an S3 bucket. Other methods of loading DAGs are possible as well. +.. END LOADING_DAGS_OVERVIEW + +.. BEGIN LOADING_DAGS_FROM_S3 + From S3 Bucket ^^^^^^^^^^^^^^ @@ -176,6 +187,10 @@ build arguments. --build-arg aws_session_token=YOUR_SESSION_TOKEN \ --build-arg s3_uri=YOUR_S3_URI . +.. END LOADING_DAGS_FROM_S3 + +.. BEGIN LOADING_DAGS_FROM_LOCAL + From Local Folder ^^^^^^^^^^^^^^^^^ @@ -194,6 +209,10 @@ If choosing to load DAGs onto a different path than ``/opt/airflow/dags``, then the new path will need to be updated in the Airflow config. +.. END LOADING_DAGS_FROM_LOCAL + +.. BEGIN DEPENDENCIES + Installing Python Dependencies ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -206,28 +225,19 @@ Uncomment the two appropriate lines in the Dockerfile that copy the ``requirements.txt`` file to the container, and run ``pip install`` to install the dependencies on the container. -Building Image for AWS |executorName| Executor -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Detailed instructions on how to use the Docker image, that you have -created via this readme, with the |executorName| Executor can be found -:ref:`here <setup_guide>`. - -.. END DOCKERFILE_AUTH_SECOND_METHOD +.. END DEPENDENCIES .. BEGIN LOGGING Logging ------- -Airflow tasks executed via this executor run in containers within +Airflow tasks executed via this executor run within the configured VPC. This means that logs are not directly accessible to -the Airflow Webserver and when containers are stopped, after task -completion, the logs would be permanently lost. +the Airflow UI, after task completion, the logs would be permanently lost. Remote logging should be employed when using the |executorName| executor to persist -your Airflow Task logs and make them viewable from the Airflow -Webserver. +your Airflow Task logs and make them viewable from the Airflow UI. Configuring Remote Logging ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/providers/amazon/docs/executors/index.rst b/providers/amazon/docs/executors/index.rst index 117cd1faccc..8117d4d969f 100644 --- a/providers/amazon/docs/executors/index.rst +++ b/providers/amazon/docs/executors/index.rst @@ -26,3 +26,4 @@ Amazon Executors ECS Executor <ecs-executor> Batch Executor (experimental) <batch-executor> + Lambda Executor (experimental) <lambda-executor> diff --git a/providers/amazon/docs/executors/lambda-executor.rst b/providers/amazon/docs/executors/lambda-executor.rst new file mode 100644 index 00000000000..3d09931fd96 --- /dev/null +++ b/providers/amazon/docs/executors/lambda-executor.rst @@ -0,0 +1,358 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
Unless required by applicable law or agreed to in writing,
+   software distributed under the License is distributed on an
+   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+   KIND, either express or implied. See the License for the
+   specific language governing permissions and limitations
+   under the License.
+
+
+.. warning::
+   The Lambda Executor is alpha/experimental at the moment and may be unstable and subject to change without warning.
+.. |executorName| replace:: Lambda
+.. |dockerfileLink| replace:: `here <https://github.com/apache/airflow/blob/main/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/Dockerfile>`__
+.. |appHandlerLink| replace:: `here <https://github.com/apache/airflow/blob/main/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py>`__
+
+===================
+AWS Lambda Executor
+===================
+
+This is an Airflow executor powered by AWS Lambda. Each task that Airflow schedules
+for execution is run as an asynchronous Lambda function invocation. Some benefits
+of an executor like this include:
+
+1. Task isolation: No task can be a noisy neighbor for another. Resources are
+   isolated to each individual task. Any failures affect only the single task
+   running in that Lambda function. No single user can overload the environment
+   by triggering too many tasks.
+2. Lightweight execution: Lambda functions start quickly and have minimal overhead,
+   making this executor ideal for short-running tasks that don't require substantial
+   CPU or memory.
+3. Cost effective: Compute resources only exist for the lifetime of the Airflow task
+   itself. Lambda's pay-per-use pricing model ensures you only pay for actual task
+   execution time.
+4. Highly scalable: Lambda can scale to handle thousands of concurrent executions
+   with no pre-provisioning.
+
+For a quick start guide, see :ref:`here <lambda_setup_guide>`; it will
+get you up and running with a basic configuration. Both an example Dockerfile
+(found |dockerfileLink|) and Lambda invocation handler (found |appHandlerLink|)
+are provided. The sections below provide more general details about configuration,
+the provided example Dockerfile, and logging.
+
+How Does it Work?
+-----------------
+The executor leverages AWS Lambda for compute and SQS as a communication channel.
+When a task is queued to run with the Lambda executor, a Lambda function is asynchronously
+invoked with a payload containing the task key being run and the Airflow CLI
+command/workload to run. The Lambda function executes the task and then must
+publish a message to the SQS results queue with the same task key it received as well as
+the return code of the Airflow task execution process. The Lambda executor periodically
+polls this SQS results queue for messages and updates its internal state with the results
+it receives. The Lambda function must also be configured with a Dead Letter Queue (DLQ) to
+which Lambda will automatically send messages for invocations that fail to be processed,
+for example due to the 15-minute maximum timeout of Lambda functions. The Lambda executor will
+also poll this DLQ for messages and update the internal state with the results it receives.
+
+.. image:: ../img/lambda_executor.png
+
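+To make this contract concrete, the sketch below shows the shape of the two
+messages involved, using the ``task_key``, ``command`` and ``return_code`` keys
+expected by the provided example handler. All values are illustrative
+placeholders; the real command is built by the executor and varies by Airflow
+version.
+
+.. code-block:: python
+
+    import json
+
+    # Payload the executor sends when it asynchronously invokes the Lambda function.
+    # "task_key" is a JSON-serialized Airflow task instance key; "command" is the
+    # CLI command/workload the handler should run (placeholder values shown).
+    invocation_payload = {
+        "task_key": json.dumps(
+            {
+                "dag_id": "example_dag",
+                "task_id": "example_task",
+                "run_id": "manual__2025-01-01",
+                "try_number": 1,
+            }
+        ),
+        "command": ["airflow", "tasks", "run", "example_dag", "example_task", "manual__2025-01-01"],
+    }
+
+    # Message the Lambda function must publish to the SQS results queue once the
+    # task process exits: the same task key it received plus the return code.
+    result_message = json.dumps({"task_key": invocation_payload["task_key"], "return_code": 0})
+
+.. 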
_lambda_config-options: + +Config Options +-------------- + +There are a number of configuration options available, which can either +be set directly in the airflow.cfg file under an "aws_lambda_executor" +section or via environment variables using the +``AIRFLOW__AWS_LAMBDA_EXECUTOR__<OPTION_NAME>`` format, for example +``AIRFLOW__AWS_LAMBDA_EXECUTOR__FUNCTION_NAME = "myLambdaFunction"``. For +more information on how to set these options, see `Setting Configuration +Options <https://airflow.apache.org/docs/apache-airflow/stable/howto/set-config.html>`__ + +.. note:: + Configuration options must be consistent across all the hosts/environments running the Airflow components (Scheduler, Webserver, Lambda function, etc). See `here <https://airflow.apache.org/docs/apache-airflow/stable/configurations-ref.html>`__ for more details on setting configurations. + +In the case of conflicts, the order of precedence from lowest to highest is: + +1. Load default values for options which have defaults. +2. Load any values explicitly provided through airflow.cfg or + environment variables. These are checked with Airflow's config + precedence. + +Required config options: +~~~~~~~~~~~~~~~~~~~~~~~~ + +- FUNCTION_NAME - Name or ARN of the AWS Lambda function to invoke. Required. +- QUEUE_URL - URL of the SQS queue that will receive task execution results. Required. +- DEAD_LETTER_QUEUE_URL - URL of the Dead Letter Queue for processing Lambda execution failures. Required + +Optional config options: +~~~~~~~~~~~~~~~~~~~~~~~~ + +- FUNCTION_QUALIFIER - Version or alias of the Lambda function to invoke. Defaults to ``$LATEST``. +- CONN_ID - The Airflow connection (i.e. credentials) used by the Lambda + executor to make API calls to AWS Lambda and SQS. Defaults to ``aws_default``. +- REGION_NAME - The name of the AWS Region where AWS Lambda and SQS are configured. +- CHECK_HEALTH_ON_STARTUP - Whether to check the Lambda Executor health on startup by testing connectivity to required AWS services. Defaults to ``True``. +- MAX_INVOKE_ATTEMPTS - The maximum number of times the Lambda Executor should attempt to invoke a function. This refers to instances where the invocation fails (i.e. AWS API failures, throttling, etc). Defaults to ``3``. + +For a more detailed description of available options, including type +hints and examples, see the ``config_templates`` folder in the Amazon +provider package. + +.. _dockerfile_for_lambda_executor: + +.. include:: general.rst + :start-after: .. BEGIN DOCKERFILE + :end-before: .. END DOCKERFILE + + +The most secure method is to use IAM roles. When creating a Lambda Function +Definition, you are able to select an execution role. This role needs +permissions to publish messages to the SQS queues and to write to CloudWatchLogs +or S3 if using AWS remote logging and/or using S3 to synchronize dags +(e.g. ``CloudWatchLogsFullAccess`` or ``CloudWatchLogsFullAccessV2``). +The AWS credentials used on the Scheduler need permissions to +describe and invoke Lambda functions as well as to describe and read/delete +SQS messages. + +To create a new role for use by Lambda follow the below: + +1. Navigate to the IAM page on the AWS console, and from the left hand + tab, under Access Management, select Roles. +2. On the Roles page, click Create role on the top right hand corner. +3. Under Trusted entity type, select AWS Service. +4. Select Lambda from the drop down under Use case. Click Next. +5. 
In the Permissions page, select the permissions the role will need, + depending on the actions it will perform (at least including the + permissions described above). + Click Next after selecting all the required permissions. +6. Enter a name for the new role, and an optional description. Review + the Trusted Entities, and the permissions for the role. Add any tags + as necessary, and click Create role. + +When creating the Lambda Function (see the :ref:`setup guide <lambda_setup_guide>` for more details), select the newly created role. + +.. include:: general.rst + :start-after: .. BEGIN DOCKERFILE_AUTH_SECOND_METHOD + :end-before: .. END DOCKERFILE_AUTH_SECOND_METHOD + +Base Image +~~~~~~~~~~ + +The Docker image is built upon the ``public.ecr.aws/lambda/python`` image. +This base image contains the necessary components for the Lambda function to +run. Apache Airflow is installed into the image via pip. You can modify the +Dockerfile to install additional packages or different versions of Airflow as +needed for your use case. + +It is also possible to build the image based of ``apache/airflow:latest`` + and the Lambda runtime can be included separately (follow steps `here <https://docs.aws.amazon.com/lambda/latest/dg/images-create.html#images-ric>`__). + + +.. include:: general.rst + :start-after: .. BEGIN LOADING_DAGS_OVERVIEW + :end-before: .. END LOADING_DAGS_OVERVIEW + +From S3 Bucket +^^^^^^^^^^^^^^ + +Dags can be loaded from S3 when using the provided example app.py, which +contains logic to synchronize the DAGs from S3 to the local filesystem of +the Lambda function (see the app.py code |appHandlerLink|). + +To load Dags from an S3 bucket add ``--build-arg s3_uri=YOUR_S3_URI`` in +the docker build command. Replace ``YOUR_S3_URI`` with the URI of your S3 +bucket/path containing your dags. Make sure you have the appropriate +permissions to read from the bucket. + +.. code-block:: bash + + docker build -t my-airflow-image --build-arg s3_uri=YOUR_S3_URI . + +.. include:: general.rst + :start-after: .. BEGIN LOADING_DAGS_FROM_LOCAL + :end-before: .. END LOADING_DAGS_FROM_LOCAL + +.. include:: general.rst + :start-after: .. BEGIN DEPENDENCIES + :end-before: .. END DEPENDENCIES + +Using the AWS Lambda Executor Docker Image +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Detailed instructions on how to use the Docker image, that you have +created via these steps can be found :ref:`here <lambda_setup_guide>`. + +.. _lamba_logging: + +.. include:: general.rst + :start-after: .. BEGIN LOGGING + :end-before: .. END LOGGING + +- The configuration options for Airflow remote logging should be + configured on all hosts and containers running Airflow. For example + the Webserver requires this config so that it can fetch logs from + the remote location and the Lambda container requires the config so that + it can upload the logs to the remote location. See + `here <https://airflow.apache.org/docs/apache-airflow/stable/howto/set-config.html>`__ + to read more about how to set Airflow configuration via config file + or environment variable exports. +- Adding the Airflow remote logging config to Lambda can be done + in many ways. Some examples include, but are not limited to: + + - Exported as environment variables directly in the Dockerfile (see + the Dockerfile section :ref:`above <dockerfile_for_lambda_executor>`) + - Updating the ``airflow.cfg`` file or copy/mounting/downloading a + custom ``airflow.cfg`` in the Dockerfile. 
+  - Added in the Lambda Function definition in plain text or via
+    `Secrets/System
+    Manager <https://docs.aws.amazon.com/secretsmanager/latest/userguide/retrieving-secrets_lambda.html>`__
+
+- Remember that you must have credentials configured within the container to be able
+  to interact with the remote service for your logs (e.g. S3,
+  CloudWatch Logs, etc.). This can be done in many ways. Some examples
+  include, but are not limited to:
+
+  - Export credentials into the Dockerfile directly (see the
+    Dockerfile section :ref:`above <dockerfile_for_lambda_executor>`)
+  - Configure an Airflow Connection and provide this as the `remote
+    logging conn
+    id <https://airflow.apache.org/docs/apache-airflow/stable/configurations-ref.html#remote-log-conn-id>`__
+    (exported into the Lambda Function by any of the means listed above or
+    your preferred method). Airflow will then use these credentials
+    *specifically* for interacting with your chosen remote logging
+    destination.
+
+.. note::
+   Configuration options must be consistent across all the hosts/environments running the Airflow components (Scheduler, Webserver, Lambda invocations, etc). See `here <https://airflow.apache.org/docs/apache-airflow/stable/configurations-ref.html>`__ for more details on setting configurations.
+
+.. _lambda_logging:
+
+Lambda Logging
+~~~~~~~~~~~~~~
+
+Lambda will send logging information to CloudWatch Logs for the Function
+invocations themselves. These logs will include the Airflow Task Operator
+logging and any other logging that occurs throughout the life of the process
+running in the Lambda function. This can be helpful for debugging issues with
+remote logging or while testing remote logging configuration.
+
+**Note: These logs will NOT be viewable from the Airflow Webserver UI.**
+
+.. _lambda_setup_guide:
+
+
+Setting up a Lambda Executor for Apache Airflow
+-----------------------------------------------
+
+There are 3 steps involved in getting a Lambda Executor to work in Apache Airflow:
+
+1. Creating a database that Airflow and the tasks running in Lambda Functions can
+   connect to (Note: In Airflow 3, tasks no longer directly access the DB).
+
+2. Creating and configuring a Lambda Function (and its associated SQS queues) that
+   can run tasks from Airflow.
+
+3. Configuring Airflow to use the Lambda Executor.
+
+There are different options for selecting a database backend. See `here <https://airflow.apache.org/docs/apache-airflow/stable/howto/set-up-database.html>`_ for more information about the different options supported by Airflow. The following guide will explain how to set up a PostgreSQL RDS Instance on AWS.
+
+.. include:: general.rst
+    :start-after: .. BEGIN DATABASE_CONNECTION
+    :end-before: .. END DATABASE_CONNECTION
+
+
+Creating a Lambda Function
+--------------------------
+
+To create the Lambda Function, first build the Docker image (see the :ref:`Dockerfile <dockerfile_for_lambda_executor>` section above for instructions).
+
+Once the image is built, it needs to be put in a repository where it can be pulled by Lambda. There are multiple ways
+to accomplish this. This guide will use Amazon Elastic Container Registry (ECR).
+
+.. include:: general.rst
+    :start-after: .. BEGIN ECR_STEPS
+    :end-before: .. END ECR_STEPS
+
+Create SQS Queues
+~~~~~~~~~~~~~~~~~
+
+You will also need two SQS queues: 1) a results queue to receive the results of the Airflow task executions,
+and 2) a Dead Letter Queue (DLQ) to handle failed invocations of the Lambda function.
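+
+If you prefer to create the queues programmatically, the following is a minimal
+boto3 sketch (the queue names are placeholders reusing the examples from the
+provider configuration); otherwise, use the console steps below.
+
+.. code-block:: python
+
+    import boto3
+
+    # Create the results queue and the dead letter queue used by the Lambda executor.
+    # Adjust names, region and queue attributes (e.g. message retention) as needed.
+    sqs = boto3.client("sqs", region_name="us-east-1")
+
+    results_queue_url = sqs.create_queue(QueueName="airflow-lambda-executor-results-queue")["QueueUrl"]
+    dlq_url = sqs.create_queue(QueueName="airflow-lambda-executor-dlq")["QueueUrl"]
+
+    # These URLs are the values expected by the executor's queue_url and
+    # dead_letter_queue_url configuration options.
+    print(results_queue_url)
+    print(dlq_url)
+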
+Create SQS queues by following these steps (for each queue):
+
+1. Log in to your AWS Management Console and navigate to the Amazon SQS Service.
+
+2. Click "Create queue".
+
+3. Select "Standard" as the queue type.
+
+4. Provide a name for the queue, and select the defaults that work for your use case.
+
+5. Click Create.
+
+Create Function
+~~~~~~~~~~~~~~~
+
+Finally, create the function:
+
+1. Log in to your AWS Management Console and navigate to the AWS Lambda Service.
+
+2. Click "Create Function".
+
+3. Select "Container image" as the function type.
+
+4. Provide a name for the function, select the ECR repository and image tag that you created in the previous steps, and
+   select the execution role that you created for use in the Lambda Function.
+
+5. Click Create.
+
+6. Once created, add the following environment variables to the container under Configuration > Environment variables:
+
+- ``AIRFLOW__DATABASE__SQL_ALCHEMY_CONN``, with the value being the PostgreSQL connection string in the following format using the values set during the `Database section <#create-the-rds-db-instance>`_ above:
+
+.. code-block:: bash
+
+   postgresql+psycopg2://<username>:<password>@<endpoint>/<database_name>
+
+
+- ``AIRFLOW__AWS_LAMBDA_EXECUTOR__QUEUE_URL``, with the value being the URL of the SQS results queue created above.
+
+7. Add other configuration as necessary for Airflow generally (see `here <https://airflow.apache.org/docs/apache-airflow/stable/configurations-ref.html>`__), the Lambda executor (see :ref:`here <lambda_config-options>`) or for remote logging (see :ref:`here <lamba_logging>`). Note that any configuration changes should be made across the entire Airflow environment to keep configuration consistent.
+
+8. Add the dead letter queue URL to the Lambda function under Configuration > Asynchronous invocation > edit > Dead-letter queue. Select the SQS queue you created above.
+
+9. If using Airflow 2, allow access to the database. There are many possible methods, but one simple approach is to add a connection to the RDS database via Configuration > RDS databases and follow the steps of the wizard.
+
+Configure Airflow
+~~~~~~~~~~~~~~~~~
+
+To configure Airflow to utilize the Lambda Executor and leverage the resources we've set up, create a script (e.g., ``lambda_executor_config.sh``) with at least the following contents:
+
+.. code-block:: bash
+
+   export AIRFLOW__CORE__EXECUTOR='airflow.providers.amazon.aws.executors.aws_lambda.AwsLambdaExecutor'
+
+   export AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=<postgres-connection-string>
+
+   export AIRFLOW__AWS_LAMBDA_EXECUTOR__FUNCTION_NAME=<lambda-function-name>
+
+   export AIRFLOW__AWS_LAMBDA_EXECUTOR__QUEUE_URL=<sqs_queue_url>
+
+   export AIRFLOW__AWS_LAMBDA_EXECUTOR__DEAD_LETTER_QUEUE_URL=<sqs_dlq_queue_url>
+
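+
+Before starting the scheduler, you can optionally smoke test the wiring with a
+short boto3 sketch that invokes the function asynchronously, the same way the
+executor does. The function name, region and payload values below are placeholders.
+
+.. code-block:: python
+
+    import json
+
+    import boto3
+
+    # Invoke the function asynchronously with an illustrative payload. A StatusCode
+    # of 202 means Lambda accepted the event; the handler should then post a message
+    # containing the same task_key and a return_code to the results queue.
+    lambda_client = boto3.client("lambda", region_name="us-east-1")
+    response = lambda_client.invoke(
+        FunctionName="airflow-lambda-executor-function",
+        InvocationType="Event",
+        Payload=json.dumps({"task_key": "smoke-test", "command": ["airflow", "version"]}),
+    )
+    print(response["StatusCode"])
+
+.. include:: general.rst
+    :start-after: .. BEGIN INIT_DB
+    :end-before: .. 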
END INIT_DB diff --git a/providers/amazon/docs/img/lambda_executor.png b/providers/amazon/docs/img/lambda_executor.png new file mode 100644 index 00000000000..98773c916c1 Binary files /dev/null and b/providers/amazon/docs/img/lambda_executor.png differ diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml index 64013749ac4..fb7bf428790 100644 --- a/providers/amazon/provider.yaml +++ b/providers/amazon/provider.yaml @@ -1004,6 +1004,82 @@ config: example: '{"Tags": [{"Key": "key", "Value": "value"}]}' default: ~ + aws_lambda_executor: + description: | + This section only applies if you are using the AwsLambdaExecutor in + Airflow's ``[core.executor]`` configuration. + For more information see: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/lambda/client/invoke.html + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html + https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/executors/lambda-executor.html + options: + conn_id: + description: | + The Airflow connection (i.e. credentials) used by the Lambda executor to make API calls. + version_added: "9.9.0" + type: string + example: "aws_default" + default: "aws_default" + region_name: + description: | + The name of the AWS Region where Amazon Lambda is configured. + version_added: "9.9.0" + type: string + example: "us-east-1" + default: ~ + check_health_on_startup: + description: | + Whether or not to check the Lambda Executor health on startup. + version_added: "9.9.0" + type: boolean + example: "True" + default: "True" + max_run_task_attempts: + description: | + The maximum number of times the Lambda Executor should attempt to start an Airflow task. + version_added: "9.9.0" + type: integer + example: "3" + default: "3" + queue_url: + description: | + The URL of the SQS queue to use for the Lambda executor. Required. + The Lambda executor will poll this queue for results of the lambda function's Airflow Task. + version_added: "9.9.0" + type: string + example: "airflow-lambda-executor-results-queue" + default: ~ + dead_letter_queue_url: + description: | + The URL of the SQS dead letter queue to use for the Lambda function. The Lambda executor + will poll this queue for timeout/exception results of the lambda function. Required + version_added: "9.9.0" + type: string + example: "airflow-lambda-executor-dlq" + default: ~ + function_name: + description: | + The name of the Lambda function to invoke. Required. + version_added: "9.9.0" + type: string + example: "airflow-lambda-executor-function" + default: ~ + qualifier: + description: | + The version or alias of the Lambda function to invoke. If not specified, the latest version is used. + version_added: "9.9.0" + type: string + example: "1" + default: ~ + end_wait_timeout: + description: | + The number of seconds to wait for all lambda executor invocations to complete when terminating the + executor/scheduler. Default of 0 means wait indefinitely. 
+ version_added: "9.9.0" + type: integer + example: "500" + default: "0" + aws_ecs_executor: description: | This section only applies if you are using the AwsEcsExecutor in diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/__init__.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/__init__.py new file mode 100644 index 00000000000..b3f8fdcb26e --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +__all__ = ["AwsLambdaExecutor"] + +from airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor import AwsLambdaExecutor diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/Dockerfile b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/Dockerfile new file mode 100644 index 00000000000..f5c4648caf4 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/Dockerfile @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Use the official AWS Lambda Python base image. +# If you wish to use a different Python version, please update the line below +FROM public.ecr.aws/lambda/python:3.12 + +# hadolint ignore=DL3041 +RUN dnf -y install unzip \ + && dnf clean all \ + && rm -rf /var/cache/dnf + +# Install the AWS CLI. +RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" && \ + unzip awscliv2.zip \ + && ./aws/install \ + && rm -rf ./aws awscliv2.zip + +## Install Airflow and dependencies. +# The most current version of Airflow is installed by default, along with the amazon and postgres +# provider packages. +# If you would like to install a specific version, you can use the requirements.txt to change the +# version along with installing your dependencies or update this Dockerfile to install a specific +# version of Airflow. 
+ +# NOTE: If you change the below line, specifically removing the amazon extra, please ensure boto3 is +# installed via another method. Boto3 is required for the Lambda executor to function properly. +# hadolint ignore=SC2102 +RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir apache-airflow[amazon,postgres] +# /tmp is the only writable directory in Lambda, so we need to set the AIRFLOW_HOME there. +ENV AIRFLOW_HOME=/tmp/airflow +# Dags are read-only, so they can be stored in opt (or another path provided). +ARG container_dag_path=/opt/airflow/dags +ENV AIRFLOW__CORE__DAGS_FOLDER=$container_dag_path +RUN mkdir -p /opt/airflow/dags + +# Python dependencies can be installed by providing a requirements.txt. +# If the file is in a different location, use the requirements_path build argument to specify +# the file path. +ARG requirements_path=./requirements.txt +ENV REQUIREMENTS_PATH=$requirements_path +# Uncomment the two lines below to copy the requirements.txt file to the container, and +# install the dependencies. +# COPY --chown=airflow:root $REQUIREMENTS_PATH /opt/airflow/requirements.txt +# RUN pip install --no-cache-dir -r /opt/airflow/requirements.txt + +## AWS Authentication +# The image requires access to AWS services. This Dockerfile supports 2 ways to authenticate with AWS. +# The first is using build arguments where you can provide the AWS credentials as arguments +# passed when building the image. The other option is to leverage the Lambda execution role. Airflow +# will default to using Boto credential strategy which will look for roles from Lambda, this is the +# preferred approach. See the Lambda Executor Airflow documentation for more details. + +# If you would like to use an alternative method of authentication, feel free to make the +# necessary changes to this file. + +# Uncomment to use these arguments to provide AWS authentication information if not using the Lambda +# execution role. +#ARG aws_access_key_id +#ARG aws_secret_access_key +#ARG aws_default_region +#ARG aws_session_token + +#ENV AWS_ACCESS_KEY_ID=$aws_access_key_id +#ENV AWS_SECRET_ACCESS_KEY=$aws_secret_access_key +#ENV AWS_DEFAULT_REGION=$aws_default_region +#ENV AWS_SESSION_TOKEN=$aws_session_token + +## Loading DAGs +# This Dockerfile supports 2 ways to load DAGs onto the container. +# One is to download them from S3 at runtime during the Lambda app invocation. The other +# is to copy the dags into the image at build time, this will make task execution +# much faster, since the images will already be present but the image will need to be rebuilt +# every time the DAGs are updated. +# If you would like to use an alternative method of loading DAGs, feel free to make the +# necessary changes to this file. + +ARG host_dag_path=./dags +ENV HOST_DAG_PATH=$host_dag_path +# Uncomment the line below to copy the DAGs from the host to the container. +# COPY $HOST_DAG_PATH $AIRFLOW__CORE__DAGS_FOLDER + +# Use these arguments to load DAGs at runtime. If you are using the provided Lambda function (app.py), +# it will check for this environment variable and download the DAGs from S3. See the example app in Lambda +# Executor documentation for details. +ARG s3_uri +ENV S3_URI=$s3_uri + +# Copy your Lambda function code into the Docker build directory/context. +COPY app.py ${LAMBDA_TASK_ROOT}/ + +# Specify the Lambda function handler (update if you provide a different handler with a different name). 
+CMD ["app.lambda_handler"] diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/__init__.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py new file mode 100644 index 00000000000..9d9b5a5b15e --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import logging +import os +import subprocess +from tempfile import mkdtemp + +import boto3 + +""" +Example Lambda function to execute an Airflow command or workload. Use or modify this code as needed. +""" + +log = logging.getLogger() +log.setLevel(logging.INFO) + + +# Get the S3 URI from the environment variable. Set either on the Lambda function or in the +# docker image used for the lambda invocations. +S3_URI = os.environ.get("S3_URI", None) +# Input and output keys +TASK_KEY_KEY = "task_key" +COMMAND_KEY = "command" +RETURN_CODE_KEY = "return_code" + + +def lambda_handler(event, context): + log.info("Received event: %s", event) + log.info("Received context: %s", context) + + command = event.get(COMMAND_KEY) + task_key = event.get(TASK_KEY_KEY) + + # Any pre-processing or validation of the command or use of the context can be done here or above. + + # Sync dags from s3 to the local dags directory + if S3_URI: + fetch_dags_from_s3(S3_URI) + # This function must be called, it executes the Airflow command and reports to SQS. + run_and_report(command, task_key) + + # Any post-processing or cleanup can be done here. 
+ + +def run_and_report(command, task_key): + """Execute the provided Airflow command or workload and report the result via SQS.""" + try: + log.info("Starting execution for task: %s", task_key) + result = subprocess.run( + command, shell=isinstance(command, str), stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) + return_code = result.returncode + log.info("Execution completed for task %s with return code %s", task_key, return_code) + log.info("Output:") + log.info("%s", result.stdout.decode()) + except Exception: + log.exception("Error executing task %s: ", task_key) + return_code = 1 # Non-zero indicates failure to run the task + + queue_url = get_queue_url() + message = json.dumps({TASK_KEY_KEY: task_key, RETURN_CODE_KEY: return_code}) + try: + sqs_client = get_sqs_client() + sqs_client.send_message(QueueUrl=queue_url, MessageBody=message) + log.info("Sent result to SQS %s", message) + except Exception: + log.exception("Failed to send message to SQS for task %s", task_key) + + +def get_sqs_client(): + """Create an SQS client. Credentials and region are automatically picked up from the environment.""" + return boto3.client("sqs") + + +def get_queue_url(): + """ + Get the SQS queue URL from the environment variable. + + Set either on the Lambda function or in the image used for the lambda invocations. + """ + queue_url = os.environ.get("AIRFLOW__AWS_LAMBDA_EXECUTOR__QUEUE_URL", os.environ.get("QUEUE_URL", None)) + if not queue_url: + raise RuntimeError( + "No Queue URL detected (either AIRFLOW__AWS_LAMBDA_EXECUTOR__QUEUE_URL or " + "QUEUE_URL); Will be unable to send task results. Exiting!" + ) + return queue_url + + +def fetch_dags_from_s3(s3_uri): + """Fetch DAGs from S3 and sync them to the local dags directory.""" + log.info("Fetching DAGs from S3 URI: %s", s3_uri) + # Use a named temporary directory for the local dags folder, only tmp is writeable in Lambda + local_dags_dir = mkdtemp(prefix="airflow_dags_") + log.info("Setting AIRFLOW__CORE__DAGS_FOLDER to: %s", local_dags_dir) + os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = local_dags_dir + + # S3 URI format s3://bucket-name/path/to/dags/ + bucket_name = s3_uri.split("/")[2] + prefix = "/".join(s3_uri.split("/")[3:]) + + s3_resource = boto3.resource("s3") + bucket = s3_resource.Bucket(bucket_name) + + for obj in bucket.objects.filter(Prefix=prefix): + if obj.key.endswith("/"): + # Skip directories + continue + key = obj.key + local_path = os.path.join(local_dags_dir, os.path.basename(key)) + log.info("Downloading %s to %s", key, local_path) + bucket.download_file(key, local_path) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py new file mode 100644 index 00000000000..3f477095f44 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py @@ -0,0 +1,479 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import time +from collections import deque +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from boto3.session import NoCredentialsError +from botocore.utils import ClientError + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.executors.base_executor import BaseExecutor +from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.providers.amazon.aws.executors.aws_lambda.utils import ( + CONFIG_GROUP_NAME, + INVALID_CREDENTIALS_EXCEPTIONS, + AllLambdaConfigKeys, + CommandType, + LambdaQueuedTask, +) +from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import ( + calculate_next_attempt_delay, + exponential_backoff_retry, +) +from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook +from airflow.providers.amazon.aws.hooks.sqs import SqsHook +from airflow.stats import Stats +from airflow.utils import timezone + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.executors import workloads + from airflow.models.taskinstance import TaskInstance + + +class AwsLambdaExecutor(BaseExecutor): + """ + An Airflow Executor that submits tasks to AWS Lambda asynchronously. + + When execute_async() is called, the executor invokes a specified AWS Lambda function (asynchronously) + with a payload that includes the task command and a unique task key. + + The Lambda function writes its result directly to an SQS queue, which is then polled by this executor + to update task state in Airflow. + """ + + if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS: + # In the v3 path, we store workloads, not commands as strings. + # TODO: TaskSDK: move this type change into BaseExecutor + queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pending_tasks: deque = deque() + self.running_tasks: dict[str, TaskInstanceKey] = {} + self.lambda_function_name = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.FUNCTION_NAME) + self.sqs_queue_url = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUEUE_URL) + self.dlq_url = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.DLQ_URL) + self.qualifier = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUALIFIER, fallback=None) + # Maximum number of retries to invoke Lambda. 
+ self.max_invoke_attempts = conf.get( + CONFIG_GROUP_NAME, + AllLambdaConfigKeys.MAX_INVOKE_ATTEMPTS, + ) + + self.attempts_since_last_successful_connection = 0 + self.IS_BOTO_CONNECTION_HEALTHY = False + self.load_connections(check_connection=False) + + def start(self): + """Call this when the Executor is run for the first time by the scheduler.""" + check_health = conf.getboolean(CONFIG_GROUP_NAME, AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP) + + if not check_health: + return + + self.log.info("Starting Lambda Executor and determining health...") + try: + self.check_health() + except AirflowException: + self.log.error("Stopping the Airflow Scheduler from starting until the issue is resolved.") + raise + + def check_health(self): + """ + Check the health of the Lambda and SQS connections. + + For lambda: Use get_function to test if the lambda connection works and the function can be + described. + For SQS: Use get_queue_attributes is used as a close analog to describe to test if the SQS + connection is working. + """ + self.IS_BOTO_CONNECTION_HEALTHY = False + + def _check_queue(queue_url): + sqs_get_queue_attrs_response = self.sqs_client.get_queue_attributes( + QueueUrl=queue_url, AttributeNames=["ApproximateNumberOfMessages"] + ) + approx_num_msgs = sqs_get_queue_attrs_response.get("Attributes").get( + "ApproximateNumberOfMessages" + ) + self.log.info( + "SQS connection is healthy and queue %s is present with %s messages.", + queue_url, + approx_num_msgs, + ) + + self.log.info("Checking Lambda and SQS connections") + try: + # Check Lambda health + lambda_get_response = self.lambda_client.get_function(FunctionName=self.lambda_function_name) + if self.lambda_function_name not in lambda_get_response["Configuration"]["FunctionName"]: + raise AirflowException("Lambda function %s not found.", self.lambda_function_name) + self.log.info( + "Lambda connection is healthy and function %s is present.", self.lambda_function_name + ) + + # Check SQS results queue + _check_queue(self.sqs_queue_url) + # Check SQS dead letter queue + _check_queue(self.dlq_url) + + # If we reach this point, both connections are healthy and all resources are present + self.IS_BOTO_CONNECTION_HEALTHY = True + except Exception: + self.log.exception("Lambda Executor health check failed") + raise AirflowException( + "The Lambda executor will not be able to run Airflow tasks until the issue is addressed." + ) + + def load_connections(self, check_connection: bool = True): + """ + Retrieve the AWS connection via Hooks to leverage the Airflow connection system. + + :param check_connection: If True, check the health of the connection after loading it. + """ + self.log.info("Loading Connections") + aws_conn_id = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.AWS_CONN_ID) + region_name = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.REGION_NAME, fallback=None) + self.sqs_client = SqsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn + self.lambda_client = LambdaHook(aws_conn_id=aws_conn_id, region_name=region_name).conn + + self.attempts_since_last_successful_connection += 1 + self.last_connection_reload = timezone.utcnow() + + if check_connection: + self.check_health() + self.attempts_since_last_successful_connection = 0 + + def sync(self): + """ + Sync the executor with the current state of tasks. + + Check in on currently running tasks and attempt to run any new tasks that have been queued. 
+ """ + if not self.IS_BOTO_CONNECTION_HEALTHY: + exponential_backoff_retry( + self.last_connection_reload, + self.attempts_since_last_successful_connection, + self.load_connections, + ) + if not self.IS_BOTO_CONNECTION_HEALTHY: + return + try: + self.sync_running_tasks() + self.attempt_task_runs() + except (ClientError, NoCredentialsError) as error: + error_code = error.response["Error"]["Code"] + if error_code in INVALID_CREDENTIALS_EXCEPTIONS: + self.IS_BOTO_CONNECTION_HEALTHY = False + self.log.warning( + "AWS credentials are either missing or expired: %s.\nRetrying connection", error + ) + except Exception: + self.log.exception("An error occurred while syncing tasks") + + def queue_workload(self, workload: workloads.All, session: Session | None) -> None: + from airflow.executors import workloads + + if not isinstance(workload, workloads.ExecuteTask): + raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") + ti = workload.ti + self.queued_tasks[ti.key] = workload + + def _process_workloads(self, workloads: list[workloads.All]) -> None: + from airflow.executors.workloads import ExecuteTask + + for w in workloads: + if not isinstance(w, ExecuteTask): + raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}") + + command = [w] + key = w.ti.key + queue = w.ti.queue + executor_config = w.ti.executor_config or {} + + del self.queued_tasks[key] + self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type] + self.running.add(key) + + def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None): + """ + Save the task to be executed in the next sync by inserting the commands into a queue. + + :param key: A unique task key (typically a tuple identifying the task instance). + :param command: The shell command string to execute. + :param executor_config: (Unused) to keep the same signature as the base. + :param queue: (Unused) to keep the same signature as the base. + """ + if len(command) == 1: + from airflow.executors.workloads import ExecuteTask + + if isinstance(command[0], ExecuteTask): + workload = command[0] + ser_input = workload.model_dump_json() + command = [ + "python", + "-m", + "airflow.sdk.execution_time.execute_workload", + "--json-string", + ser_input, + ] + else: + raise RuntimeError( + f"LambdaExecutor doesn't know how to handle workload of type: {type(command[0])}" + ) + + self.pending_tasks.append( + LambdaQueuedTask( + key, command, queue if queue else "", executor_config or {}, 1, timezone.utcnow() + ) + ) + + def attempt_task_runs(self): + """ + Attempt to run tasks that are queued in the pending_tasks. + + Each task is submitted to AWS Lambda with a payload containing the task key and command. + The task key is used to track the task's state in Airflow. 
+ """ + queue_len = len(self.pending_tasks) + for _ in range(queue_len): + task_to_run = self.pending_tasks.popleft() + task_key = task_to_run.key + cmd = task_to_run.command + attempt_number = task_to_run.attempt_number + failure_reasons = [] + ser_task_key = json.dumps(task_key._asdict()) + payload = { + "task_key": ser_task_key, + "command": cmd, + } + if timezone.utcnow() < task_to_run.next_attempt_time: + self.pending_tasks.append(task_to_run) + continue + + self.log.info("Submitting task %s to Lambda function %s", task_key, self.lambda_function_name) + + try: + invoke_kwargs = { + "FunctionName": self.lambda_function_name, + "InvocationType": "Event", + "Payload": json.dumps(payload), + } + if self.qualifier: + invoke_kwargs["Qualifier"] = self.qualifier + response = self.lambda_client.invoke(**invoke_kwargs) + except NoCredentialsError: + self.pending_tasks.append(task_to_run) + raise + except ClientError as e: + error_code = e.response["Error"]["Code"] + if error_code in INVALID_CREDENTIALS_EXCEPTIONS: + self.pending_tasks.append(task_to_run) + raise + failure_reasons.append(str(e)) + except Exception as e: + # Failed to even get a response back from the Boto3 API or something else went + # wrong. For any possible failure we want to add the exception reasons to the + # failure list so that it is logged to the user and most importantly the task is + # added back to the pending list to be retried later. + failure_reasons.append(str(e)) + + if failure_reasons: + # Make sure the number of attempts does not exceed max invoke attempts + if int(attempt_number) < int(self.max_invoke_attempts): + task_to_run.attempt_number += 1 + task_to_run.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay( + attempt_number + ) + self.pending_tasks.append(task_to_run) + else: + reasons_str = ", ".join(failure_reasons) + self.log.error( + "Lambda invoke %s has failed a maximum of %s times. Marking as failed. Reasons: %s", + task_key, + attempt_number, + reasons_str, + ) + self.log_task_event( + event="lambda invoke failure", + ti_key=task_key, + extra=( + f"Task could not be queued after {attempt_number} attempts. " + f"Marking as failed. Reasons: {reasons_str}" + ), + ) + self.fail(task_key) + else: + status_code = response.get("StatusCode") + self.log.info("Invoked Lambda for task %s with status %s", task_key, status_code) + self.running_tasks[ser_task_key] = task_key + # Add the serialized task key as the info, this will be assigned on the ti as the external_executor_id + self.running_state(task_key, ser_task_key) + + def sync_running_tasks(self): + """ + Poll the SQS queue for messages indicating task completion. + + Each message is expected to contain a JSON payload with 'task_key' and 'return_code'. + Based on the return code, update the task state accordingly. + """ + if not len(self.running_tasks): + self.log.debug("No running tasks to process.") + return + + self.process_queue(self.sqs_queue_url) + if self.dlq_url and self.running_tasks: + self.process_queue(self.dlq_url) + + def process_queue(self, queue_url: str): + """ + Poll the SQS queue for messages indicating task completion. + + Each message is expected to contain a JSON payload with 'task_key' and 'return_code'. + + Based on the return code, update the task state accordingly. + """ + response = self.sqs_client.receive_message( + QueueUrl=queue_url, + MaxNumberOfMessages=10, + ) + + messages = response.get("Messages", []) + # Pagination? Maybe we don't need it. Since we always delete messages after looking at them. 
+ # But then that may delete messages that could have been adopted. Let's leave it for now and see how it goes. + if messages and queue_url == self.dlq_url: + self.log.warning("%d messages received from the dead letter queue", len(messages)) + + for message in messages: + receipt_handle = message["ReceiptHandle"] + body = json.loads(message["Body"]) + return_code = body.get("return_code") + ser_task_key = body.get("task_key") + # Fetch the real task key from the running_tasks dict, using the serialized task key. + try: + task_key = self.running_tasks[ser_task_key] + except KeyError: + self.log.warning( + "Received task %s from the queue which is not found in running tasks. Removing message.", + ser_task_key, + ) + task_key = None + + if task_key: + if return_code == 0: + self.success(task_key) + self.log.info( + "Successful Lambda invocation for task %s received from SQS queue.", task_key + ) + else: + # In this case the Lambda likely started but failed at run time since we got a non-zero + # return code. We could consider retrying these tasks within the executor, because this _likely_ + # means the Airflow task did not run to completion, however we can't be sure (maybe the + # lambda runtime code has a bug and is returning a non-zero when it actually passed?). So + # perhaps not retrying is the safest option. + self.fail(task_key) + self.log.error( + "Lambda invocation for task: %s has failed to run with return code %s", + task_key, + return_code, + ) + # Remove the task from the tracking mapping. + self.running_tasks.pop(ser_task_key) + + # Delete the message from the queue. + self.sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle) + + def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: + """ + Adopt task instances which have an external_executor_id (the serialized task key). + + Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling. + + :param tis: The task instances to adopt. + """ + with Stats.timer("lambda_executor.adopt_task_instances.duration"): + adopted_tis: list[TaskInstance] = [] + + if serialized_task_keys := [ + (ti, ti.external_executor_id) for ti in tis if ti.external_executor_id + ]: + for ti, ser_task_key in serialized_task_keys: + try: + task_key = TaskInstanceKey.from_dict(json.loads(ser_task_key)) + except Exception: + # If that task fails to deserialize, we should just skip it. + self.log.exception( + "Task failed to be adopted because the key could not be deserialized" + ) + continue + self.running_tasks[ser_task_key] = task_key + adopted_tis.append(ti) + + if adopted_tis: + tasks = [f"{task} in state {task.state}" for task in adopted_tis] + task_instance_str = "\n\t".join(tasks) + self.log.info( + "Adopted the following %d tasks from a dead executor:\n\t%s", + len(adopted_tis), + task_instance_str, + ) + + not_adopted_tis = [ti for ti in tis if ti not in adopted_tis] + return not_adopted_tis + + def end(self, heartbeat_interval=10): + """ + End execution. Poll until all outstanding tasks are marked as completed. + + This is a blocking call and async Lambda tasks can not be cancelled, so this will wait until + all tasks are either completed or the timeout is reached. + + :param heartbeat_interval: The interval in seconds to wait between checks for task completion. 
+ """ + self.log.info("Received signal to end, waiting for outstanding tasks to finish.") + time_to_wait = int(conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.END_WAIT_TIMEOUT)) + start_time = timezone.utcnow() + while True: + if time_to_wait: + current_time = timezone.utcnow() + elapsed_time = (current_time - start_time).total_seconds() + if elapsed_time > time_to_wait: + self.log.warning( + "Timed out waiting for tasks to finish. Some tasks may not be handled gracefully" + " as the executor is force ending due to timeout." + ) + break + self.sync() + if not self.running_tasks: + self.log.info("All tasks completed; executor ending.") + break + self.log.info("Waiting for %d task(s) to complete.", len(self.running_tasks)) + time.sleep(heartbeat_interval) + + def terminate(self): + """Get called when the daemon receives a SIGTERM.""" + self.log.warning("Terminating Lambda executor. In-flight tasks cannot be stopped.") diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/utils.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/utils.py new file mode 100644 index 00000000000..483b81508db --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/utils.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys + +if TYPE_CHECKING: + from airflow.models.taskinstancekey import TaskInstanceKey + + +CONFIG_GROUP_NAME = "aws_lambda_executor" +INVALID_CREDENTIALS_EXCEPTIONS = [ + "ExpiredTokenException", + "InvalidClientTokenId", + "UnrecognizedClientException", +] + + +@dataclass +class LambdaQueuedTask: + """Represents a Lambda task that is queued. 
The task will be run in the next heartbeat.""" + + key: TaskInstanceKey + command: CommandType + queue: str + executor_config: ExecutorConfigType + attempt_number: int + next_attempt_time: datetime.datetime + + +class InvokeLambdaKwargsConfigKeys(BaseConfigKeys): + """Config keys loaded which are valid lambda invoke args.""" + + FUNCTION_NAME = "function_name" + QUALIFIER = "function_qualifier" + + +class AllLambdaConfigKeys(InvokeLambdaKwargsConfigKeys): + """All config keys which are related to the Lambda Executor.""" + + AWS_CONN_ID = "conn_id" + CHECK_HEALTH_ON_STARTUP = "check_health_on_startup" + MAX_INVOKE_ATTEMPTS = "max_run_task_attempts" + REGION_NAME = "region_name" + QUEUE_URL = "queue_url" + DLQ_URL = "dead_letter_queue_url" + END_WAIT_TIMEOUT = "end_wait_timeout" + + +CommandType = Sequence[str] +ExecutorConfigType = dict[str, Any] diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index 136c907a08e..731e9bb377c 100644 --- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -1152,6 +1152,74 @@ def get_provider_info(): }, }, }, + "aws_lambda_executor": { + "description": "This section only applies if you are using the AwsLambdaExecutor in\nAirflow's ``[core.executor]`` configuration.\nFor more information see:\nhttps://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/lambda/client/invoke.html\nhttps://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html\nhttps://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/executors/lambda-executor.html\n", + "options": { + "conn_id": { + "description": "The Airflow connection (i.e. credentials) used by the Lambda executor to make API calls.\n", + "version_added": "9.9.0", + "type": "string", + "example": "aws_default", + "default": "aws_default", + }, + "region_name": { + "description": "The name of the AWS Region where Amazon Lambda is configured.\n", + "version_added": "9.9.0", + "type": "string", + "example": "us-east-1", + "default": None, + }, + "check_health_on_startup": { + "description": "Whether or not to check the Lambda Executor health on startup.\n", + "version_added": "9.9.0", + "type": "boolean", + "example": "True", + "default": "True", + }, + "max_run_task_attempts": { + "description": "The maximum number of times the Lambda Executor should attempt to start an Airflow task.\n", + "version_added": "9.9.0", + "type": "integer", + "example": "3", + "default": "3", + }, + "queue_url": { + "description": "The URL of the SQS queue to use for the Lambda executor. Required.\nThe Lambda executor will poll this queue for results of the lambda function's Airflow Task.\n", + "version_added": "9.9.0", + "type": "string", + "example": "airflow-lambda-executor-results-queue", + "default": None, + }, + "dead_letter_queue_url": { + "description": "The URL of the SQS dead letter queue to use for the Lambda function. The Lambda executor\nwill poll this queue for timeout/exception results of the lambda function. Required\n", + "version_added": "9.9.0", + "type": "string", + "example": "airflow-lambda-executor-dlq", + "default": None, + }, + "function_name": { + "description": "The name of the Lambda function to invoke. 
Required.\n", + "version_added": "9.9.0", + "type": "string", + "example": "airflow-lambda-executor-function", + "default": None, + }, + "qualifier": { + "description": "The version or alias of the Lambda function to invoke. If not specified, the latest version is used.\n", + "version_added": "9.9.0", + "type": "string", + "example": "1", + "default": None, + }, + "end_wait_timeout": { + "description": "The number of seconds to wait for all lambda executor invocations to complete when terminating the\nexecutor/scheduler. Default of 0 means wait indefinitely.\n", + "version_added": "9.9.0", + "type": "integer", + "example": "500", + "default": "0", + }, + }, + }, "aws_ecs_executor": { "description": "This section only applies if you are using the AwsEcsExecutor in\nAirflow's ``[core]`` configuration.\nFor more information on any of these execution parameters, see the link below:\nhttps://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/run_task.html\nFor boto3 credential management, see\nhttps://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html\n", "options": { diff --git a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/__init__.py b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py new file mode 100644 index 00000000000..f1123ca5e07 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py @@ -0,0 +1,928 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import datetime as dt +import json +from unittest import mock + +import pytest +from botocore.exceptions import ClientError +from semver import VersionInfo + +from airflow.exceptions import AirflowException +from airflow.executors.base_executor import BaseExecutor +from airflow.models.taskinstance import TaskInstance +from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.providers.amazon.aws.executors.aws_lambda import lambda_executor +from airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor import AwsLambdaExecutor +from airflow.providers.amazon.aws.executors.aws_lambda.utils import CONFIG_GROUP_NAME, AllLambdaConfigKeys +from airflow.utils import timezone +from airflow.utils.state import TaskInstanceState +from airflow.version import version as airflow_version_str + +from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + +airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3])) + +DEFAULT_QUEUE_URL = "queue-url" +DEFAULT_DLQ_URL = "dlq-url" +DEFAULT_FUNCTION_NAME = "function-name" + + +@pytest.fixture +def set_env_vars(): + overrides: dict[tuple[str, str], str] = { + (CONFIG_GROUP_NAME, AllLambdaConfigKeys.REGION_NAME): "us-west-1", + (CONFIG_GROUP_NAME, AllLambdaConfigKeys.FUNCTION_NAME): DEFAULT_FUNCTION_NAME, + (CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUEUE_URL): DEFAULT_QUEUE_URL, + (CONFIG_GROUP_NAME, AllLambdaConfigKeys.DLQ_URL): DEFAULT_DLQ_URL, + (CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUALIFIER): "1", + (CONFIG_GROUP_NAME, AllLambdaConfigKeys.MAX_INVOKE_ATTEMPTS): "3", + (CONFIG_GROUP_NAME, AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP): "True", + } + with conf_vars(overrides): + yield + + +@pytest.fixture +def mock_airflow_key(): + def _key(): + key_mock = mock.Mock() + # Use a "random" value (memory id of the mock obj) so each key serializes uniquely + key_mock._asdict = mock.Mock(return_value={"mock_key": id(key_mock)}) + return key_mock + + return _key + + +def _generate_mock_cmd(): + return ["airflow", "tasks", "run", "dag_id", "task_id", "run_id", "--local"] + + +# The following two fixtures look different because no existing test +# cares if they have unique values, so the same value is always used. 
+@pytest.fixture +def mock_cmd(): + return _generate_mock_cmd() + + +@pytest.fixture +def mock_executor(set_env_vars) -> AwsLambdaExecutor: + """Mock Lambda to a repeatable starting state..""" + executor = AwsLambdaExecutor() + executor.IS_BOTO_CONNECTION_HEALTHY = True + + # Replace boto3 clients with mocks + lambda_mock = mock.Mock(spec=executor.lambda_client) + lambda_mock.invoke.return_value = {"StatusCode": 0, "failures": []} + executor.lambda_client = lambda_mock + + sqs_mock = mock.Mock(spec=executor.sqs_client) + sqs_mock.receive_message.return_value = {"Messages": []} + executor.sqs_client = sqs_mock + + return executor + + +class TestAwsLambdaExecutor: + @mock.patch( + "airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor.AwsLambdaExecutor.change_state" + ) + def test_execute(self, change_state_mock, mock_airflow_key, mock_executor, mock_cmd): + """Test execution from end-to-end.""" + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + + assert len(mock_executor.pending_tasks) == 0 + mock_executor.execute_async(airflow_key, mock_cmd) + assert len(mock_executor.pending_tasks) == 1 + + mock_executor.attempt_task_runs() + mock_executor.lambda_client.invoke.assert_called_once() + + # Task is stored in active worker. + assert len(mock_executor.running_tasks) == 1 + assert json.dumps(airflow_key._asdict()) in mock_executor.running_tasks + change_state_mock.assert_called_once_with( + airflow_key, TaskInstanceState.RUNNING, ser_airflow_key, remove_running=False + ) + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3+") + @mock.patch( + "airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor.AwsLambdaExecutor.change_state" + ) + def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock_cmd): + """Test task sdk execution from end-to-end.""" + from airflow.executors.workloads import ExecuteTask + + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + + workload = mock.Mock(spec=ExecuteTask) + workload.ti = mock.Mock(spec=TaskInstance) + workload.ti.key = airflow_key + ser_workload = json.dumps({"test_key": "test_value"}) + workload.model_dump_json.return_value = ser_workload + + mock_executor.queue_workload(workload, mock.Mock()) + + assert mock_executor.queued_tasks[workload.ti.key] == workload + assert len(mock_executor.pending_tasks) == 0 + assert len(mock_executor.running) == 0 + mock_executor._process_workloads([workload]) + assert len(mock_executor.queued_tasks) == 0 + assert len(mock_executor.running) == 1 + assert workload.ti.key in mock_executor.running + assert len(mock_executor.pending_tasks) == 1 + assert mock_executor.pending_tasks[0].command == [ + "python", + "-m", + "airflow.sdk.execution_time.execute_workload", + "--json-string", + '{"test_key": "test_value"}', + ] + + mock_executor.attempt_task_runs() + mock_executor.lambda_client.invoke.assert_called_once() + assert len(mock_executor.pending_tasks) == 0 + + # Task is stored in active worker. 
+ assert len(mock_executor.running_tasks) == 1 + assert mock_executor.running_tasks[ser_airflow_key] == workload.ti.key + change_state_mock.assert_called_once_with( + workload.ti.key, TaskInstanceState.RUNNING, ser_airflow_key, remove_running=False + ) + + @mock.patch.object(lambda_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) + def test_success_execute_api_exception(self, mock_backoff, mock_executor, mock_cmd, mock_airflow_key): + """Test what happens when Lambda throws an initial exception on invoke, but ultimately passes on retries.""" + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + + invoke_exception = Exception("Test exception") + invoke_success = {"StatusCode": 0, "failures": []} + mock_executor.lambda_client.invoke.side_effect = [invoke_exception, invoke_exception, invoke_success] + mock_executor.execute_async(airflow_key, mock_cmd) + expected_retry_count = 2 + + # Fail 2 times + for _ in range(expected_retry_count): + mock_executor.attempt_task_runs() + # Task is not stored in active workers. + assert len(mock_executor.running_tasks) == 0 + + # Pass in last attempt + mock_executor.attempt_task_runs() + assert len(mock_executor.pending_tasks) == 0 + assert ser_airflow_key in mock_executor.running_tasks + assert mock_backoff.call_count == expected_retry_count + for attempt_number in range(1, expected_retry_count): + mock_backoff.assert_has_calls([mock.call(attempt_number)]) + + def test_failed_execute_api_exception(self, mock_executor, mock_cmd, mock_airflow_key): + """Test what happens when Lambda refuses to execute a task and throws an exception""" + mock_airflow_key = mock_airflow_key() + + mock_executor.lambda_client.invoke.side_effect = Exception("Test exception") + mock_executor.execute_async(mock_airflow_key, mock_cmd) + + # No matter what, don't schedule until invoke becomes successful. 
+ for _ in range(int(mock_executor.max_invoke_attempts) * 2): + mock_executor.attempt_task_runs() + # Task is not stored in running tasks + assert len(mock_executor.running_tasks) == 0 + + def test_failed_execute_creds_exception(self, mock_executor, mock_cmd, mock_airflow_key): + """Test what happens when Lambda refuses to execute a task and throws an exception due to credentials""" + airflow_key = mock_airflow_key() + + mock_executor.IS_BOTO_CONNECTION_HEALTHY = True + mock_executor.execute_async(airflow_key, mock_cmd) + assert mock_executor.pending_tasks[0].attempt_number == 1 + + error_to_raise = ClientError( + {"Error": {"Code": "ExpiredTokenException", "Message": "foobar"}}, "OperationName" + ) + mock_executor.lambda_client.invoke.side_effect = error_to_raise + + # Sync will ultimately call attempt_task_runs, which is the code under test + mock_executor.sync() + + # Task should end up back in the queue + assert mock_executor.pending_tasks[0].key == airflow_key + # The connection should get marked as unhealthy + assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY + # We retry on connections issues indefinitely, so the attempt number should be 1 + assert mock_executor.pending_tasks[0].attempt_number == 1 + + def test_failed_execute_client_error_exception(self, mock_executor, mock_cmd, mock_airflow_key): + """Test what happens when Lambda refuses to execute a task and throws an exception for non-credentials issue""" + airflow_key = mock_airflow_key() + mock_executor.IS_BOTO_CONNECTION_HEALTHY = True + mock_executor.execute_async(airflow_key, mock_cmd) + assert mock_executor.pending_tasks[0].attempt_number == 1 + + error_to_raise = ClientError( + {"Error": {"Code": "RandomeError", "Message": "foobar"}}, "OperationName" + ) + mock_executor.lambda_client.invoke.side_effect = error_to_raise + + # Sync will ultimately call attempt_task_runs, which is the code under test + mock_executor.sync() + + # Task should end up back in the queue + assert mock_executor.pending_tasks[0].key == airflow_key + # The connection should stay marked as healthy because the error is something else + assert mock_executor.IS_BOTO_CONNECTION_HEALTHY + # Not a retry so increment attempts + assert mock_executor.pending_tasks[0].attempt_number == 2 + + @mock.patch.object(lambda_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) + def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor): + """ + Test case when all tasks fail to run. + + The executor should attempt each task exactly once per sync() iteration. + It should preserve the order of tasks, and attempt each task up to + `max_invoke_attempts` times before dropping the task. 
+ """ + airflow_keys = [ + TaskInstanceKey("a", "task_a", "c", 1, -1), + TaskInstanceKey("a", "task_b", "c", 1, -1), + ] + airflow_cmd1 = _generate_mock_cmd() + airflow_cmd2 = _generate_mock_cmd() + commands = [airflow_cmd1, airflow_cmd2] + + failures = [Exception("Failure 1"), Exception("Failure 2")] + + mock_executor.execute_async(airflow_keys[0], commands[0]) + mock_executor.execute_async(airflow_keys[1], commands[1]) + + assert len(mock_executor.pending_tasks) == 2 + assert len(mock_executor.running_tasks) == 0 + + mock_executor.lambda_client.invoke.side_effect = failures + mock_executor.attempt_task_runs() + + for i in range(2): + payload = json.loads(mock_executor.lambda_client.invoke.call_args_list[i].kwargs["Payload"]) + assert airflow_keys[i].task_id in payload["task_key"] + + assert len(mock_executor.pending_tasks) == 2 + assert len(mock_executor.running_tasks) == 0 + + mock_executor.lambda_client.invoke.call_args_list.clear() + + mock_executor.lambda_client.invoke.side_effect = failures + mock_executor.attempt_task_runs() + + for i in range(2): + payload = json.loads(mock_executor.lambda_client.invoke.call_args_list[i].kwargs["Payload"]) + assert airflow_keys[i].task_id in payload["task_key"] + + assert len(mock_executor.pending_tasks) == 2 + assert len(mock_executor.running_tasks) == 0 + + mock_executor.lambda_client.invoke.call_args_list.clear() + + mock_executor.lambda_client.invoke.side_effect = failures + mock_executor.attempt_task_runs() + + assert ( + len(mock_executor.pending_tasks) == 0 + ) # Pending now zero since we've had three failures to invoke + assert len(mock_executor.running_tasks) == 0 + + if airflow_version >= (2, 10, 0): + events = [(x.event, x.task_id, x.try_number) for x in mock_executor._task_event_logs] + assert events == [ + ("lambda invoke failure", "task_a", 1), + ("lambda invoke failure", "task_b", 1), + ] + + @mock.patch.object(lambda_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) + def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor): + """ + Test case when one task fail to run, others succeed, and a new task gets queued. + + """ + airflow_keys = [ + TaskInstanceKey("a", "task_a", "c", 1, -1), + TaskInstanceKey("a", "task_b", "c", 1, -1), + ] + airflow_cmd1 = _generate_mock_cmd() + airflow_cmd2 = _generate_mock_cmd() + airflow_commands = [airflow_cmd1, airflow_cmd2] + + success_response = {"StatusCode": 0, "failures": []} + + responses = [Exception("Failure 1"), success_response] + + mock_executor.execute_async(airflow_keys[0], airflow_commands[0]) + mock_executor.execute_async(airflow_keys[1], airflow_commands[1]) + + assert len(mock_executor.pending_tasks) == 2 + + mock_executor.lambda_client.invoke.side_effect = responses + mock_executor.attempt_task_runs() + + for i in range(2): + payload = json.loads(mock_executor.lambda_client.invoke.call_args_list[i].kwargs["Payload"]) + assert airflow_keys[i].task_id in payload["task_key"] + + assert len(mock_executor.pending_tasks) == 1 + assert len(mock_executor.running_tasks) == 1 + + mock_executor.lambda_client.invoke.call_args_list.clear() + + # queue new task + airflow_keys[1] = TaskInstanceKey("a", "task_c", "c", 1, -1) + airflow_commands[1] = _generate_mock_cmd() + mock_executor.execute_async(airflow_keys[1], airflow_commands[1]) + + assert len(mock_executor.pending_tasks) == 2 + # assert that the order of pending tasks is preserved i.e. the first task is 1st etc. 
+ assert mock_executor.pending_tasks[0].key == airflow_keys[0] + assert mock_executor.pending_tasks[0].command == airflow_commands[0] + + responses = [Exception("Failure 1"), success_response] + mock_executor.lambda_client.invoke.side_effect = responses + mock_executor.attempt_task_runs() + + for i in range(2): + payload = json.loads(mock_executor.lambda_client.invoke.call_args_list[i].kwargs["Payload"]) + assert airflow_keys[i].task_id in payload["task_key"] + + assert len(mock_executor.pending_tasks) == 1 + assert len(mock_executor.running_tasks) == 2 + + mock_executor.lambda_client.invoke.call_args_list.clear() + + responses = [Exception("Failure 1")] + mock_executor.lambda_client.invoke.side_effect = responses + mock_executor.attempt_task_runs() + + payload = json.loads(mock_executor.lambda_client.invoke.call_args_list[0].kwargs["Payload"]) + assert airflow_keys[0].task_id in payload["task_key"] + + if airflow_version >= (2, 10, 0): + events = [(x.event, x.task_id, x.try_number) for x in mock_executor._task_event_logs] + assert events == [("lambda invoke failure", "task_a", 1)] + + @mock.patch.object(BaseExecutor, "fail") + @mock.patch.object(BaseExecutor, "success") + def test_sync_running_dlq(self, success_mock, fail_mock, mock_executor, mock_airflow_key): + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + + mock_executor.running_tasks.clear() + mock_executor.running_tasks[ser_airflow_key] = airflow_key + mock_executor.sqs_client.receive_message.side_effect = [ + {}, # First request from the results queue will be empt + { + # Second request from the DLQ will have a message + "Messages": [ + { + "ReceiptHandle": "receipt_handle", + "Body": json.dumps( + { + "task_key": ser_airflow_key, + # DLQ messages will have the input (task_key, command) instead of return_code + "command": "command", + } + ), + } + ] + }, + ] + + mock_executor.sync_running_tasks() + # Receive messages should be called twice + assert mock_executor.sqs_client.receive_message.call_count == 2 + assert mock_executor.sqs_client.receive_message.call_args_list[0].kwargs == { + "QueueUrl": DEFAULT_QUEUE_URL, + "MaxNumberOfMessages": 10, + } + + assert mock_executor.sqs_client.receive_message.call_args_list[1].kwargs == { + "QueueUrl": DEFAULT_DLQ_URL, + "MaxNumberOfMessages": 10, + } + + # Task is not stored in active workers. 
+ assert len(mock_executor.running_tasks) == 0 + success_mock.assert_not_called() + fail_mock.assert_called_once() + assert mock_executor.sqs_client.delete_message.call_count == 1 + assert mock_executor.sqs_client.delete_message.call_args_list[0].kwargs == { + "QueueUrl": DEFAULT_DLQ_URL, + "ReceiptHandle": "receipt_handle", + } + + @mock.patch.object(BaseExecutor, "fail") + @mock.patch.object(BaseExecutor, "success") + def test_sync_running_success(self, success_mock, fail_mock, mock_executor, mock_airflow_key): + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + + mock_executor.running_tasks.clear() + mock_executor.running_tasks[ser_airflow_key] = airflow_key + # Success message + mock_executor.sqs_client.receive_message.return_value = { + "Messages": [ + { + "ReceiptHandle": "receipt_handle", + "Body": json.dumps( + { + "task_key": ser_airflow_key, + "return_code": 0, + } + ), + } + ] + } + + mock_executor.sync_running_tasks() + mock_executor.sqs_client.receive_message.assert_called_once() + assert mock_executor.sqs_client.receive_message.call_args_list[0].kwargs == { + "QueueUrl": DEFAULT_QUEUE_URL, + "MaxNumberOfMessages": 10, + } + + # Task is not stored in active workers. + assert len(mock_executor.running_tasks) == 0 + # Task is immediately succeeded. + success_mock.assert_called_once() + fail_mock.assert_not_called() + assert mock_executor.sqs_client.delete_message.call_count == 1 + assert mock_executor.sqs_client.delete_message.call_args_list[0].kwargs == { + "QueueUrl": DEFAULT_QUEUE_URL, + "ReceiptHandle": "receipt_handle", + } + + @mock.patch.object(BaseExecutor, "fail") + @mock.patch.object(BaseExecutor, "success") + def test_sync_running_fail(self, success_mock, fail_mock, mock_executor, mock_airflow_key): + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + + mock_executor.running_tasks.clear() + mock_executor.running_tasks[ser_airflow_key] = airflow_key + # Failure message + mock_executor.sqs_client.receive_message.return_value = { + "Messages": [ + { + "ReceiptHandle": "receipt_handle", + "Body": json.dumps( + { + "task_key": ser_airflow_key, + "return_code": 1, # Non-zero return code, task failed + } + ), + } + ] + } + + mock_executor.sync_running_tasks() + mock_executor.sqs_client.receive_message.assert_called_once() + + # Task is not stored in active workers. + assert len(mock_executor.running_tasks) == 0 + # Task is immediately succeeded. + success_mock.assert_not_called() + fail_mock.assert_called_once() + assert mock_executor.sqs_client.delete_message.call_count == 1 + + @mock.patch.object(BaseExecutor, "fail") + @mock.patch.object(BaseExecutor, "success") + def test_sync_running_short_circuit(self, success_mock, fail_mock, mock_executor, mock_airflow_key): + mock_executor.running_tasks.clear() + # No running tasks, so we will short circuit + + mock_executor.sync_running_tasks() + mock_executor.sqs_client.receive_message.assert_not_called() + + # Task is still stored in active workers. + assert len(mock_executor.running_tasks) == 0 + # Task is immediately succeeded. 
+ success_mock.assert_not_called() + fail_mock.assert_not_called() + assert mock_executor.sqs_client.delete_message.call_count == 0 + + @mock.patch.object(BaseExecutor, "fail") + @mock.patch.object(BaseExecutor, "success") + def test_sync_running_no_updates(self, success_mock, fail_mock, mock_executor, mock_airflow_key): + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + + mock_executor.running_tasks.clear() + mock_executor.running_tasks[ser_airflow_key] = airflow_key + # No messages, so we will not loop + mock_executor.sqs_client.receive_message.return_value = {"Messages": []} + + mock_executor.sync_running_tasks() + # Both the results queue and DLQ should have been checked + assert mock_executor.sqs_client.receive_message.call_count == 2 + + # Task is still stored in active workers. + assert len(mock_executor.running_tasks) == 1 + # Task is immediately succeeded. + success_mock.assert_not_called() + fail_mock.assert_not_called() + assert mock_executor.sqs_client.delete_message.call_count == 0 + + @mock.patch.object(BaseExecutor, "fail") + @mock.patch.object(BaseExecutor, "success") + def test_sync_running_two_tasks_one_relevant( + self, success_mock, fail_mock, mock_executor, mock_airflow_key + ): + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + airflow_key_2 = mock_airflow_key() + ser_airflow_key_2 = json.dumps(airflow_key_2._asdict()) + + mock_executor.running_tasks.clear() + mock_executor.running_tasks[ser_airflow_key] = airflow_key + mock_executor.running_tasks[ser_airflow_key_2] = airflow_key_2 + # Success message + mock_executor.sqs_client.receive_message.side_effect = [ + { + "Messages": [ + { + "ReceiptHandle": "receipt_handle", + "Body": json.dumps( + { + "task_key": ser_airflow_key, + "return_code": 0, + } + ), + } + ] + }, + {}, # No messages from DLQ + ] + + mock_executor.sync_running_tasks() + # Both the results queue and DLQ should have been checked + assert mock_executor.sqs_client.receive_message.call_count == 2 + + # One task left running + assert len(mock_executor.running_tasks) == 1 + # Task one completed, task two is still running + assert ser_airflow_key_2 in mock_executor.running_tasks + # Task is immediately succeeded. 
+ success_mock.assert_called_once() + fail_mock.assert_not_called() + assert mock_executor.sqs_client.delete_message.call_count == 1 + + @mock.patch.object(BaseExecutor, "fail") + @mock.patch.object(BaseExecutor, "success") + def test_sync_running_unknown_task(self, success_mock, fail_mock, mock_executor, mock_airflow_key): + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + airflow_key_2 = mock_airflow_key() + ser_airflow_key_2 = json.dumps(airflow_key_2._asdict()) + + mock_executor.running_tasks.clear() + # Only add one of the tasks to the running list, the other will be unknown + mock_executor.running_tasks[ser_airflow_key] = airflow_key + + # Receive the known task and unknown task + mock_executor.sqs_client.receive_message.return_value = { + "Messages": [ + { + "ReceiptHandle": "receipt_handle", + "Body": json.dumps( + { + "task_key": ser_airflow_key, + "return_code": 0, + } + ), + }, + { + "ReceiptHandle": "receipt_handle", + "Body": json.dumps( + { + "task_key": ser_airflow_key_2, + "return_code": 0, + } + ), + }, + ] + } + + mock_executor.sync_running_tasks() + mock_executor.sqs_client.receive_message.assert_called_once() + + # The known task is set to succeeded, unknown task is dropped + assert len(mock_executor.running_tasks) == 0 + success_mock.assert_called_once() + fail_mock.assert_not_called() + # Both messages from the queue should be deleted, both known and unknown + assert mock_executor.sqs_client.delete_message.call_count == 2 + + def test_start_no_check_health(self, mock_executor): + mock_executor.check_health = mock.Mock() + with conf_vars({(CONFIG_GROUP_NAME, AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP): "False"}): + mock_executor.start() + + assert mock_executor.check_health.call_count == 0 + + def test_start_check_health_success(self, mock_executor): + mock_executor.check_health = mock.Mock() + with conf_vars({(CONFIG_GROUP_NAME, AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP): "True"}): + mock_executor.start() + + assert mock_executor.check_health.call_count == 1 + + def test_start_check_health_fail(self, mock_executor): + mock_executor.check_health = mock.Mock() + mock_executor.check_health.side_effect = AirflowException("Test exception") + with conf_vars({(CONFIG_GROUP_NAME, AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP): "True"}): + with pytest.raises(AirflowException): + mock_executor.start() + + assert mock_executor.check_health.call_count == 1 + + def test_check_health_success(self, mock_executor): + mock_executor.IS_BOTO_CONNECTION_HEALTHY = False + mock_executor.sqs_client.get_queue_attributes.return_value = { + "Attributes": {"ApproximateNumberOfMessages": 0} + } + mock_executor.lambda_client.get_function.return_value = { + "Configuration": { + "FunctionName": DEFAULT_FUNCTION_NAME, + "State": "Active", + } + } + mock_executor.check_health() + assert mock_executor.sqs_client.get_queue_attributes.call_count == 2 + assert mock_executor.lambda_client.get_function.call_count == 1 + assert mock_executor.IS_BOTO_CONNECTION_HEALTHY + + def test_check_health_lambda_fails(self, mock_executor): + mock_executor.IS_BOTO_CONNECTION_HEALTHY = False + mock_executor.sqs_client.get_queue_attributes.return_value = { + "Attributes": {"ApproximateNumberOfMessages": 0} + } + mock_executor.lambda_client.get_function.return_value = ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "foobar"}}, "OperationName" + ) + + with pytest.raises(AirflowException): + mock_executor.check_health() + assert 
mock_executor.lambda_client.get_function.call_count == 1 + # Lambda has already failed so SQS should not be called + assert mock_executor.sqs_client.get_queue_attributes.call_count == 0 + assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY + + def test_check_health_sqs_fails(self, mock_executor): + mock_executor.IS_BOTO_CONNECTION_HEALTHY = False + mock_executor.sqs_client.get_queue_attributes.return_value = ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "foobar"}}, "OperationName" + ) + mock_executor.lambda_client.get_function.return_value = { + "Configuration": { + "FunctionName": DEFAULT_FUNCTION_NAME, + "State": "Active", + } + } + with pytest.raises(AirflowException): + mock_executor.check_health() + assert mock_executor.lambda_client.get_function.call_count == 1 + # Lambda has already failed so SQS should not be called + assert mock_executor.sqs_client.get_queue_attributes.call_count == 1 + assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY + + def test_check_health_sqs_results_queue_success_dlq_fails(self, mock_executor): + mock_executor.IS_BOTO_CONNECTION_HEALTHY = False + mock_executor.sqs_client.get_queue_attributes.side_effect = [ + {"Attributes": {"ApproximateNumberOfMessages": 0}}, + ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "foobar"}}, "OperationName" + ), + ] + mock_executor.lambda_client.get_function.return_value = { + "Configuration": { + "FunctionName": DEFAULT_FUNCTION_NAME, + "State": "Active", + } + } + with pytest.raises(AirflowException): + mock_executor.check_health() + assert mock_executor.lambda_client.get_function.call_count == 1 + # Lambda has already failed so SQS should not be called + assert mock_executor.sqs_client.get_queue_attributes.call_count == 2 + assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY + + def test_sync_already_unhealthy(self, mock_executor): + # Something has set the connection to unhealthy (tested elsewhere) + mock_executor.IS_BOTO_CONNECTION_HEALTHY = False + mock_executor.sync_running_tasks = mock.Mock() + mock_executor.attempt_task_runs = mock.Mock() + mock_executor.load_connections = mock.Mock() + # Set the last connection reload to be more than 60 seconds ago so that we get a reload + mock_executor.last_connection_reload = timezone.utcnow() - dt.timedelta(seconds=100) + # We should not be able to sync + mock_executor.sync() + assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY + mock_executor.sync_running_tasks.assert_not_called() + mock_executor.attempt_task_runs.assert_not_called() + mock_executor.load_connections.assert_called_once() + + def test_sync_already_unhealthy_then_repaired(self, mock_executor): + # Something has set the connection to unhealthy (tested elsewhere) + mock_executor.IS_BOTO_CONNECTION_HEALTHY = False + mock_executor.sync_running_tasks = mock.Mock() + mock_executor.attempt_task_runs = mock.Mock() + + def check_health_side_effect(): + mock_executor.IS_BOTO_CONNECTION_HEALTHY = True + + mock_executor.check_health = mock.Mock(side_effect=check_health_side_effect) + # Set the last connection reload to be more than 60 seconds ago so that we get a reload + mock_executor.last_connection_reload = timezone.utcnow() - dt.timedelta(seconds=100) + # Sync should repair itself and continue to call the sync methods + mock_executor.sync() + assert mock_executor.IS_BOTO_CONNECTION_HEALTHY + mock_executor.sync_running_tasks.assert_called_once() + mock_executor.attempt_task_runs.assert_called_once() + + @pytest.mark.parametrize( + "error_code", + [ + 
"ExpiredTokenException", + "InvalidClientTokenId", + "UnrecognizedClientException", + ], + ) + def test_sync_become_unhealthy_no_creds(self, error_code, mock_executor): + # Something has set the connection to unhealthy (tested elsewhere) + mock_executor.IS_BOTO_CONNECTION_HEALTHY = True + mock_executor.log.warning = mock.Mock() + mock_executor.attempt_task_runs = mock.Mock() + error_to_raise = ClientError({"Error": {"Code": error_code, "Message": "foobar"}}, "OperationName") + mock_executor.sync_running_tasks = mock.Mock(side_effect=error_to_raise) + + # sync should catch the error and handle it, setting connection to unhealthy + mock_executor.sync() + assert not mock_executor.IS_BOTO_CONNECTION_HEALTHY + mock_executor.sync_running_tasks.assert_called_once() + mock_executor.attempt_task_runs.assert_not_called() + # Check that the substring "AWS credentials are either missing or expired" was logged + mock_executor.log.warning.assert_called_once() + assert "AWS credentials are either missing or expired" in mock_executor.log.warning.call_args[0][0] + + def test_sync_exception(self, mock_executor): + # Something has set the connection to unhealthy (tested elsewhere) + mock_executor.IS_BOTO_CONNECTION_HEALTHY = True + mock_executor.log.exception = mock.Mock() + mock_executor.attempt_task_runs = mock.Mock() + mock_executor.sync_running_tasks = mock.Mock(side_effect=Exception()) + + # sync should catch the error and log, don't kill scheduler by letting it raise up higher. + mock_executor.sync() + # Not a credentials error that we can tell, so connection stays healthy + assert mock_executor.IS_BOTO_CONNECTION_HEALTHY + mock_executor.sync_running_tasks.assert_called_once() + mock_executor.attempt_task_runs.assert_not_called() + # Check that the substring "AWS credentials are either missing or expired" was logged + mock_executor.log.exception.assert_called_once() + assert "An error occurred while syncing tasks" in mock_executor.log.exception.call_args[0][0] + + def test_try_adopt_task_instances(self, mock_executor, mock_airflow_key): + """Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event.""" + # airflow_key_1 = mock_airflow_key() + airflow_key_1 = TaskInstanceKey("a", "task_a", "c", 1, -1) + ser_airflow_key_1 = json.dumps(airflow_key_1._asdict()) + # airflow_key_2 = mock_airflow_key() + airflow_key_2 = TaskInstanceKey("a", "task_b", "c", 1, -1) + ser_airflow_key_2 = json.dumps(airflow_key_2._asdict()) + + orphaned_tasks = [ + mock.Mock(spec=TaskInstance), + mock.Mock(spec=TaskInstance), + mock.Mock(spec=TaskInstance), + ] + orphaned_tasks[0].external_executor_id = ser_airflow_key_1 + orphaned_tasks[1].external_executor_id = ser_airflow_key_2 + orphaned_tasks[ + 2 + ].external_executor_id = None # One orphaned task has no external_executor_id, not adopted + + for task in orphaned_tasks: + task.try_number = 1 + + not_adopted_tasks = mock_executor.try_adopt_task_instances(orphaned_tasks) + + # Two of the three tasks should be adopted. + assert len(orphaned_tasks) - 1 == len(mock_executor.running_tasks) + assert ser_airflow_key_1 in mock_executor.running_tasks + + assert mock_executor.running_tasks[ser_airflow_key_1] == airflow_key_1 + assert ser_airflow_key_2 in mock_executor.running_tasks + assert mock_executor.running_tasks[ser_airflow_key_2] == airflow_key_2 + + # The remaining one task is unable to be adopted. 
+ assert len(not_adopted_tasks) == 1 + assert not_adopted_tasks[0] == orphaned_tasks[2] + + @mock.patch.object(BaseExecutor, "fail") + @mock.patch.object(BaseExecutor, "success") + def test_end(self, success_mock, fail_mock, mock_executor, mock_airflow_key): + """Test that executor can end successfully; waiting for all tasks to naturally exit.""" + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + + mock_executor.running_tasks.clear() + mock_executor.running_tasks[ser_airflow_key] = airflow_key + # First message is empty, so we loop again while waiting for tasks to finish + mock_executor.sqs_client.receive_message.side_effect = [ + {}, + {}, + { + "Messages": [ + { + "ReceiptHandle": "receipt_handle", + "Body": json.dumps( + { + "task_key": ser_airflow_key, + "return_code": 0, + } + ), + } + ] + }, + ] + mock_executor.end(heartbeat_interval=0) + # Assert that the sqs_client mock method receive_message was called exactly twice + assert mock_executor.sqs_client.receive_message.call_count == 3 + + # Task is not stored in active workers. + assert len(mock_executor.running_tasks) == 0 + success_mock.assert_called_once() + fail_mock.assert_not_called() + assert mock_executor.sqs_client.delete_message.call_count == 1 + + @mock.patch("airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor.timezone") + def test_end_timeout(self, mock_timezone, mock_executor, mock_airflow_key): + """Test that executor can end successfully; waiting for all tasks to naturally exit.""" + # Mock the sync method of the mock_executor object so we can count how many times it was called + mock_executor.sync = mock.Mock() + mock_executor.log.warning = mock.Mock() + current_time = timezone.utcnow() + mock_timezone.utcnow.side_effect = [ + current_time, + current_time, + current_time + dt.timedelta(seconds=5), + current_time + dt.timedelta(seconds=10), + ] + + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + + mock_executor.running_tasks.clear() + mock_executor.running_tasks[ser_airflow_key] = airflow_key + + with conf_vars({(CONFIG_GROUP_NAME, AllLambdaConfigKeys.END_WAIT_TIMEOUT): "5"}): + mock_executor.end(heartbeat_interval=0) + + # Task is still stored in active workers. + assert len(mock_executor.running_tasks) == 1 + assert mock_executor.sync.call_count == 2 + mock_executor.log.warning.assert_called_once_with( + "Timed out waiting for tasks to finish. Some tasks may not be handled gracefully" + " as the executor is force ending due to timeout." + ) + + def test_terminate(self, mock_executor, mock_airflow_key): + """Test that executor can terminate successfully.""" + airflow_key = mock_airflow_key() + ser_airflow_key = json.dumps(airflow_key._asdict()) + + mock_executor.running_tasks.clear() + mock_executor.running_tasks[ser_airflow_key] = airflow_key + mock_executor.log.warning = mock.Mock() + + mock_executor.terminate() + mock_executor.log.warning.assert_called_once_with( + "Terminating Lambda executor. In-flight tasks cannot be stopped." + ) + assert len(mock_executor.running_tasks) == 1
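
A few illustrative sketches of how the pieces in this changeset fit together follow; they are editorial examples grounded in the code above, not part of the patch.

The ``[aws_lambda_executor]`` options registered in get_provider_info.py above map onto an airflow.cfg section. A minimal sketch, with placeholder region, account ID and resource names, and with the executor selected by its full import path (which appears in the test imports above); confirm any alias or additional options against provider.yaml and the lambda-executor.rst docs in this commit:

[core]
executor = airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor.AwsLambdaExecutor

[aws_lambda_executor]
conn_id = aws_default
region_name = us-east-1
function_name = airflow-lambda-executor-function
queue_url = https://sqs.us-east-1.amazonaws.com/111122223333/airflow-lambda-executor-results-queue
dead_letter_queue_url = https://sqs.us-east-1.amazonaws.com/111122223333/airflow-lambda-executor-dlq
max_run_task_attempts = 3
end_wait_timeout = 0
check_health_on_startup = True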
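On the submit side, attempt_task_runs() wraps each queued task into a small JSON payload (the serialized TaskInstanceKey plus the command) and fires an asynchronous "Event" invocation, retrying with a backoff delay up to max_run_task_attempts times before failing the task. A minimal boto3 sketch of that call, assuming a placeholder region, function name and command; the real values come from the executor's configuration and the scheduler's queued workload:

from __future__ import annotations

import json

import boto3

from airflow.models.taskinstancekey import TaskInstanceKey

lambda_client = boto3.client("lambda", region_name="us-east-1")  # placeholder region

# The executor serializes the TaskInstanceKey with json.dumps(key._asdict()); the same
# string later comes back on the results queue and is stored as the external_executor_id.
key = TaskInstanceKey(
    dag_id="example_dag",
    task_id="example_task",
    run_id="manual__2025-05-15T00:00:00",
    try_number=1,
    map_index=-1,
)
payload = {
    "task_key": json.dumps(key._asdict()),
    "command": ["airflow", "tasks", "run", "example_dag", "example_task", "manual__2025-05-15T00:00:00", "--local"],
}

lambda_client.invoke(
    FunctionName="airflow-lambda-executor-function",  # [aws_lambda_executor] function_name
    InvocationType="Event",  # asynchronous; completion is reported back via the SQS results queue
    Payload=json.dumps(payload),
    # A Qualifier kwarg is only added when a function qualifier is configured.
)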
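On the result side, process_queue() expects each SQS message body to be JSON containing the same serialized task_key plus a return_code, where 0 means the Airflow task process succeeded; the dead letter queue instead carries the original invoke payload for invocations that timed out or crashed. A sketch of the message a Lambda runtime could publish when the task process exits, and of how the serialized key round-trips for adoption after a scheduler restart; the queue URL is a placeholder and the actual sender is the Lambda runtime image shipped with this executor:

from __future__ import annotations

import json

import boto3

from airflow.models.taskinstancekey import TaskInstanceKey

RESULTS_QUEUE_URL = "https://sqs.us-east-1.amazonaws.com/111122223333/airflow-lambda-executor-results-queue"

ser_task_key = json.dumps(
    TaskInstanceKey(
        dag_id="example_dag",
        task_id="example_task",
        run_id="manual__2025-05-15T00:00:00",
        try_number=1,
        map_index=-1,
    )._asdict()
)

# return_code == 0 is treated as success by process_queue(); any other value marks the task failed.
boto3.client("sqs", region_name="us-east-1").send_message(
    QueueUrl=RESULTS_QUEUE_URL,
    MessageBody=json.dumps({"task_key": ser_task_key, "return_code": 0}),
)

# The executor assigns ser_task_key as the task instance's external_executor_id, which is what
# try_adopt_task_instances() deserializes in order to re-attach orphaned tasks to a new executor.
restored_key = TaskInstanceKey.from_dict(json.loads(ser_task_key))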