This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new eeb5c925041 Fix Weaviate tenant-aware ingestion (#67298)
eeb5c925041 is described below

commit eeb5c9250417e761c401df8cfd421db576a0dbd1
Author: iwannagotobed <[email protected]>
AuthorDate: Thu Jun 4 20:11:07 2026 +0900

    Fix Weaviate tenant-aware ingestion (#67298)
    
    * Fix Weaviate tenant-aware ingestion
    
    * docs: Document Weaviate ingest operator parameters
---
 .../airflow/providers/weaviate/hooks/weaviate.py   |  43 ++++++-
 .../providers/weaviate/operators/weaviate.py       |   8 +-
 .../tests/unit/weaviate/hooks/test_weaviate.py     | 128 +++++++++++++++++++++
 .../tests/unit/weaviate/operators/test_weaviate.py |  46 ++++++++
 4 files changed, 219 insertions(+), 6 deletions(-)

diff --git a/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py b/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py
index 7928b940195..10dce9bed09 100644
--- a/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py
+++ b/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py
@@ -411,6 +411,7 @@ class WeaviateHook(BaseHook):
         uuid_col: str = "id",
         retry_attempts_per_object: int = 5,
         references: ReferenceInputs | None = None,
+        tenant: str | None = None,
     ) -> None:
         """
         Add multiple objects or object references at once into weaviate.
@@ -421,10 +422,13 @@ class WeaviateHook(BaseHook):
         :param uuid_col: Name of the column containing the UUID.
         :param retry_attempts_per_object: number of time to try in case of failure before giving up.
         :param references: The references of the object to be added as a dictionary. Use `wvc.Reference.to` to create the correct values in the dict.
+        :param tenant: The tenant to which the objects will be added.
         """
         converted_data = self._convert_dataframe_to_list(data)
 
         collection = self.get_collection(collection_name)
+        if tenant:
+            collection = collection.with_tenant(tenant)
         with collection.batch.dynamic() as batch:
             # Batch import all data
             for data_obj in converted_data:
@@ -585,14 +589,17 @@ class WeaviateHook(BaseHook):
             )
         return all_objects
 
-    def delete_object(self, collection_name: str, uuid: UUID | str) -> bool:
+    def delete_object(self, collection_name: str, uuid: UUID | str, tenant: str | None = None) -> bool:
         """
         Delete an object from weaviate.
 
         :param collection_name: Collection name associated with the object given.
         :param uuid: uuid of the object to be deleted
+        :param tenant: The tenant from which the object will be deleted.
         """
         collection = self.get_collection(collection_name)
+        if tenant:
+            collection = collection.with_tenant(tenant)
         return collection.data.delete_by_id(uuid=uuid)
 
     def update_object(
@@ -640,7 +647,11 @@ class WeaviateHook(BaseHook):
         return collection.data.exists(uuid=uuid)
 
     def _delete_objects(
-        self, uuids: list[UUID], collection_name: str, retry_attempts_per_object: int = 5
+        self,
+        uuids: list[UUID],
+        collection_name: str,
+        retry_attempts_per_object: int = 5,
+        tenant: str | None = None,
     ) -> None:
         """
         Delete multiple objects.
@@ -650,6 +661,7 @@ class WeaviateHook(BaseHook):
         :param uuids: Collection of uuids.
         :param collection_name: Name of the collection in Weaviate schema where data is to be ingested.
         :param retry_attempts_per_object: number of times to try in case of failure before giving up.
+        :param tenant: The tenant from which the objects will be deleted.
         """
         for uuid in uuids:
             for attempt in Retrying(
@@ -661,7 +673,7 @@ class WeaviateHook(BaseHook):
             ):
                 with attempt:
                     try:
-                        self.delete_object(uuid=uuid, collection_name=collection_name)
+                        self.delete_object(uuid=uuid, collection_name=collection_name, tenant=tenant)
                         self.log.debug("Deleted object with uuid %s", uuid)
                     except weaviate.exceptions.UnexpectedStatusCodeException as e:
                         if e.status_code == 404:
@@ -728,6 +740,7 @@ class WeaviateHook(BaseHook):
         document_column: str,
         uuid_column: str,
         collection_name: str,
+        tenant: str | None = None,
         offset: int = 0,
         limit: int = 2000,
     ) -> dict[str, set]:
