Summary of this patch (reviewer notes; everything below the first "diff --git" line is the patch itself):

  * Both streamable_http_app() signatures (lowlevel/server.py and
    mcpserver/server.py) gain a ``session_idle_timeout: float | None = None``
    keyword and forward it unchanged.
  * The streamable HTTP transport gains an in-flight request counter with
    mark_request_started() / mark_request_finished(). Starting a request sets
    the idle scope deadline to ``math.inf``; finishing the last request sets it
    to ``anyio.current_time() + idle_timeout_seconds`` unless the transport is
    terminated or no timeout was supplied.
  * The session manager wraps handle_request() in try/finally with those two
    calls on both the existing-session and the new-session path, and calls
    task_status.started() only after http_transport.idle_scope is assigned.
  * Tests: the fixed 100 ms sleep becomes a polling loop bounded by
    anyio.fail_after(1), an Accept header is added to the follow-up request,
    and two new tests are appended.

NOTE(review): the original text of this patch had its line breaks and
indentation collapsed. Line breaks and blank lines were restored so that every
hunk matches the old/new line counts in its "@@" header; leading indentation
is inferred from Python syntax and must be confirmed against the target tree.
Two placements in particular need confirming:
  1. task_status.started() is assumed to sit at the ``try`` body level, i.e.
     outside ``if self.session_idle_timeout is not None:``. Placed inside it,
     started() would never run when no timeout is configured.
  2. the two final asserts of the new in-flight test are assumed to sit at
     function-body level, after the ``async with`` block.
NOTE(review): the new in-flight test sleeps 2.0 s against a 1.0 s idle timeout;
consider smaller values with the same ratio if suite run time matters.

diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py
index 59de0ace4..40a5498c2 100644
--- a/src/mcp/server/lowlevel/server.py
+++ b/src/mcp/server/lowlevel/server.py
@@ -354,7 +354,7 @@ def session_manager(self) -> StreamableHTTPSessionManager:
                 "Session manager can only be accessed after calling streamable_http_app(). "
                 "The session manager is created lazily to avoid unnecessary initialization."
             )
