diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py index cb8b09f655..949a1f2413 100644 --- a/src/google/adk/agents/parallel_agent.py +++ b/src/google/adk/agents/parallel_agent.py @@ -17,9 +17,11 @@ from __future__ import annotations import asyncio +import logging import sys from typing import AsyncGenerator from typing import ClassVar +from typing import Optional from typing_extensions import override @@ -31,6 +33,8 @@ from .invocation_context import InvocationContext from .parallel_agent_config import ParallelAgentConfig +logger = logging.getLogger('google_adk.' + __name__) + def _create_branch_ctx_for_sub_agent( agent: BaseAgent, @@ -48,29 +52,75 @@ def _create_branch_ctx_for_sub_agent( return invocation_context +async def _iter_with_idle_timeout( + gen: AsyncGenerator[Event, None], + timeout_secs: float, + branch_name: str, +) -> AsyncGenerator[Event, None]: + """Wrap *gen*, raising TimeoutError if no event arrives within *timeout_secs*. + + Uses asyncio.wait_for on each __anext__ call so that a branch whose upstream + model stream silently stalls (connection open, no chunks) is detected and + cancelled rather than hanging the parent ParallelAgent indefinitely. + """ + while True: + try: + event = await asyncio.wait_for( + gen.__anext__(), + timeout=timeout_secs, + ) + except StopAsyncIteration: + return + except asyncio.TimeoutError as exc: + logger.warning( + 'ParallelAgent branch %r has not produced an event for %.1fs. ' + 'The upstream model stream may have stalled. Cancelling the branch ' + 'so the parent agent is not blocked indefinitely.', + branch_name, + timeout_secs, + ) + raise asyncio.TimeoutError( + f'Branch {branch_name!r} idle for >{timeout_secs}s without an event' + ) from exc + yield event + + async def _merge_agent_run( agent_runs: list[AsyncGenerator[Event, None]], + *, + branch_names: list[str] | None = None, + branch_idle_timeout_secs: float | None = None, ) -> AsyncGenerator[Event, None]: """Merges agent runs using asyncio.TaskGroup on Python 3.11+.""" sentinel = object() - queue = asyncio.Queue() + queue: asyncio.Queue = asyncio.Queue() + names = branch_names or [f'branch-{i}' for i in range(len(agent_runs))] # Agents are processed in parallel. # Events for each agent are put on queue sequentially. - async def process_an_agent(events_for_one_agent): + async def process_an_agent(events_for_one_agent, branch_name: str): try: - async for event in events_for_one_agent: + gen = ( + _iter_with_idle_timeout( + events_for_one_agent, branch_idle_timeout_secs, branch_name + ) + if branch_idle_timeout_secs is not None + else events_for_one_agent + ) + async for event in gen: resume_signal = asyncio.Event() - await queue.put((event, resume_signal)) + # put_nowait: the queue is unbounded so this never blocks, and it is + # safe to call from a finally block that may run during cancellation. + queue.put_nowait((event, resume_signal)) # Wait for upstream to consume event before generating new events. await resume_signal.wait() finally: - # Mark agent as finished. - await queue.put((sentinel, None)) + # Mark agent as finished. put_nowait is cancellation-safe (see above). + queue.put_nowait((sentinel, None)) async with asyncio.TaskGroup() as tg: - for events_for_one_agent in agent_runs: - tg.create_task(process_an_agent(events_for_one_agent)) + for events_for_one_agent, name in zip(agent_runs, names): + tg.create_task(process_an_agent(events_for_one_agent, name)) sentinel_count = 0 # Run until all agents finished processing. @@ -88,6 +138,9 @@ async def process_an_agent(events_for_one_agent): # TODO - remove once Python <3.11 is no longer supported. async def _merge_agent_run_pre_3_11( agent_runs: list[AsyncGenerator[Event, None]], + *, + branch_names: list[str] | None = None, + branch_idle_timeout_secs: float | None = None, ) -> AsyncGenerator[Event, None]: """Merges agent runs for Python 3.10 without asyncio.TaskGroup. @@ -96,12 +149,16 @@ async def _merge_agent_run_pre_3_11( Args: agent_runs: Async generators that yield events from each agent. + branch_names: Optional names for each branch, used in log messages. + branch_idle_timeout_secs: If set, cancel a branch that produces no event + for this many seconds (guards against silently stalled model streams). Yields: Event: The next event from the merged generator. """ sentinel = object() - queue = asyncio.Queue() + queue: asyncio.Queue = asyncio.Queue() + names = branch_names or [f'branch-{i}' for i in range(len(agent_runs))] def propagate_exceptions(tasks): # Propagate exceptions and errors from tasks. @@ -113,21 +170,31 @@ def propagate_exceptions(tasks): # Agents are processed in parallel. # Events for each agent are put on queue sequentially. - async def process_an_agent(events_for_one_agent): + async def process_an_agent(events_for_one_agent, branch_name: str): try: - async for event in events_for_one_agent: + gen = ( + _iter_with_idle_timeout( + events_for_one_agent, branch_idle_timeout_secs, branch_name + ) + if branch_idle_timeout_secs is not None + else events_for_one_agent + ) + async for event in gen: resume_signal = asyncio.Event() - await queue.put((event, resume_signal)) + queue.put_nowait((event, resume_signal)) # Wait for upstream to consume event before generating new events. await resume_signal.wait() finally: - # Mark agent as finished. - await queue.put((sentinel, None)) + # put_nowait is cancellation-safe: the queue is unbounded so it never + # blocks, and it will not raise even if the task is being cancelled. + queue.put_nowait((sentinel, None)) tasks = [] try: - for events_for_one_agent in agent_runs: - tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent))) + for events_for_one_agent, name in zip(agent_runs, names): + tasks.append( + asyncio.create_task(process_an_agent(events_for_one_agent, name)) + ) sentinel_count = 0 # Run until all agents finished processing. @@ -142,6 +209,10 @@ async def process_an_agent(events_for_one_agent): # Signal to agent that event has been processed by runner and it can # continue now. resume_signal.set() + # A task may have put its sentinel AND raised an exception; if the sentinel + # was consumed before propagate_exceptions ran in the loop body, the error + # would be silently lost. Check once more after the loop to surface it. + propagate_exceptions(tasks) finally: for task in tasks: task.cancel() @@ -155,11 +226,23 @@ class ParallelAgent(BaseAgent): - Running different algorithms simultaneously. - Generating multiple responses for review by a subsequent evaluation agent. + + Attributes: + branch_idle_timeout_secs: Optional per-branch idle timeout in seconds. + When set, any branch that produces no event for this many seconds is + cancelled and raises ``asyncio.TimeoutError``, which unblocks the + parent agent instead of hanging indefinitely. This guards against + upstream model streams that stall silently (connection open, no + chunks arriving). ``None`` (the default) disables the timeout and + preserves the original unbounded-wait behaviour. """ config_type: ClassVar[type[BaseAgentConfig]] = ParallelAgentConfig """The config type for this agent.""" + branch_idle_timeout_secs: Optional[float] = None + """Per-branch idle timeout in seconds; None disables the guard.""" + @override async def _run_async_impl( self, ctx: InvocationContext @@ -173,6 +256,7 @@ async def _run_async_impl( yield self._create_agent_state_event(ctx) agent_runs = [] + branch_names = [] # Prepare and collect async generators for each sub-agent. for sub_agent in self.sub_agents: sub_agent_ctx = _create_branch_ctx_for_sub_agent(self, sub_agent, ctx) @@ -180,6 +264,7 @@ async def _run_async_impl( # Only include sub-agents that haven't finished in a previous run. if not sub_agent_ctx.end_of_agents.get(sub_agent.name): agent_runs.append(sub_agent.run_async(sub_agent_ctx)) + branch_names.append(sub_agent.name) pause_invocation = False try: @@ -188,7 +273,13 @@ async def _run_async_impl( if sys.version_info >= (3, 11) else _merge_agent_run_pre_3_11 ) - async with Aclosing(merge_func(agent_runs)) as agen: + async with Aclosing( + merge_func( + agent_runs, + branch_names=branch_names, + branch_idle_timeout_secs=self.branch_idle_timeout_secs, + ) + ) as agen: async for event in agen: yield event if ctx.should_pause_invocation(event): diff --git a/tests/unittests/agents/test_parallel_agent_idle_timeout.py b/tests/unittests/agents/test_parallel_agent_idle_timeout.py new file mode 100644 index 0000000000..197ec9b56c --- /dev/null +++ b/tests/unittests/agents/test_parallel_agent_idle_timeout.py @@ -0,0 +1,432 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ParallelAgent.branch_idle_timeout_secs (issue #5455). + +The bug: when an upstream model stream silently stalls (connection stays open, +no chunks arrive, no exception), process_an_agent in _merge_agent_run blocks +forever on events_for_one_agent.__anext__(), queue.get() in the merge loop +waits forever, and the SSE stream never closes. + +The fix: _iter_with_idle_timeout wraps each __anext__ call with +asyncio.wait_for so that a stalled branch raises asyncio.TimeoutError instead +of hanging indefinitely. +""" + +from __future__ import annotations + +import asyncio +import sys +from typing import AsyncGenerator + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.parallel_agent import _iter_with_idle_timeout +from google.adk.agents.parallel_agent import _merge_agent_run +from google.adk.agents.parallel_agent import _merge_agent_run_pre_3_11 +from google.adk.agents.parallel_agent import ParallelAgent +from google.adk.apps.app import ResumabilityConfig +from google.adk.events.event import Event +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +import pytest +from typing_extensions import override + +# --------------------------------------------------------------------------- +# Shared test helpers +# --------------------------------------------------------------------------- + + +def _make_event(author: str, text: str = 'hello') -> Event: + return Event( + author=author, + invocation_id='test-inv', + content=types.Content(parts=[types.Part(text=text)]), + ) + + +async def _make_ctx(agent: BaseAgent) -> InvocationContext: + svc = InMemorySessionService() + session = await svc.create_session(app_name='app', user_id='user') + return InvocationContext( + invocation_id='test-inv', + agent=agent, + session=session, + session_service=svc, + resumability_config=ResumabilityConfig(is_resumable=False), + ) + + +class _NormalAgent(BaseAgent): + """Yields one event then exits.""" + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield _make_event(self.name) + + +class _StallingAgent(BaseAgent): + """Yields one event then stalls forever (simulates silent model freeze). + + This is the core reproduction of issue #5455: the branch produces some + events (tool call responses arrive) then the model stream stops emitting + chunks without closing the connection. + """ + + stall_after: int = 1 + """Number of events to yield before stalling.""" + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + for i in range(self.stall_after): + yield _make_event(self.name, f'event-{i}') + # Simulate a connection that is open but silent. + await asyncio.sleep(3600) # stall for 1 hour + yield _make_event(self.name, 'never-reached') # pragma: no cover + + +class _SlowAgent(BaseAgent): + """Yields one event after a short but noticeable delay.""" + + delay: float = 0.05 + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + await asyncio.sleep(self.delay) + yield _make_event(self.name) + + +# --------------------------------------------------------------------------- +# Unit tests for _iter_with_idle_timeout +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_iter_with_idle_timeout_passes_events_through(): + """Normal generator: all events pass through unmodified.""" + + async def _gen(): + for i in range(3): + yield _make_event('agent', f'msg-{i}') + + events = [ + e + async for e in _iter_with_idle_timeout( + _gen(), timeout_secs=1.0, branch_name='test' + ) + ] + assert len(events) == 3 + assert [e.content.parts[0].text for e in events] == [ + 'msg-0', + 'msg-1', + 'msg-2', + ] + + +@pytest.mark.asyncio +async def test_iter_with_idle_timeout_raises_on_stall(): + """Generator that stalls: TimeoutError is raised with a descriptive message.""" + + async def _stalling_gen(): + yield _make_event('agent', 'first') + await asyncio.sleep(3600) # simulate silent stall + yield _make_event('agent', 'never') # pragma: no cover + + gen = _stalling_gen() + it = _iter_with_idle_timeout(gen, timeout_secs=0.05, branch_name='my-branch') + + first = await it.__anext__() + assert first.content.parts[0].text == 'first' + + with pytest.raises(asyncio.TimeoutError) as exc_info: + await it.__anext__() + + assert 'my-branch' in str(exc_info.value) + assert ( + '0.05' in str(exc_info.value) + or '0.1' in str(exc_info.value) + or 'idle' in str(exc_info.value).lower() + ) + await gen.aclose() + + +@pytest.mark.asyncio +async def test_iter_with_idle_timeout_stops_on_empty_generator(): + """Empty generator terminates cleanly.""" + + async def _empty(): + return + yield # make it an async generator # noqa: unreachable + + events = [ + e + async for e in _iter_with_idle_timeout( + _empty(), timeout_secs=1.0, branch_name='empty' + ) + ] + assert events == [] + + +# --------------------------------------------------------------------------- +# Integration tests for _merge_agent_run (3.11+) and _merge_agent_run_pre_3_11 +# --------------------------------------------------------------------------- + + +async def _gen_from_events(events: list[Event]) -> AsyncGenerator[Event, None]: + for e in events: + yield e + + +async def _stalling_gen( + yield_first: bool = True, +) -> AsyncGenerator[Event, None]: + if yield_first: + yield _make_event('stalling') + await asyncio.sleep(3600) + yield _make_event('never') # pragma: no cover + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'merge_fn', + [ + _merge_agent_run, + _merge_agent_run_pre_3_11, + ], +) +async def test_merge_no_timeout_normal(merge_fn): + """Without timeout, two normal generators merge correctly.""" + runs = [ + _gen_from_events([_make_event('a', '1'), _make_event('a', '2')]), + _gen_from_events([_make_event('b', '3')]), + ] + events = [e async for e in merge_fn(runs)] + authors = {e.author for e in events} + assert authors == {'a', 'b'} + assert len(events) == 3 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'merge_fn', + [ + _merge_agent_run, + _merge_agent_run_pre_3_11, + ], +) +async def test_merge_timeout_stalled_branch_raises(merge_fn): + """A stalled branch raises TimeoutError, unblocking the merge loop.""" + runs = [ + _gen_from_events([_make_event('normal', 'ok')]), + _stalling_gen(yield_first=True), + ] + with pytest.raises( + (asyncio.TimeoutError, ExceptionGroup, BaseExceptionGroup) + ): + events = [] + async for e in merge_fn( + runs, + branch_names=['normal', 'stalling'], + branch_idle_timeout_secs=0.05, + ): + events.append(e) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'merge_fn', + [ + _merge_agent_run, + _merge_agent_run_pre_3_11, + ], +) +async def test_merge_timeout_does_not_fire_for_fast_branches(merge_fn): + """Timeout is generous enough that normal (fast) branches complete cleanly.""" + runs = [ + _gen_from_events([_make_event('a', '1'), _make_event('a', '2')]), + _gen_from_events([_make_event('b', '3')]), + ] + events = [ + e + async for e in merge_fn( + runs, + branch_idle_timeout_secs=5.0, # generous + ) + ] + assert len(events) == 3 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'merge_fn', + [ + _merge_agent_run, + _merge_agent_run_pre_3_11, + ], +) +async def test_merge_timeout_none_disables_guard(merge_fn): + """branch_idle_timeout_secs=None preserves the original unbounded behaviour.""" + runs = [ + _gen_from_events([_make_event('a')]), + _gen_from_events([_make_event('b')]), + ] + # Should complete without any timeout machinery + events = [e async for e in merge_fn(runs, branch_idle_timeout_secs=None)] + assert len(events) == 2 + + +# --------------------------------------------------------------------------- +# End-to-end ParallelAgent tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_parallel_agent_no_timeout_field_by_default(): + """branch_idle_timeout_secs defaults to None (no timeout).""" + agent = ParallelAgent(name='p', sub_agents=[]) + assert agent.branch_idle_timeout_secs is None + + +@pytest.mark.asyncio +async def test_parallel_agent_completes_normally_with_timeout_set(): + """With timeout set, normal branches complete and all events arrive.""" + a1 = _NormalAgent(name='a1') + a2 = _NormalAgent(name='a2') + parent = ParallelAgent( + name='parent', + sub_agents=[a1, a2], + branch_idle_timeout_secs=5.0, + ) + ctx = await _make_ctx(parent) + events = [e async for e in parent.run_async(ctx)] + authors = {e.author for e in events} + assert 'a1' in authors + assert 'a2' in authors + + +@pytest.mark.asyncio +async def test_parallel_agent_stalled_branch_raises_timeout(): + """A branch that stalls causes the ParallelAgent to raise TimeoutError. + + This is the direct reproduction of issue #5455: one sub-agent produces some + events (like tool responses) then the model stream silently stops. Without + branch_idle_timeout_secs the SSE stream hangs forever. With it, a + TimeoutError is raised promptly so the caller can close the stream. + """ + normal = _NormalAgent(name='normal') + stalling = _StallingAgent(name='stalling', stall_after=1) + parent = ParallelAgent( + name='parent', + sub_agents=[normal, stalling], + branch_idle_timeout_secs=0.1, + ) + ctx = await _make_ctx(parent) + + with pytest.raises( + (asyncio.TimeoutError, ExceptionGroup, BaseExceptionGroup) + ): + async for _ in parent.run_async(ctx): + pass + + +@pytest.mark.asyncio +async def test_parallel_agent_stalled_from_start_raises_timeout(): + """Branch that stalls immediately (no events at all) also detected.""" + normal = _NormalAgent(name='normal') + stalling = _StallingAgent(name='stalling', stall_after=0) + parent = ParallelAgent( + name='parent', + sub_agents=[normal, stalling], + branch_idle_timeout_secs=0.1, + ) + ctx = await _make_ctx(parent) + + with pytest.raises( + (asyncio.TimeoutError, ExceptionGroup, BaseExceptionGroup) + ): + async for _ in parent.run_async(ctx): + pass + + +@pytest.mark.asyncio +async def test_parallel_agent_slow_but_not_stalled_completes(): + """A slow branch (delay < timeout) is not incorrectly cancelled.""" + slow = _SlowAgent(name='slow', delay=0.05) + normal = _NormalAgent(name='normal') + parent = ParallelAgent( + name='parent', + sub_agents=[normal, slow], + branch_idle_timeout_secs=5.0, # generous — should not fire + ) + ctx = await _make_ctx(parent) + events = [e async for e in parent.run_async(ctx)] + authors = {e.author for e in events} + assert 'normal' in authors + assert 'slow' in authors + + +@pytest.mark.asyncio +async def test_parallel_agent_timeout_logged_as_warning(caplog): + """Warning is logged with branch name and timeout duration when stall fires.""" + import logging + + stalling = _StallingAgent(name='stalling_branch', stall_after=0) + parent = ParallelAgent( + name='parent', + sub_agents=[stalling], + branch_idle_timeout_secs=0.05, + ) + ctx = await _make_ctx(parent) + + with pytest.raises( + (asyncio.TimeoutError, ExceptionGroup, BaseExceptionGroup) + ): + with caplog.at_level(logging.WARNING): + async for _ in parent.run_async(ctx): + pass + + assert any('stalling_branch' in r.message for r in caplog.records) + assert any( + 'idle' in r.message.lower() or 'stall' in r.message.lower() + for r in caplog.records + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason='ExceptionGroup only available on Python 3.11+', +) +async def test_parallel_agent_timeout_error_is_timeout_in_exception_group(): + """On 3.11+ the raised ExceptionGroup wraps a TimeoutError.""" + stalling = _StallingAgent(name='stalling', stall_after=0) + parent = ParallelAgent( + name='parent', + sub_agents=[stalling], + branch_idle_timeout_secs=0.05, + ) + ctx = await _make_ctx(parent) + + with pytest.raises(BaseExceptionGroup) as exc_info: + async for _ in parent.run_async(ctx): + pass + + flat = exc_info.value.exceptions + assert any(isinstance(e, asyncio.TimeoutError) for e in flat) diff --git a/tests/unittests/sessions/test_v0_storage_event.py b/tests/unittests/sessions/test_v0_storage_event.py index 3f542af8b4..9fbfbb09fb 100644 --- a/tests/unittests/sessions/test_v0_storage_event.py +++ b/tests/unittests/sessions/test_v0_storage_event.py @@ -125,3 +125,24 @@ def test_from_event_preserves_short_error_message(): storage_event = StorageEvent.from_event(session, event) assert storage_event.error_message == short_error + + +def test_from_event_with_none_error_message(): + session = Session( + app_name="app", + user_id="user", + id="session_id", + state={}, + events=[], + last_update_time=0.0, + ) + event = Event( + id="event_id", + invocation_id="inv_id", + author="agent", + timestamp=1.0, + ) + + storage_event = StorageEvent.from_event(session, event) + + assert storage_event.error_message is None