@@ -737,6 +750,7 @@ class WeaviateHook(BaseHook):
         :param data: A single pandas DataFrame.
         :param document_column: The name of the property to query.
         :param collection_name: The name of the collection to query.
+        :param tenant: The tenant to query.
         :param uuid_column: The name of the column containing the UUID.
         :param offset: pagination parameter to indicate the which object to start fetching data.
         :param limit: pagination param to indicate the number of records to fetch from start object.
@@ -745,6 +759,8 @@ class WeaviateHook(BaseHook):
         document_keys = set(data[document_column])
         while True:
             collection = self.get_collection(collection_name)
+            if tenant:
+                collection = collection.with_tenant(tenant)
             data_objects = collection.query.fetch_objects(
                 filters=Filter.any_of(
                     [Filter.by_property(document_column).equal(key) for key in document_keys]
@@ -791,7 +807,12 @@ class WeaviateHook(BaseHook):
         return grouped_key_to_set
 
     def _get_segregated_documents(
-        self, data: pd.DataFrame, document_column: str, collection_name: str, uuid_column: str
+        self,
+        data: pd.DataFrame,
+        document_column: str,
+        collection_name: str,
+        uuid_column: str,
+        tenant: str | None = None,
     ) -> tuple[dict[str, set], set, set, set]:
         """
         Segregate documents into changed, unchanged and new document, when compared to Weaviate db.
@@ -800,6 +821,7 @@ class WeaviateHook(BaseHook):
         :param document_column: The name of the property to query.
         :param collection_name: The name of the collection to query.
         :param uuid_column: The name of the column containing the UUID.
+        :param tenant: The tenant to query.
         """
         changed_documents = set()
         unchanged_docs = set()
@@ -809,6 +831,7 @@ class WeaviateHook(BaseHook):
             uuid_column=uuid_column,
             document_column=document_column,
             collection_name=collection_name,
+            tenant=tenant,
         )
 
         input_documents_to_uuid = self._prepare_document_to_uuid_map(
@@ -836,6 +859,7 @@ class WeaviateHook(BaseHook):
         total_objects_count: int = 1,
         batch_delete_error: Sequence | None = None,
         verbose: bool = False,
+        tenant: str | None = None,
     ) -> Sequence[dict[str, UUID | str]]:
         """
         Delete all object that belong to list of documents.
@@ -847,6 +871,7 @@ class WeaviateHook(BaseHook):
             query is 10,000, if we have more objects to delete we need to run query multiple times.
         :param batch_delete_error: list to hold errors while inserting.
         :param verbose: Flag to enable verbose output during the ingestion process.
+        :param tenant: The tenant from which document objects will be deleted.
         """
         batch_delete_error = batch_delete_error or []
 
@@ -854,6 +879,8 @@ class WeaviateHook(BaseHook):
         MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS = 10000
 
         collection = self.get_collection(collection_name)
+        if tenant:
+            collection = collection.with_tenant(tenant)
         delete_many_return = collection.data.delete_many(
             where=Filter.any_of([Filter.by_property(document_column).equal(key) for key in document_keys]),
             verbose=verbose,
@@ -881,6 +908,7 @@ class WeaviateHook(BaseHook):
         uuid_column: str | None = None,
         vector_column: str = "Vector",
         verbose: bool = False,
+        tenant: str | None = None,
     ) -> Sequence[dict[str, UUID | str] | None]:
         """
         Create or replace objects belonging to documents.
@@ -909,6 +937,7 @@ class WeaviateHook(BaseHook):
         :param uuid_column: Column with pre-generated UUIDs. If not provided, UUIDs will be generated.
         :param vector_column: Column with embedding vectors for pre-embedded data.
         :param verbose: Flag to enable verbose output during the ingestion process.
+        :param tenant: The tenant to which objects will be added.
         :return: list of UUID which failed to create
         """
         if existing not in ["skip", "replace", "error"]:
@@ -960,6 +989,7 @@ class WeaviateHook(BaseHook):
             document_column=document_column,
             uuid_column=uuid_column,
             collection_name=collection_name,
+            tenant=tenant,
         )
         if verbose:
             self.log.info(
@@ -1001,6 +1031,7 @@ class WeaviateHook(BaseHook):
                     total_objects_count=total_objects_count,
                     batch_delete_error=batch_delete_error,
                     verbose=verbose,
+                    tenant=tenant,
                 )
             data = data[data[document_column].isin(new_documents.union(changed_documents))]
             self.log.info("Batch inserting %s objects for non-existing and changed documents.", data.shape[0])
