This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 0723a8f01d Introduce Amazon Bedrock service (#38602) 0723a8f01d is described below commit 0723a8f01d1bc9eb62324a222ba34b82a8d8252c Author: D. Ferruzzi <ferru...@amazon.com> AuthorDate: Sat Mar 30 01:54:42 2024 -0700 Introduce Amazon Bedrock service (#38602) * Introduce Amazon Bedrock service --- airflow/providers/amazon/aws/hooks/bedrock.py | 39 +++++++++ airflow/providers/amazon/aws/operators/bedrock.py | 93 +++++++++++++++++++++ airflow/providers/amazon/provider.yaml | 12 +++ .../operators/bedrock.rst | 72 ++++++++++++++++ .../aws/amazon-bedrock_light-bg@4x.png | Bin 0 -> 12621 bytes tests/providers/amazon/aws/hooks/test_bedrock.py | 27 ++++++ .../providers/amazon/aws/operators/test_bedrock.py | 59 +++++++++++++ .../system/providers/amazon/aws/example_bedrock.py | 76 +++++++++++++++++ 8 files changed, 378 insertions(+) diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py b/airflow/providers/amazon/aws/hooks/bedrock.py new file mode 100644 index 0000000000..11bacd9414 --- /dev/null +++ b/airflow/providers/amazon/aws/hooks/bedrock.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class BedrockRuntimeHook(AwsBaseHook): + """ + Interact with the Amazon Bedrock Runtime. + + Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock-runtime") <BedrockRuntime.Client>`. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + client_type = "bedrock-runtime" + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = self.client_type + super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py new file mode 100644 index 0000000000..d8eaf9e5d3 --- /dev/null +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields +from airflow.utils.helpers import prune_dict + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class BedrockInvokeModelOperator(AwsBaseOperator[BedrockRuntimeHook]): + """ + Invoke the specified Bedrock model to run inference using the input provided. + + Use InvokeModel to run inference for text models, image models, and embedding models. + To see the format and content of the input_data field for different models, refer to + `Inference parameters docs <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`_. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BedrockInvokeModelOperator` + + :param model_id: The ID of the Bedrock model. (templated) + :param input_data: Input data in the format specified in the content-type request header. (templated) + :param content_type: The MIME type of the input data in the request. (templated) Default: application/json + :param accept_type: The desired MIME type of the inference body in the response. + (templated) Default: application/json + + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. 
See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + aws_hook_class = BedrockRuntimeHook + template_fields: Sequence[str] = aws_template_fields( + "model_id", "input_data", "content_type", "accept_type" + ) + + def __init__( + self, + model_id: str, + input_data: dict[str, Any], + content_type: str | None = None, + accept_type: str | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.model_id = model_id + self.input_data = input_data + self.content_type = content_type + self.accept_type = accept_type + + def execute(self, context: Context) -> dict[str, str | int]: + # These are optional values which the API defaults to "application/json" if not provided here. + invoke_kwargs = prune_dict({"contentType": self.content_type, "accept": self.accept_type}) + + response = self.hook.conn.invoke_model( + body=json.dumps(self.input_data), + modelId=self.model_id, + **invoke_kwargs, + ) + + response_body = json.loads(response["body"].read()) + self.log.info("Bedrock %s prompt: %s", self.model_id, self.input_data) + self.log.info("Bedrock model response: %s", response_body) + return response_body diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index e2b0df930e..4c4f7cf597 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -142,6 +142,12 @@ integrations: - /docs/apache-airflow-providers-amazon/operators/athena/athena_boto.rst - /docs/apache-airflow-providers-amazon/operators/athena/athena_sql.rst tags: [aws] + - integration-name: Amazon Bedrock + external-doc-url: https://aws.amazon.com/bedrock/ + logo: /integration-logos/aws/amazon-bedrock_light-bg@4x.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/bedrock.rst + tags: [aws] - 
integration-name: Amazon Chime external-doc-url: https://aws.amazon.com/chime/ logo: /integration-logos/aws/Amazon-Chime-light-bg.png @@ -363,6 +369,9 @@ operators: - integration-name: AWS Batch python-modules: - airflow.providers.amazon.aws.operators.batch + - integration-name: Amazon Bedrock + python-modules: + - airflow.providers.amazon.aws.operators.bedrock - integration-name: Amazon CloudFormation python-modules: - airflow.providers.amazon.aws.operators.cloud_formation @@ -514,6 +523,9 @@ hooks: python-modules: - airflow.providers.amazon.aws.hooks.athena - airflow.providers.amazon.aws.hooks.athena_sql + - integration-name: Amazon Bedrock + python-modules: + - airflow.providers.amazon.aws.hooks.bedrock - integration-name: Amazon Chime python-modules: - airflow.providers.amazon.aws.hooks.chime diff --git a/docs/apache-airflow-providers-amazon/operators/bedrock.rst b/docs/apache-airflow-providers-amazon/operators/bedrock.rst new file mode 100644 index 0000000000..3e84cbc445 --- /dev/null +++ b/docs/apache-airflow-providers-amazon/operators/bedrock.rst @@ -0,0 +1,72 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+ +============== +Amazon Bedrock +============== + +`Amazon Bedrock <https://aws.amazon.com/bedrock/>`__ is a fully managed service that +offers a choice of high-performing foundation models (FMs) from leading AI companies +like AI21 Labs, Anthropic, Cohere, Meta, Mistral AI, Stability AI, and Amazon via a +single API, along with a broad set of capabilities you need to build generative AI +applications with security, privacy, and responsible AI. + +Prerequisite Tasks +------------------ + +.. include:: ../_partials/prerequisite_tasks.rst + +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + +Operators +--------- + +.. _howto/operator:BedrockInvokeModelOperator: + +Invoke an existing Amazon Bedrock Model +======================================= + +To invoke an existing Amazon Bedrock model, you can use +:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockInvokeModelOperator`. + +Note that every model family has different input and output formats. +For example, to invoke a Meta Llama model you would use: + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_invoke_llama_model] + :end-before: [END howto_operator_invoke_llama_model] + +To invoke an Amazon Titan model you would use: + +.. 
exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_invoke_titan_model] + :end-before: [END howto_operator_invoke_titan_model] + +For details on the different formats, see `Inference parameters for foundation models <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`__. + + +Reference +--------- + +* `AWS boto3 library documentation for Amazon Bedrock <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock.html>`__ diff --git a/docs/integration-logos/aws/amazon-bedrock_light-bg@4x.png b/docs/integration-logos/aws/amazon-bedrock_light-bg@4x.png new file mode 100644 index 0000000000..e6af4b7276 Binary files /dev/null and b/docs/integration-logos/aws/amazon-bedrock_light-bg@4x.png differ diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py b/tests/providers/amazon/aws/hooks/test_bedrock.py new file mode 100644 index 0000000000..73612aacbc --- /dev/null +++ b/tests/providers/amazon/aws/hooks/test_bedrock.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook + + +class TestBedrockRuntimeHook: + def test_conn_returns_a_boto3_connection(self): + hook = BedrockRuntimeHook() + + assert hook.conn is not None + assert hook.conn.meta.service_model.service_name == "bedrock-runtime" diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py new file mode 100644 index 0000000000..f6274de48f --- /dev/null +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +from typing import Generator +from unittest import mock + +import pytest +from moto import mock_aws + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook +from airflow.providers.amazon.aws.operators.bedrock import BedrockInvokeModelOperator + +MODEL_ID = "meta.llama2-13b-chat-v1" +PROMPT = "A very important question." +GENERATED_RESPONSE = "An important answer." 
+MOCK_RESPONSE = json.dumps( + { + "generation": GENERATED_RESPONSE, + "prompt_token_count": len(PROMPT), + "generation_token_count": len(GENERATED_RESPONSE), + "stop_reason": "stop", + } +) + + +@pytest.fixture +def runtime_hook() -> Generator[BedrockRuntimeHook, None, None]: + with mock_aws(): + yield BedrockRuntimeHook(aws_conn_id="aws_default") + + +class TestBedrockInvokeModelOperator: + @mock.patch.object(BedrockRuntimeHook, "conn") + def test_invoke_model_prompt_good_combinations(self, mock_conn): + mock_conn.invoke_model.return_value["body"].read.return_value = MOCK_RESPONSE + operator = BedrockInvokeModelOperator( + task_id="test_task", model_id=MODEL_ID, input_data={"input_data": {"prompt": PROMPT}} + ) + + response = operator.execute({}) + + assert response["generation"] == GENERATED_RESPONSE diff --git a/tests/system/providers/amazon/aws/example_bedrock.py b/tests/system/providers/amazon/aws/example_bedrock.py new file mode 100644 index 0000000000..e86e5a2e92 --- /dev/null +++ b/tests/system/providers/amazon/aws/example_bedrock.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from datetime import datetime + +from airflow.models.baseoperator import chain +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.operators.bedrock import BedrockInvokeModelOperator +from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder + +sys_test_context_task = SystemTestContextBuilder().build() + +DAG_ID = "example_bedrock" +PROMPT = "What color is an orange?" + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + tags=["example"], + catchup=False, +) as dag: + test_context = sys_test_context_task() + env_id = test_context["ENV_ID"] + + # [START howto_operator_invoke_llama_model] + invoke_llama_model = BedrockInvokeModelOperator( + task_id="invoke_llama", + model_id="meta.llama2-13b-chat-v1", + input_data={"prompt": PROMPT}, + ) + # [END howto_operator_invoke_llama_model] + + # [START howto_operator_invoke_titan_model] + invoke_titan_model = BedrockInvokeModelOperator( + task_id="invoke_titan", + model_id="amazon.titan-text-express-v1", + input_data={"inputText": PROMPT}, + ) + # [END howto_operator_invoke_titan_model] + + chain( + # TEST SETUP + test_context, + # TEST BODY + invoke_llama_model, + invoke_titan_model, + # TEST TEARDOWN + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)