This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new eeb5c925041 Fix Weaviate tenant-aware ingestion (#67298)
eeb5c925041 is described below
commit eeb5c9250417e761c401df8cfd421db576a0dbd1
Author: iwannagotobed <[email protected]>
AuthorDate: Thu Jun 4 20:11:07 2026 +0900
Fix Weaviate tenant-aware ingestion (#67298)
* Fix Weaviate tenant-aware ingestion
* docs: Document Weaviate ingest operator parameters
---
.../airflow/providers/weaviate/hooks/weaviate.py | 43 ++++++-
.../providers/weaviate/operators/weaviate.py | 8 +-
.../tests/unit/weaviate/hooks/test_weaviate.py | 128 +++++++++++++++++++++
.../tests/unit/weaviate/operators/test_weaviate.py | 46 ++++++++
4 files changed, 219 insertions(+), 6 deletions(-)
diff --git
a/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py
b/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py
index 7928b940195..10dce9bed09 100644
--- a/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py
+++ b/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py
@@ -411,6 +411,7 @@ class WeaviateHook(BaseHook):
uuid_col: str = "id",
retry_attempts_per_object: int = 5,
references: ReferenceInputs | None = None,
+ tenant: str | None = None,
) -> None:
"""
Add multiple objects or object references at once into weaviate.
@@ -421,10 +422,13 @@ class WeaviateHook(BaseHook):
:param uuid_col: Name of the column containing the UUID.
:param retry_attempts_per_object: number of time to try in case of
failure before giving up.
:param references: The references of the object to be added as a
dictionary. Use `wvc.Reference.to` to create the correct values in the dict.
+ :param tenant: The tenant to which the objects will be added.
"""
converted_data = self._convert_dataframe_to_list(data)
collection = self.get_collection(collection_name)
+ if tenant:
+ collection = collection.with_tenant(tenant)
with collection.batch.dynamic() as batch:
# Batch import all data
for data_obj in converted_data:
@@ -585,14 +589,17 @@ class WeaviateHook(BaseHook):
)
return all_objects
- def delete_object(self, collection_name: str, uuid: UUID | str) -> bool:
+ def delete_object(self, collection_name: str, uuid: UUID | str, tenant: str | None = None) -> bool:
"""
Delete an object from weaviate.
:param collection_name: Collection name associated with the object
given.
:param uuid: uuid of the object to be deleted
+ :param tenant: The tenant from which the object will be deleted.
"""
collection = self.get_collection(collection_name)
+ if tenant:
+ collection = collection.with_tenant(tenant)
return collection.data.delete_by_id(uuid=uuid)
def update_object(
@@ -640,7 +647,11 @@ class WeaviateHook(BaseHook):
return collection.data.exists(uuid=uuid)
def _delete_objects(
- self, uuids: list[UUID], collection_name: str,
retry_attempts_per_object: int = 5
+ self,
+ uuids: list[UUID],
+ collection_name: str,
+ retry_attempts_per_object: int = 5,
+ tenant: str | None = None,
) -> None:
"""
Delete multiple objects.
@@ -650,6 +661,7 @@ class WeaviateHook(BaseHook):
:param uuids: Collection of uuids.
:param collection_name: Name of the collection in Weaviate schema
where data is to be ingested.
:param retry_attempts_per_object: number of times to try in case of
failure before giving up.
+ :param tenant: The tenant from which the objects will be deleted.
"""
for uuid in uuids:
for attempt in Retrying(
@@ -661,7 +673,7 @@ class WeaviateHook(BaseHook):
):
with attempt:
try:
- self.delete_object(uuid=uuid,
collection_name=collection_name)
+ self.delete_object(uuid=uuid, collection_name=collection_name, tenant=tenant)
self.log.debug("Deleted object with uuid %s", uuid)
except weaviate.exceptions.UnexpectedStatusCodeException
as e:
if e.status_code == 404:
@@ -728,6 +740,7 @@ class WeaviateHook(BaseHook):
document_column: str,
uuid_column: str,
collection_name: str,
+ tenant: str | None = None,
offset: int = 0,
limit: int = 2000,
) -> dict[str, set]:
@@ -737,6 +750,7 @@ class WeaviateHook(BaseHook):
:param data: A single pandas DataFrame.
:param document_column: The name of the property to query.
:param collection_name: The name of the collection to query.
+ :param tenant: The tenant to query.
:param uuid_column: The name of the column containing the UUID.
:param offset: pagination parameter to indicate the which object to
start fetching data.
:param limit: pagination param to indicate the number of records to
fetch from start object.
@@ -745,6 +759,8 @@ class WeaviateHook(BaseHook):
document_keys = set(data[document_column])
while True:
collection = self.get_collection(collection_name)
+ if tenant:
+ collection = collection.with_tenant(tenant)
data_objects = collection.query.fetch_objects(
filters=Filter.any_of(
[Filter.by_property(document_column).equal(key) for key in
document_keys]
@@ -791,7 +807,12 @@ class WeaviateHook(BaseHook):
return grouped_key_to_set
def _get_segregated_documents(
- self, data: pd.DataFrame, document_column: str, collection_name: str,
uuid_column: str
+ self,
+ data: pd.DataFrame,
+ document_column: str,
+ collection_name: str,
+ uuid_column: str,
+ tenant: str | None = None,
) -> tuple[dict[str, set], set, set, set]:
"""
Segregate documents into changed, unchanged and new document, when
compared to Weaviate db.
@@ -800,6 +821,7 @@ class WeaviateHook(BaseHook):
:param document_column: The name of the property to query.
:param collection_name: The name of the collection to query.
:param uuid_column: The name of the column containing the UUID.
+ :param tenant: The tenant to query.
"""
changed_documents = set()
unchanged_docs = set()
@@ -809,6 +831,7 @@ class WeaviateHook(BaseHook):
uuid_column=uuid_column,
document_column=document_column,
collection_name=collection_name,
+ tenant=tenant,
)
input_documents_to_uuid = self._prepare_document_to_uuid_map(
@@ -836,6 +859,7 @@ class WeaviateHook(BaseHook):
total_objects_count: int = 1,
batch_delete_error: Sequence | None = None,
verbose: bool = False,
+ tenant: str | None = None,
) -> Sequence[dict[str, UUID | str]]:
"""
Delete all object that belong to list of documents.
@@ -847,6 +871,7 @@ class WeaviateHook(BaseHook):
query is 10,000, if we have more objects to delete we need to run
query multiple times.
:param batch_delete_error: list to hold errors while inserting.
:param verbose: Flag to enable verbose output during the ingestion
process.
+ :param tenant: The tenant from which document objects will be deleted.
"""
batch_delete_error = batch_delete_error or []
@@ -854,6 +879,8 @@ class WeaviateHook(BaseHook):
MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS = 10000
collection = self.get_collection(collection_name)
+ if tenant:
+ collection = collection.with_tenant(tenant)
delete_many_return = collection.data.delete_many(
where=Filter.any_of([Filter.by_property(document_column).equal(key) for key in
document_keys]),
verbose=verbose,
@@ -881,6 +908,7 @@ class WeaviateHook(BaseHook):
uuid_column: str | None = None,
vector_column: str = "Vector",
verbose: bool = False,
+ tenant: str | None = None,
) -> Sequence[dict[str, UUID | str] | None]:
"""
Create or replace objects belonging to documents.
@@ -909,6 +937,7 @@ class WeaviateHook(BaseHook):
:param uuid_column: Column with pre-generated UUIDs. If not provided,
UUIDs will be generated.
:param vector_column: Column with embedding vectors for pre-embedded
data.
:param verbose: Flag to enable verbose output during the ingestion
process.
+ :param tenant: The tenant to which objects will be added.
:return: list of UUID which failed to create
"""
if existing not in ["skip", "replace", "error"]:
@@ -960,6 +989,7 @@ class WeaviateHook(BaseHook):
document_column=document_column,
uuid_column=uuid_column,
collection_name=collection_name,
+ tenant=tenant,
)
if verbose:
self.log.info(
@@ -1001,6 +1031,7 @@ class WeaviateHook(BaseHook):
total_objects_count=total_objects_count,
batch_delete_error=batch_delete_error,
verbose=verbose,
+ tenant=tenant,
)
data =
data[data[document_column].isin(new_documents.union(changed_documents))]
self.log.info("Batch inserting %s objects for non-existing and
changed documents.", data.shape[0])
@@ -1011,6 +1042,7 @@ class WeaviateHook(BaseHook):
data=data,
vector_col=vector_column,
uuid_col=uuid_column,
+ tenant=tenant,
)
if batch_delete_error:
if batch_delete_error:
@@ -1019,10 +1051,13 @@ class WeaviateHook(BaseHook):
self._delete_objects(
[item["uuid"] for item in batch_delete_error],
collection_name=collection_name,
+ tenant=tenant,
)
if verbose:
collection = self.get_collection(collection_name)
+ if tenant:
+ collection = collection.with_tenant(tenant)
self.log.info(
"Total objects in collection %s : %s ",
collection_name,
diff --git
a/providers/weaviate/src/airflow/providers/weaviate/operators/weaviate.py
b/providers/weaviate/src/airflow/providers/weaviate/operators/weaviate.py
index 6ad0c02bba5..de080c72320 100644
--- a/providers/weaviate/src/airflow/providers/weaviate/operators/weaviate.py
+++ b/providers/weaviate/src/airflow/providers/weaviate/operators/weaviate.py
@@ -43,10 +43,12 @@ class WeaviateIngestOperator(BaseOperator):
custom vectors and store them in the Weaviate class.
:param conn_id: The Weaviate connection.
- :param collection: The Weaviate collection to be used for storing the data
objects into.
+ :param collection_name: The Weaviate collection to be used for storing the
data objects into.
:param input_data: The list of dicts or pandas dataframe representing
Weaviate data objects to generate
embeddings on (or provides custom vectors) and store them in the
Weaviate class.
:param vector_col: key/column name in which the vectors are stored.
+ :param uuid_column: Column with pre-generated UUIDs.
+ :param tenant: The tenant to which objects will be added.
:param hook_params: Optional config params to be passed to the underlying
hook.
Should match the desired hook constructor params.
"""
@@ -88,6 +90,7 @@ class WeaviateIngestOperator(BaseOperator):
data=self.input_data,
vector_col=self.vector_col,
uuid_col=self.uuid_column,
+ tenant=self.tenant,
)
@@ -118,7 +121,7 @@ class WeaviateDocumentIngestOperator(BaseOperator):
:param document_column: Column in DataFrame that identifying source
document.
:param uuid_column: Column with pre-generated UUIDs. If not provided,
UUIDs will be generated.
:param vector_column: Column with embedding vectors for pre-embedded data.
- :param tenant: The tenant to which the object will be added.
+ :param tenant: The tenant to which objects will be added.
:param verbose: Flag to enable verbose output during the ingestion process.
:param hook_params: Optional config params to be passed to the underlying
hook.
Should match the desired hook constructor params.
@@ -172,5 +175,6 @@ class WeaviateDocumentIngestOperator(BaseOperator):
uuid_column=self.uuid_column,
vector_column=self.vector_col,
verbose=self.verbose,
+ tenant=self.tenant,
)
return batch_delete_error
diff --git a/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py
b/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py
index 036a8143a0f..e8682688639 100644
--- a/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py
+++ b/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py
@@ -601,6 +601,24 @@ def test_batch_data(data, expected_length, weaviate_hook):
assert mock_batch_context.add_object.call_count == expected_length
+def test_batch_data_uses_tenant_collection(weaviate_hook):
+ mock_collection = MagicMock()
+ mock_tenant_collection = MagicMock()
+ mock_collection.with_tenant.return_value = mock_tenant_collection
+ weaviate_hook.get_collection = MagicMock(return_value=mock_collection)
+
+ weaviate_hook.batch_data("TestCollection", [{"name": "John"}], tenant="tenant-a")
+
+ mock_collection.with_tenant.assert_called_once_with("tenant-a")
+ mock_collection.batch.dynamic.assert_not_called()
+ mock_tenant_collection.batch.dynamic.return_value.__enter__.return_value.add_object.assert_called_once_with(
+ properties={"name": "John"},
+ references=None,
+ uuid=None,
+ vector=None,
+ )
+
+
def test_batch_data_retry(weaviate_hook):
"""Test to ensure retrying working as expected"""
# Mock the Weaviate Collection
@@ -819,6 +837,32 @@ def test__delete_objects(delete_object, weaviate_hook):
assert delete_object.call_count == 5
+def test_delete_object_uses_tenant_collection(weaviate_hook):
+ mock_collection = MagicMock()
+ mock_tenant_collection = MagicMock()
+ mock_collection.with_tenant.return_value = mock_tenant_collection
+ weaviate_hook.get_collection = MagicMock(return_value=mock_collection)
+
+ weaviate_hook.delete_object(collection_name="test", uuid="1", tenant="tenant-a")
+
+ mock_collection.with_tenant.assert_called_once_with("tenant-a")
+ mock_tenant_collection.data.delete_by_id.assert_called_once_with(uuid="1")
+ mock_collection.data.delete_by_id.assert_not_called()
+
+
+def test__delete_objects_passes_tenant_to_delete_object(weaviate_hook):
+ weaviate_hook.delete_object = MagicMock()
+
+ weaviate_hook._delete_objects(uuids=["1", "2"], collection_name="test", tenant="tenant-a")
+
+ weaviate_hook.delete_object.assert_has_calls(
+ [
+ mock.call(uuid="1", collection_name="test", tenant="tenant-a"),
+ mock.call(uuid="2", collection_name="test", tenant="tenant-a"),
+ ]
+ )
+
+
def test__prepare_document_to_uuid_map(weaviate_hook):
input_data = [
{"id": "1", "name": "ross", "age": "12", "gender": "m"},
@@ -860,6 +904,46 @@ def
test___get_segregated_documents(_get_documents_to_uuid_map, _prepare_documen
assert new_documents == {"hjk.doc"}
+def test__get_documents_to_uuid_map_uses_tenant_collection(weaviate_hook):
+ mock_collection = MagicMock()
+ mock_tenant_collection = MagicMock()
+ mock_collection.with_tenant.return_value = mock_tenant_collection
+ mock_tenant_collection.query.fetch_objects.return_value.objects = []
+ weaviate_hook.get_collection = MagicMock(return_value=mock_collection)
+
+ weaviate_hook._get_documents_to_uuid_map(
+ data=pd.DataFrame.from_dict({"doc": ["abc.xml"]}),
+ document_column="doc",
+ uuid_column="id",
+ collection_name="test",
+ tenant="tenant-a",
+ )
+
+ mock_collection.with_tenant.assert_called_once_with("tenant-a")
+ mock_tenant_collection.query.fetch_objects.assert_called_once()
+ mock_collection.query.fetch_objects.assert_not_called()
+
+
+def test__delete_all_documents_objects_uses_tenant_collection(weaviate_hook):
+ mock_collection = MagicMock()
+ mock_tenant_collection = MagicMock()
+ mock_collection.with_tenant.return_value = mock_tenant_collection
+ mock_tenant_collection.data.delete_many.return_value.matches = 1
+ mock_tenant_collection.data.delete_many.return_value.failed = 0
+ weaviate_hook.get_collection = MagicMock(return_value=mock_collection)
+
+ weaviate_hook._delete_all_documents_objects(
+ document_keys=["abc.xml"],
+ document_column="doc",
+ collection_name="test",
+ tenant="tenant-a",
+ )
+
+ mock_collection.with_tenant.assert_called_once_with("tenant-a")
+ mock_tenant_collection.data.delete_many.assert_called_once()
+ mock_collection.data.delete_many.assert_not_called()
+
+
@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._get_segregated_documents")
@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._generate_uuids")
def test_error_option_of_create_or_replace_document_objects(
@@ -922,6 +1006,7 @@ def test_skip_option_of_create_or_replace_document_objects(
pd.testing.assert_frame_equal(
batch_data.call_args_list[0].kwargs["data"],
df[df["doc"].isin(new_documents)]
)
+ assert batch_data.call_args_list[0].kwargs["tenant"] is None
@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._delete_all_documents_objects")
@@ -966,11 +1051,54 @@ def
test_replace_option_of_create_or_replace_document_objects(
collection_name="test",
batch_delete_error=[],
verbose=False,
+ tenant=None,
)
pd.testing.assert_frame_equal(
batch_data.call_args_list[0].kwargs["data"],
df[df["doc"].isin(changed_documents.union(new_documents))],
)
+ assert batch_data.call_args_list[0].kwargs["tenant"] is None
+
+
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._delete_all_documents_objects")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.batch_data")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._get_segregated_documents")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._generate_uuids")
+def test_create_or_replace_document_objects_passes_tenant_to_helpers(
+ _generate_uuids, _get_segregated_documents, batch_data, _delete_all_documents_objects, weaviate_hook
+):
+ df = pd.DataFrame.from_dict(
+ {
+ "id": ["1", "2", "3"],
+ "name": ["ross", "bob", "joy"],
+ "doc": ["abc.xml", "zyx.html", "zyx.html"],
+ }
+ )
+ documents_to_uuid_map, changed_documents, unchanged_documents, new_documents = (
+ {"abc.xml": {"uuid"}},
+ {"abc.xml"},
+ {},
+ {"zyx.html"},
+ )
+ _generate_uuids.return_value = (df, "id")
+ _get_segregated_documents.return_value = (
+ documents_to_uuid_map,
+ changed_documents,
+ unchanged_documents,
+ new_documents,
+ )
+
+ weaviate_hook.create_or_replace_document_objects(
+ data=df,
+ collection_name="test",
+ existing="replace",
+ document_column="doc",
+ tenant="tenant-a",
+ )
+
+ assert _get_segregated_documents.call_args_list[0].kwargs["tenant"] == "tenant-a"
+ assert _delete_all_documents_objects.call_args_list[0].kwargs["tenant"] == "tenant-a"
+ assert batch_data.call_args_list[0].kwargs["tenant"] == "tenant-a"
@mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom")
diff --git a/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
b/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
index 8f1f538b8c9..0f09fb35d52 100644
--- a/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
+++ b/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
@@ -57,9 +57,30 @@ class TestWeaviateIngestOperator:
data=[{"data": "sample_data"}],
vector_col="Vector",
uuid_col="id",
+ tenant=None,
)
mock_log.debug.assert_called_once_with("Input data: %s", [{"data":
"sample_data"}])
+ def test_execute_passes_tenant_to_hook(self):
+ operator = WeaviateIngestOperator(
+ task_id="weaviate_task",
+ conn_id="weaviate_conn",
+ collection_name="my_collection",
+ input_data=[{"data": "sample_data"}],
+ tenant="tenant-a",
+ )
+ operator.hook.batch_data = MagicMock()
+
+ operator.execute(context=None)
+
+ operator.hook.batch_data.assert_called_once_with(
+ collection_name="my_collection",
+ data=[{"data": "sample_data"}],
+ vector_col="Vector",
+ uuid_col="id",
+ tenant="tenant-a",
+ )
+
@pytest.mark.db_test
def test_templates(self, create_task_instance_of_operator):
dag_id = "TestWeaviateIngestOperator"
@@ -129,9 +150,34 @@ class TestWeaviateDocumentIngestOperator:
uuid_column="id",
vector_column="vector",
verbose=False,
+ tenant=None,
)
mock_log.debug.assert_called_once_with("Total input objects : %s",
len([{"data": "sample_data"}]))
+ def test_execute_passes_tenant_to_hook(self):
+ operator = WeaviateDocumentIngestOperator(
+ task_id="weaviate_task",
+ conn_id="weaviate_conn",
+ input_data=[{"data": "sample_data"}],
+ collection_name="my_collection",
+ document_column="docLink",
+ tenant="tenant-a",
+ )
+ operator.hook.create_or_replace_document_objects = MagicMock()
+
+ operator.execute(context=None)
+
+ operator.hook.create_or_replace_document_objects.assert_called_once_with(
+ data=[{"data": "sample_data"}],
+ collection_name="my_collection",
+ document_column="docLink",
+ existing="skip",
+ uuid_column="id",
+ vector_column="Vector",
+ verbose=False,
+ tenant="tenant-a",
+ )
+
@pytest.mark.db_test
def test_partial_hook_params(self, dag_maker, session):
with dag_maker(dag_id="test_partial_hook_params", session=session):