This is an automated email from the ASF dual-hosted git repository.
skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/hamilton.git
The following commit(s) were added to refs/heads/main by this push:
new a47fc422 While loop DFS implementation (#1483)
a47fc422 is described below
commit a47fc4227d86528a5bd75bc6bc02cbeb01530d16
Author: Stefan Krawczyk <[email protected]>
AuthorDate: Sat Mar 7 09:59:56 2026 -0800
While loop DFS implementation (#1483)
* While loop DFS implementation
This should help anyone who creates a really large and deep graph, since the
recursive DFS can hit Python's recursion depth limit on such graphs.
* Changes
A deep, core change to the DFS implementation: the recursive traversal is
replaced with an iterative, stack-based one.
* How I tested this
- unit test
- ran a small benchmark -- no performance regressions.
* Notes
Even though this is a fairly core change, I think it is isolated enough
that we can be confident in it.
Alternatively, we could add an environment variable to switch between the two
implementations, but I don't see much value in that given that we're only
swapping the function call stack for an explicit stack.
* Fixes up iteration to behave like before
This ensures we don't visit nodes twice. It also
ensures that the traversal order is the same as before.
I don't want to change the traversal order if we don't
have to.
---
hamilton/graph.py | 22 ++++++----
tests/test_graph.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 128 insertions(+), 8 deletions(-)
diff --git a/hamilton/graph.py b/hamilton/graph.py
index 78595c4f..f3fac755 100644
--- a/hamilton/graph.py
+++ b/hamilton/graph.py
@@ -1091,13 +1091,19 @@ class FunctionGraph:
nodes = set()
user_nodes = set()
- def dfs_traverse(node: node.Node):
- nodes.add(node)
- for n in next_nodes_fn(node):
- if n not in nodes:
- dfs_traverse(n)
- if node.user_defined:
- user_nodes.add(node)
+ def dfs_traverse_iterative(start_node: node.Node):
+ """Iterative DFS to avoid recursion depth limits with large
DAGs."""
+ stack = [start_node]
+ nodes.add(start_node)
+ while stack:
+ n = stack.pop()
+ if n.user_defined:
+ user_nodes.add(n)
+ # reversed() preserves the same traversal order as the recursive version
+ for next_n in reversed(next_nodes_fn(n)):
+ if next_n not in nodes:
+ nodes.add(next_n)
+ stack.append(next_n)
missing_vars = []
for var in starting_nodes:
@@ -1108,7 +1114,7 @@ class FunctionGraph:
# if it's not in the runtime inputs, it's a properly missing variable
missing_vars.append(var)
continue # collect all missing final variables
- dfs_traverse(self.nodes[var])
+ dfs_traverse_iterative(self.nodes[var])
if missing_vars:
missing_vars_str = ",\n".join(missing_vars)
raise ValueError(f"Unknown nodes [{missing_vars_str}] requested.
Check for typos?")
diff --git a/tests/test_graph.py b/tests/test_graph.py
index 6c6faf3b..1d802b02 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -17,6 +17,7 @@
import inspect
import pathlib
+import sys
import uuid
from itertools import permutations
@@ -26,6 +27,7 @@ import pytest
import hamilton.graph_utils
import hamilton.htypes
from hamilton import ad_hoc_utils, base, graph, node
+from hamilton import function_modifiers as fm
from hamilton.execution import graph_functions
from hamilton.function_modifiers import schema
from hamilton.lifecycle import base as lifecycle_base
@@ -538,6 +540,118 @@ def test_get_downstream_nodes():
assert actual_nodes == expected_nodes
+def test_get_upstream_nodes_large_chain_no_recursion_error():
+ """Regression test: get_upstream_nodes with only final_node on a large
chain DAG.
+
+ A recursive DFS would exceed Python's recursion limit (~1000) when traversing
+ a long dependency chain from a single final node. This test verifies that
+ the iterative DFS in directional_dfs_traverse handles large DAGs correctly.
+
+ Chain size is chosen to exceed the recursion limit: sys.getrecursionlimit() + 200 nodes.
+ """
+
+ def step(prev: float) -> float:
+ """Single step in a linear chain."""
+ return prev + 1.0
+
+ # Build a linear chain: node_0 -> node_1 -> ... -> node_N
+ chain_size = sys.getrecursionlimit() + 200 # Exceeds recursion limit
+ config = {}
+ for i in range(chain_size):
+ prev = f"node_{i - 1}" if i > 0 else 0.0
+ config[f"node_{i}"] = {
+ "prev": fm.source(prev) if i > 0 else fm.value(0.0),
+ }
+ decorated = fm.parameterize(**config)(step)
+ module = ad_hoc_utils.create_temporary_module(decorated,
module_name="large_chain")
+
+ fg = graph.FunctionGraph.from_modules(module, config={})
+ final_node = f"node_{chain_size - 1}"
+
+ # This would raise RecursionError with recursive DFS
+ nodes, user_nodes = fg.get_upstream_nodes([final_node])
+
+ assert len(nodes) == chain_size
+ assert len(user_nodes) == 0
+ assert all(fg.nodes[f"node_{i}"] in nodes for i in range(chain_size))
+
+
+def test_get_upstream_nodes_diamond_dag():
+ """Tests that diamond-shaped DAGs don't produce duplicate visits.
+
+ DAG shape:
+ x, y (inputs)
+ |
+ left right (both depend on x and y)
+ \\ /
+ bottom (depends on left and right)
+
+ The shared inputs x and y are reachable via both left and right.
+ With a naive iterative DFS (mark-on-pop), x and y could be pushed
+ onto the stack multiple times. This verifies they appear exactly once.
+ """
+
+ def left(x: int, y: int) -> int:
+ return x + y
+
+ def right(x: int, y: int) -> int:
+ return x * y
+
+ def bottom(left: int, right: int) -> int:
+ return left + right
+
+ module = ad_hoc_utils.create_temporary_module(left, right, bottom)
+ fg = graph.FunctionGraph.from_modules(module, config={})
+ nodes, user_nodes = fg.get_upstream_nodes(["bottom"])
+
+ assert len(nodes) == 5 # x, y, left, right, bottom
+ assert {n.name for n in nodes} == {"x", "y", "left", "right", "bottom"}
+ # x and y are external inputs
+ assert {n.name for n in user_nodes} == {"x", "y"}
+
+
+def test_get_upstream_nodes_single_node():
+ """Tests traversal of a single node with no dependencies."""
+
+ def solo() -> int:
+ return 42
+
+ module = ad_hoc_utils.create_temporary_module(solo)
+ fg = graph.FunctionGraph.from_modules(module, config={})
+ nodes, user_nodes = fg.get_upstream_nodes(["solo"])
+
+ assert len(nodes) == 1
+ assert {n.name for n in nodes} == {"solo"}
+ assert len(user_nodes) == 0
+
+
+def test_get_upstream_nodes_overlapping_starting_nodes():
+ """Tests that overlapping subgraphs from multiple starting nodes are
handled correctly.
+
+ DAG shape:
+ shared (input)
+ / \\
+ a b (both depend on shared)
+
+ Requesting both a and b as starting nodes means 'shared' is reachable
+ from both traversals. It should still appear exactly once in the result.
+ """
+
+ def a(shared: int) -> int:
+ return shared + 1
+
+ def b(shared: int) -> int:
+ return shared + 2
+
+ module = ad_hoc_utils.create_temporary_module(a, b)
+ fg = graph.FunctionGraph.from_modules(module, config={})
+ nodes, user_nodes = fg.get_upstream_nodes(["a", "b"])
+
+ assert len(nodes) == 3 # shared, a, b
+ assert {n.name for n in nodes} == {"shared", "a", "b"}
+ assert {n.name for n in user_nodes} == {"shared"}
+
+
def test_function_graph_from_multiple_sources():
fg = graph.FunctionGraph.from_modules(
tests.resources.dummy_functions, tests.resources.parametrized_nodes,
config={}