-        return self._session_manager  # pragma: no cover
+        return self._session_manager
 
     async def run(
         self,
@@ -567,6 +567,7 @@ def streamable_http_app(
         stateless_http: bool = False,
         event_store: EventStore | None = None,
         retry_interval: int | None = None,
+        session_idle_timeout: float | None = None,
         transport_security: TransportSecuritySettings | None = None,
         host: str = "127.0.0.1",
         auth: AuthSettings | None = None,
@@ -588,6 +589,7 @@ def streamable_http_app(
             app=self,
             event_store=event_store,
             retry_interval=retry_interval,
+            session_idle_timeout=session_idle_timeout,
             json_response=json_response,
             stateless=stateless_http,
             security_settings=transport_security,
diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py
index be77705da..3c21e7aed 100644
--- a/src/mcp/server/mcpserver/server.py
+++ b/src/mcp/server/mcpserver/server.py
@@ -1050,6 +1050,7 @@ def streamable_http_app(
         stateless_http: bool = False,
         event_store: EventStore | None = None,
         retry_interval: int | None = None,
+        session_idle_timeout: float | None = None,
         transport_security: TransportSecuritySettings | None = None,
         host: str = "127.0.0.1",
     ) -> Starlette:
@@ -1060,6 +1061,7 @@ def streamable_http_app(
             stateless_http=stateless_http,
             event_store=event_store,
             retry_interval=retry_interval,
+            session_idle_timeout=session_idle_timeout,
             transport_security=transport_security,
             host=host,
             auth=self.settings.auth,
diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py
index f14201857..d15373fe3 100644
--- a/src/mcp/server/streamable_http.py
+++ b/src/mcp/server/streamable_http.py
@@ -7,6 +7,7 @@
 """
 
 import logging
+import math
 import re
 from abc import ABC, abstractmethod
 from collections.abc import AsyncGenerator, Awaitable, Callable
@@ -171,9 +172,27 @@ def __init__(
         ] = {}
         self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {}
         self._terminated = False
+        self._active_request_count = 0
         # Idle timeout cancel scope; managed by the session manager.
         self.idle_scope: anyio.CancelScope | None = None
 
+    def mark_request_started(self) -> None:
+        """Suspend idle reaping while at least one HTTP request is in flight."""
+        self._active_request_count += 1
+        if self.idle_scope is not None:
+            self.idle_scope.deadline = math.inf
+
+    def mark_request_finished(self, idle_timeout_seconds: float | None) -> None:
+        """Resume idle reaping once the last in-flight request completes."""
+        self._active_request_count = max(0, self._active_request_count - 1)
+        if (
+            idle_timeout_seconds is not None
+            and self.idle_scope is not None
+            and self._active_request_count == 0
+            and not self._terminated
+        ):
+            self.idle_scope.deadline = anyio.current_time() + idle_timeout_seconds
+
     @property
     def is_terminated(self) -> bool:
         """Check if this transport has been explicitly terminated."""
diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py
index c25314eab..e0fc2d2d9 100644
--- a/src/mcp/server/streamable_http_manager.py
+++ b/src/mcp/server/streamable_http_manager.py
@@ -196,10 +196,11 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
         if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
             transport = self._server_instances[request_mcp_session_id]
             logger.debug("Session already exists, handling request directly")
-            # Push back idle deadline on activity
-            if transport.idle_scope is not None and self.session_idle_timeout is not None:
-                transport.idle_scope.deadline = anyio.current_time() + self.session_idle_timeout  # pragma: no cover
-            await transport.handle_request(scope, receive, send)
+            transport.mark_request_started()
+            try:
+                await transport.handle_request(scope, receive, send)
+            finally:
+                transport.mark_request_finished(self.session_idle_timeout)
             return
 
         if request_mcp_session_id is None:
@@ -223,7 +224,6 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
                 async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None:
                     async with http_transport.connect() as streams:
                         read_stream, write_stream = streams
-                        task_status.started()
                         try:
                             # Use a cancel scope for idle timeout — when the
                             # deadline passes the scope cancels app.run() and
@@ -234,6 +234,8 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
                                 idle_scope.deadline = anyio.current_time() + self.session_idle_timeout
                                 http_transport.idle_scope = idle_scope
 
+                            task_status.started()
+
                             with idle_scope:
                                 await self.app.run(
                                     read_stream,
@@ -267,7 +269,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
                 await self._task_group.start(run_server)
 
                 # Handle the HTTP request and return the response
-                await http_transport.handle_request(scope, receive, send)
+                http_transport.mark_request_started()
+                try:
+                    await http_transport.handle_request(scope, receive, send)
+                finally:
+                    http_transport.mark_request_finished(self.session_idle_timeout)
         else:
             # Unknown or expired session ID - return 404 per MCP spec
             # TODO: Align error code once spec clarifies
diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py
index 47cfbf14a..318c2a591 100644
--- a/tests/server/test_streamable_http_manager.py
+++ b/tests/server/test_streamable_http_manager.py
@@ -2,20 +2,31 @@
 
 import json
 import logging
+from contextlib import asynccontextmanager
 from typing import Any
 from unittest.mock import AsyncMock, patch
 
 import anyio
 import httpx
 import pytest
-from starlette.types import Message
+from starlette.applications import Starlette
+from starlette.routing import Mount
+from starlette.types import Message, Receive, Scope, Send
 
 from mcp import Client
 from mcp.client.streamable_http import streamable_http_client
 from mcp.server import Server, ServerRequestContext, streamable_http_manager
 from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport
 from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
-from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams
+from mcp.types import (
+    INVALID_REQUEST,
+    CallToolRequestParams,
+    CallToolResult,
+    ListToolsResult,
+    PaginatedRequestParams,
+    TextContent,
+    Tool,
+)
 
 
 @pytest.mark.anyio
@@ -374,8 +385,10 @@ async def mock_receive(): # pragma: no cover
 
         assert session_id is not None, "Session ID not found in response headers"
 
-        # Wait for the 50ms idle timeout to fire and cleanup to complete
-        await anyio.sleep(0.1)
+        # Wait deterministically for the idle timeout to fire and cleanup to complete.
+        with anyio.fail_after(1):
+            while session_id in manager._server_instances:
+                await anyio.sleep(0)
 
         # Verify via public API: old session ID now returns 404
         response_messages: list[Message] = []
@@ -388,6 +401,7 @@ async def capture_send(message: Message):
             "method": "POST",
             "path": "/mcp",
             "headers": [
+                (b"accept", b"application/json, text/event-stream"),
                 (b"content-type", b"application/json"),
                 (b"mcp-session-id", session_id.encode()),
             ],
@@ -413,3 +427,55 @@ def test_session_idle_timeout_rejects_non_positive():
 def test_session_idle_timeout_rejects_stateless():
     with pytest.raises(RuntimeError, match="not supported in stateless"):
         StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True)
+
+
+def test_streamable_http_app_exposes_session_idle_timeout():
+    app = Server("test-streamable-http-app")
+
+    starlette_app = app.streamable_http_app(session_idle_timeout=30)
+
+    assert starlette_app is not None
+    assert app.session_manager.session_idle_timeout == 30
+
+
+@pytest.mark.anyio
+async def test_session_idle_timeout_does_not_cancel_in_flight_request():
+    async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
+        return ListToolsResult(
+            tools=[
+                Tool(
+                    name="slow",
+                    description="Slow tool",
+                    input_schema={"type": "object", "properties": {}},
+                )
+            ]
+        )
+
+    async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
+        await anyio.sleep(2.0)
+        return CallToolResult(content=[TextContent(type="text", text="ok")])
+
+    server = Server("idle-timeout-active-request", on_list_tools=on_list_tools, on_call_tool=on_call_tool)
+    manager = StreamableHTTPSessionManager(app=server, session_idle_timeout=1.0)
+
+    async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
+        await manager.handle_request(scope, receive, send)
+
+    @asynccontextmanager
+    async def lifespan(app: Starlette):
+        async with manager.run():
+            yield
+
+    starlette_app = Starlette(routes=[Mount("/", app=handle_streamable_http)], lifespan=lifespan)
+
+    async with (
+        starlette_app.router.lifespan_context(starlette_app),
+        httpx.ASGITransport(starlette_app) as transport,
+        httpx.AsyncClient(transport=transport, base_url="http://testserver") as http_client,
+        Client(streamable_http_client("http://testserver/", http_client=http_client)) as client,
+    ):
+        with anyio.fail_after(5):
+            result = await client.call_tool("slow", {})
+
+    assert isinstance(result.content[0], TextContent)
+    assert result.content[0].text == "ok"