sunank200 commented on code in PR #37101:
URL: https://github.com/apache/airflow/pull/37101#discussion_r1500339341


##########
tests/datasets/test_dataset.py:
##########
@@ -269,3 +269,70 @@ def test_dag_with_complex_dataset_triggers(session, 
dag_maker):
     assert isinstance(
         serialized_dag_dict["dataset_triggers"], dict
     ), "Serialized 'dataset_triggers' should be a dict"
+
+
+def datasets_equal(d1, d2):
+    if type(d1) != type(d2):
+        return False
+
+    if isinstance(d1, Dataset):
+        return d1.uri == d2.uri
+
+    elif isinstance(d1, (DatasetAny, DatasetAll)):
+        if len(d1.objects) != len(d2.objects):
+            return False
+
+        # Compare each pair of objects
+        for obj1, obj2 in zip(d1.objects, d2.objects):
+            # If obj1 or obj2 is a Dataset, DatasetAny, or DatasetAll instance,
+            # recursively call datasets_equal
+            if not datasets_equal(obj1, obj2):
+                return False
+        return True
+
+    return False
+
+
+dataset1 = Dataset(uri="s3://bucket1/data1")
+dataset2 = Dataset(uri="s3://bucket2/data2")
+dataset3 = Dataset(uri="s3://bucket3/data3")
+dataset4 = Dataset(uri="s3://bucket4/data4")
+dataset5 = Dataset(uri="s3://bucket5/data5")
+
+test_cases = [
+    (lambda: dataset1, dataset1),
+    (lambda: dataset1 & dataset2, DatasetAll(dataset1, dataset2)),
+    (lambda: dataset1 | dataset2, DatasetAny(dataset1, dataset2)),
+    (lambda: dataset1 | (dataset2 & dataset3), DatasetAny(dataset1, 
DatasetAll(dataset2, dataset3))),
+    (lambda: dataset1 | dataset2 & dataset3, DatasetAny(dataset1, 
DatasetAll(dataset2, dataset3))),
+    (
+        lambda: ((dataset1 & dataset2) | dataset3) & (dataset4 | dataset5),
+        DatasetAll(DatasetAny(DatasetAll(dataset1, dataset2), dataset3), 
DatasetAny(dataset4, dataset5)),
+    ),
+    (lambda: dataset1 & dataset2 | dataset3, DatasetAny(DatasetAll(dataset1, 
dataset2), dataset3)),
+    (
+        lambda: (dataset1 | dataset2) & (dataset3 | dataset4),
+        DatasetAll(DatasetAny(dataset1, dataset2), DatasetAny(dataset3, 
dataset4)),
+    ),
+    (
+        lambda: (dataset1 & dataset2) | (dataset3 & (dataset4 | dataset5)),
+        DatasetAny(DatasetAll(dataset1, dataset2), DatasetAll(dataset3, 
DatasetAny(dataset4, dataset5))),
+    ),
+    (
+        lambda: (dataset1 & dataset2) & (dataset3 & dataset4),
+        DatasetAll(dataset1, dataset2, DatasetAll(dataset3, dataset4)),
+    ),
+    (lambda: dataset1 | dataset2 | dataset3, DatasetAny(dataset1, dataset2, 
dataset3)),
+    (lambda: dataset1 & dataset2 & dataset3, DatasetAll(dataset1, dataset2, 
dataset3)),
+    (
+        lambda: ((dataset1 & dataset2) | dataset3) & (dataset4 | dataset5),
+        DatasetAll(DatasetAny(DatasetAll(dataset1, dataset2), dataset3), 
DatasetAny(dataset4, dataset5)),
+    ),
+]
+
+
+@pytest.mark.parametrize("expression, expected", test_cases)
+def test_extract_datasets(expression, expected):
+    expr = expression()
+    result = extract_datasets(expr)

Review Comment:
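   Lambda functions are used to delay the evaluation of each expression until the lambda is called. The individual `Dataset` objects (`dataset1`…`dataset5`) are still created at module import time. However, the `&`/`|` combinations built from them are not evaluated until `expression()` is called inside the test function.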



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to