Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions src/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ async def main():
elicitation_callback: ElicitationFnT | None = None
"""Callback for handling elicitation requests."""

streamable_http_session_id: str | None = None
"""Optional pre-existing MCP session ID used when server is a StreamableHTTP URL."""

streamable_http_initialize_result: InitializeResult | None = None
"""Previously negotiated InitializeResult used to resume a StreamableHTTP session."""

streamable_http_terminate_on_close: bool = True
"""Whether a URL-based StreamableHTTP client should terminate the session on close."""

_session: ClientSession | None = field(init=False, default=None)
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
_transport: Transport = field(init=False)
Expand All @@ -103,10 +112,31 @@ def __post_init__(self) -> None:
if isinstance(self.server, Server | MCPServer):
self._transport = InMemoryTransport(self.server, raise_exceptions=self.raise_exceptions)
elif isinstance(self.server, str):
self._transport = streamable_http_client(self.server)
self._transport = streamable_http_client(
self.server,
session_id=self.streamable_http_session_id,
terminate_on_close=self.streamable_http_terminate_on_close,
)
else:
self._transport = self.server

@classmethod
def resume_session(
cls,
server: str,
*,
session_id: str,
initialize_result: InitializeResult,
**kwargs: Any,
) -> Client:
"""Create a URL-based client configured to resume an existing StreamableHTTP session."""
return cls(
server=server,
streamable_http_session_id=session_id,
streamable_http_initialize_result=initialize_result,
**kwargs,
)

async def __aenter__(self) -> Client:
"""Enter the async context manager."""
if self._session is not None:
Expand All @@ -129,7 +159,15 @@ async def __aenter__(self) -> Client:
)
)

await self._session.initialize()
if self.streamable_http_session_id is None and self.streamable_http_initialize_result is not None:
raise RuntimeError(
"streamable_http_initialize_result requires streamable_http_session_id for session resumption"
)

if self.streamable_http_session_id is not None and self.streamable_http_initialize_result is not None:
self._session.resume(self.streamable_http_initialize_result)
else:
await self._session.initialize()

# Transfer ownership to self for __aexit__ to handle
self._exit_stack = exit_stack.pop_all()
Expand Down
8 changes: 8 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ async def initialize(self) -> types.InitializeResult:

return result

def resume(self, initialize_result: types.InitializeResult) -> None:
"""Mark this session as resumed using previously negotiated initialization data.

This bypasses the initialize/initialized handshake and seeds the session with
server capabilities and metadata from an earlier connection.
"""
self._initialize_result = initialize_result

@property
def initialize_result(self) -> types.InitializeResult | None:
"""The server's InitializeResult. None until initialize() has been called.
Expand Down
4 changes: 4 additions & 0 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class StreamableHttpParameters(BaseModel):
# Close the client session when the transport closes.
terminate_on_close: bool = True

# Optional pre-existing MCP session ID for explicit session resumption.
session_id: str | None = None


ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters

Expand Down Expand Up @@ -296,6 +299,7 @@ async def _establish_session(
url=server_params.url,
http_client=httpx_client,
terminate_on_close=server_params.terminate_on_close,
session_id=server_params.session_id,
)
read, write = await session_stack.enter_async_context(client)

Expand Down
19 changes: 16 additions & 3 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,15 @@ class RequestContext:
class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""

def __init__(self, url: str) -> None:
def __init__(self, url: str, session_id: str | None = None) -> None:
"""Initialize the StreamableHTTP transport.

Args:
url: The endpoint URL.
session_id: Optional pre-existing MCP session ID to resume.
"""
self.url = url
self.session_id: str | None = None
self.session_id: str | None = session_id
self.protocol_version: str | None = None

def _prepare_headers(self) -> dict[str, str]:
Expand Down Expand Up @@ -512,6 +513,7 @@ async def streamable_http_client(
*,
http_client: httpx.AsyncClient | None = None,
terminate_on_close: bool = True,
session_id: str | None = None,
) -> AsyncGenerator[TransportStreams, None]:
"""Client transport for StreamableHTTP.

