This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0723a8f01d Introduce Amazon Bedrock service (#38602)
0723a8f01d is described below

commit 0723a8f01d1bc9eb62324a222ba34b82a8d8252c
Author: D. Ferruzzi <ferru...@amazon.com>
AuthorDate: Sat Mar 30 01:54:42 2024 -0700

    Introduce Amazon Bedrock service (#38602)
    
    * Introduce Amazon Bedrock service
---
 airflow/providers/amazon/aws/hooks/bedrock.py      |  39 +++++++++
 airflow/providers/amazon/aws/operators/bedrock.py  |  93 +++++++++++++++++++++
 airflow/providers/amazon/provider.yaml             |  12 +++
 .../operators/bedrock.rst                          |  72 ++++++++++++++++
 .../aws/amazon-bedrock_light...@4x.png             | Bin 0 -> 12621 bytes
 tests/providers/amazon/aws/hooks/test_bedrock.py   |  27 ++++++
 .../providers/amazon/aws/operators/test_bedrock.py |  59 +++++++++++++
 .../system/providers/amazon/aws/example_bedrock.py |  76 +++++++++++++++++
 8 files changed, 378 insertions(+)

diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py 
b/airflow/providers/amazon/aws/hooks/bedrock.py
new file mode 100644
index 0000000000..11bacd9414
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/bedrock.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class BedrockRuntimeHook(AwsBaseHook):
+    """
+    Interact with the Amazon Bedrock Runtime API.
+
+    Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock-runtime") <BedrockRuntime.Client>`.
+
+    Any extra positional or keyword arguments (for example ``aws_conn_id``) are
+    forwarded to the underlying AwsBaseHook; ``client_type`` is always forced to
+    the value of :attr:`client_type`, overriding anything the caller passed.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    client_type = "bedrock-runtime"
+
+    def __init__(self, *args, **kwargs) -> None:
+        # Pin the boto3 client type; a caller-supplied client_type is replaced.
+        super().__init__(*args, **{**kwargs, "client_type": self.client_type})
diff --git a/airflow/providers/amazon/aws/operators/bedrock.py 
b/airflow/providers/amazon/aws/operators/bedrock.py
new file mode 100644
index 0000000000..d8eaf9e5d3
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/bedrock.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.utils.helpers import prune_dict
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class BedrockInvokeModelOperator(AwsBaseOperator[BedrockRuntimeHook]):
+    """
+    Invoke the specified Bedrock model to run inference using the input provided.
+
+    Use InvokeModel to run inference for text models, image models, and embedding models.
+    To see the format and content of the input_data field for different models, refer to
+    `Inference parameters docs <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`_.
+
+    The request body is built with ``json.dumps(input_data)`` and the response body is
+    decoded with ``json.loads``; the decoded response is returned from ``execute``.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BedrockInvokeModelOperator`
+
+    :param model_id: The ID of the Bedrock model. (templated)
+    :param input_data: Input data in the format specified in the content-type request header.
+        It is always serialized to JSON before being sent. (templated)
+    :param content_type: The MIME type of the input data in the request.
+        (templated) Default: application/json
+    :param accept_type: The desired MIME type of the inference body in the response.
+        (templated) Default: application/json
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = BedrockRuntimeHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "model_id", "input_data", "content_type", "accept_type"
+    )
+
+    def __init__(
+        self,
+        model_id: str,
+        input_data: dict[str, Any],
+        content_type: str | None = None,
+        accept_type: str | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.model_id = model_id
+        self.input_data = input_data
+        self.content_type = content_type
+        self.accept_type = accept_type
+
+    def execute(self, context: Context) -> dict[str, Any]:
+        # These are optional values which the API defaults to "application/json" if not provided here;
+        # prune_dict drops the ones left as None so they are not sent at all.
+        invoke_kwargs = prune_dict({"contentType": self.content_type, "accept": self.accept_type})
+
+        # Log the prompt before the call so it is still visible in the task log if invocation fails.
+        self.log.info("Bedrock %s prompt: %s", self.model_id, self.input_data)
+        response = self.hook.conn.invoke_model(
+            body=json.dumps(self.input_data),
+            modelId=self.model_id,
+            **invoke_kwargs,
+        )
+
+        # The response body is file-like (it is read()); decode its contents as JSON.
+        response_body = json.loads(response["body"].read())
+        self.log.info("Bedrock model response: %s", response_body)
+        return response_body
diff --git a/airflow/providers/amazon/provider.yaml 
b/airflow/providers/amazon/provider.yaml
index e2b0df930e..4c4f7cf597 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -142,6 +142,12 @@ integrations:
       - /docs/apache-airflow-providers-amazon/operators/athena/athena_boto.rst
       - /docs/apache-airflow-providers-amazon/operators/athena/athena_sql.rst
     tags: [aws]
+  - integration-name: Amazon Bedrock
+    external-doc-url: https://aws.amazon.com/bedrock/
+    logo: /integration-logos/aws/amazon-bedrock_light...@4x.png
+    how-to-guide:
+      - /docs/apache-airflow-providers-amazon/operators/bedrock.rst
+    tags: [aws]
   - integration-name: Amazon Chime
     external-doc-url: https://aws.amazon.com/chime/
     logo: /integration-logos/aws/Amazon-Chime-light-bg.png
@@ -363,6 +369,9 @@ operators:
   - integration-name: AWS Batch
     python-modules:
       - airflow.providers.amazon.aws.operators.batch
+  - integration-name: Amazon Bedrock
+    python-modules:
+      - airflow.providers.amazon.aws.operators.bedrock
   - integration-name: Amazon CloudFormation
     python-modules:
       - airflow.providers.amazon.aws.operators.cloud_formation
@@ -514,6 +523,9 @@ hooks:
     python-modules:
       - airflow.providers.amazon.aws.hooks.athena
       - airflow.providers.amazon.aws.hooks.athena_sql
+  - integration-name: Amazon Bedrock
+    python-modules:
+      - airflow.providers.amazon.aws.hooks.bedrock
   - integration-name: Amazon Chime
     python-modules:
       - airflow.providers.amazon.aws.hooks.chime
diff --git a/docs/apache-airflow-providers-amazon/operators/bedrock.rst 
b/docs/apache-airflow-providers-amazon/operators/bedrock.rst
new file mode 100644
index 0000000000..3e84cbc445
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/operators/bedrock.rst
@@ -0,0 +1,72 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+==============
+Amazon Bedrock
+==============
+
+`Amazon Bedrock <https://aws.amazon.com/bedrock/>`__ is a fully managed service that
+offers a choice of high-performing foundation models (FMs) from leading AI companies
+like AI21 Labs, Anthropic, Cohere, Meta, Mistral AI, Stability AI, and Amazon via a
+single API, along with a broad set of capabilities you need to build generative AI
+applications with security, privacy, and responsible AI.
+
+Prerequisite Tasks
+------------------
+
+.. include:: ../_partials/prerequisite_tasks.rst
+
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
+Operators
+---------
+
+.. _howto/operator:BedrockInvokeModelOperator:
+
+Invoke an existing Amazon Bedrock Model
+=======================================
+
+To invoke an existing Amazon Bedrock model, you can use
+:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockInvokeModelOperator`.
+
+Note that every model family has different input and output formats.
+For example, to invoke a Meta Llama model you would use:
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_invoke_llama_model]
+    :end-before: [END howto_operator_invoke_llama_model]
+
+To invoke an Amazon Titan model you would use:
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_invoke_titan_model]
+    :end-before: [END howto_operator_invoke_titan_model]
+
+For details on the different formats, see `Inference parameters for foundation models <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`__.
+
+
+Reference
+---------
+
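+* `AWS boto3 library documentation for Amazon Bedrock <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock.html>`__
+* `AWS boto3 library documentation for Amazon Bedrock Runtime <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime.html>`__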
diff --git a/docs/integration-logos/aws/amazon-bedrock_light...@4x.png 
b/docs/integration-logos/aws/amazon-bedrock_light...@4x.png
new file mode 100644
index 0000000000..e6af4b7276
Binary files /dev/null and 
b/docs/integration-logos/aws/amazon-bedrock_light...@4x.png differ
diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py 
b/tests/providers/amazon/aws/hooks/test_bedrock.py
new file mode 100644
index 0000000000..73612aacbc
--- /dev/null
+++ b/tests/providers/amazon/aws/hooks/test_bedrock.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook
+
+
+class TestBedrockRuntimeHook:
+    def test_conn_returns_a_boto3_connection(self):
+        """The hook's ``conn`` must be a boto3 client bound to the bedrock-runtime service."""
+        client = BedrockRuntimeHook().conn
+
+        assert client is not None
+        assert client.meta.service_model.service_name == "bedrock-runtime"
diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py 
b/tests/providers/amazon/aws/operators/test_bedrock.py
new file mode 100644
index 0000000000..f6274de48f
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_bedrock.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+from typing import Generator
+from unittest import mock
+
+import pytest
+from moto import mock_aws
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook
+from airflow.providers.amazon.aws.operators.bedrock import 
BedrockInvokeModelOperator
+
+MODEL_ID = "meta.llama2-13b-chat-v1"
+PROMPT = "A very important question."
+GENERATED_RESPONSE = "An important answer."
+# Canned JSON payload shaped like a Llama 2 invoke_model response body.
+# Character counts stand in for real token counts.
+MOCK_RESPONSE = json.dumps(
+    dict(
+        generation=GENERATED_RESPONSE,
+        prompt_token_count=len(PROMPT),
+        generation_token_count=len(GENERATED_RESPONSE),
+        stop_reason="stop",
+    )
+)
+
+
+@pytest.fixture
+def runtime_hook() -> Generator[BedrockRuntimeHook, None, None]:
+    """Provide a BedrockRuntimeHook created and used inside moto's ``mock_aws`` context."""
+    with mock_aws():
+        hook = BedrockRuntimeHook(aws_conn_id="aws_default")
+        yield hook
+
+
+class TestBedrockInvokeModelOperator:
+    @mock.patch.object(BedrockRuntimeHook, "conn")
+    def test_invoke_model_prompt_good_combinations(self, mock_conn):
+        """input_data is JSON-encoded, unset optional args are omitted, and the decoded body is returned."""
+        mock_conn.invoke_model.return_value["body"].read.return_value = MOCK_RESPONSE
+        input_data = {"prompt": PROMPT}
+        operator = BedrockInvokeModelOperator(task_id="test_task", model_id=MODEL_ID, input_data=input_data)
+
+        response = operator.execute({})
+
+        assert response["generation"] == GENERATED_RESPONSE
+        # content_type/accept_type default to None and must be pruned rather than sent to the API.
+        mock_conn.invoke_model.assert_called_once_with(body=json.dumps(input_data), modelId=MODEL_ID)
+
+    @mock.patch.object(BedrockRuntimeHook, "conn")
+    def test_invoke_model_forwards_content_and_accept_types(self, mock_conn):
+        """Explicit content_type/accept_type are forwarded as contentType/accept."""
+        mock_conn.invoke_model.return_value["body"].read.return_value = MOCK_RESPONSE
+        input_data = {"prompt": PROMPT}
+        operator = BedrockInvokeModelOperator(
+            task_id="test_task",
+            model_id=MODEL_ID,
+            input_data=input_data,
+            content_type="application/json",
+            accept_type="application/json",
+        )
+
+        operator.execute({})
+
+        mock_conn.invoke_model.assert_called_once_with(
+            body=json.dumps(input_data),
+            modelId=MODEL_ID,
+            contentType="application/json",
+            accept="application/json",
+        )
diff --git a/tests/system/providers/amazon/aws/example_bedrock.py 
b/tests/system/providers/amazon/aws/example_bedrock.py
new file mode 100644
index 0000000000..e86e5a2e92
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_bedrock.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow.models.baseoperator import chain
+from airflow.models.dag import DAG
+from airflow.providers.amazon.aws.operators.bedrock import 
BedrockInvokeModelOperator
+from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
+
+sys_test_context_task = SystemTestContextBuilder().build()
+
+DAG_ID = "example_bedrock"
+PROMPT = "What color is an orange?"
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    tags=["example"],
+    catchup=False,
+) as dag:
+    test_context = sys_test_context_task()
+
+    # [START howto_operator_invoke_llama_model]
+    invoke_llama_model = BedrockInvokeModelOperator(
+        task_id="invoke_llama",
+        model_id="meta.llama2-13b-chat-v1",
+        input_data={"prompt": PROMPT},
+    )
+    # [END howto_operator_invoke_llama_model]
+
+    # [START howto_operator_invoke_titan_model]
+    invoke_titan_model = BedrockInvokeModelOperator(
+        task_id="invoke_titan",
+        model_id="amazon.titan-text-express-v1",
+        input_data={"inputText": PROMPT},
+    )
+    # [END howto_operator_invoke_titan_model]
+
+    chain(
+        # TEST SETUP
+        test_context,
+        # TEST BODY
+        invoke_llama_model,
+        invoke_titan_model,
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # Attach the watcher downstream of every task so the system test run
+    # properly reports success/failure of the DAG as a whole.
+    list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)

Reply via email to