This is an automated email from the ASF dual-hosted git repository. skrawcz pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/burr.git
commit af1a6e416392a5de0737c0f04e77b796dfec98b2 Author: Stefan Krawczyk <[email protected]> AuthorDate: Sun Mar 15 22:02:59 2026 -0700 Add comparison operators to when() conditions Extend Condition.when() to support Django-style lookup operators via double-underscore suffixes (e.g. when(age__gte=18), when(status__in=[...])). Supported operators: eq, ne, gt, gte, lt, lte, in, notin, contains. Plain key=value usage remains fully backward compatible. Includes 43 new parametrized tests and updated transition docs. --- burr/core/action.py | 90 ++++++++++++++++++++++--- docs/concepts/transitions.rst | 30 +++++++++ tests/core/test_action.py | 149 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 260 insertions(+), 9 deletions(-) diff --git a/burr/core/action.py b/burr/core/action.py index b2e7c16d..cd98558c 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -423,26 +423,98 @@ class Condition(Function): def reads(self) -> list[str]: return self._keys + _OPERATORS = { + "eq": ("==", lambda a, b: a == b), + "ne": ("!=", lambda a, b: a != b), + "lt": ("<", lambda a, b: a < b), + "lte": ("<=", lambda a, b: a <= b), + "gt": (">", lambda a, b: a > b), + "gte": (">=", lambda a, b: a >= b), + "in": ("in", lambda a, b: a in b), + "notin": ("not in", lambda a, b: a not in b), + "contains": ("contains", lambda a, b: b in a), + } + + @classmethod + def _parse_kwarg(cls, kwarg_key: str, value): + """Parse a kwarg key into (state_key, operator_symbol, comparison_func, explicit). + + Supports Django-style lookups: ``key__gte=10`` parses as key >= 10. + Plain ``key=value`` defaults to equality (implicit). + + Returns a tuple of (state_key, symbol, func, explicit) where explicit + indicates whether an operator suffix was present. 
+ """ + for suffix, (symbol, func) in cls._OPERATORS.items(): + dunder = f"__{suffix}" + if kwarg_key.endswith(dunder): + state_key = kwarg_key[: -len(dunder)] + if not state_key: + raise ValueError( + f"Invalid when() key: '{kwarg_key}' — " f"no state key before '__{suffix}'" + ) + return state_key, symbol, func, True + return kwarg_key, "=", lambda a, b: a == b, False + @classmethod def when(cls, **kwargs): - """Returns a condition that checks if the given keys are in the - state and equal to the given values. + """Returns a condition that checks state values using optional operators. You can also refer to this as ``from burr.core import when`` in the API. - :param kwargs: Keyword arguments of keys and values to check -- will be an AND condition - :return: A condition that checks if the given keys are in the state and equal to the given values + Basic equality (unchanged from original):: + + when(foo="bar") # state["foo"] == "bar" + when(foo="bar", baz="qux") # state["foo"] == "bar" AND state["baz"] == "qux" + + Comparison operators via ``__`` suffix:: + + when(age__gt=18) # state["age"] > 18 + when(age__gte=18) # state["age"] >= 18 + when(age__lt=18) # state["age"] < 18 + when(age__lte=18) # state["age"] <= 18 + when(age__ne=0) # state["age"] != 0 + when(age__eq=18) # state["age"] == 18 (explicit) + + Membership operators:: + + when(status__in=["a", "b"]) # state["status"] in ["a", "b"] + when(status__notin=["x", "y"]) # state["status"] not in ["x", "y"] + when(tags__contains="python") # "python" in state["tags"] + + Multiple conditions are ANDed together:: + + when(age__gte=18, status="active") # age >= 18 AND status == "active" + + :param kwargs: Keyword arguments with optional ``__operator`` suffixes + :return: A condition that checks all specified constraints (AND) """ - keys = list(kwargs.keys()) + parsed = [] + for kwarg_key, value in kwargs.items(): + state_key, symbol, func, explicit = cls._parse_kwarg(kwarg_key, value) + parsed.append((state_key, 
symbol, func, value, explicit)) + + state_keys = list(dict.fromkeys(p[0] for p in parsed)) def condition_func(state: State) -> bool: - for key, value in kwargs.items(): - if state.get(key) != value: + for state_key, _symbol, func, value, _explicit in parsed: + if not func(state.get(state_key), value): return False return True - name = f"{', '.join(f'{key}={value}' for key, value in sorted(kwargs.items()))}" - return Condition(keys, condition_func, name=name) + name_parts = [] + for state_key, symbol, _func, value, explicit in sorted(parsed, key=lambda p: p[0]): + if not explicit: + # Backward-compatible format: key=value (no repr, no spaces) + name_parts.append(f"{state_key}={value}") + elif symbol.isalnum() or " " in symbol: + # Word operators like "in", "not in", "contains" + name_parts.append(f"{state_key} {symbol} {value!r}") + else: + # Symbol operators like >=, !=, etc. + name_parts.append(f"{state_key}{symbol}{value!r}") + name = ", ".join(name_parts) + return Condition(state_keys, condition_func, name=name) def __repr__(self): return f"condition: {self._name}" diff --git a/docs/concepts/transitions.rst b/docs/concepts/transitions.rst index a3338395..3ade8ef0 100644 --- a/docs/concepts/transitions.rst +++ b/docs/concepts/transitions.rst @@ -54,6 +54,36 @@ Conditions have a few APIs, but the most common are the three convenience functi ) +``when()`` also supports comparison operators via Django-style ``__`` suffixes: + +.. 
code-block:: python + + from burr.core import when + with_transitions( + ("check", "adult", when(age__gte=18)), # age >= 18 + ("check", "child", when(age__lt=18)), # age < 18 + ("check", "valid", when(score__gt=0, score__lte=100)), # 0 < score <= 100 + ("check", "active", when(status__in=["active", "pending"])), # membership + ("check", "tagged", when(tags__contains="python")), # collection contains value + ("check", "clean", when(status__notin=["banned", "suspended"])), # not in + ("check", "changed", when(status__ne="initial")), # not equal + ) + +Available operators: + +- ``key=value`` — exact equality (default, unchanged) +- ``key__eq=value`` — explicit equality +- ``key__ne=value`` — not equal +- ``key__gt=value`` — greater than +- ``key__gte=value`` — greater than or equal +- ``key__lt=value`` — less than +- ``key__lte=value`` — less than or equal +- ``key__in=[values]`` — value is in the given collection +- ``key__notin=[values]`` — value is not in the given collection +- ``key__contains=value`` — collection/string in state contains the value + +Multiple keyword arguments are ANDed together. For more complex expressions, use ``expr()``. + Conditions are evaluated in the order they are specified, and the first one that evaluates to True will be the transition that is selected when determining which action to run next. If no condition evaluates to ``True``, the application execution will stop early. 
diff --git a/tests/core/test_action.py b/tests/core/test_action.py index 83fecf3b..381bb6ff 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -126,6 +126,155 @@ def test_condition_when_complex(): assert cond.run(State({"foo": "baz", "baz": "corge"})) == {Condition.KEY: False} +# --- when() operator tests --- + + +@pytest.mark.parametrize( + "kwargs,state_dict,expected", + [ + # __eq (explicit equality) + ({"age__eq": 18}, {"age": 18}, True), + ({"age__eq": 18}, {"age": 19}, False), + # __ne (not equal) + ({"age__ne": 0}, {"age": 5}, True), + ({"age__ne": 0}, {"age": 0}, False), + # __gt (greater than) + ({"age__gt": 18}, {"age": 19}, True), + ({"age__gt": 18}, {"age": 18}, False), + ({"age__gt": 18}, {"age": 17}, False), + # __gte (greater than or equal) + ({"age__gte": 18}, {"age": 19}, True), + ({"age__gte": 18}, {"age": 18}, True), + ({"age__gte": 18}, {"age": 17}, False), + # __lt (less than) + ({"age__lt": 18}, {"age": 17}, True), + ({"age__lt": 18}, {"age": 18}, False), + ({"age__lt": 18}, {"age": 19}, False), + # __lte (less than or equal) + ({"age__lte": 18}, {"age": 17}, True), + ({"age__lte": 18}, {"age": 18}, True), + ({"age__lte": 18}, {"age": 19}, False), + # __in (membership) + ({"status__in": ["active", "pending"]}, {"status": "active"}, True), + ({"status__in": ["active", "pending"]}, {"status": "pending"}, True), + ({"status__in": ["active", "pending"]}, {"status": "banned"}, False), + # __notin (not in) + ({"status__notin": ["banned", "suspended"]}, {"status": "active"}, True), + ({"status__notin": ["banned", "suspended"]}, {"status": "banned"}, False), + # __contains (collection contains value) + ({"tags__contains": "python"}, {"tags": ["python", "java"]}, True), + ({"tags__contains": "go"}, {"tags": ["python", "java"]}, False), + ({"text__contains": "hello"}, {"text": "say hello world"}, True), + ({"text__contains": "goodbye"}, {"text": "say hello world"}, False), + ], + ids=[ + "eq-match", + "eq-no-match", + "ne-different", +
"ne-same", + "gt-above", + "gt-equal", + "gt-below", + "gte-above", + "gte-equal", + "gte-below", + "lt-below", + "lt-equal", + "lt-above", + "lte-below", + "lte-equal", + "lte-above", + "in-first", + "in-second", + "in-missing", + "notin-absent", + "notin-present", + "contains-list-match", + "contains-list-no-match", + "contains-str-match", + "contains-str-no-match", + ], +) +def test_condition_when_operators(kwargs, state_dict, expected): + cond = Condition.when(**kwargs) + assert cond.run(State(state_dict)) == {Condition.KEY: expected} + + +@pytest.mark.parametrize( + "kwargs,expected_reads", + [ + ({"age__gte": 18}, ["age"]), + ({"status__in": ["a"]}, ["status"]), + ({"tags__contains": "x"}, ["tags"]), + ({"age__gte": 18, "status": "active"}, ["age", "status"]), + # same key with different operators + ({"age__gte": 10, "age__lt": 20}, ["age"]), + ], + ids=["gte", "in", "contains", "mixed", "same-key-two-ops"], +) +def test_condition_when_operators_reads(kwargs, expected_reads): + cond = Condition.when(**kwargs) + assert sorted(cond.reads) == sorted(expected_reads) + + +@pytest.mark.parametrize( + "kwargs,expected_name", + [ + ({"age__gte": 18}, "age>=18"), + ({"age__lt": 5}, "age<5"), + ({"age__ne": 0}, "age!=0"), + ({"status__in": ["a", "b"]}, "status in ['a', 'b']"), + ({"status__notin": ["x"]}, "status not in ['x']"), + ({"tags__contains": "py"}, "tags contains 'py'"), + # plain equality still uses old format + ({"foo": "bar"}, "foo=bar"), + ({"foo": "bar", "baz": "qux"}, "baz=qux, foo=bar"), + ], + ids=["gte", "lt", "ne", "in", "notin", "contains", "plain-eq", "plain-multi"], +) +def test_condition_when_operators_name(kwargs, expected_name): + cond = Condition.when(**kwargs) + assert cond.name == expected_name + + +def test_condition_when_operators_combined(): + """Test multiple operators ANDed together.""" + cond = Condition.when(age__gte=18, status="active", score__lt=100) + assert cond.run(State({"age": 20, "status": "active", "score": 50})) == {Condition.KEY: True} +
assert cond.run(State({"age": 17, "status": "active", "score": 50})) == {Condition.KEY: False} + assert cond.run(State({"age": 20, "status": "inactive", "score": 50})) == {Condition.KEY: False} + assert cond.run(State({"age": 20, "status": "active", "score": 100})) == {Condition.KEY: False} + + +def test_condition_when_operators_with_invert(): + """Ensure operator-based conditions work with ~ (invert).""" + cond = ~Condition.when(age__gte=18) + assert cond.run(State({"age": 17})) == {Condition.KEY: True} + assert cond.run(State({"age": 18})) == {Condition.KEY: False} + + +def test_condition_when_operators_with_or(): + """Ensure operator-based conditions work with | (or).""" + cond = Condition.when(age__lt=13) | Condition.when(age__gte=65) + assert cond.run(State({"age": 10})) == {Condition.KEY: True} + assert cond.run(State({"age": 70})) == {Condition.KEY: True} + assert cond.run(State({"age": 30})) == {Condition.KEY: False} + + +def test_condition_when_operators_with_and(): + """Ensure operator-based conditions work with & (and).""" + cond = Condition.when(age__gte=18) & Condition.when(age__lt=65) + assert cond.run(State({"age": 30})) == {Condition.KEY: True} + assert cond.run(State({"age": 17})) == {Condition.KEY: False} + assert cond.run(State({"age": 65})) == {Condition.KEY: False} + + +def test_condition_when_invalid_key(): + """Empty state key before operator suffix should raise.""" + with pytest.raises(ValueError, match="no state key"): + Condition.when(__gte=18) + + def test_condition_default(): cond = default assert cond.name == "default"
