This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 13ee1f1cfb0 Add `S3VectorsQueryVectorsOperator` (#66724)
13ee1f1cfb0 is described below
commit 13ee1f1cfb0f21e57ed44a3c73395dfeed0d61b3
Author: John Jackson <[email protected]>
AuthorDate: Tue May 12 07:08:47 2026 -0700
Add `S3VectorsQueryVectorsOperator` (#66724)
---
providers/amazon/docs/operators/s3_vectors.rst | 15 +++++
.../providers/amazon/aws/operators/s3_vectors.py | 69 ++++++++++++++++++++++
.../tests/system/amazon/aws/example_s3_vectors.py | 12 ++++
.../unit/amazon/aws/operators/test_s3_vectors.py | 47 +++++++++++++++
4 files changed, 143 insertions(+)
diff --git a/providers/amazon/docs/operators/s3_vectors.rst
b/providers/amazon/docs/operators/s3_vectors.rst
index 6e1d364377a..ac8bf6dedac 100644
--- a/providers/amazon/docs/operators/s3_vectors.rst
+++ b/providers/amazon/docs/operators/s3_vectors.rst
@@ -92,6 +92,21 @@ To insert vectors into an Amazon S3 Vectors index, use
:start-after: [START howto_operator_s3vectors_put_vectors]
:end-before: [END howto_operator_s3vectors_put_vectors]
+
+.. _howto/operator:S3VectorsQueryVectorsOperator:
+
+Query Vectors
+-------------
+
+To query vectors by similarity in an Amazon S3 Vectors index, use
+:class:`~airflow.providers.amazon.aws.operators.s3_vectors.S3VectorsQueryVectorsOperator`.
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_s3_vectors.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_s3vectors_query_vectors]
+ :end-before: [END howto_operator_s3vectors_query_vectors]
+
Reference
---------
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_vectors.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_vectors.py
index 1ca23e05bf2..cac1d2409f9 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_vectors.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_vectors.py
@@ -298,3 +298,72 @@ class
S3VectorsPutVectorsOperator(AwsBaseOperator[AwsBaseHook]):
vectors=self.vectors,
)
self.log.info("Put %d vectors successfully", len(self.vectors))
+
+
class S3VectorsQueryVectorsOperator(AwsBaseOperator[AwsBaseHook]):
    """
    Query vectors by similarity in an Amazon S3 Vectors index.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:S3VectorsQueryVectorsOperator`

    :param vector_bucket_name: The name of the vector bucket. (templated)
    :param index_name: The name of the index. (templated)
    :param top_k: The number of results to return. A rendered template string such
        as ``"5"`` is converted to ``int`` before the request is sent. (templated)
    :param query_vector: The query vector dict (e.g. ``{"float32": [0.1, 0.2, ...]}``)
    :param filter: Optional filter expression dict. Left out of the request when ``None``.
    :param return_metadata: Whether to return metadata with results.
    :param return_distance: Whether to return distance scores.
    """

    aws_hook_class = AwsBaseHook
    template_fields: tuple[str, ...] = (
        *AwsBaseOperator.template_fields,
        "vector_bucket_name",
        "index_name",
        "top_k",
    )

    def __init__(
        self,
        *,
        vector_bucket_name: str,
        index_name: str,
        top_k: int,
        query_vector: dict[str, Any],
        # Shadows the builtin ``filter``; the name mirrors the API field it populates.
        filter: dict[str, Any] | None = None,
        return_metadata: bool = True,
        return_distance: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.vector_bucket_name = vector_bucket_name
        self.index_name = index_name
        self.top_k = top_k
        self.query_vector = query_vector
        self.filter = filter
        self.return_metadata = return_metadata
        self.return_distance = return_distance

    @property
    def _hook_parameters(self) -> dict[str, Any]:
        """Extend the base hook parameters to request an ``s3vectors`` client."""
        return {**super()._hook_parameters, "client_type": "s3vectors"}

    def execute(self, context: Context) -> list[dict[str, Any]]:
        """
        Run the similarity query and return the matching vectors.

        :return: The ``vectors`` list from the ``query_vectors`` response, or an
            empty list when the response carries no ``vectors`` key.
        """
        # ``top_k`` is a template field: unless native rendering is enabled, a Jinja
        # value renders to a str. Coerce it so the integer ``topK`` argument and the
        # ``%d`` log placeholder both receive an int.
        top_k = int(self.top_k)
        self.log.info("Querying top %d vectors from index %s", top_k, self.index_name)
        # prune_dict strips the unset (None) ``filter`` entry so it is not sent.
        kwargs: dict[str, Any] = prune_dict(
            {
                "vectorBucketName": self.vector_bucket_name,
                "indexName": self.index_name,
                "topK": top_k,
                "queryVector": self.query_vector,
                "filter": self.filter,
                "returnMetadata": self.return_metadata,
                "returnDistance": self.return_distance,
            }
        )
        response = self.hook.conn.query_vectors(**kwargs)
        vectors = response.get("vectors", [])
        self.log.info("Query returned %d results", len(vectors))
        return vectors
diff --git a/providers/amazon/tests/system/amazon/aws/example_s3_vectors.py
b/providers/amazon/tests/system/amazon/aws/example_s3_vectors.py
index 9faedd9549e..dda071944c8 100644
--- a/providers/amazon/tests/system/amazon/aws/example_s3_vectors.py
+++ b/providers/amazon/tests/system/amazon/aws/example_s3_vectors.py
@@ -24,6 +24,7 @@ from airflow.providers.amazon.aws.operators.s3_vectors import
(
S3VectorsDeleteIndexOperator,
S3VectorsDeleteVectorBucketOperator,
S3VectorsPutVectorsOperator,
+ S3VectorsQueryVectorsOperator,
)
from airflow.providers.common.compat.sdk import DAG, chain
@@ -86,6 +87,16 @@ with DAG(
)
# [END howto_operator_s3vectors_delete_vector_bucket]
+ # [START howto_operator_s3vectors_query_vectors]
+ query_vectors = S3VectorsQueryVectorsOperator(
+ task_id="query_vectors",
+ vector_bucket_name=bucket_name,
+ index_name=index_name,
+ top_k=3,
+ query_vector={"float32": [0.1, 0.2, 0.3, 0.4]},
+ )
+ # [END howto_operator_s3vectors_query_vectors]
+
# [START howto_operator_s3vectors_delete_index]
delete_index = S3VectorsDeleteIndexOperator(
task_id="delete_index",
@@ -100,6 +111,7 @@ with DAG(
create_vector_bucket,
create_index,
put_vectors,
+ query_vectors,
delete_index,
delete_vector_bucket,
)
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_s3_vectors.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_s3_vectors.py
index e483d0396db..df564d498b0 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_s3_vectors.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_s3_vectors.py
@@ -27,6 +27,7 @@ from airflow.providers.amazon.aws.operators.s3_vectors import
(
S3VectorsDeleteIndexOperator,
S3VectorsDeleteVectorBucketOperator,
S3VectorsPutVectorsOperator,
+ S3VectorsQueryVectorsOperator,
)
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
@@ -250,3 +251,49 @@ class TestS3VectorsPutVectorsOperator:
def test_template_fields(self):
validate_template_fields(self.operator)
+
+
# Query payload sent by the QueryVectors tests, and the canned ``vectors``
# list the mocked client hands back to the operator.
QUERY_VECTOR = {"float32": [0.1, 0.2, 0.3, 0.4]}
QUERY_RESULTS = [
    {
        "key": "vec1",
        "distance": 0.95,
        "metadata": {"label": "test"},
    },
]
+
+
class TestS3VectorsQueryVectorsOperator:
    """Unit tests for ``S3VectorsQueryVectorsOperator`` against a mocked client."""

    def setup_method(self):
        self.operator = S3VectorsQueryVectorsOperator(
            task_id="query_vectors",
            vector_bucket_name=BUCKET_NAME,
            index_name=INDEX_NAME,
            top_k=5,
            query_vector=QUERY_VECTOR,
        )

    def _attach_conn(self, query_response):
        """Swap the hook's client for a mock whose ``query_vectors`` returns ``query_response``."""
        conn = MagicMock()
        conn.query_vectors.return_value = query_response
        self.operator.hook.conn = conn
        return conn

    def test_execute(self):
        conn = self._attach_conn({"vectors": QUERY_RESULTS, "distanceMetric": "cosine"})

        returned = self.operator.execute({})

        # No filter was configured, so the request must not carry a ``filter`` key.
        conn.query_vectors.assert_called_once_with(
            vectorBucketName=BUCKET_NAME,
            indexName=INDEX_NAME,
            topK=5,
            queryVector=QUERY_VECTOR,
            returnMetadata=True,
            returnDistance=True,
        )
        assert returned == QUERY_RESULTS

    def test_execute_with_filter(self):
        label_filter = {"equals": {"key": "label", "value": "test"}}
        conn = self._attach_conn({"vectors": []})
        self.operator.filter = label_filter

        self.operator.execute({})

        sent_kwargs = conn.query_vectors.call_args.kwargs
        assert sent_kwargs["filter"] == {"equals": {"key": "label", "value": "test"}}

    def test_template_fields(self):
        validate_template_fields(self.operator)