This is an automated email from the ASF dual-hosted git repository.

skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/hamilton.git


The following commit(s) were added to refs/heads/main by this push:
     new a47fc422 While loop DFS implementation (#1483)
a47fc422 is described below

commit a47fc4227d86528a5bd75bc6bc02cbeb01530d16
Author: Stefan Krawczyk <[email protected]>
AuthorDate: Sat Mar 7 09:59:56 2026 -0800

    While loop DFS implementation (#1483)
    
    * While loop DFS implementation
    
    This lets very large and deep graphs be traversed without hitting Python's recursion limit.
    
    * Changes
    Core change to the DFS implementation: the recursive traversal is replaced with an iterative one that uses an explicit stack.
    
    * How I tested this
     - unit test
     - ran a small benchmark -- no performance regressions.
    
    * Notes
    Even though this is a core change, it is isolated enough that we are
    confident in it.
    Alternatively, we could add an environment variable to switch between the
    two implementations, but I don't see much value in that given that we're
    only swapping the function call stack for an explicit stack.
    
    * Fixes up iteration to behave like before
    
    This ensures we don't visit any node twice. It also
    ensures that we keep the same order of traversal as before.
    I don't want to change the traversal order if we don't
    have to.
---
 hamilton/graph.py   |  22 ++++++----
 tests/test_graph.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 128 insertions(+), 8 deletions(-)

diff --git a/hamilton/graph.py b/hamilton/graph.py
index 78595c4f..f3fac755 100644
--- a/hamilton/graph.py
+++ b/hamilton/graph.py
@@ -1091,13 +1091,19 @@ class FunctionGraph:
         nodes = set()
         user_nodes = set()
 
-        def dfs_traverse(node: node.Node):
-            nodes.add(node)
-            for n in next_nodes_fn(node):
-                if n not in nodes:
-                    dfs_traverse(n)
-            if node.user_defined:
-                user_nodes.add(node)
+        def dfs_traverse_iterative(start_node: node.Node):
+            """Iterative DFS to avoid recursion depth limits with large 
DAGs."""
+            stack = [start_node]
+            nodes.add(start_node)
+            while stack:
+                n = stack.pop()
+                if n.user_defined:
+                    user_nodes.add(n)
+                # reversed() preserves the same traversal order as the 
recursive version
+                for next_n in reversed(next_nodes_fn(n)):
+                    if next_n not in nodes:
+                        nodes.add(next_n)
+                        stack.append(next_n)
 
         missing_vars = []
         for var in starting_nodes:
@@ -1108,7 +1114,7 @@ class FunctionGraph:
                     # if it's not in the runtime inputs, it's a properly 
missing variable
                     missing_vars.append(var)
                 continue  # collect all missing final variables
-            dfs_traverse(self.nodes[var])
+            dfs_traverse_iterative(self.nodes[var])
         if missing_vars:
             missing_vars_str = ",\n".join(missing_vars)
             raise ValueError(f"Unknown nodes [{missing_vars_str}] requested. 
Check for typos?")
diff --git a/tests/test_graph.py b/tests/test_graph.py
index 6c6faf3b..1d802b02 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -17,6 +17,7 @@
 
 import inspect
 import pathlib
+import sys
 import uuid
 from itertools import permutations
 
@@ -26,6 +27,7 @@ import pytest
 import hamilton.graph_utils
 import hamilton.htypes
 from hamilton import ad_hoc_utils, base, graph, node
+from hamilton import function_modifiers as fm
 from hamilton.execution import graph_functions
 from hamilton.function_modifiers import schema
 from hamilton.lifecycle import base as lifecycle_base
@@ -538,6 +540,118 @@ def test_get_downstream_nodes():
     assert actual_nodes == expected_nodes
 
 
+def test_get_upstream_nodes_large_chain_no_recursion_error():
+    """Regression test: get_upstream_nodes with only final_node on a large 
chain DAG.
+
+    A recursive DFS would exceed Python's recursion limit (~1000) when 
traversing
+    a long dependency chain from a single final node. This test verifies that
+    the iterative DFS in directional_dfs_traverse handles large DAGs correctly.
+
+    Chain size is chosen to exceed recursion limit: 1200 nodes > 1000.
+    """
+
+    def step(prev: float) -> float:
+        """Single step in a linear chain."""
+        return prev + 1.0
+
+    # Build a linear chain: node_0 -> node_1 -> ... -> node_N
+    chain_size = sys.getrecursionlimit() + 200  # Exceeds recursion limit
+    config = {}
+    for i in range(chain_size):
+        prev = f"node_{i - 1}" if i > 0 else 0.0
+        config[f"node_{i}"] = {
+            "prev": fm.source(prev) if i > 0 else fm.value(0.0),
+        }
+    decorated = fm.parameterize(**config)(step)
+    module = ad_hoc_utils.create_temporary_module(decorated, 
module_name="large_chain")
+
+    fg = graph.FunctionGraph.from_modules(module, config={})
+    final_node = f"node_{chain_size - 1}"
+
+    # This would raise RecursionError with recursive DFS
+    nodes, user_nodes = fg.get_upstream_nodes([final_node])
+
+    assert len(nodes) == chain_size
+    assert len(user_nodes) == 0
+    assert all(fg.nodes[f"node_{i}"] in nodes for i in range(chain_size))
+
+
+def test_get_upstream_nodes_diamond_dag():
+    """Tests that diamond-shaped DAGs don't produce duplicate visits.
+
+    DAG shape:
+        x, y (inputs)
+          |
+        left   right    (both depend on x and y)
+           \\   /
+           bottom        (depends on left and right)
+
+    The shared inputs x and y are reachable via both left and right.
+    With a naive iterative DFS (mark-on-pop), x and y could be pushed
+    onto the stack multiple times. This verifies they appear exactly once.
+    """
+
+    def left(x: int, y: int) -> int:
+        return x + y
+
+    def right(x: int, y: int) -> int:
+        return x * y
+
+    def bottom(left: int, right: int) -> int:
+        return left + right
+
+    module = ad_hoc_utils.create_temporary_module(left, right, bottom)
+    fg = graph.FunctionGraph.from_modules(module, config={})
+    nodes, user_nodes = fg.get_upstream_nodes(["bottom"])
+
+    assert len(nodes) == 5  # x, y, left, right, bottom
+    assert {n.name for n in nodes} == {"x", "y", "left", "right", "bottom"}
+    # x and y are external inputs
+    assert {n.name for n in user_nodes} == {"x", "y"}
+
+
+def test_get_upstream_nodes_single_node():
+    """Tests traversal of a single node with no dependencies."""
+
+    def solo() -> int:
+        return 42
+
+    module = ad_hoc_utils.create_temporary_module(solo)
+    fg = graph.FunctionGraph.from_modules(module, config={})
+    nodes, user_nodes = fg.get_upstream_nodes(["solo"])
+
+    assert len(nodes) == 1
+    assert {n.name for n in nodes} == {"solo"}
+    assert len(user_nodes) == 0
+
+
+def test_get_upstream_nodes_overlapping_starting_nodes():
+    """Tests that overlapping subgraphs from multiple starting nodes are 
handled correctly.
+
+    DAG shape:
+        shared (input)
+          / \\
+        a     b    (both depend on shared)
+
+    Requesting both a and b as starting nodes means 'shared' is reachable
+    from both traversals. It should still appear exactly once in the result.
+    """
+
+    def a(shared: int) -> int:
+        return shared + 1
+
+    def b(shared: int) -> int:
+        return shared + 2
+
+    module = ad_hoc_utils.create_temporary_module(a, b)
+    fg = graph.FunctionGraph.from_modules(module, config={})
+    nodes, user_nodes = fg.get_upstream_nodes(["a", "b"])
+
+    assert len(nodes) == 3  # shared, a, b
+    assert {n.name for n in nodes} == {"shared", "a", "b"}
+    assert {n.name for n in user_nodes} == {"shared"}
+
+
 def test_function_graph_from_multiple_sources():
     fg = graph.FunctionGraph.from_modules(
         tests.resources.dummy_functions, tests.resources.parametrized_nodes, 
config={}

Reply via email to