Expand All @@ -521,6 +523,8 @@ async def streamable_http_client(
client with recommended MCP timeouts will be created. To configure headers,
authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here.
terminate_on_close: If True, send a DELETE request to terminate the session when the context exits.
session_id: Optional pre-existing MCP session ID to include in requests,
enabling explicit session resumption.

Yields:
Tuple containing:
Expand All @@ -538,7 +542,7 @@ async def streamable_http_client(
# Create default client with recommended MCP timeouts
client = create_mcp_http_client()

transport = StreamableHTTPTransport(url)
transport = StreamableHTTPTransport(url, session_id=session_id)

logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

Expand All @@ -557,10 +561,19 @@ async def streamable_http_client(
write_stream_reader,
anyio.create_task_group() as tg,
):
get_stream_started = False

def start_get_stream() -> None:
nonlocal get_stream_started
if get_stream_started:
return
get_stream_started = True
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

# If we're resuming an existing session, start the GET stream immediately.
if session_id:
start_get_stream()

tg.start_soon(
transport.post_writer,
client,
Expand Down
100 changes: 98 additions & 2 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from unittest.mock import patch
from unittest.mock import AsyncMock, MagicMock, patch

import anyio
import pytest
Expand Down Expand Up @@ -307,7 +307,103 @@ async def test_complete_with_prompt_reference(simple_server: Server):
def test_client_with_url_initializes_streamable_http_transport():
with patch("mcp.client.client.streamable_http_client") as mock:
_ = Client("http://localhost:8000/mcp")
mock.assert_called_once_with("http://localhost:8000/mcp")
mock.assert_called_once_with("http://localhost:8000/mcp", session_id=None, terminate_on_close=True)


def test_client_with_url_and_session_id_initializes_streamable_http_transport():
with patch("mcp.client.client.streamable_http_client") as mock:
_ = Client("http://localhost:8000/mcp", streamable_http_session_id="resume-session-id")
mock.assert_called_once_with(
"http://localhost:8000/mcp",
session_id="resume-session-id",
terminate_on_close=True,
)


def test_client_with_url_and_terminate_on_close_false_initializes_streamable_http_transport():
with patch("mcp.client.client.streamable_http_client") as mock:
_ = Client("http://localhost:8000/mcp", streamable_http_terminate_on_close=False)
mock.assert_called_once_with("http://localhost:8000/mcp", session_id=None, terminate_on_close=False)


def test_client_resume_session_builder_initializes_streamable_http_transport():
initialize_result = types.InitializeResult(
protocol_version="2025-03-26",
capabilities=types.ServerCapabilities(),
server_info=types.Implementation(name="server", version="1.0"),
)
with patch("mcp.client.client.streamable_http_client") as mock:
_ = Client.resume_session(
"http://localhost:8000/mcp",
session_id="resume-session-id",
initialize_result=initialize_result,
)
mock.assert_called_once_with(
"http://localhost:8000/mcp",
session_id="resume-session-id",
terminate_on_close=True,
)


async def test_client_resume_session_skips_initialize():
initialize_result = types.InitializeResult(
protocol_version="2025-03-26",
capabilities=types.ServerCapabilities(),
server_info=types.Implementation(name="server", version="1.0"),
)

transport_cm = AsyncMock()
transport_cm.__aenter__.return_value = (MagicMock(), MagicMock())
transport_cm.__aexit__.return_value = None

session = MagicMock()
session.initialize = AsyncMock()
session.resume = MagicMock()
session.initialize_result = initialize_result
session_cm = AsyncMock()
session_cm.__aenter__.return_value = session
session_cm.__aexit__.return_value = None

with (
patch("mcp.client.client.streamable_http_client", return_value=transport_cm),
patch("mcp.client.client.ClientSession", return_value=session_cm),
):
async with Client.resume_session(
"http://localhost:8000/mcp",
session_id="resume-session-id",
initialize_result=initialize_result,
) as client:
assert client.initialize_result == initialize_result

session.initialize.assert_not_awaited()
session.resume.assert_called_once_with(initialize_result)


async def test_client_streamable_initialize_result_requires_session_id():
initialize_result = types.InitializeResult(
protocol_version="2025-03-26",
capabilities=types.ServerCapabilities(),
server_info=types.Implementation(name="server", version="1.0"),
)

transport_cm = AsyncMock()
transport_cm.__aenter__.return_value = (MagicMock(), MagicMock())
transport_cm.__aexit__.return_value = None

session_cm = AsyncMock()
session_cm.__aenter__.return_value = AsyncMock()
session_cm.__aexit__.return_value = None

with (
patch("mcp.client.client.streamable_http_client", return_value=transport_cm),
patch("mcp.client.client.ClientSession", return_value=session_cm),
):
client = Client(
"http://localhost:8000/mcp",
streamable_http_initialize_result=initialize_result,
)
with pytest.raises(RuntimeError, match="requires streamable_http_session_id"):
await client.__aenter__()


async def test_client_uses_transport_directly(app: MCPServer):
Expand Down
22 changes: 22 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,28 @@ async def mock_server():
assert result.protocol_version == LATEST_PROTOCOL_VERSION


@pytest.mark.anyio
async def test_client_session_resume_sets_initialize_result():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

session = ClientSession(server_to_client_receive, client_to_server_send)
assert session.initialize_result is None

resumed_result = InitializeResult(
protocol_version=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities(),
server_info=Implementation(name="mock-server", version="0.1.0"),
)
session.resume(resumed_result)
assert session.initialize_result == resumed_result

await client_to_server_send.aclose()
await client_to_server_receive.aclose()
await server_to_client_send.aclose()
await server_to_client_receive.aclose()


@pytest.mark.anyio
@pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}])
async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None):
Expand Down
7 changes: 6 additions & 1 deletion tests/client/test_session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,11 @@ async def test_client_session_group_disconnect_non_existent_server():
"mcp.client.session_group.sse_client",
), # url, headers, timeout, sse_read_timeout
(
StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False),
StreamableHttpParameters(
url="http://test.com/stream",
terminate_on_close=False,
session_id="resumed-session-id",
),
"streamablehttp",
"mcp.client.session_group.streamable_http_client",
), # url, headers, timeout, sse_read_timeout, terminate_on_close
Expand Down Expand Up @@ -363,6 +367,7 @@ async def test_client_session_group_establish_session_parameterized(
call_args = mock_specific_client_func.call_args
assert call_args.kwargs["url"] == server_params_instance.url
assert call_args.kwargs["terminate_on_close"] == server_params_instance.terminate_on_close
assert call_args.kwargs["session_id"] == server_params_instance.session_id
assert isinstance(call_args.kwargs["http_client"], httpx.AsyncClient)

mock_client_cm_instance.__aenter__.assert_awaited_once()
Expand Down
51 changes: 51 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,6 +1801,57 @@ async def test_handle_sse_event_skips_empty_data():
await read_stream.aclose()


def test_streamable_http_transport_includes_seeded_session_id_header():
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp", session_id="resume-session-id")

headers = transport._prepare_headers()

assert headers["mcp-session-id"] == "resume-session-id"


def test_streamable_http_client_resumption_starts_get_stream_once(monkeypatch: pytest.MonkeyPatch):
start_count = 0

async def fake_handle_get_stream(
self: StreamableHTTPTransport, # noqa: ARG001
client: httpx.AsyncClient, # noqa: ARG001
read_stream_writer: Any, # noqa: ARG001
) -> None:
nonlocal start_count
start_count += 1
await anyio.sleep(0)

async def fake_post_writer(
self: StreamableHTTPTransport, # noqa: ARG001
client: httpx.AsyncClient, # noqa: ARG001
write_stream_reader: Any, # noqa: ARG001
read_stream_writer: Any, # noqa: ARG001
write_stream: Any, # noqa: ARG001
start_get_stream: Any, # noqa: ARG001
tg: Any, # noqa: ARG001
) -> None:
# Call twice; the second call should hit the early return guard.
start_get_stream()
start_get_stream()
await anyio.sleep(0)

monkeypatch.setattr(StreamableHTTPTransport, "handle_get_stream", fake_handle_get_stream)
monkeypatch.setattr(StreamableHTTPTransport, "post_writer", fake_post_writer)

async def exercise_client() -> None:
async with streamable_http_client(
"http://localhost:8000/mcp",
session_id="resume-session-id",
terminate_on_close=False,
):
await anyio.sleep(0)

anyio.run(exercise_client)

if start_count != 1:
raise AssertionError(f"Expected exactly one GET stream start, got {start_count}")


@pytest.mark.anyio
async def test_priming_event_not_sent_for_old_protocol_version():
"""Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""
Expand Down
Loading