From 36991d8bc1e5cea4980eb9dd99fdfc5e3286f34c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 13 Mar 2026 11:14:12 +0000 Subject: [PATCH 1/2] fix: eliminate port allocation race in test_streamable_http fixtures MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous pattern picked a free port via socket.bind(0), released it, then started a uvicorn subprocess hoping to rebind — a TOCTOU race that caused intermittent CI failures when pytest-xdist workers stole the port between release and rebind (ConnectError, 404 against the wrong server). Added run_uvicorn_in_thread() which pre-binds the listening socket with port=0 and passes it to uvicorn via server.run(sockets=[sock]). The port is held atomically from bind until shutdown and is known before the server thread even starts — no polling, no race. The kernel's listen queue buffers any connections that arrive during uvicorn startup. Migrated the four test_streamable_http.py fixtures (basic_server, event_server, json_response_server, context_aware_server) that share create_app(). These include the SSE auto-reconnect tests that genuinely need real TCP to exercise connection lifecycle. Running the server in-process means coverage now tracks transport code that was previously subprocess-invisible; adjusted pragmas accordingly (targeted no-cover on unreached error paths, lax no-cover on timing-dependent branches). wait_for_server() is kept for files not touched by this PR. 
--- src/mcp/server/session.py | 2 +- src/mcp/server/streamable_http.py | 70 +++--- src/mcp/server/transport_security.py | 22 +- tests/shared/test_streamable_http.py | 321 ++++++++------------------- tests/test_helpers.py | 10 +- 5 files changed, 146 insertions(+), 279 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ce467e6c9..947c3b20d 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -222,7 +222,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover + async def send_resource_updated(self, uri: str | AnyUrl) -> None: """Send a resource updated notification.""" await self.send_notification( types.ResourceUpdatedNotification( diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index aa99e7c88..011deb1f0 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -177,7 +177,7 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated - def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover + def close_sse_stream(self, request_id: RequestId) -> None: """Close SSE connection for a specific request without terminating the stream. This method closes the HTTP connection for the specified request, triggering @@ -200,12 +200,12 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover writer.close() # Also close and remove request streams - if request_id in self._request_streams: + if request_id in self._request_streams: # pragma: no branch send_stream, receive_stream = self._request_streams.pop(request_id) send_stream.close() receive_stream.close() - def close_standalone_sse_stream(self) -> None: # pragma: no cover + def close_standalone_sse_stream(self) -> None: """Close the standalone GET SSE stream, triggering client reconnection. 
This method closes the HTTP connection for the standalone GET stream used @@ -240,10 +240,10 @@ def _create_session_message( # Only provide close callbacks when client supports resumability if self._event_store and protocol_version >= "2025-11-25": - async def close_stream_callback() -> None: # pragma: no cover + async def close_stream_callback() -> None: self.close_sse_stream(request_id) - async def close_standalone_stream_callback() -> None: # pragma: no cover + async def close_standalone_stream_callback() -> None: self.close_standalone_sse_stream() metadata = ServerMessageMetadata( @@ -291,7 +291,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: no cover + if headers: response_headers.update(headers) if self.mcp_session_id: @@ -342,7 +342,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: } # If an event ID was provided, include it - if event_message.event_id: # pragma: no cover + if event_message.event_id: event_data["id"] = event_message.event_id return event_data @@ -372,7 +372,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await error_response(scope, receive, send) return - if self._terminated: # pragma: no cover + if self._terminated: # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -387,7 +387,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_get_request(request, send) elif request.method == "DELETE": await self._handle_delete_request(request, send) - else: # pragma: no cover + else: await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: @@ -467,7 +467,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: 
Re try: message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) - except ValidationError as e: # pragma: no cover + except ValidationError as e: response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, @@ -493,7 +493,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re ) await response(scope, receive, send) return - elif not await self._validate_request_headers(request, send): # pragma: no cover + elif not await self._validate_request_headers(request, send): return # For notifications and responses only, return 202 Accepted @@ -633,7 +633,7 @@ async def sse_writer(): # pragma: lax no cover finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: no cover + except Exception as err: # pragma: lax no cover logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -659,7 +659,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) - if not has_sse: # pragma: no cover + if not has_sse: response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, @@ -667,11 +667,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - if not await self._validate_request_headers(request, send): # pragma: no cover - return + if not await self._validate_request_headers(request, send): + return # pragma: no cover # Handle resumability: check for Last-Event-ID header - if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): await self._replay_events(last_event_id, request, send) return @@ -681,11 +681,11 @@ async def _handle_get_request(self, 
request: Request, send: Send) -> None: "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Check if we already have an active GET stream - if GET_STREAM_KEY in self._request_streams: # pragma: no cover + if GET_STREAM_KEY in self._request_streams: response = self._create_error_response( "Conflict: Only one SSE stream is allowed per session", HTTPStatus.CONFLICT, @@ -714,7 +714,7 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Error in standalone SSE writer") finally: logger.debug("Closing standalone SSE writer") @@ -791,13 +791,13 @@ async def terminate(self) -> None: # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, "Allow": "GET, POST, DELETE", } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id response = self._create_error_response( @@ -824,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: request_session_id = self._get_session_id(request) # If no session ID provided but required, return error - if not request_session_id: # pragma: no cover + if not request_session_id: response = self._create_error_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, @@ -849,11 +849,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool protocol_version = 
request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: # pragma: no cover + if protocol_version is None: protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) response = self._create_error_response( f"Bad Request: Unsupported protocol version: {protocol_version}. " @@ -865,13 +865,13 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. 
""" event_store = self._event_store - if not event_store: + if not event_store: # pragma: no cover return try: @@ -881,7 +881,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Get protocol version from header (already validated in _validate_protocol_version) @@ -902,7 +902,7 @@ async def send_event(event_message: EventMessage) -> None: stream_id = await event_store.replay_events_after(last_event_id, send_event) # If stream ID not in mapping, create it - if stream_id and stream_id not in self._request_streams: + if stream_id and stream_id not in self._request_streams: # pragma: no branch # Register SSE writer so close_sse_stream() can close it self._sse_stream_writers[stream_id] = sse_stream_writer @@ -919,10 +919,10 @@ async def send_event(event_message: EventMessage) -> None: event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("Replay SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay sender") # Create and start EventSourceResponse @@ -934,13 +934,13 @@ async def send_event(event_message: EventMessage) -> None: try: await response(request.scope, request.receive, send) - except Exception: + except Exception: # pragma: no cover logger.exception("Error in replay response") finally: await sse_stream_writer.aclose() await sse_stream_reader.aclose() - except Exception: + except Exception: # pragma: no cover logger.exception("Error replaying events") response = self._create_error_response( "Error replaying events", @@ -991,7 +991,7 @@ async def message_router(): if isinstance(message, 
JSONRPCResponse | JSONRPCError) and message.id is not None: target_request_id = str(message.id) # Extract related_request_id from meta if it exists - elif ( # pragma: no cover + elif ( session_message.metadata is not None and isinstance( session_message.metadata, @@ -1015,10 +1015,10 @@ async def message_router(): try: # Send both the message and the event ID await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) - except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover + except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: lax no cover # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) - else: # pragma: no cover + else: logger.debug( f"""Request stream {request_stream_id} not found for message. Still processing message as the client diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0..8a0a6c15d 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -40,7 +40,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def _validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: # pragma: lax no cover """Validate the Host header against allowed values.""" if not host: logger.warning("Missing Host header in request") @@ -62,7 +62,7 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: # pragma: lax no cover """Validate the Origin header against allowed values.""" # Origin can be 
absent for same-origin requests if not origin: @@ -103,14 +103,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover + # Validate Host header + host = request.headers.get("host") # pragma: lax no cover + if not self._validate_host(host): # pragma: lax no cover + return Response("Invalid Host header", status_code=421) - # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover + # Validate Origin header + origin = request.headers.get("origin") # pragma: lax no cover + if not self._validate_origin(origin): # pragma: lax no cover + return Response("Invalid Origin header", status_code=403) - return None # pragma: no cover + return None # pragma: lax no cover diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f8ca30441..9119b9d14 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -6,10 +6,7 @@ from __future__ import annotations as _annotations import json -import multiprocessing -import socket import time -import traceback from collections.abc import AsyncIterator, Generator from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -21,7 +18,6 @@ import httpx import pytest import requests -import uvicorn from httpx_sse import ServerSentEvent from starlette.applications import Starlette from starlette.requests import Request @@ -65,7 +61,7 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.test_helpers import 
run_uvicorn_in_thread # Test constants SERVER_NAME = "test_streamable_http_server" @@ -108,7 +104,7 @@ async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | self._events.append((stream_id, event_id, message)) return event_id - async def replay_events_after( # pragma: no cover + async def replay_events_after( self, last_event_id: EventId, send_callback: EventCallback, @@ -117,11 +113,11 @@ async def replay_events_after( # pragma: no cover # Find the stream ID of the last event target_stream_id = None for stream_id, event_id, _ in self._events: - if event_id == last_event_id: + if event_id == last_event_id: # pragma: no branch target_stream_id = stream_id break - if target_stream_id is None: + if target_stream_id is None: # pragma: no cover # If event ID not found, return None return None @@ -132,7 +128,7 @@ async def replay_events_after( # pragma: no cover for stream_id, event_id, message in self._events: if stream_id == target_stream_id and int(event_id) > last_event_id_int: # Skip priming events (None message) - if message is not None: + if message is not None: # pragma: no branch await send_callback(EventMessage(message, event_id)) return target_stream_id @@ -144,18 +140,18 @@ class ServerState: @asynccontextmanager -async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: no cover +async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: yield ServerState() -async def _handle_read_resource( # pragma: no cover +async def _handle_read_resource( ctx: ServerRequestContext[ServerState], params: ReadResourceRequestParams ) -> ReadResourceResult: uri = str(params.uri) parsed = urlparse(uri) if parsed.scheme == "foobar": text = f"Read {parsed.netloc}" - elif parsed.scheme == "slow": + elif parsed.scheme == "slow": # pragma: no cover await anyio.sleep(2.0) text = f"Slow response from {parsed.netloc}" else: @@ -163,7 +159,7 @@ async def _handle_read_resource( # pragma: no 
cover return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) -async def _handle_list_tools( # pragma: no cover +async def _handle_list_tools( ctx: ServerRequestContext[ServerState], params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -228,9 +224,7 @@ async def _handle_list_tools( # pragma: no cover ) -async def _handle_call_tool( # pragma: no cover - ctx: ServerRequestContext[ServerState], params: CallToolRequestParams -) -> CallToolResult: +async def _handle_call_tool(ctx: ServerRequestContext[ServerState], params: CallToolRequestParams) -> CallToolResult: name = params.name args = params.arguments or {} @@ -239,7 +233,7 @@ async def _handle_call_tool( # pragma: no cover await ctx.session.send_resource_updated(uri="http://test_resource") return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - elif name == "long_running_with_checkpoints": + elif name == "long_running_with_checkpoints": # pragma: no cover await ctx.session.send_log_message( level="info", data="Tool started", @@ -272,7 +266,7 @@ async def _handle_call_tool( # pragma: no cover if sampling_result.content.type == "text": response = sampling_result.content.text - else: + else: # pragma: no cover response = str(sampling_result.content) return CallToolResult( content=[ @@ -360,7 +354,7 @@ async def _handle_call_tool( # pragma: no cover related_request_id=ctx.request_id, ) - if ctx.close_sse_stream: + if ctx.close_sse_stream: # pragma: no branch await ctx.close_sse_stream() await anyio.sleep(sleep_time) @@ -371,7 +365,7 @@ async def _handle_call_tool( # pragma: no cover await ctx.session.send_resource_updated(uri="http://notification_1") await anyio.sleep(0.1) - if ctx.close_standalone_sse_stream: + if ctx.close_standalone_sse_stream: # pragma: no branch await ctx.close_standalone_sse_stream() await anyio.sleep(1.5) @@ -382,7 +376,7 @@ async def _handle_call_tool( # pragma: no cover return 
CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) -def _create_server() -> Server[ServerState]: # pragma: no cover +def _create_server() -> Server[ServerState]: return Server( SERVER_NAME, lifespan=_server_lifespan, @@ -396,7 +390,7 @@ def create_app( is_json_response_enabled: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, -) -> Starlette: # pragma: no cover +) -> Starlette: """Create a Starlette application for testing using the session manager. Args: @@ -431,74 +425,11 @@ def create_app( return app -def run_server( - port: int, - is_json_response_enabled: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, -) -> None: # pragma: no cover - """Run the test server. - - Args: - port: Port to listen on. - is_json_response_enabled: If True, use JSON responses instead of SSE streams. - event_store: Optional event store for testing resumability. - retry_interval: Retry interval in milliseconds for SSE polling. 
- """ - - app = create_app(is_json_response_enabled, event_store, retry_interval) - # Configure server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="info", - limit_concurrency=10, - timeout_keep_alive=5, - access_log=False, - ) - - # Start the server - server = uvicorn.Server(config=config) - - # This is important to catch exceptions and prevent test hangs - try: - server.run() - except Exception: - traceback.print_exc() - - -# Test fixtures - using same approach as SSE tests -@pytest.fixture -def basic_server_port() -> int: - """Find an available port for the basic server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - @pytest.fixture -def json_server_port() -> int: - """Find an available port for the JSON response server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" - proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) +def basic_server_url() -> Generator[str, None, None]: + """Start a basic server in a background thread. 
Yields its base URL.""" + with run_uvicorn_in_thread(create_app(), limit_concurrency=10, timeout_keep_alive=5, access_log=False) as url: + yield url @pytest.fixture @@ -508,69 +439,28 @@ def event_store() -> SimpleEventStore: @pytest.fixture -def event_server_port() -> int: - """Find an available port for the event store server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def event_server( - event_server_port: int, event_store: SimpleEventStore -) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store and retry_interval enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(event_server_port) - - yield event_store, f"http://127.0.0.1:{event_server_port}" - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def json_response_server(json_server_port: int) -> Generator[None, None, None]: - """Start a server with JSON response enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": json_server_port, "is_json_response_enabled": True}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(json_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def basic_server_url(basic_server_port: int) -> str: - """Get the URL for the basic test server.""" - return f"http://127.0.0.1:{basic_server_port}" +def event_server(event_store: SimpleEventStore) -> Generator[tuple[SimpleEventStore, str], None, None]: + """Start a server with event store and retry_interval enabled in a background thread.""" + with run_uvicorn_in_thread( + create_app(event_store=event_store, retry_interval=500), + limit_concurrency=10, + timeout_keep_alive=5, + access_log=False, + ) as url: + yield 
event_store, url @pytest.fixture -def json_server_url(json_server_port: int) -> str: - """Get the URL for the JSON response test server.""" - return f"http://127.0.0.1:{json_server_port}" +def json_server_url() -> Generator[str, None, None]: + """Start a server with JSON response enabled in a background thread. Yields its base URL.""" + with run_uvicorn_in_thread( + create_app(is_json_response_enabled=True), limit_concurrency=10, timeout_keep_alive=5, access_log=False + ) as url: + yield url # Basic request validation tests -def test_accept_header_validation(basic_server: None, basic_server_url: str): +def test_accept_header_validation(basic_server_url: str): """Test that Accept header is properly validated.""" # Test without Accept header (suppress requests library default Accept: */*) session = requests.Session() @@ -595,7 +485,7 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): "application/*;q=0.9, text/*;q=0.8", ], ) -def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): +def test_accept_header_wildcard(basic_server_url: str, accept_header: str): """Test that wildcard Accept headers are accepted per RFC 7231.""" response = requests.post( f"{basic_server_url}/mcp", @@ -616,7 +506,7 @@ def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accep "text/*", ], ) -def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): +def test_accept_header_incompatible(basic_server_url: str, accept_header: str): """Test that incompatible Accept headers are rejected for SSE mode.""" response = requests.post( f"{basic_server_url}/mcp", @@ -630,7 +520,7 @@ def test_accept_header_incompatible(basic_server: None, basic_server_url: str, a assert "Not Acceptable" in response.text -def test_content_type_validation(basic_server: None, basic_server_url: str): +def test_content_type_validation(basic_server_url: str): """Test that Content-Type header is 
properly validated.""" # Test with incorrect Content-Type response = requests.post( @@ -646,7 +536,7 @@ def test_content_type_validation(basic_server: None, basic_server_url: str): assert "Invalid Content-Type" in response.text -def test_json_validation(basic_server: None, basic_server_url: str): +def test_json_validation(basic_server_url: str): """Test that JSON content is properly validated.""" # Test with invalid JSON response = requests.post( @@ -661,7 +551,7 @@ def test_json_validation(basic_server: None, basic_server_url: str): assert "Parse error" in response.text -def test_json_parsing(basic_server: None, basic_server_url: str): +def test_json_parsing(basic_server_url: str): """Test that JSON content is properly parse.""" # Test with valid JSON but invalid JSON-RPC response = requests.post( @@ -676,7 +566,7 @@ def test_json_parsing(basic_server: None, basic_server_url: str): assert "Validation error" in response.text -def test_method_not_allowed(basic_server: None, basic_server_url: str): +def test_method_not_allowed(basic_server_url: str): """Test that unsupported HTTP methods are rejected.""" # Test with unsupported method (PUT) response = requests.put( @@ -691,7 +581,7 @@ def test_method_not_allowed(basic_server: None, basic_server_url: str): assert "Method Not Allowed" in response.text -def test_session_validation(basic_server: None, basic_server_url: str): +def test_session_validation(basic_server_url: str): """Test session ID validation.""" # session_id not used directly in this test @@ -766,7 +656,7 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server: None, basic_server_url: str): +def test_session_termination(basic_server_url: str): """Test session termination via DELETE and subsequent request handling.""" response = requests.post( f"{basic_server_url}/mcp", @@ -806,7 +696,7 @@ def test_session_termination(basic_server: None, basic_server_url: 
str): assert "Session has been terminated" in response.text -def test_response(basic_server: None, basic_server_url: str): +def test_response(basic_server_url: str): """Test response handling for a valid request.""" mcp_url = f"{basic_server_url}/mcp" response = requests.post( @@ -841,7 +731,7 @@ def test_response(basic_server: None, basic_server_url: str): assert tools_response.headers.get("Content-Type") == "text/event-stream" -def test_json_response(json_response_server: None, json_server_url: str): +def test_json_response(json_server_url: str): """Test response handling when is_json_response_enabled is True.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -856,7 +746,7 @@ def test_json_response(json_response_server: None, json_server_url: str): assert response.headers.get("Content-Type") == "application/json" -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): +def test_json_response_accept_json_only(json_server_url: str): """Test that json_response servers only require application/json in Accept header.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -871,7 +761,7 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_ assert response.headers.get("Content-Type") == "application/json" -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): +def test_json_response_missing_accept_header(json_server_url: str): """Test that json_response servers reject requests without Accept header.""" mcp_url = f"{json_server_url}/mcp" # Suppress requests library default Accept: */* header @@ -888,7 +778,7 @@ def test_json_response_missing_accept_header(json_response_server: None, json_se assert "Not Acceptable" in response.text -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): +def test_json_response_incorrect_accept_header(json_server_url: str): """Test that json_response servers 
reject requests with incorrect Accept header.""" mcp_url = f"{json_server_url}/mcp" # Test with only text/event-stream (wrong for JSON server) @@ -912,7 +802,7 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ "application/*;q=0.9", ], ) -def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str): +def test_json_response_wildcard_accept_header(json_server_url: str, accept_header: str): """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -927,7 +817,7 @@ def test_json_response_wildcard_accept_header(json_response_server: None, json_s assert response.headers.get("Content-Type") == "application/json" -def test_get_sse_stream(basic_server: None, basic_server_url: str): +def test_get_sse_stream(basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -987,7 +877,7 @@ def test_get_sse_stream(basic_server: None, basic_server_url: str): assert second_get.status_code == 409 -def test_get_validation(basic_server: None, basic_server_url: str): +def test_get_validation(basic_server_url: str): """Test validation for GET requests.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -1044,14 +934,14 @@ def test_get_validation(basic_server: None, basic_server_url: str): # Client-specific fixtures @pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover +async def http_client(basic_server_url: str): # pragma: no cover """Create test client matching the SSE test pattern.""" async with httpx.AsyncClient(base_url=basic_server_url) as client: yield client @pytest.fixture -async def initialized_client_session(basic_server: None, basic_server_url: str): +async def initialized_client_session(basic_server_url: str): """Create 
initialized StreamableHTTP client session.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1060,7 +950,7 @@ async def initialized_client_session(basic_server: None, basic_server_url: str): @pytest.mark.anyio -async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_basic_connection(basic_server_url: str): """Test basic client connection with initialization.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1105,7 +995,7 @@ async def test_streamable_http_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_persistence(basic_server_url: str): """Test that session ID persists across requests.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1126,7 +1016,7 @@ async def test_streamable_http_client_session_persistence(basic_server: None, ba @pytest.mark.anyio -async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): +async def test_streamable_http_client_json_response(json_server_url: str): """Test client with JSON response mode.""" async with streamable_http_client(f"{json_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1147,7 +1037,7 @@ async def test_streamable_http_client_json_response(json_response_server: None, @pytest.mark.anyio -async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): +async def 
test_streamable_http_client_get_stream(basic_server_url: str): """Test GET stream functionality for server-initiated messages.""" notifications_received: list[types.ServerNotification] = [] @@ -1198,7 +1088,7 @@ async def capture_session_id(response: httpx.Response) -> None: @pytest.mark.anyio -async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_termination(basic_server_url: str): """Test client session termination functionality.""" # Use httpx client with event hooks to capture session ID httpx_client, captured_ids = create_session_id_capturing_client() @@ -1233,9 +1123,7 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba @pytest.mark.anyio -async def test_streamable_http_client_session_termination_204( - basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch -): +async def test_streamable_http_client_session_termination_204(basic_server_url: str, monkeypatch: pytest.MonkeyPatch): """Test client session termination functionality with a 204 response. This test patches the httpx client to return a 204 response for DELETEs. 
@@ -1412,7 +1300,7 @@ async def run_tool(): @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): +async def test_streamablehttp_server_sampling(basic_server_url: str): """Test server-initiated sampling request through streamable HTTP transport.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False @@ -1462,7 +1350,7 @@ async def sampling_callback( # Context-aware server implementation for testing request context propagation -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -1487,15 +1375,13 @@ async def _handle_context_list_tools( # pragma: no cover ) -async def _handle_context_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: +async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: name = params.name args = params.arguments or {} if name == "echo_headers": headers_info: dict[str, Any] = {} - if ctx.request and isinstance(ctx.request, Request): + if ctx.request and isinstance(ctx.request, Request): # pragma: no branch headers_info = dict(ctx.request.headers) return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) @@ -1506,19 +1392,18 @@ async def _handle_context_call_tool( # pragma: no cover "method": None, "path": None, } - if ctx.request and isinstance(ctx.request, Request): + if ctx.request and isinstance(ctx.request, Request): # pragma: no branch request = ctx.request context_data["headers"] = dict(request.headers) context_data["method"] = request.method context_data["path"] = request.url.path return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: 
{name}")]) + return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) # pragma: no cover -# Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover - """Run the context-aware test server.""" +def create_context_aware_app() -> Starlette: + """Build the context-aware test app (echoes request headers via tools).""" server = Server( "ContextAwareServer", on_list_tools=_handle_context_list_tools, @@ -1531,7 +1416,7 @@ def run_context_aware_server(port: int): # pragma: no cover json_response=False, ) - app = Starlette( + return Starlette( debug=True, routes=[ Mount("/mcp", app=session_manager.handle_request), @@ -1539,36 +1424,16 @@ def run_context_aware_server(port: int): # pragma: no cover lifespan=lambda app: session_manager.run(), ) - server_instance = uvicorn.Server( - config=uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - ) - server_instance.run() - @pytest.fixture -def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") +def context_server_url() -> Generator[str, None, None]: + """Start the context-aware server in a background thread. 
Yields its base URL.""" + with run_uvicorn_in_thread(create_context_aware_app()) as url: + yield url @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_propagation(context_server_url: str) -> None: """Test that request context is properly propagated through StreamableHTTP.""" custom_headers = { "Authorization": "Bearer test-token", @@ -1577,7 +1442,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: } async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_server_url}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1601,7 +1466,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_isolation(context_server_url: str) -> None: """Test that request contexts are isolated between StreamableHTTP clients.""" contexts: list[dict[str, Any]] = [] @@ -1614,7 +1479,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No } async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_server_url}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1639,9 +1504,9 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): +async def 
test_client_includes_protocol_version_header_after_init(context_server_url: str): """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): + async with streamable_http_client(f"{context_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # Initialize and get the negotiated version init_result = await session.initialize() @@ -1659,7 +1524,7 @@ async def test_client_includes_protocol_version_header_after_init(context_aware_ assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version -def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): +def test_server_validates_protocol_version_header(basic_server_url: str): """Test that server returns 400 Bad Request version if header unsupported or invalid.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1717,7 +1582,7 @@ def test_server_validates_protocol_version_header(basic_server: None, basic_serv assert response.status_code == 200 -def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): +def test_server_backwards_compatibility_no_protocol_version(basic_server_url: str): """Test server accepts requests without protocol version header.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1747,7 +1612,7 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server: None, @pytest.mark.anyio -async def test_client_crash_handled(basic_server: None, basic_server_url: str): +async def test_client_crash_handled(basic_server_url: str): """Test that cases where the client crashes are handled gracefully.""" # Simulate bad client that crashes after init @@ -2219,9 +2084,7 @@ async def message_handler( @pytest.mark.anyio -async def 
test_streamable_http_client_does_not_mutate_provided_client( - basic_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_does_not_mutate_provided_client(basic_server_url: str) -> None: """Test that streamable_http_client does not mutate the provided httpx client's headers.""" # Create a client with custom headers original_headers = { @@ -2252,9 +2115,7 @@ async def test_streamable_http_client_does_not_mutate_provided_client( @pytest.mark.anyio -async def test_streamable_http_client_mcp_headers_override_defaults( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_mcp_headers_override_defaults(context_server_url: str) -> None: """Test that MCP protocol headers override httpx.AsyncClient default headers.""" # httpx.AsyncClient has default "accept: */*" header # We need to verify that our MCP accept header overrides it in actual requests @@ -2263,7 +2124,10 @@ async def test_streamable_http_client_mcp_headers_override_defaults( # Verify client has default accept header assert client.headers.get("accept") == "*/*" - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_server_url}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2283,9 +2147,7 @@ async def test_streamable_http_client_mcp_headers_override_defaults( @pytest.mark.anyio -async def test_streamable_http_client_preserves_custom_with_mcp_headers( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_server_url: str) -> None: """Test that both custom headers and MCP protocol headers are sent in requests.""" custom_headers = { "X-Custom-Header": "custom-value", @@ -2294,7 +2156,10 @@ async def 
test_streamable_http_client_preserves_custom_with_mcp_headers( } async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_server_url}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 810c72820..27a3dc4c5 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -62,11 +62,15 @@ def wait_for_server(port: int, timeout: float = 20.0) -> None: """Wait for server to be ready to accept connections. Polls the server port until it accepts connections or timeout is reached. - This eliminates race conditions without arbitrary sleeps. + + .. deprecated:: + This has a race: the port may be bound by a different server (another + pytest-xdist worker). Prefer :func:`run_uvicorn_in_thread` which holds + the port atomically from bind until shutdown. 
Args: port: The port number to check - timeout: Maximum time to wait in seconds (default 5.0) + timeout: Maximum time to wait in seconds Raises: TimeoutError: If server doesn't start within the timeout period @@ -77,9 +81,7 @@ def wait_for_server(port: int, timeout: float = 20.0) -> None: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.settimeout(0.1) s.connect(("127.0.0.1", port)) - # Server is ready return except (ConnectionRefusedError, OSError): - # Server not ready yet, retry quickly time.sleep(0.01) raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") # pragma: no cover From 7e091ba7bef7f82993c97e67feac20dfa837f134 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 13:41:10 +0000 Subject: [PATCH 2/2] fix: bound uvicorn graceful shutdown so thread.join doesn't abandon it mid-drain MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Without timeout_graceful_shutdown, uvicorn's shutdown() waits indefinitely for open connections to close. The SSE reconnection tests in test_streamable_http.py can leave streams open at fixture teardown, so the 5s thread.join times out and abandons the thread mid-shutdown. On Windows Proactor, the abandoned connection transports still have pending Overlapped Recv operations when the event loop is torn down. GC later finds them during an unrelated test, surfacing as PytestUnraisableExceptionWarning. Previously hidden by subprocess isolation. timeout_graceful_shutdown=1 gives uvicorn a bounded window to drain connections, then cancels remaining tasks via asyncio — transports unwind through CancelledError instead of being abandoned. 
--- tests/test_helpers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 27a3dc4c5..c7a9d449e 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -47,6 +47,12 @@ def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None # which Python 3.14 deprecates. Under filterwarnings=error this crashes # the server thread silently. Starlette is asgi3; skip the autodetect. config_kwargs.setdefault("interface", "asgi3") + # shutdown() waits indefinitely for open connections to drain. SSE tests + # may leave streams open at teardown, so without a bound the join below + # times out and abandons the thread mid-shutdown — on Windows the + # Proactor's Overlapped Recv ops get GC'd pending. This bounds the wait, + # then cancels remaining tasks via asyncio so transports unwind cleanly. + config_kwargs.setdefault("timeout_graceful_shutdown", 1) server = uvicorn.Server(config=uvicorn.Config(app=app, **config_kwargs)) thread = threading.Thread(target=server.run, kwargs={"sockets": [sock]}, daemon=True)