diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7c964a334..31d5c227a 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -11,6 +11,7 @@ from mcp.client.experimental import ExperimentalClientFeatures from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared._context import RequestContext +from mcp.shared.dispatcher import Dispatcher from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -109,8 +110,8 @@ class ClientSession( ): def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None, + write_stream: MemoryObjectSendStream[SessionMessage] | None = None, read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, @@ -121,8 +122,9 @@ def __init__( *, sampling_capabilities: types.SamplingCapability | None = None, experimental_task_handlers: ExperimentalTaskHandlers | None = None, + dispatcher: Dispatcher | None = None, ) -> None: - super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds) + super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds, dispatcher=dispatcher) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback self._sampling_capabilities = sampling_capabilities diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ce467e6c9..16411d517 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -40,6 +40,7 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages +from mcp.shared.dispatcher import JSONRPCDispatcher from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY @@ -157,9 +158,9 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: return True - async def _receive_loop(self) -> None: + async def _run(self) -> None: async with self._incoming_message_stream_writer: - await super()._receive_loop() + await super()._run() async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): match responder.request: @@ -676,12 +677,14 @@ async def send_message(self, message: SessionMessage) -> None: WARNING: This is a low-level experimental method that may change without notice. Prefer using higher-level methods like send_notification() or - send_request() for normal operations. + send_request() for normal operations. Only works with the default + JSON-RPC dispatcher. Args: message: The session message to send """ - await self._write_stream.send(message) + assert isinstance(self._dispatcher, JSONRPCDispatcher), "send_message requires the default JSON-RPC dispatcher" + await self._dispatcher._write_stream.send(message) # type: ignore[reportPrivateUsage] async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py new file mode 100644 index 000000000..60decfc0e --- /dev/null +++ b/src/mcp/shared/dispatcher.py @@ -0,0 +1,278 @@ +"""Dispatcher abstraction: the wire-protocol layer beneath a session. + +A ``Dispatcher`` is responsible for encoding MCP messages for the wire, +correlating request/response pairs, and delivering incoming messages to +session-provided handlers. The session itself deals only in MCP-level +dicts (``{"method": ..., "params": ...}`` for requests/notifications, result +dicts for responses) and never sees the wire encoding. + +The default ``JSONRPCDispatcher`` wraps messages in JSON-RPC 2.0 envelopes +and exchanges them over anyio memory streams — this is what every built-in +transport (stdio, Streamable HTTP, WebSocket) feeds into. Custom dispatchers +may use other encodings and request/response models as long as MCP's +request/notification/response semantics are preserved. + +!!! warning + The ``Dispatcher`` Protocol is experimental. Custom transports that + carry JSON-RPC should implement the ``Transport`` Protocol from + ``mcp.client._transport`` instead — that path is stable. +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from typing import Any, Protocol + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +from mcp.shared.exceptions import MCPError +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.response_router import ResponseRouter +from mcp.types import ( + CONNECTION_CLOSED, + ErrorData, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + RequestId, +) + +OnRequestFn = Callable[[RequestId, dict[str, Any], MessageMetadata], Awaitable[None]] +"""Called when the peer sends us a request. Receives ``(request_id, {"method", "params"}, metadata)``.""" + +OnNotificationFn = Callable[[dict[str, Any]], Awaitable[None]] +"""Called when the peer sends us a notification. Receives ``{"method", "params"}``.""" + +OnErrorFn = Callable[[Exception], Awaitable[None]] +"""Called for transport-level errors and orphaned error responses.""" + + +class Dispatcher(Protocol): + """Wire-protocol layer beneath ``BaseSession``. + + Session generates request IDs (they double as progress tokens); the dispatcher + uses them for correlation if its protocol needs that. ``send_request`` blocks + until the correlated response arrives and returns the raw result dict, which + the session then validates into an MCP result type. + + Implementations must be cancellation-safe: if ``send_request`` is cancelled + (e.g. by the session's timeout), any correlation state for that request must + be cleaned up. + """ + + def set_handlers( + self, + on_request: OnRequestFn, + on_notification: OnNotificationFn, + on_error: OnErrorFn, + ) -> None: + """Wire incoming-message callbacks. Called once, before ``run``.""" + ... + + async def run(self) -> None: + """Run the receive loop. Returns when the connection closes. + + Started in the session's task group; cancelled on session exit. + """ + ... + + async def send_request( + self, + request_id: RequestId, + request: dict[str, Any], + metadata: MessageMetadata = None, + timeout: float | None = None, + ) -> dict[str, Any]: + """Send a request and wait for its response. + + ``request`` is ``{"method": str, "params": dict | None}``. Returns the raw + result dict. Raises ``MCPError`` if the peer returns an error response. + Raises ``TimeoutError`` if no response arrives within ``timeout``. + + The send itself must not be subject to the timeout — only the wait for + the response — so that requests are reliably delivered even when the + caller sets an aggressive deadline. + """ + ... + + async def send_notification( + self, + notification: dict[str, Any], + related_request_id: RequestId | None = None, + ) -> None: + """Send a fire-and-forget notification. ``notification`` is ``{"method", "params"}``.""" + ... + + async def send_response( + self, + request_id: RequestId, + response: dict[str, Any] | ErrorData, + ) -> None: + """Send a response to a request we previously received via ``on_request``.""" + ... + + +class JSONRPCDispatcher: + """Default dispatcher: JSON-RPC 2.0 over anyio memory streams. + + This is the behaviour ``BaseSession`` had before the dispatcher extraction — + every built-in transport produces a pair of streams that feed into here. + """ + + def __init__( + self, + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], + response_routers: list[ResponseRouter], + ) -> None: + self._read_stream = read_stream + self._write_stream = write_stream + self._response_routers = response_routers + self._response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] = {} + self._on_request: OnRequestFn | None = None + self._on_notification: OnNotificationFn | None = None + self._on_error: OnErrorFn | None = None + + def set_handlers( + self, + on_request: OnRequestFn, + on_notification: OnNotificationFn, + on_error: OnErrorFn, + ) -> None: + self._on_request = on_request + self._on_notification = on_notification + self._on_error = on_error + + async def send_request( + self, + request_id: RequestId, + request: dict[str, Any], + metadata: MessageMetadata = None, + timeout: float | None = None, + ) -> dict[str, Any]: + response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) + self._response_streams[request_id] = response_stream + try: + jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request) + await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) + with anyio.fail_after(timeout): + response_or_error = await response_stream_reader.receive() + if isinstance(response_or_error, JSONRPCError): + raise MCPError.from_jsonrpc_error(response_or_error) + return response_or_error.result + finally: + self._response_streams.pop(request_id, None) + await response_stream.aclose() + await response_stream_reader.aclose() + + async def send_notification( + self, + notification: dict[str, Any], + related_request_id: RequestId | None = None, + ) -> None: + jsonrpc_notification = JSONRPCNotification(jsonrpc="2.0", **notification) + session_message = SessionMessage( + message=jsonrpc_notification, + metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, + ) + await self._write_stream.send(session_message) + + async def send_response( + self, + request_id: RequestId, + response: dict[str, Any] | ErrorData, + ) -> None: + if isinstance(response, ErrorData): + message: JSONRPCResponse | JSONRPCError = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) + else: + message = JSONRPCResponse(jsonrpc="2.0", id=request_id, result=response) + await self._write_stream.send(SessionMessage(message=message)) + + async def run(self) -> None: + assert self._on_request is not None + assert self._on_notification is not None + assert self._on_error is not None + + async with self._read_stream, self._write_stream: + try: + async for message in self._read_stream: + if isinstance(message, Exception): + await self._on_error(message) + elif isinstance(message.message, JSONRPCRequest): + await self._on_request( + message.message.id, + message.message.model_dump(by_alias=True, mode="json", exclude_none=True), + message.metadata, + ) + elif isinstance(message.message, JSONRPCNotification): + await self._on_notification( + message.message.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + else: + await self._route_response(message) + except anyio.ClosedResourceError: + # Expected when the peer disconnects abruptly. + logging.debug("Read stream closed by client") + except Exception as e: + logging.exception(f"Unhandled exception in receive loop: {e}") # pragma: no cover + finally: + # Deliver CONNECTION_CLOSED to every request still awaiting a response. + for id, stream in self._response_streams.items(): + error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") + try: + await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) + await stream.aclose() + except Exception: # pragma: no cover + pass + self._response_streams.clear() + # Handlers are bound methods of the session; the session holds us. Break + # the cycle so refcount GC can free both promptly. + self._on_request = None + self._on_notification = None + self._on_error = None + + async def _route_response(self, message: SessionMessage) -> None: + # Runtime-true (run() only calls us in the response/error branch) but the + # type checker can't see that, hence the guard. + if not isinstance(message.message, JSONRPCResponse | JSONRPCError): + return # pragma: no cover + + assert self._on_error is not None + + if message.message.id is None: + error = message.message.error + logging.warning(f"Received error with null ID: {error.message}") + await self._on_error(MCPError(error.code, error.message, error.data)) + return + + response_id = self._normalize_request_id(message.message.id) + + # Response routers (experimental task support) get first look. + if isinstance(message.message, JSONRPCError): + for router in self._response_routers: + if router.route_error(response_id, message.message.error): + return + else: + response_data: dict[str, Any] = message.message.result or {} + for router in self._response_routers: + if router.route_response(response_id, response_data): + return + + stream = self._response_streams.pop(response_id, None) + if stream: + await stream.send(message.message) + else: + await self._on_error(RuntimeError(f"Received response with an unknown request ID: {message}")) + + @staticmethod + def _normalize_request_id(response_id: RequestId) -> RequestId: + # We send integer IDs; some peers echo them back as strings. + if isinstance(response_id, str): + try: + return int(response_id) + except ValueError: + logging.warning(f"Response ID {response_id!r} cannot be normalized to match pending requests") + return response_id diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 9364abb73..8e1edc818 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -11,11 +11,11 @@ from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +from mcp.shared.dispatcher import Dispatcher, JSONRPCDispatcher from mcp.shared.exceptions import MCPError -from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.message import MessageMetadata, SessionMessage from mcp.shared.response_router import ResponseRouter from mcp.types import ( - CONNECTION_CLOSED, INVALID_PARAMS, REQUEST_TIMEOUT, CancelledNotification, @@ -23,10 +23,6 @@ ClientRequest, ClientResult, ErrorData, - JSONRPCError, - JSONRPCNotification, - JSONRPCRequest, - JSONRPCResponse, ProgressNotification, ProgressToken, RequestParamsMeta, @@ -166,14 +162,18 @@ class BaseSession( ReceiveNotificationT, ], ): - """Implements an MCP "session" on top of read/write streams, including features - like request/response linking, notifications, and progress. + """Implements an MCP "session" on top of a wire-protocol dispatcher, including + features like request/response linking, notifications, and progress. This class is an async context manager that automatically starts processing messages when entered. + + By default the session constructs a ``JSONRPCDispatcher`` from the supplied + read/write streams — this is the path every built-in transport uses. A custom + dispatcher can be passed via the ``dispatcher`` keyword argument to use a + different wire protocol; see ``mcp.shared.dispatcher``. """ - _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] @@ -181,14 +181,13 @@ class BaseSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None, + write_stream: MemoryObjectSendStream[SessionMessage] | None = None, # If none, reading will never time out read_timeout_seconds: float | None = None, + *, + dispatcher: Dispatcher | None = None, ) -> None: - self._read_stream = read_stream - self._write_stream = write_stream - self._response_streams = {} self._request_id = 0 self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} @@ -196,6 +195,12 @@ def __init__( self._response_routers = [] self._exit_stack = AsyncExitStack() + if dispatcher is None: + if read_stream is None or write_stream is None: + raise TypeError("either dispatcher or both read_stream and write_stream must be provided") + dispatcher = JSONRPCDispatcher(read_stream, write_stream, self._response_routers) + self._dispatcher = dispatcher + def add_response_router(self, router: ResponseRouter) -> None: """Register a response router to handle responses for non-standard requests. @@ -212,11 +217,21 @@ def add_response_router(self, router: ResponseRouter) -> None: self._response_routers.append(router) async def __aenter__(self) -> Self: + self._dispatcher.set_handlers( + on_request=self._on_incoming_request, + on_notification=self._on_incoming_notification, + on_error=self._handle_incoming, + ) self._task_group = anyio.create_task_group() await self._task_group.__aenter__() - self._task_group.start_soon(self._receive_loop) + self._task_group.start_soon(self._run) return self + async def _run(self) -> None: + """Run the dispatcher's receive loop. Hook for subclasses that need to + wrap the loop's lifetime (e.g. to close resources when it exits).""" + await self._dispatcher.run() + async def __aexit__( self, exc_type: type[BaseException] | None, @@ -248,9 +263,6 @@ async def send_request( request_id = self._request_id self._request_id = request_id + 1 - response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) - self._response_streams[request_id] = response_stream - # Set up progress token if progress callback is provided request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) if progress_callback is not None: @@ -263,31 +275,19 @@ async def send_request( # Store the callback for this request self._progress_callbacks[request_id] = progress_callback - try: - jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) - await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) - - # request read timeout takes precedence over session read timeout - timeout = request_read_timeout_seconds or self._session_read_timeout_seconds + # request read timeout takes precedence over session read timeout + timeout = request_read_timeout_seconds or self._session_read_timeout_seconds + try: try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() + result = await self._dispatcher.send_request(request_id, request_data, metadata, timeout) except TimeoutError: class_name = request.__class__.__name__ message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." raise MCPError(code=REQUEST_TIMEOUT, message=message) - - if isinstance(response_or_error, JSONRPCError): - raise MCPError.from_jsonrpc_error(response_or_error) - else: - return result_type.model_validate(response_or_error.result, by_name=False) - + return result_type.model_validate(result, by_name=False) finally: - self._response_streams.pop(request_id, None) self._progress_callbacks.pop(request_id, None) - await response_stream.aclose() - await response_stream_reader.aclose() async def send_notification( self, @@ -297,29 +297,19 @@ async def send_notification( """Emits a notification, which is a one-way message that does not expect a response.""" # Some transport implementations may need to set the related_request_id # to attribute to the notifications to the request that triggered them. - jsonrpc_notification = JSONRPCNotification( - jsonrpc="2.0", - **notification.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - session_message = SessionMessage( - message=jsonrpc_notification, - metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, + await self._dispatcher.send_notification( + notification.model_dump(by_alias=True, mode="json", exclude_none=True), + related_request_id, ) - await self._write_stream.send(session_message) async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: if isinstance(response, ErrorData): - jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - session_message = SessionMessage(message=jsonrpc_error) - await self._write_stream.send(session_message) + await self._dispatcher.send_response(request_id, response) else: - jsonrpc_response = JSONRPCResponse( - jsonrpc="2.0", - id=request_id, - result=response.model_dump(by_alias=True, mode="json", exclude_none=True), + await self._dispatcher.send_response( + request_id, + response.model_dump(by_alias=True, mode="json", exclude_none=True), ) - session_message = SessionMessage(message=jsonrpc_response) - await self._write_stream.send(session_message) @property def _receive_request_adapter(self) -> TypeAdapter[ReceiveRequestT]: @@ -330,165 +320,65 @@ def _receive_request_adapter(self) -> TypeAdapter[ReceiveRequestT]: def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: raise NotImplementedError - async def _receive_loop(self) -> None: - async with self._read_stream, self._write_stream: - try: - async for message in self._read_stream: - if isinstance(message, Exception): - await self._handle_incoming(message) - elif isinstance(message.message, JSONRPCRequest): - try: - validated_request = self._receive_request_adapter.validate_python( - message.message.model_dump(by_alias=True, mode="json", exclude_none=True), - by_name=False, - ) - responder = RequestResponder( - request_id=message.message.id, - request_meta=validated_request.params.meta if validated_request.params else None, - request=validated_request, - session=self, - on_complete=lambda r: self._in_flight.pop(r.request_id, None), - message_metadata=message.metadata, - ) - self._in_flight[responder.request_id] = responder - await self._received_request(responder) - - if not responder._completed: # type: ignore[reportPrivateUsage] - await self._handle_incoming(responder) - except Exception: - # For request validation errors, send a proper JSON-RPC error - # response instead of crashing the server - logging.warning("Failed to validate request", exc_info=True) - logging.debug(f"Message that failed validation: {message.message}") - error_response = JSONRPCError( - jsonrpc="2.0", - id=message.message.id, - error=ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""), - ) - session_message = SessionMessage(message=error_response) - await self._write_stream.send(session_message) + async def _on_incoming_request( + self, request_id: RequestId, payload: dict[str, Any], metadata: MessageMetadata + ) -> None: + """Dispatcher callback: a request arrived from the peer.""" + try: + validated_request = self._receive_request_adapter.validate_python(payload, by_name=False) + responder = RequestResponder( + request_id=request_id, + request_meta=validated_request.params.meta if validated_request.params else None, + request=validated_request, + session=self, + on_complete=lambda r: self._in_flight.pop(r.request_id, None), + message_metadata=metadata, + ) + self._in_flight[responder.request_id] = responder + await self._received_request(responder) + + if not responder._completed: # type: ignore[reportPrivateUsage] + await self._handle_incoming(responder) + except Exception: + logging.warning("Failed to validate request", exc_info=True) + logging.debug(f"Message that failed validation: {payload}") + await self._dispatcher.send_response( + request_id, + ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""), + ) - elif isinstance(message.message, JSONRPCNotification): + async def _on_incoming_notification(self, payload: dict[str, Any]) -> None: + """Dispatcher callback: a notification arrived from the peer.""" + try: + notification = self._receive_notification_adapter.validate_python(payload, by_name=False) + # Handle cancellation notifications + if isinstance(notification, CancelledNotification): + cancelled_id = notification.params.request_id + if cancelled_id in self._in_flight: # pragma: no branch + await self._in_flight[cancelled_id].cancel() + else: + # Handle progress notifications callback + if isinstance(notification, ProgressNotification): + progress_token = notification.params.progress_token + # If there is a progress callback for this token, + # call it with the progress information + if progress_token in self._progress_callbacks: + callback = self._progress_callbacks[progress_token] try: - notification = self._receive_notification_adapter.validate_python( - message.message.model_dump(by_alias=True, mode="json", exclude_none=True), - by_name=False, + await callback( + notification.params.progress, + notification.params.total, + notification.params.message, ) - # Handle cancellation notifications - if isinstance(notification, CancelledNotification): - cancelled_id = notification.params.request_id - if cancelled_id in self._in_flight: # pragma: no branch - await self._in_flight[cancelled_id].cancel() - else: - # Handle progress notifications callback - if isinstance(notification, ProgressNotification): - progress_token = notification.params.progress_token - # If there is a progress callback for this token, - # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - try: - await callback( - notification.params.progress, - notification.params.total, - notification.params.message, - ) - except Exception: - logging.exception("Progress callback raised an exception") - await self._received_notification(notification) - await self._handle_incoming(notification) except Exception: - # For other validation errors, log and continue - logging.warning( # pragma: no cover - f"Failed to validate notification:. Message was: {message.message}", - exc_info=True, - ) - else: # Response or error - await self._handle_response(message) - - except anyio.ClosedResourceError: - # This is expected when the client disconnects abruptly. - # Without this handler, the exception would propagate up and - # crash the server's task group. - logging.debug("Read stream closed by client") - except Exception as e: - # Other exceptions are not expected and should be logged. We purposefully - # catch all exceptions here to avoid crashing the server. - logging.exception(f"Unhandled exception in receive loop: {e}") # pragma: no cover - finally: - # after the read stream is closed, we need to send errors - # to any pending requests - for id, stream in self._response_streams.items(): - error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") - try: - await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) - await stream.aclose() - except Exception: # pragma: no cover - # Stream might already be closed - pass - self._response_streams.clear() - - def _normalize_request_id(self, response_id: RequestId) -> RequestId: - """Normalize a response ID to match how request IDs are stored. - - Since the client always sends integer IDs, we normalize string IDs - to integers when possible. This matches the TypeScript SDK approach: - https://github.com/modelcontextprotocol/typescript-sdk/blob/a606fb17909ea454e83aab14c73f14ea45c04448/src/shared/protocol.ts#L861 - - Args: - response_id: The response ID from the incoming message. - - Returns: - The normalized ID (int if possible, otherwise original value). - """ - if isinstance(response_id, str): - try: - return int(response_id) - except ValueError: - logging.warning(f"Response ID {response_id!r} cannot be normalized to match pending requests") - return response_id - - async def _handle_response(self, message: SessionMessage) -> None: - """Handle an incoming response or error message. - - Checks response routers first (e.g., for task-related responses), - then falls back to the normal response stream mechanism. - """ - # This check is always true at runtime: the caller (_receive_loop) only invokes - # this method in the else branch after checking for JSONRPCRequest and - # JSONRPCNotification. However, the type checker can't infer this from the - # method signature, so we need this guard for type narrowing. - if not isinstance(message.message, JSONRPCResponse | JSONRPCError): - return # pragma: no cover - - if message.message.id is None: - # Narrows to JSONRPCError since JSONRPCResponse.id is always RequestId - error = message.message.error - logging.warning(f"Received error with null ID: {error.message}") - await self._handle_incoming(MCPError(error.code, error.message, error.data)) - return - # Normalize response ID to handle type mismatches (e.g., "0" vs 0) - response_id = self._normalize_request_id(message.message.id) - - # First, check response routers (e.g., TaskResultHandler) - if isinstance(message.message, JSONRPCError): - # Route error to routers - for router in self._response_routers: - if router.route_error(response_id, message.message.error): - return # Handled - else: - # Route success response to routers - response_data: dict[str, Any] = message.message.result or {} - for router in self._response_routers: - if router.route_response(response_id, response_data): - return # Handled - - # Fall back to normal response streams - stream = self._response_streams.pop(response_id, None) - if stream: - await stream.send(message.message) - else: - await self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) + logging.exception("Progress callback raised an exception") + await self._received_notification(notification) + await self._handle_incoming(notification) + except Exception: + logging.warning( # pragma: no cover + f"Failed to validate notification. Message was: {payload}", + exc_info=True, + ) async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: """Can be overridden by subclasses to handle a request without needing to diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index c7bf8fafa..6a8016b95 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -5,6 +5,7 @@ import pytest from pydantic import TypeAdapter +from mcp.shared.dispatcher import JSONRPCDispatcher from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, RequestId, SendResultT from mcp.types import ClientNotification, ClientRequest, ClientResult, EmptyResult, ErrorData, PingRequest @@ -46,17 +47,21 @@ def _receive_notification_adapter(self) -> TypeAdapter[Any]: async def mock_send(*args: Any, **kwargs: Any): raise RuntimeError("Simulated network error") + # JSON-RPC correlation state lives on the dispatcher now; reach through to it. + dispatcher = session._dispatcher + assert isinstance(dispatcher, JSONRPCDispatcher) + # Record the response streams before the test - initial_stream_count = len(session._response_streams) + initial_stream_count = len(dispatcher._response_streams) # Run the test with the patched method - with patch.object(session._write_stream, "send", mock_send): + with patch.object(dispatcher._write_stream, "send", mock_send): with pytest.raises(RuntimeError): await session.send_request(request, EmptyResult) # Verify that no response streams were leaked - assert len(session._response_streams) == initial_stream_count, ( - f"Expected {initial_stream_count} response streams after request, but found {len(session._response_streams)}" + assert len(dispatcher._response_streams) == initial_stream_count, ( + f"Expected {initial_stream_count} response streams after request, but found {len(dispatcher._response_streams)}" ) # Clean up diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 46925916d..8ae16cbce 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -71,8 +71,8 @@ async def test_server_run_exits_cleanly_when_transport_yields_exception_then_clo 1. Transport yields an Exception into the read stream (streamable_http.py does this in its broad POST-handler except). 2. Transport closes the read stream (terminate() in stateless mode). - 3. _receive_loop exits its `async with read_stream, write_stream:` block, - closing the write stream. + 3. The dispatcher's receive loop exits its `async with read_stream, write_stream:` + block, closing the write stream. 4. Meanwhile _handle_message(exc) was spawned via tg.start_soon and runs after the write stream is closed. @@ -84,8 +84,8 @@ async def test_server_run_exits_cleanly_when_transport_yields_exception_then_clo read_send, read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) # Zero-buffer on the write stream forces send() to block until received. - # With no receiver, a send() sits blocked until _receive_loop exits its - # `async with self._read_stream, self._write_stream:` block and closes the + # With no receiver, a send() sits blocked until the dispatcher's receive loop + # exits its `async with read_stream, write_stream:` block and closes the # stream, at which point the blocked send raises ClosedResourceError. # This deterministically reproduces the race without sleeps. write_send, write_recv = anyio.create_memory_object_stream[SessionMessage](0) @@ -99,7 +99,7 @@ async def test_server_run_exits_cleanly_when_transport_yields_exception_then_clo # Before this fix, this raised ExceptionGroup(ClosedResourceError). await server.run(read_recv, write_send, server.create_initialization_options(), stateless=True) - # write_send was closed inside _receive_loop's `async with`; receive_nowait + # write_send was closed inside the dispatcher's `async with`; receive_nowait # raises EndOfStream iff the buffer is empty (i.e., server wrote nothing). with pytest.raises(anyio.EndOfStream): write_recv.receive_nowait() diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py new file mode 100644 index 000000000..364a77232 --- /dev/null +++ b/tests/shared/test_dispatcher.py @@ -0,0 +1,122 @@ +"""Tests for the Dispatcher abstraction beneath BaseSession.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from mcp.client._memory import InMemoryTransport +from mcp.client.session import ClientSession +from mcp.server.mcpserver import Context, MCPServer +from mcp.shared._context import RequestContext +from mcp.shared.dispatcher import ( + JSONRPCDispatcher, + OnErrorFn, + OnNotificationFn, + OnRequestFn, +) +from mcp.shared.message import MessageMetadata +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + ErrorData, + RequestId, + SamplingMessage, + TextContent, +) + +pytestmark = pytest.mark.anyio + + +class SpyDispatcher: + """A custom Dispatcher that wraps JSONRPCDispatcher and records traffic. + + This is the shape a real non-JSON-RPC dispatcher (gRPC, CBOR, etc.) would + take: satisfy the Dispatcher Protocol structurally, deal in MCP-level dicts. + Wrapping JSONRPCDispatcher here lets us assert the session never bypasses us + while still talking to a real server on the other end. + """ + + def __init__(self, inner: JSONRPCDispatcher) -> None: + self._inner = inner + self.sent_requests: list[dict[str, Any]] = [] + self.sent_notifications: list[dict[str, Any]] = [] + self.sent_responses: list[dict[str, Any] | ErrorData] = [] + + def set_handlers(self, on_request: OnRequestFn, on_notification: OnNotificationFn, on_error: OnErrorFn) -> None: + self._inner.set_handlers(on_request, on_notification, on_error) + + async def run(self) -> None: + await self._inner.run() + + async def send_request( + self, + request_id: RequestId, + request: dict[str, Any], + metadata: MessageMetadata = None, + timeout: float | None = None, + ) -> dict[str, Any]: + self.sent_requests.append(request) + return await self._inner.send_request(request_id, request, metadata, timeout) + + async def send_notification( + self, notification: dict[str, Any], related_request_id: RequestId | None = None + ) -> None: + self.sent_notifications.append(notification) + await self._inner.send_notification(notification, related_request_id) + + async def send_response(self, request_id: RequestId, response: dict[str, Any] | ErrorData) -> None: + self.sent_responses.append(response) + await self._inner.send_response(request_id, response) + + +async def test_client_session_accepts_custom_dispatcher(): + """ClientSession round-trips through a custom dispatcher end-to-end, including + a server-initiated request (sampling) so all five dispatcher methods fire.""" + app = MCPServer("test") + + @app.tool() + async def ask(question: str, ctx: Context) -> str: + answer = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=question))], + max_tokens=10, + ) + assert isinstance(answer.content, TextContent) + return answer.content.text + + async def sampling_callback( + context: RequestContext[ClientSession], params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=TextContent(type="text", text="42"), + model="test", + stop_reason="endTurn", + ) + + # InMemoryTransport runs the server for us and yields client-side streams — + # we intercept those streams and feed them through a custom dispatcher. + async with InMemoryTransport(app) as (client_read, client_write): + inner = JSONRPCDispatcher(client_read, client_write, response_routers=[]) + spy = SpyDispatcher(inner) + + async with ClientSession(dispatcher=spy, sampling_callback=sampling_callback) as session: + await session.initialize() + result = await session.call_tool("ask", {"question": "meaning of life?"}) + content = result.content[0] + assert isinstance(content, TextContent) + assert content.text == "42" + + # initialize, tools/call (triggers sampling on the server), tools/list (schema refresh) + assert [r["method"] for r in spy.sent_requests] == ["initialize", "tools/call", "tools/list"] + assert [n["method"] for n in spy.sent_notifications] == ["notifications/initialized"] + # The server's sampling/createMessage request hit us; our response went back through the spy. + assert len(spy.sent_responses) == 1 + response = spy.sent_responses[0] + assert isinstance(response, dict) and response["model"] == "test" + + +async def test_base_session_requires_streams_or_dispatcher(): + with pytest.raises(TypeError, match="either dispatcher or both"): + ClientSession()