diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 34d6a360f..906e147a2 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -95,6 +95,15 @@ async def main(): elicitation_callback: ElicitationFnT | None = None """Callback for handling elicitation requests.""" + streamable_http_session_id: str | None = None + """Optional pre-existing MCP session ID used when server is a StreamableHTTP URL.""" + + streamable_http_initialize_result: InitializeResult | None = None + """Previously negotiated InitializeResult used to resume a StreamableHTTP session.""" + + streamable_http_terminate_on_close: bool = True + """Whether a URL-based StreamableHTTP client should terminate the session on close.""" + _session: ClientSession | None = field(init=False, default=None) _exit_stack: AsyncExitStack | None = field(init=False, default=None) _transport: Transport = field(init=False) @@ -103,10 +112,31 @@ def __post_init__(self) -> None: if isinstance(self.server, Server | MCPServer): self._transport = InMemoryTransport(self.server, raise_exceptions=self.raise_exceptions) elif isinstance(self.server, str): - self._transport = streamable_http_client(self.server) + self._transport = streamable_http_client( + self.server, + session_id=self.streamable_http_session_id, + terminate_on_close=self.streamable_http_terminate_on_close, + ) else: self._transport = self.server + @classmethod + def resume_session( + cls, + server: str, + *, + session_id: str, + initialize_result: InitializeResult, + **kwargs: Any, + ) -> Client: + """Create a URL-based client configured to resume an existing StreamableHTTP session.""" + return cls( + server=server, + streamable_http_session_id=session_id, + streamable_http_initialize_result=initialize_result, + **kwargs, + ) + async def __aenter__(self) -> Client: """Enter the async context manager.""" if self._session is not None: @@ -129,7 +159,15 @@ async def __aenter__(self) -> Client: ) ) - await self._session.initialize() + if self.streamable_http_session_id is None and self.streamable_http_initialize_result is not None: + raise RuntimeError( + "streamable_http_initialize_result requires streamable_http_session_id for session resumption" + ) + + if self.streamable_http_session_id is not None and self.streamable_http_initialize_result is not None: + self._session.resume(self.streamable_http_initialize_result) + else: + await self._session.initialize() # Transfer ownership to self for __aexit__ to handle self._exit_stack = exit_stack.pop_all() diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7c964a334..1e1e29078 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -191,6 +191,14 @@ async def initialize(self) -> types.InitializeResult: return result + def resume(self, initialize_result: types.InitializeResult) -> None: + """Mark this session as resumed using previously negotiated initialization data. + + This bypasses the initialize/initialized handshake and seeds the session with + server capabilities and metadata from an earlier connection. + """ + self._initialize_result = initialize_result + @property def initialize_result(self) -> types.InitializeResult | None: """The server's InitializeResult. None until initialize() has been called. diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 961021264..81b9a94be 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -63,6 +63,9 @@ class StreamableHttpParameters(BaseModel): # Close the client session when the transport closes. terminate_on_close: bool = True + # Optional pre-existing MCP session ID for explicit session resumption. + session_id: str | None = None + ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters @@ -296,6 +299,7 @@ async def _establish_session( url=server_params.url, http_client=httpx_client, terminate_on_close=server_params.terminate_on_close, + session_id=server_params.session_id, ) read, write = await session_stack.enter_async_context(client) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3afb94b03..836616d86 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -72,14 +72,15 @@ class RequestContext: class StreamableHTTPTransport: """StreamableHTTP client transport implementation.""" - def __init__(self, url: str) -> None: + def __init__(self, url: str, session_id: str | None = None) -> None: """Initialize the StreamableHTTP transport. Args: url: The endpoint URL. + session_id: Optional pre-existing MCP session ID to resume. """ self.url = url - self.session_id: str | None = None + self.session_id: str | None = session_id self.protocol_version: str | None = None def _prepare_headers(self) -> dict[str, str]: @@ -512,6 +513,7 @@ async def streamable_http_client( *, http_client: httpx.AsyncClient | None = None, terminate_on_close: bool = True, + session_id: str | None = None, ) -> AsyncGenerator[TransportStreams, None]: """Client transport for StreamableHTTP. @@ -521,6 +523,8 @@ async def streamable_http_client( client with recommended MCP timeouts will be created. To configure headers, authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here. terminate_on_close: If True, send a DELETE request to terminate the session when the context exits. + session_id: Optional pre-existing MCP session ID to include in requests, + enabling explicit session resumption. Yields: Tuple containing: @@ -538,7 +542,7 @@ async def streamable_http_client( # Create default client with recommended MCP timeouts client = create_mcp_http_client() - transport = StreamableHTTPTransport(url) + transport = StreamableHTTPTransport(url, session_id=session_id) logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") @@ -557,10 +561,19 @@ async def streamable_http_client( write_stream_reader, anyio.create_task_group() as tg, ): + get_stream_started = False def start_get_stream() -> None: + nonlocal get_stream_started + if get_stream_started: + return + get_stream_started = True tg.start_soon(transport.handle_get_stream, client, read_stream_writer) + # If we're resuming an existing session, start the GET stream immediately. + if session_id: + start_get_stream() + tg.start_soon( transport.post_writer, client, diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 18368e6bb..709cb0a2c 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -2,7 +2,7 @@ from __future__ import annotations -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import anyio import pytest @@ -307,7 +307,103 @@ async def test_complete_with_prompt_reference(simple_server: Server): def test_client_with_url_initializes_streamable_http_transport(): with patch("mcp.client.client.streamable_http_client") as mock: _ = Client("http://localhost:8000/mcp") - mock.assert_called_once_with("http://localhost:8000/mcp") + mock.assert_called_once_with("http://localhost:8000/mcp", session_id=None, terminate_on_close=True) + + +def test_client_with_url_and_session_id_initializes_streamable_http_transport(): + with patch("mcp.client.client.streamable_http_client") as mock: + _ = Client("http://localhost:8000/mcp", streamable_http_session_id="resume-session-id") + mock.assert_called_once_with( + "http://localhost:8000/mcp", + session_id="resume-session-id", + terminate_on_close=True, + ) + + +def test_client_with_url_and_terminate_on_close_false_initializes_streamable_http_transport(): + with patch("mcp.client.client.streamable_http_client") as mock: + _ = Client("http://localhost:8000/mcp", streamable_http_terminate_on_close=False) + mock.assert_called_once_with("http://localhost:8000/mcp", session_id=None, terminate_on_close=False) + + +def test_client_resume_session_builder_initializes_streamable_http_transport(): + initialize_result = types.InitializeResult( + protocol_version="2025-03-26", + capabilities=types.ServerCapabilities(), + server_info=types.Implementation(name="server", version="1.0"), + ) + with patch("mcp.client.client.streamable_http_client") as mock: + _ = Client.resume_session( + "http://localhost:8000/mcp", + session_id="resume-session-id", + initialize_result=initialize_result, + ) + mock.assert_called_once_with( + "http://localhost:8000/mcp", + session_id="resume-session-id", + terminate_on_close=True, + ) + + +async def test_client_resume_session_skips_initialize(): + initialize_result = types.InitializeResult( + protocol_version="2025-03-26", + capabilities=types.ServerCapabilities(), + server_info=types.Implementation(name="server", version="1.0"), + ) + + transport_cm = AsyncMock() + transport_cm.__aenter__.return_value = (MagicMock(), MagicMock()) + transport_cm.__aexit__.return_value = None + + session = MagicMock() + session.initialize = AsyncMock() + session.resume = MagicMock() + session.initialize_result = initialize_result + session_cm = AsyncMock() + session_cm.__aenter__.return_value = session + session_cm.__aexit__.return_value = None + + with ( + patch("mcp.client.client.streamable_http_client", return_value=transport_cm), + patch("mcp.client.client.ClientSession", return_value=session_cm), + ): + async with Client.resume_session( + "http://localhost:8000/mcp", + session_id="resume-session-id", + initialize_result=initialize_result, + ) as client: + assert client.initialize_result == initialize_result + + session.initialize.assert_not_awaited() + session.resume.assert_called_once_with(initialize_result) + + +async def test_client_streamable_initialize_result_requires_session_id(): + initialize_result = types.InitializeResult( + protocol_version="2025-03-26", + capabilities=types.ServerCapabilities(), + server_info=types.Implementation(name="server", version="1.0"), + ) + + transport_cm = AsyncMock() + transport_cm.__aenter__.return_value = (MagicMock(), MagicMock()) + transport_cm.__aexit__.return_value = None + + session_cm = AsyncMock() + session_cm.__aenter__.return_value = AsyncMock() + session_cm.__aexit__.return_value = None + + with ( + patch("mcp.client.client.streamable_http_client", return_value=transport_cm), + patch("mcp.client.client.ClientSession", return_value=session_cm), + ): + client = Client( + "http://localhost:8000/mcp", + streamable_http_initialize_result=initialize_result, + ) + with pytest.raises(RuntimeError, match="requires streamable_http_session_id"): + await client.__aenter__() async def test_client_uses_transport_directly(app: MCPServer): diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f25c964f0..ad2611b58 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -606,6 +606,28 @@ async def mock_server(): assert result.protocol_version == LATEST_PROTOCOL_VERSION +@pytest.mark.anyio +async def test_client_session_resume_sets_initialize_result(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + session = ClientSession(server_to_client_receive, client_to_server_send) + assert session.initialize_result is None + + resumed_result = InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ) + session.resume(resumed_result) + assert session.initialize_result == resumed_result + + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + + @pytest.mark.anyio @pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}]) async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None): diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f3..932700bec 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -294,7 +294,11 @@ async def test_client_session_group_disconnect_non_existent_server(): "mcp.client.session_group.sse_client", ), # url, headers, timeout, sse_read_timeout ( - StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False), + StreamableHttpParameters( + url="http://test.com/stream", + terminate_on_close=False, + session_id="resumed-session-id", + ), "streamablehttp", "mcp.client.session_group.streamable_http_client", ), # url, headers, timeout, sse_read_timeout, terminate_on_close @@ -363,6 +367,7 @@ async def test_client_session_group_establish_session_parameterized( call_args = mock_specific_client_func.call_args assert call_args.kwargs["url"] == server_params_instance.url assert call_args.kwargs["terminate_on_close"] == server_params_instance.terminate_on_close + assert call_args.kwargs["session_id"] == server_params_instance.session_id assert isinstance(call_args.kwargs["http_client"], httpx.AsyncClient) mock_client_cm_instance.__aenter__.assert_awaited_once() diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f8ca30441..040924c65 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -10,6 +10,7 @@ import socket import time import traceback +import unittest from collections.abc import AsyncIterator, Generator from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -1801,6 +1802,55 @@ async def test_handle_sse_event_skips_empty_data(): await read_stream.aclose() +def test_streamable_http_transport_includes_seeded_session_id_header(): + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp", session_id="resume-session-id") + + headers = transport._prepare_headers() + + assert headers["mcp-session-id"] == "resume-session-id" + + +def test_streamable_http_client_resumption_starts_get_stream_once(monkeypatch: pytest.MonkeyPatch): + start_count = 0 + + async def fake_handle_get_stream( + self: StreamableHTTPTransport, # noqa: ARG001 + client: httpx.AsyncClient, # noqa: ARG001 + read_stream_writer: Any, # noqa: ARG001 + ) -> None: + nonlocal start_count + start_count += 1 + await anyio.sleep(0) + + async def fake_post_writer( + self: StreamableHTTPTransport, # noqa: ARG001 + client: httpx.AsyncClient, # noqa: ARG001 + write_stream_reader: Any, # noqa: ARG001 + read_stream_writer: Any, # noqa: ARG001 + write_stream: Any, # noqa: ARG001 + start_get_stream: Any, # noqa: ARG001 + tg: Any, # noqa: ARG001 + ) -> None: + # Call twice; the second call should hit the early return guard. + start_get_stream() + start_get_stream() + await anyio.sleep(0) + + monkeypatch.setattr(StreamableHTTPTransport, "handle_get_stream", fake_handle_get_stream) + monkeypatch.setattr(StreamableHTTPTransport, "post_writer", fake_post_writer) + + async def exercise_client() -> None: + async with streamable_http_client( + "http://localhost:8000/mcp", + session_id="resume-session-id", + terminate_on_close=False, + ): + await anyio.sleep(0) + unittest.TestCase().assertEqual(start_count, 1, f"Expected exactly one GET stream start, got {start_count}") + + anyio.run(exercise_client) + + @pytest.mark.anyio async def test_priming_event_not_sent_for_old_protocol_version(): """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""