@@ -1011,6 +1042,7 @@ class WeaviateHook(BaseHook):
                 data=data,
                 vector_col=vector_column,
                 uuid_col=uuid_column,
+                tenant=tenant,
             )
             if batch_delete_error:
                 if batch_delete_error:
@@ -1019,10 +1051,13 @@ class WeaviateHook(BaseHook):
                 self._delete_objects(
                     [item["uuid"] for item in batch_delete_error],
                     collection_name=collection_name,
+                    tenant=tenant,
                 )
 
         if verbose:
             collection = self.get_collection(collection_name)
+            if tenant:
+                collection = collection.with_tenant(tenant)
             self.log.info(
                 "Total objects in collection %s : %s ",
                 collection_name,
diff --git a/providers/weaviate/src/airflow/providers/weaviate/operators/weaviate.py b/providers/weaviate/src/airflow/providers/weaviate/operators/weaviate.py
index 6ad0c02bba5..de080c72320 100644
--- a/providers/weaviate/src/airflow/providers/weaviate/operators/weaviate.py
+++ b/providers/weaviate/src/airflow/providers/weaviate/operators/weaviate.py
@@ -43,10 +43,12 @@ class WeaviateIngestOperator(BaseOperator):
     custom vectors and store them in the Weaviate class.
 
     :param conn_id: The Weaviate connection.
-    :param collection: The Weaviate collection to be used for storing the data objects into.
+    :param collection_name: The Weaviate collection to be used for storing the data objects into.
     :param input_data: The list of dicts or pandas dataframe representing Weaviate data objects to generate
         embeddings on (or provides custom vectors) and store them in the Weaviate class.
     :param vector_col: key/column name in which the vectors are stored.
+    :param uuid_column: Column with pre-generated UUIDs.
+    :param tenant: The tenant to which objects will be added.
     :param hook_params: Optional config params to be passed to the underlying hook.
         Should match the desired hook constructor params.
     """
@@ -88,6 +90,7 @@ class WeaviateIngestOperator(BaseOperator):
             data=self.input_data,
             vector_col=self.vector_col,
             uuid_col=self.uuid_column,
+            tenant=self.tenant,
         )
 
 
@@ -118,7 +121,7 @@ class WeaviateDocumentIngestOperator(BaseOperator):
     :param document_column: Column in DataFrame that identifying source document.
     :param uuid_column: Column with pre-generated UUIDs. If not provided, UUIDs will be generated.
     :param vector_column: Column with embedding vectors for pre-embedded data.
-    :param tenant: The tenant to which the object will be added.
+    :param tenant: The tenant to which objects will be added.
     :param verbose: Flag to enable verbose output during the ingestion process.
     :param hook_params: Optional config params to be passed to the underlying hook.
         Should match the desired hook constructor params.
@@ -172,5 +175,6 @@ class WeaviateDocumentIngestOperator(BaseOperator):
             uuid_column=self.uuid_column,
             vector_column=self.vector_col,
             verbose=self.verbose,
+            tenant=self.tenant,
         )
         return batch_delete_error
diff --git a/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py b/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py
index 036a8143a0f..e8682688639 100644
--- a/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py
+++ b/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py
@@ -601,6 +601,24 @@ def test_batch_data(data, expected_length, weaviate_hook):
     assert mock_batch_context.add_object.call_count == expected_length
 
 
+def test_batch_data_uses_tenant_collection(weaviate_hook):
+    mock_collection = MagicMock()
+    mock_tenant_collection = MagicMock()
+    mock_collection.with_tenant.return_value = mock_tenant_collection
+    weaviate_hook.get_collection = MagicMock(return_value=mock_collection)
+
+    weaviate_hook.batch_data("TestCollection", [{"name": "John"}], tenant="tenant-a")
+
+    mock_collection.with_tenant.assert_called_once_with("tenant-a")
+    mock_collection.batch.dynamic.assert_not_called()
+    mock_tenant_collection.batch.dynamic.return_value.__enter__.return_value.add_object.assert_called_once_with(
+        properties={"name": "John"},
+        references=None,
+        uuid=None,
+        vector=None,
+    )
+
+
 def test_batch_data_retry(weaviate_hook):
     """Test to ensure retrying working as expected"""
     # Mock the Weaviate Collection
@@ -819,6 +837,32 @@ def test__delete_objects(delete_object, weaviate_hook):
     assert delete_object.call_count == 5
 
 
+def test_delete_object_uses_tenant_collection(weaviate_hook):
+    mock_collection = MagicMock()
+    mock_tenant_collection = MagicMock()
+    mock_collection.with_tenant.return_value = mock_tenant_collection
+    weaviate_hook.get_collection = MagicMock(return_value=mock_collection)
+
+    weaviate_hook.delete_object(collection_name="test", uuid="1", tenant="tenant-a")
+
+    mock_collection.with_tenant.assert_called_once_with("tenant-a")
+    mock_tenant_collection.data.delete_by_id.assert_called_once_with(uuid="1")
+    mock_collection.data.delete_by_id.assert_not_called()
+
+
+def test__delete_objects_passes_tenant_to_delete_object(weaviate_hook):
+    weaviate_hook.delete_object = MagicMock()
+
+    weaviate_hook._delete_objects(uuids=["1", "2"], collection_name="test", tenant="tenant-a")
+
+    weaviate_hook.delete_object.assert_has_calls(
+        [
+            mock.call(uuid="1", collection_name="test", tenant="tenant-a"),
+            mock.call(uuid="2", collection_name="test", tenant="tenant-a"),
+        ]
+    )
+
+
 def test__prepare_document_to_uuid_map(weaviate_hook):
     input_data = [
         {"id": "1", "name": "ross", "age": "12", "gender": "m"},
@@ -860,6 +904,46 @@ def test___get_segregated_documents(_get_documents_to_uuid_map, _prepare_documen
     assert new_documents == {"hjk.doc"}
 
 
+def test__get_documents_to_uuid_map_uses_tenant_collection(weaviate_hook):
+    mock_collection = MagicMock()
+    mock_tenant_collection = MagicMock()
+    mock_collection.with_tenant.return_value = mock_tenant_collection
+    mock_tenant_collection.query.fetch_objects.return_value.objects = []
+    weaviate_hook.get_collection = MagicMock(return_value=mock_collection)
+
+    weaviate_hook._get_documents_to_uuid_map(
+        data=pd.DataFrame.from_dict({"doc": ["abc.xml"]}),
+        document_column="doc",
+        uuid_column="id",
+        collection_name="test",
+        tenant="tenant-a",
+    )
+
+    mock_collection.with_tenant.assert_called_once_with("tenant-a")
+    mock_tenant_collection.query.fetch_objects.assert_called_once()
+    mock_collection.query.fetch_objects.assert_not_called()
+
+
+def test__delete_all_documents_objects_uses_tenant_collection(weaviate_hook):
+    mock_collection = MagicMock()
+    mock_tenant_collection = MagicMock()
+    mock_collection.with_tenant.return_value = mock_tenant_collection
+    mock_tenant_collection.data.delete_many.return_value.matches = 1
+    mock_tenant_collection.data.delete_many.return_value.failed = 0
+    weaviate_hook.get_collection = MagicMock(return_value=mock_collection)
+
+    weaviate_hook._delete_all_documents_objects(
+        document_keys=["abc.xml"],
+        document_column="doc",
+        collection_name="test",
+        tenant="tenant-a",
+    )
+
+    mock_collection.with_tenant.assert_called_once_with("tenant-a")
+    mock_tenant_collection.data.delete_many.assert_called_once()
+    mock_collection.data.delete_many.assert_not_called()
+
+
 @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._get_segregated_documents")
 @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._generate_uuids")
 def test_error_option_of_create_or_replace_document_objects(
@@ -922,6 +1006,7 @@ def test_skip_option_of_create_or_replace_document_objects(
     pd.testing.assert_frame_equal(
         batch_data.call_args_list[0].kwargs["data"], df[df["doc"].isin(new_documents)]
     )
+    assert batch_data.call_args_list[0].kwargs["tenant"] is None
 
 
 @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._delete_all_documents_objects")
@@ -966,11 +1051,54 @@ def test_replace_option_of_create_or_replace_document_objects(
         collection_name="test",
         batch_delete_error=[],
         verbose=False,
+        tenant=None,
     )
     pd.testing.assert_frame_equal(
         batch_data.call_args_list[0].kwargs["data"],
         df[df["doc"].isin(changed_documents.union(new_documents))],
     )
+    assert batch_data.call_args_list[0].kwargs["tenant"] is None
+
+
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._delete_all_documents_objects")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.batch_data")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._get_segregated_documents")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._generate_uuids")
+def test_create_or_replace_document_objects_passes_tenant_to_helpers(
+    _generate_uuids, _get_segregated_documents, batch_data, _delete_all_documents_objects, weaviate_hook
+):
+    df = pd.DataFrame.from_dict(
+        {
+            "id": ["1", "2", "3"],
+            "name": ["ross", "bob", "joy"],
+            "doc": ["abc.xml", "zyx.html", "zyx.html"],
+        }
+    )
+    documents_to_uuid_map, changed_documents, unchanged_documents, new_documents = (
+        {"abc.xml": {"uuid"}},
+        {"abc.xml"},
+        {},
+        {"zyx.html"},
+    )
+    _generate_uuids.return_value = (df, "id")
+    _get_segregated_documents.return_value = (
+        documents_to_uuid_map,
+        changed_documents,
+        unchanged_documents,
+        new_documents,
+    )
+
+    weaviate_hook.create_or_replace_document_objects(
+        data=df,
+        collection_name="test",
+        existing="replace",
+        document_column="doc",
+        tenant="tenant-a",
+    )
+
+    assert _get_segregated_documents.call_args_list[0].kwargs["tenant"] == "tenant-a"
+    assert _delete_all_documents_objects.call_args_list[0].kwargs["tenant"] == "tenant-a"
+    assert batch_data.call_args_list[0].kwargs["tenant"] == "tenant-a"
 
 
 @mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom")
diff --git a/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py b/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
index 8f1f538b8c9..0f09fb35d52 100644
--- a/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
+++ b/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
@@ -57,9 +57,30 @@ class TestWeaviateIngestOperator:
             data=[{"data": "sample_data"}],
             vector_col="Vector",
             uuid_col="id",
+            tenant=None,
         )
         mock_log.debug.assert_called_once_with("Input data: %s", [{"data": "sample_data"}])
 
+    def test_execute_passes_tenant_to_hook(self):
+        operator = WeaviateIngestOperator(
+            task_id="weaviate_task",
+            conn_id="weaviate_conn",
+            collection_name="my_collection",
+            input_data=[{"data": "sample_data"}],
+            tenant="tenant-a",
+        )
+        operator.hook.batch_data = MagicMock()
+
+        operator.execute(context=None)
+
+        operator.hook.batch_data.assert_called_once_with(
+            collection_name="my_collection",
+            data=[{"data": "sample_data"}],
+            vector_col="Vector",
+            uuid_col="id",
+            tenant="tenant-a",
+        )
+
     @pytest.mark.db_test
     def test_templates(self, create_task_instance_of_operator):
         dag_id = "TestWeaviateIngestOperator"
@@ -129,9 +150,34 @@ class TestWeaviateDocumentIngestOperator:
             uuid_column="id",
             vector_column="vector",
             verbose=False,
+            tenant=None,
         )
         mock_log.debug.assert_called_once_with("Total input objects : %s", len([{"data": "sample_data"}]))
 
+    def test_execute_passes_tenant_to_hook(self):
+        operator = WeaviateDocumentIngestOperator(
+            task_id="weaviate_task",
+            conn_id="weaviate_conn",
+            input_data=[{"data": "sample_data"}],
+            collection_name="my_collection",
+            document_column="docLink",
+            tenant="tenant-a",
+        )
+        operator.hook.create_or_replace_document_objects = MagicMock()
+
+        operator.execute(context=None)
+
+        operator.hook.create_or_replace_document_objects.assert_called_once_with(
+            data=[{"data": "sample_data"}],
+            collection_name="my_collection",
+            document_column="docLink",
+            existing="skip",
+            uuid_column="id",
+            vector_column="Vector",
+            verbose=False,
+            tenant="tenant-a",
+        )
+
     @pytest.mark.db_test
     def test_partial_hook_params(self, dag_maker, session):
         with dag_maker(dag_id="test_partial_hook_params", session=session):

Reply via email to