Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 108 additions & 17 deletions src/google/adk/agents/parallel_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
from __future__ import annotations

import asyncio
import logging
import sys
from typing import AsyncGenerator
from typing import ClassVar
from typing import Optional

from typing_extensions import override

Expand All @@ -31,6 +33,8 @@
from .invocation_context import InvocationContext
from .parallel_agent_config import ParallelAgentConfig

logger = logging.getLogger('google_adk.' + __name__)


def _create_branch_ctx_for_sub_agent(
agent: BaseAgent,
Expand All @@ -48,29 +52,75 @@ def _create_branch_ctx_for_sub_agent(
return invocation_context


async def _iter_with_idle_timeout(
gen: AsyncGenerator[Event, None],
timeout_secs: float,
branch_name: str,
) -> AsyncGenerator[Event, None]:
"""Wrap *gen*, raising TimeoutError if no event arrives within *timeout_secs*.

Uses asyncio.wait_for on each __anext__ call so that a branch whose upstream
model stream silently stalls (connection open, no chunks) is detected and
cancelled rather than hanging the parent ParallelAgent indefinitely.
"""
while True:
try:
event = await asyncio.wait_for(
gen.__anext__(),
timeout=timeout_secs,
)
except StopAsyncIteration:
return
except asyncio.TimeoutError as exc:
logger.warning(
'ParallelAgent branch %r has not produced an event for %.1fs. '
'The upstream model stream may have stalled. Cancelling the branch '
'so the parent agent is not blocked indefinitely.',
branch_name,
timeout_secs,
)
raise asyncio.TimeoutError(
f'Branch {branch_name!r} idle for >{timeout_secs}s without an event'
) from exc
yield event


async def _merge_agent_run(
agent_runs: list[AsyncGenerator[Event, None]],
*,
branch_names: list[str] | None = None,
branch_idle_timeout_secs: float | None = None,
) -> AsyncGenerator[Event, None]:
"""Merges agent runs using asyncio.TaskGroup on Python 3.11+."""
sentinel = object()
queue = asyncio.Queue()
queue: asyncio.Queue = asyncio.Queue()
names = branch_names or [f'branch-{i}' for i in range(len(agent_runs))]

# Agents are processed in parallel.
# Events for each agent are put on queue sequentially.
async def process_an_agent(events_for_one_agent):
async def process_an_agent(events_for_one_agent, branch_name: str):
try:
async for event in events_for_one_agent:
gen = (
_iter_with_idle_timeout(
events_for_one_agent, branch_idle_timeout_secs, branch_name
)
if branch_idle_timeout_secs is not None
else events_for_one_agent
)
async for event in gen:
resume_signal = asyncio.Event()
await queue.put((event, resume_signal))
# put_nowait: the queue is unbounded so this never blocks, and it is
# safe to call from a finally block that may run during cancellation.
queue.put_nowait((event, resume_signal))
# Wait for upstream to consume event before generating new events.
await resume_signal.wait()
finally:
# Mark agent as finished.
await queue.put((sentinel, None))
# Mark agent as finished. put_nowait is cancellation-safe (see above).
queue.put_nowait((sentinel, None))

async with asyncio.TaskGroup() as tg:
for events_for_one_agent in agent_runs:
tg.create_task(process_an_agent(events_for_one_agent))
for events_for_one_agent, name in zip(agent_runs, names):
tg.create_task(process_an_agent(events_for_one_agent, name))

sentinel_count = 0
# Run until all agents finished processing.
Expand All @@ -88,6 +138,9 @@ async def process_an_agent(events_for_one_agent):
# TODO - remove once Python <3.11 is no longer supported.
async def _merge_agent_run_pre_3_11(
agent_runs: list[AsyncGenerator[Event, None]],
*,
branch_names: list[str] | None = None,
branch_idle_timeout_secs: float | None = None,
) -> AsyncGenerator[Event, None]:
"""Merges agent runs for Python 3.10 without asyncio.TaskGroup.

Expand All @@ -96,12 +149,16 @@ async def _merge_agent_run_pre_3_11(

Args:
agent_runs: Async generators that yield events from each agent.
branch_names: Optional names for each branch, used in log messages.
branch_idle_timeout_secs: If set, cancel a branch that produces no event
for this many seconds (guards against silently stalled model streams).

Yields:
Event: The next event from the merged generator.
"""
sentinel = object()
queue = asyncio.Queue()
queue: asyncio.Queue = asyncio.Queue()
names = branch_names or [f'branch-{i}' for i in range(len(agent_runs))]

def propagate_exceptions(tasks):
# Propagate exceptions and errors from tasks.
Expand All @@ -113,21 +170,31 @@ def propagate_exceptions(tasks):

# Agents are processed in parallel.
# Events for each agent are put on queue sequentially.
async def process_an_agent(events_for_one_agent):
async def process_an_agent(events_for_one_agent, branch_name: str):
try:
async for event in events_for_one_agent:
gen = (
_iter_with_idle_timeout(
events_for_one_agent, branch_idle_timeout_secs, branch_name
)
if branch_idle_timeout_secs is not None
else events_for_one_agent
)
async for event in gen:
resume_signal = asyncio.Event()
await queue.put((event, resume_signal))
queue.put_nowait((event, resume_signal))
# Wait for upstream to consume event before generating new events.
await resume_signal.wait()
finally:
# Mark agent as finished.
await queue.put((sentinel, None))
# put_nowait is cancellation-safe: the queue is unbounded so it never
# blocks, and it will not raise even if the task is being cancelled.
queue.put_nowait((sentinel, None))

tasks = []
try:
for events_for_one_agent in agent_runs:
tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent)))
for events_for_one_agent, name in zip(agent_runs, names):
tasks.append(
asyncio.create_task(process_an_agent(events_for_one_agent, name))
)

