Lee-W commented on code in PR #37101:
URL: https://github.com/apache/airflow/pull/37101#discussion_r1500235447


##########
airflow/datasets/__init__.py:
##########
@@ -106,8 +112,91 @@ class DatasetAny(_DatasetBooleanCondition):
 
     agg_func = any
 
+    def __init__(self, *objects: Dataset | DatasetAny | DatasetAll) -> None:
+        """Initialize with one or more Dataset, DatasetAny, or DatasetAll 
instances."""
+        super().__init__(*objects)
+
+    def __or__(self, other):
+        if isinstance(other, (Dataset, DatasetAny, DatasetAll)):
+            return DatasetAny(*self.objects, other)
+        return NotImplemented
+
+    def __and__(self, other):
+        if isinstance(other, (Dataset, DatasetAny, DatasetAll)):
+            return DatasetAll(self, other)
+        return NotImplemented
+
+    def __repr__(self) -> str:
+        return f"DatasetAny({', '.join(map(str, self.objects))})"
+
 
 class DatasetAll(_DatasetBooleanCondition):
     """Use to combine datasets schedule references in an "or" relationship."""
 
     agg_func = all
+
+    def __init__(self, *objects: Dataset | DatasetAny | DatasetAll):
+        """Initialize with one or more Dataset, DatasetAny, or DatasetAll 
instances."""
+        super().__init__(*objects)
+
+    def __or__(self, other):
+        if isinstance(other, (Dataset, DatasetAny, DatasetAll)):
+            return DatasetAny(self, other)
+        return NotImplemented
+
+    def __and__(self, other):
+        if isinstance(other, (Dataset, DatasetAny, DatasetAll)):
+            return DatasetAll(*self.objects, other)
+        return NotImplemented
+
+    def __repr__(self) -> str:
+        return f"DatasetAll({', '.join(map(str, self.objects))})"
+
+
+class DatasetsExpression:
+    """
+    Represents a node in an expression tree for dataset conditions.
+
+    :param value: The value of the node, which can be a 'Dataset', '&', or '|'.
+    :param left: The left child node.
+    :param right: The right child node.
+    """
+
+    def __init__(self, value, left=None, right=None) -> None:
+        self.value = value  # value can be 'Dataset', '&', or '|'
+        self.left = left
+        self.right = right
+
+    def __or__(self, other: Dataset | DatasetsExpression) -> 
DatasetsExpression:
+        return DatasetsExpression("|", self, other)
+
+    def __and__(self, other: Dataset | DatasetsExpression) -> 
DatasetsExpression:
+        return DatasetsExpression("&", self, other)
+
+    def __repr__(self) -> str:
+        if isinstance(self.value, Dataset):
+            return f"Dataset(uri='{self.value.uri}')"
+        elif self.value == "&":
+            return repr(DatasetAll(self.left, self.right))
+        elif self.value == "|":
+            return repr(DatasetAny(self.left, self.right))
+        else:
+            return f"Invalid DatasetsExpression(value={self.value})"

Review Comment:
   ```suggestion
           return f"Invalid DatasetsExpression(value={self.value})"
   ```
   
   Not sure whether we should raise an exception in this case (an unrecognized
   `value`) instead of returning an "Invalid DatasetsExpression(...)" string.



##########
tests/datasets/test_dataset.py:
##########
@@ -269,3 +269,70 @@ def test_dag_with_complex_dataset_triggers(session, 
dag_maker):
     assert isinstance(
         serialized_dag_dict["dataset_triggers"], dict
     ), "Serialized 'dataset_triggers' should be a dict"
+
+
+def datasets_equal(d1, d2):

Review Comment:
   ```suggestion
   def datasets_equal(d1: Dataset | DatasetAny | DatasetAll, d2: Dataset | DatasetAny | DatasetAll) -> bool:
   ```



##########
tests/datasets/test_dataset.py:
##########
@@ -269,3 +269,70 @@ def test_dag_with_complex_dataset_triggers(session, 
dag_maker):
     assert isinstance(
         serialized_dag_dict["dataset_triggers"], dict
     ), "Serialized 'dataset_triggers' should be a dict"
+
+
+def datasets_equal(d1, d2):
+    if type(d1) != type(d2):
+        return False
+
+    if isinstance(d1, Dataset):
+        return d1.uri == d2.uri
+
+    elif isinstance(d1, (DatasetAny, DatasetAll)):
+        if len(d1.objects) != len(d2.objects):
+            return False
+
+        # Compare each pair of objects
+        for obj1, obj2 in zip(d1.objects, d2.objects):
+            # If obj1 or obj2 is a Dataset, DatasetAny, or DatasetAll instance,
+            # recursively call datasets_equal
+            if not datasets_equal(obj1, obj2):
+                return False
+        return True
+
+    return False
+
+
+dataset1 = Dataset(uri="s3://bucket1/data1")
+dataset2 = Dataset(uri="s3://bucket2/data2")
+dataset3 = Dataset(uri="s3://bucket3/data3")
+dataset4 = Dataset(uri="s3://bucket4/data4")
+dataset5 = Dataset(uri="s3://bucket5/data5")
+
+test_cases = [
+    (lambda: dataset1, dataset1),
+    (lambda: dataset1 & dataset2, DatasetAll(dataset1, dataset2)),
+    (lambda: dataset1 | dataset2, DatasetAny(dataset1, dataset2)),
+    (lambda: dataset1 | (dataset2 & dataset3), DatasetAny(dataset1, 
DatasetAll(dataset2, dataset3))),
+    (lambda: dataset1 | dataset2 & dataset3, DatasetAny(dataset1, 
DatasetAll(dataset2, dataset3))),
+    (
+        lambda: ((dataset1 & dataset2) | dataset3) & (dataset4 | dataset5),
+        DatasetAll(DatasetAny(DatasetAll(dataset1, dataset2), dataset3), 
DatasetAny(dataset4, dataset5)),
+    ),
+    (lambda: dataset1 & dataset2 | dataset3, DatasetAny(DatasetAll(dataset1, 
dataset2), dataset3)),
+    (
+        lambda: (dataset1 | dataset2) & (dataset3 | dataset4),
+        DatasetAll(DatasetAny(dataset1, dataset2), DatasetAny(dataset3, 
dataset4)),
+    ),
+    (
+        lambda: (dataset1 & dataset2) | (dataset3 & (dataset4 | dataset5)),
+        DatasetAny(DatasetAll(dataset1, dataset2), DatasetAll(dataset3, 
DatasetAny(dataset4, dataset5))),
+    ),
+    (
+        lambda: (dataset1 & dataset2) & (dataset3 & dataset4),
+        DatasetAll(dataset1, dataset2, DatasetAll(dataset3, dataset4)),
+    ),
+    (lambda: dataset1 | dataset2 | dataset3, DatasetAny(dataset1, dataset2, 
dataset3)),
+    (lambda: dataset1 & dataset2 & dataset3, DatasetAll(dataset1, dataset2, 
dataset3)),
+    (
+        lambda: ((dataset1 & dataset2) | dataset3) & (dataset4 | dataset5),
+        DatasetAll(DatasetAny(DatasetAll(dataset1, dataset2), dataset3), 
DatasetAny(dataset4, dataset5)),
+    ),
+]
+
+
+@pytest.mark.parametrize("expression, expected", test_cases)
+def test_extract_datasets(expression, expected):
+    expr = expression()
+    result = extract_datasets(expr)

Review Comment:
   May I know why we want to use a `lambda` and a function call here instead of
   using the expression directly?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to