This is an automated email from the ASF dual-hosted git repository.
aglinxinyuan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/texera.git
The following commit(s) were added to refs/heads/main by this push:
new ced34484ca test(pyamber): add unit tests for replace_print and
AtomicInteger (#4795)
ced34484ca is described below
commit ced34484ca8d90043fadf4780cd0d33ec8891b32
Author: Yicong Huang <[email protected]>
AuthorDate: Sun May 3 09:18:52 2026 -0700
test(pyamber): add unit tests for replace_print and AtomicInteger (#4795)
### What changes were proposed in this PR?
Adds pytest coverage for two thin core utilities:
- `amber/src/main/python/core/util/console_message/replace_print.py` —
the `replace_print` context manager that intercepts `builtins.print` and
enqueues `ConsoleMessage` payloads on a buffer.
- `amber/src/main/python/core/util/thread/atomic.py` — the
`AtomicInteger` thread-safe counter.
### Any related issues, documentation, discussions?
Closes #4793.
A known bug is pinned in the spec with an explanatory comment, plus an
xfail-strict test for the intended contract (the bug is filed separately
as a Bug issue): `AtomicInteger.get_and_set` deadlocks because it holds
the non-reentrant lock while invoking the `value` property, which tries to
acquire the same lock again. The pinned test surfaces the deadlock via a
worker thread and a timeout (so the suite does not hang), and the
xfail-strict companion asserts the intended `get_and_set` contract. Once
the deadlock is fixed, that test will XPASS, which `strict=True` reports
as a failure, forcing the spec to be updated alongside the fix.
### How was this PR tested?
```
cd amber/src/main/python
ruff check core/util/thread/test_atomic.py \
  core/util/console_message/test_replace_print.py
ruff format --check core/util/thread/test_atomic.py \
  core/util/console_message/test_replace_print.py
python -m pytest core/util/thread/test_atomic.py \
  core/util/console_message/test_replace_print.py
```
### Was this PR authored or co-authored using generative AI tooling?
Generated-by: Claude Code (claude-opus-4-7)
---------
Co-authored-by: Claude Opus 4.7 (1M context) <[email protected]>
---
.../util/console_message/test_replace_print.py | 135 ++++++++++++++++++++
.../main/python/core/util/thread/test_atomic.py | 139 +++++++++++++++++++++
2 files changed, 274 insertions(+)
diff --git a/amber/src/main/python/core/util/console_message/test_replace_print.py b/amber/src/main/python/core/util/console_message/test_replace_print.py
new file mode 100644
index 0000000000..2fc6c8e322
--- /dev/null
+++ b/amber/src/main/python/core/util/console_message/test_replace_print.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import builtins
+import io
+from typing import List
+
+import pytest
+
+from core.util.console_message.replace_print import replace_print
+from proto.org.apache.texera.amber.engine.architecture.rpc import (
+ ConsoleMessage,
+ ConsoleMessageType,
+)
+
+
+class CapturingBuffer:
+    """Fake IBuffer for tests: keeps every object handed to put(), in order."""
+
+    def __init__(self):
+        self.messages: List[ConsoleMessage] = list()
+
+    def put(self, msg):
+        # Record only; nothing is forwarded or consumed.
+        self.messages += [msg]
+
+
+class TestReplacePrintLifecycle:
+    def test_print_is_replaced_inside_the_context_and_restored_on_exit(self):
+        original = builtins.print
+        buf = CapturingBuffer()
+        with replace_print("w", buf):
+            assert builtins.print is not original
+        assert builtins.print is original
+
+    def test_print_is_restored_even_when_the_block_raises(self):
+        original = builtins.print
+        buf = CapturingBuffer()
+        with pytest.raises(RuntimeError):
+            with replace_print("w", buf):
+                raise RuntimeError("boom")
+        assert builtins.print is original
+
+    def test_exit_returns_true_for_clean_block_and_false_for_raising_block(self):
+        # Pin: __exit__ returns True when the block finished cleanly and
+        # False when it raised. Only the False case affects the `with`
+        # statement: a falsy return lets the exception propagate out,
+        # matching the docstring claim. (A truthy return with no pending
+        # exception has no effect.)
+        original = builtins.print
+        try:
+            ctx = replace_print("w", CapturingBuffer())
+            ctx.__enter__()
+            assert ctx.__exit__(None, None, None) is True
+            ctx2 = replace_print("w", CapturingBuffer())
+            ctx2.__enter__()
+            assert ctx2.__exit__(RuntimeError, RuntimeError("x"), None) is False
+        finally:
+            # __enter__/__exit__ are driven by hand, so a failed assertion or
+            # an exception escaping __exit__ would otherwise leave the patched
+            # print installed for every later test. Restore the print seen
+            # before this test started, unconditionally.
+            builtins.print = original
+
+
+class TestReplacePrintBufferPayload:
+    """Pin what replace_print enqueues on the buffer for each print() call."""
+
+    def test_print_inside_context_enqueues_a_console_message(self):
+        buf = CapturingBuffer()
+        with replace_print("worker-A", buf):
+            print("hello")
+        assert len(buf.messages) == 1
+        msg = buf.messages[0]
+        assert msg.worker_id == "worker-A"
+        assert msg.msg_type == ConsoleMessageType.PRINT
+        # Default print appends a newline; the title carries the full line.
+        assert msg.title == "hello\n"
+        assert msg.message == ""
+
+    def test_joins_args_via_the_real_print_so_sep_and_end_kwargs_apply(self):
+        buf = CapturingBuffer()
+        with replace_print("w", buf):
+            print("a", "b", "c", sep="-", end="!")
+        assert buf.messages[0].title == "a-b-c!"
+
+    def test_each_print_call_produces_one_buffer_entry(self):
+        # Pin: the wrapped print enqueues exactly one message per print()
+        # call, however many arguments it receives. A stdout-redirect
+        # approach (e.g. contextlib.redirect_stdout) would instead observe a
+        # separate write() for every argument, separator, and line ending.
+        # The replace_print docstring calls this out.
+        buf = CapturingBuffer()
+        with replace_print("w", buf):
+            print("first")
+            print("second", "third")
+        assert [m.title for m in buf.messages] == ["first\n", "second third\n"]
+
+    def test_print_with_file_kwarg_bypasses_the_buffer(self):
+        # When the caller provides a `file=...` argument, the wrap delegates
+        # straight to the original builtins.print and does not enqueue a
+        # ConsoleMessage. This is what lets explicit logging redirects keep
+        # working inside the context.
+        buf = CapturingBuffer()
+        sink = io.StringIO()
+        with replace_print("w", buf):
+            print("ignored-by-buffer", file=sink)
+        assert buf.messages == []
+        assert sink.getvalue() == "ignored-by-buffer\n"
+
+    def test_source_field_records_caller_module_function_and_line(self):
+        # The wrap walks one frame up to identify where the print() came from,
+        # so the source string carries `<module>:<func>:<lineno>`. Only the
+        # structural parts are verified: the exact line number and module
+        # name depend on this test's location, so the checks stay loose.
+        buf = CapturingBuffer()
+
+        def caller_under_test():
+            print("from-caller")
+
+        with replace_print("w", buf):
+            caller_under_test()
+
+        source = buf.messages[0].source
+        parts = source.split(":")
+        assert len(parts) == 3
+        # The reported function name is the function that called print().
+        assert parts[1] == "caller_under_test"
+        # And the line number is a positive integer.
+        assert parts[2].isdigit() and int(parts[2]) > 0
diff --git a/amber/src/main/python/core/util/thread/test_atomic.py b/amber/src/main/python/core/util/thread/test_atomic.py
new file mode 100644
index 0000000000..fa6238e0eb
--- /dev/null
+++ b/amber/src/main/python/core/util/thread/test_atomic.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import threading
+
+import pytest
+
+from core.util.thread.atomic import AtomicInteger
+
+
+class TestAtomicIntegerSingleThreaded:
+    def test_default_starts_at_zero(self):
+        assert AtomicInteger().value == 0
+
+    def test_initial_value_is_coerced_to_int(self):
+        # The constructor wraps the input through int(), which lets callers
+        # pass a numeric string or float and still get a clean integer state.
+        assert AtomicInteger("7").value == 7
+        assert AtomicInteger(3.9).value == 3  # int() truncates toward zero
+
+    def test_inc_returns_new_value_after_adding_default_one(self):
+        a = AtomicInteger(10)
+        assert a.inc() == 11
+        assert a.value == 11
+
+    def test_inc_with_custom_delta_uses_int_coercion(self):
+        a = AtomicInteger(10)
+        assert a.inc(5) == 15
+        # int("3") -> 3, the underlying state increments by 3.
+        assert a.inc("3") == 18
+
+    def test_dec_is_inc_with_negated_delta(self):
+        a = AtomicInteger(10)
+        assert a.dec() == 9
+        assert a.dec(4) == 5
+
+    def test_get_and_inc_returns_pre_increment_value(self):
+        a = AtomicInteger(10)
+        assert a.get_and_inc() == 10
+        assert a.value == 11
+
+    def test_get_and_dec_returns_pre_decrement_value(self):
+        a = AtomicInteger(10)
+        assert a.get_and_dec(2) == 10
+        assert a.value == 8
+
+    def test_value_setter_replaces_state_with_int_coercion(self):
+        a = AtomicInteger(10)
+        a.value = 42
+        assert a.value == 42
+        a.value = "100"
+        assert a.value == 100
+
+    def test_get_and_set_currently_deadlocks_on_non_reentrant_lock(self):
+        # Bug pin: get_and_set acquires self._lock and then reads self.value,
+        # which is a property that ALSO tries to acquire self._lock. The lock
+        # is a non-reentrant threading.Lock, so the call deadlocks the moment
+        # it is invoked. Document via thread + timeout so the test surfaces
+        # the deadlock without hanging the whole suite, and pair it with an
+        # xfail-strict test below that asserts the intended contract.
+        a = AtomicInteger(10)
+        started = threading.Event()
+        completed = threading.Event()
+        errors: list[BaseException] = []
+
+        def attempt():
+            started.set()
+            try:
+                a.get_and_set(99)
+                completed.set()
+            except BaseException as exc:
+                errors.append(exc)
+
+        worker = threading.Thread(target=attempt, daemon=True)
+        worker.start()
+        # Make sure the worker actually entered `attempt` — otherwise a
+        # scheduling delay alone could let the assertions below pass even on
+        # a fixed implementation.
+        assert started.wait(timeout=2.0), "worker thread never started"
+        # Give get_and_set a moment to either deadlock or return.
+        completed.wait(timeout=0.5)
+        assert not errors, (
+            f"get_and_set raised before reaching the deadlock spin: {errors[0]!r}"
+        )
+        # Check completion BEFORE liveness: once get_and_set returns, the
+        # worker thread exits, so checking is_alive() first would report the
+        # misleading "neither deadlocked nor completed" message instead of
+        # the actionable "bug appears fixed" one.
+        assert not completed.is_set(), (
+            "get_and_set unexpectedly returned — the deadlock bug appears fixed; "
+            "delete this pinned test along with the xfail below."
+        )
+        assert worker.is_alive(), (
+            "worker thread exited unexpectedly — get_and_set neither deadlocked "
+            "nor completed; the test no longer pins the documented bug."
+        )
+
+    @pytest.mark.xfail(
+        strict=True,
+        reason=(
+            "Known bug: AtomicInteger.get_and_set deadlocks because it holds "
+            "the non-reentrant lock while accessing the value property. "
+            "This xfail flips to XPASS when the bug is fixed."
+        ),
+    )
+    def test_get_and_set_should_return_old_value_and_replace_state(self):
+        # Run the call on a daemon thread with a bounded join rather than
+        # relying on @pytest.mark.timeout: that marker needs the
+        # pytest-timeout plugin (without it the deadlock would hang the whole
+        # suite), and that plugin's thread-based timeout method aborts the
+        # entire pytest process instead of failing just this test.
+        a = AtomicInteger(10)
+        results: list[int] = []
+        worker = threading.Thread(
+            target=lambda: results.append(a.get_and_set(99)), daemon=True
+        )
+        worker.start()
+        worker.join(timeout=2.0)
+        # Bail out before touching a.value: while the deadlock bug exists the
+        # stuck worker still holds the lock, so reading the property here
+        # would block the main thread as well.
+        assert not worker.is_alive(), "get_and_set did not return within 2s"
+        assert results == [10]
+        assert a.value == 99
+
+
+class TestAtomicIntegerThreadSafety:
+    def test_inc_under_concurrent_threads_is_lossless(self):
+        a = AtomicInteger(0)
+        threads_count = 8
+        per_thread = 1000
+        # Release all workers at the same moment. Without this, each thread
+        # can finish its short loop before the next one is even started, and
+        # the test would never actually exercise concurrent increments.
+        start_line = threading.Barrier(threads_count, timeout=5.0)
+
+        def worker():
+            start_line.wait()
+            for _ in range(per_thread):
+                a.inc()
+
+        threads = [
+            threading.Thread(target=worker, daemon=True)
+            for _ in range(threads_count)
+        ]
+        for t in threads:
+            t.start()
+        for t in threads:
+            # Bounded join so a locking regression fails this test instead
+            # of hanging the whole suite.
+            t.join(timeout=5.0)
+        assert not any(t.is_alive() for t in threads), "inc() workers hung"
+
+        assert a.value == threads_count * per_thread