sentinel_count = 0
# Run until all agents finished processing.
Expand All @@ -142,6 +209,10 @@ async def process_an_agent(events_for_one_agent):
# Signal to agent that event has been processed by runner and it can
# continue now.
resume_signal.set()
# A task may have put its sentinel AND raised an exception; if the sentinel
# was consumed before propagate_exceptions ran in the loop body, the error
# would be silently lost. Check once more after the loop to surface it.
propagate_exceptions(tasks)
finally:
for task in tasks:
task.cancel()
Expand All @@ -155,11 +226,23 @@ class ParallelAgent(BaseAgent):

- Running different algorithms simultaneously.
- Generating multiple responses for review by a subsequent evaluation agent.

Attributes:
branch_idle_timeout_secs: Optional per-branch idle timeout in seconds.
When set, any branch that produces no event for this many seconds is
cancelled and raises ``asyncio.TimeoutError``, which unblocks the
parent agent instead of hanging indefinitely. This guards against
upstream model streams that stall silently (connection open, no
chunks arriving). ``None`` (the default) disables the timeout and
preserves the original unbounded-wait behaviour.
"""

config_type: ClassVar[type[BaseAgentConfig]] = ParallelAgentConfig
"""The config type for this agent."""

branch_idle_timeout_secs: Optional[float] = None
"""Per-branch idle timeout in seconds; None disables the guard."""

@override
async def _run_async_impl(
self, ctx: InvocationContext
Expand All @@ -173,13 +256,15 @@ async def _run_async_impl(
yield self._create_agent_state_event(ctx)

agent_runs = []
branch_names = []
# Prepare and collect async generators for each sub-agent.
for sub_agent in self.sub_agents:
sub_agent_ctx = _create_branch_ctx_for_sub_agent(self, sub_agent, ctx)

# Only include sub-agents that haven't finished in a previous run.
if not sub_agent_ctx.end_of_agents.get(sub_agent.name):
agent_runs.append(sub_agent.run_async(sub_agent_ctx))
branch_names.append(sub_agent.name)

pause_invocation = False
try:
Expand All @@ -188,7 +273,13 @@ async def _run_async_impl(
if sys.version_info >= (3, 11)
else _merge_agent_run_pre_3_11
)
async with Aclosing(merge_func(agent_runs)) as agen:
async with Aclosing(
merge_func(
agent_runs,
branch_names=branch_names,
branch_idle_timeout_secs=self.branch_idle_timeout_secs,
)
) as agen:
async for event in agen:
yield event
if ctx.should_pause_invocation(event):
Expand Down
Loading