diff --git a/examples/servers/mrtr-options/README.md b/examples/servers/mrtr-options/README.md new file mode 100644 index 000000000..0acb27d80 --- /dev/null +++ b/examples/servers/mrtr-options/README.md @@ -0,0 +1,147 @@ +# MRTR handler-shape options (SEP-2322) + +Python-SDK counterpart to [typescript-sdk#1701]. Seven ways to write the same +weather-lookup tool, so the diff between files is the argument. + +Unlike the TS demos, the lowlevel plumbing here is **real** — each option is +an actual `mcp.server.Server` that round-trips `IncompleteResult` through the +wire protocol. The invariant test at the bottom asserts they all produce +identical client-observed behaviour. + +[typescript-sdk#1701]: https://github.com/modelcontextprotocol/typescript-sdk/pull/1701 + +## Start here + +If you just want to see what an MRTR lowlevel handler looks like without +the comparison framing, read these first: + +- [`basic.py`](mrtr_options/basic.py) — the simple-tool equivalent. One + `IncompleteResult`, one retry, done. ~130 lines, half of which are + comments explaining the two moves every MRTR handler makes. +- [`basic_multiround.py`](mrtr_options/basic_multiround.py) — the + ADO-rules SEP example. Two rounds, with `request_state` carrying + accumulated context across the retry so any server instance can + handle any round. + +Both are runnable end-to-end against the in-memory client: + +```sh +uv run python -m mrtr_options.basic +uv run python -m mrtr_options.basic_multiround +``` + +## The quadrant + +| Server infra | Pre-MRTR client | MRTR client | +| ------------------------------- | --------------------------------- | ----------- | +| Can hold SSE | E by default; A/C/D if you opt in | MRTR | +| MRTR-only (horizontally scaled) | E by necessity | MRTR | + +Both rows *work* for old clients — version negotiation succeeds, +`tools/list` is complete, tools that don't elicit are unaffected. Only +elicitation inside a tool is unavailable. Bottom-left isn't "unresolvable"; +it's "E is the only option." Top-left is "E, unless you choose to carry SSE +infra." The rows collapse for E, which is why it's the SDK default. + +## Options + +| | Author writes | SDK does | Hidden re-entry | Server state | Old client gets | +| ------------------------------ | ------------------------------- | -------------------------------- | --------------- | -------------------- | --------------------------------- | +| [E](mrtr_options/option_e_degrade.py) | MRTR-native only | Nothing | No | None | Result w/ default, or error | +| [A](mrtr_options/option_a_sse_shim.py) | MRTR-native only | Retry-loop over SSE | Yes, safe | SSE connection | Full elicitation | +| [B](mrtr_options/option_b_await_shim.py) | `await elicit()` | Exception → `IncompleteResult` | **Yes, unsafe** | None | Full elicitation | +| [C](mrtr_options/option_c_version_branch.py) | One handler, `if version` branch | Version accessor | No | SSE (old-client arm) | Full elicitation | +| [D](mrtr_options/option_d_dual_handler.py) | Two handlers | Picks by version | No | SSE (old-client arm) | Full elicitation | +| [F](mrtr_options/option_f_ctx_once.py) | MRTR-native + `ctx.once` wraps | `once()` guard in request_state | No | None | (same as E) | +| [G](mrtr_options/option_g_tool_builder.py) | Step functions + `.build()` | Step-tracking in request_state | No | None | (same as E) | +| [H](mrtr_options/option_h_linear.py) | `await ctx.elicit()` (linear) | Holds coroutine frame in memory | No | Coroutine frame | (same as E) | + +"Hidden re-entry" = the handler function is invoked more than once for a +single logical tool call, and the author can't tell from the source text. + +**A is safe** because MRTR-native code has the re-entry guard (`if not +prefs: return IncompleteResult(...)`) visible in source even though the +*loop* is hidden. + +**B is unsafe** because `await elicit()` looks like a suspension point but +is actually a re-entry point on MRTR sessions — see the `audit_log` +landmine in that file. + +## Footgun prevention (F, G) + +A–E are about the dual-path axis (old client vs new). F and G address a +different axis: even in a pure-MRTR world, the naive handler shape has a +footgun. Code above the `if not prefs` guard runs on every retry. If that +code is a DB write or HTTP POST, it executes N times for N-round +elicitation. Nothing *enforces* putting side-effects below the guard — +safety depends on the developer knowing the convention. The analogy from +SDK-WG review: the naive MRTR handler is de-facto GOTO. + +**F (`MrtrCtx.once`)** keeps the monolithic handler but wraps side-effects +in an idempotency guard. `ctx.once("audit", lambda: audit_log(...))` checks +`request_state` — if the key is marked executed, skip. Opt-in: an unwrapped +mutation still fires twice. The footgun is made *visually distinct*, which +is reviewable. + +**G (`ToolBuilder`)** decomposes the handler into named step functions. +`incomplete_step` may return `IncompleteResult` or data; `end_step` receives +everything and runs exactly once. There is no "above the guard" zone because +there is no guard — the SDK's step-tracking is the guard. Side-effects go in +`end_step`, structurally unreachable until all elicitations complete. + +Both depend on `request_state` integrity. The demos use plain base64-JSON; +a real SDK MUST HMAC-sign the blob, or the client can forge step-done +markers and skip the guards. Per-session key derived from `initialize` keeps +it stateless. Without signing, the safety story is advisory. + +## Trade-offs + +**E is the SDK default.** A horizontally-scaled server gets E for free — +it's the only thing that works on that infra. A server that can hold SSE +also gets E by default, and opts into A/C/D only if serving old-client +elicitation is worth the extra infra dependency. + +**A vs E** is the core tension. Same author-facing code (MRTR-native), the +only difference is whether old clients get elicitation. A requires shipping +`sse_retry_shim`; E requires nothing. A also carries a deployment-time +hazard E doesn't: the shim calls real SSE under the hood, so on MRTR-only +infra it fails at runtime when an old client connects — a constraint that +lives nowhere near the tool code. + +**B** is zero-migration but breaks silently for anything non-idempotent +above the await. Not a ship target. + +**C vs D** is factoring: one function with a branch vs two functions with a +dispatcher. Both put the dual-path burden on the tool author. + +**F vs G** is the footgun-prevention trade. F is minimal — one line per +side-effect, composes with any handler shape. G is structural — +double-execution impossible for `end_step`, but costs two function defs +per tool. Likely SDK answer: ship F as a primitive on the context, ship G +as an opt-in builder, recommend G for multi-round tools and F for +single-question tools. + +**H (linear continuation)** is the Option B footgun, *fixed*. Handler code +reads exactly like the SSE era — `await ctx.elicit()` is a genuine +suspension point, side-effects above it fire once — because the coroutine +frame is held in memory across rounds. The trade: server is stateful +*within* a single tool call (frame keyed by `request_state`), so +horizontally-scaled deployments need sticky routing on the token. Same +operational shape as A's SSE hold but without the long-lived connection. +Use for migrating existing SSE-era tools without rewriting, or when the +linear style is genuinely clearer than guard-first. Don't use if you need +true statelessness — E/F/G encode everything in `request_state` itself. + +## The invariant test + +`tests/server/experimental/test_mrtr_options.py` parametrises all seven +servers against the same `Client` + `elicitation_callback`, asserting +identical output. The footgun test measures `audit_count` to prove F and G +hold the side-effect to one. + +## Not in scope + +- Persistent/Tasks workflow — `ServerTaskContext` already does + `input_required`; MRTR integration is a separate PR +- `mrtrOnly` client flag — trivial to add, not demoed +- requestState HMAC signing — called out in code comments diff --git a/examples/servers/mrtr-options/mrtr_options/__init__.py b/examples/servers/mrtr-options/mrtr_options/__init__.py new file mode 100644 index 000000000..7c027d29a --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/__init__.py @@ -0,0 +1,7 @@ +"""MRTR handler-shape comparison — seven options on the same weather tool. + +See README.md for the trade-off matrix. Every option here is a real lowlevel +``mcp.server.Server`` that produces identical wire behaviour to each client +version — the server's internal choice doesn't leak. That's the argument +against per-feature ``-mrtr`` capability flags. +""" diff --git a/examples/servers/mrtr-options/mrtr_options/_shared.py b/examples/servers/mrtr-options/mrtr_options/_shared.py new file mode 100644 index 000000000..b16e42a39 --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/_shared.py @@ -0,0 +1,58 @@ +"""Domain logic shared across all options — *not* SDK machinery. + +The weather tool: given a location, asks which units, returns a temperature +string. Same tool throughout so the diff between option files is the +argument. + +``audit_log`` is the side-effect that makes the MRTR footgun concrete: under +naive re-entry it fires once per round. Options F and G tame it. +""" + +from __future__ import annotations + +from mcp import types +from mcp.server import Server, ServerRequestContext + +UNITS_SCHEMA: types.ElicitRequestedSchema = { + "type": "object", + "properties": {"units": {"type": "string", "enum": ["metric", "imperial"], "title": "Units"}}, + "required": ["units"], +} + +UNITS_REQUEST = types.ElicitRequest( + params=types.ElicitRequestFormParams(message="Which units?", requested_schema=UNITS_SCHEMA) +) + + +def lookup_weather(location: str, units: str) -> str: + temp = "22°C" if units == "metric" else "72°F" + return f"Weather in {location}: {temp}, partly cloudy." + + +_audit_count = 0 + + +def audit_log(location: str) -> None: + """The footgun. Under naive re-entry this fires N times for N-round MRTR.""" + global _audit_count + _audit_count += 1 + print(f"[audit] lookup requested for {location} (count={_audit_count})") + + +def audit_count() -> int: + return _audit_count + + +def reset_audit() -> None: + global _audit_count + _audit_count = 0 + + +async def no_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: + """Minimal tools/list handler so Client validation has something to call.""" + return types.ListToolsResult(tools=[]) + + +def build_server(name: str, on_call_tool: object, **kwargs: object) -> Server: + """Consistent Server construction across option files.""" + return Server(name, on_call_tool=on_call_tool, on_list_tools=no_tools, **kwargs) # type: ignore[arg-type] diff --git a/examples/servers/mrtr-options/mrtr_options/basic.py b/examples/servers/mrtr-options/mrtr_options/basic.py new file mode 100644 index 000000000..de9716fd1 --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/basic.py @@ -0,0 +1,128 @@ +"""The minimal MRTR lowlevel server — the simple-tool equivalent. + +No version checks, no comparison framing. Just the two moves every MRTR +handler makes: + + 1. Check ``params.input_responses`` for the answer to a prior ask. + 2. If it's not there, return ``IncompleteResult`` with the ask embedded. + +The client SDK (``mcp.client.Client.call_tool``) drives the retry loop — +this handler is invoked once per round with whatever the client collected. + +Run against the in-memory client: + + uv run python -m mrtr_options.basic +""" + +from __future__ import annotations + +import anyio + +from mcp import types +from mcp.client import Client +from mcp.client.context import ClientRequestContext +from mcp.server import Server, ServerRequestContext + + +async def on_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Look up weather for a location. Asks which units you want.", + input_schema={ + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + ) + ] + ) + + +async def on_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.IncompleteResult: + """The MRTR tool handler. Called once per round.""" + location = (params.arguments or {}).get("location", "?") + + # ─────────────────────────────────────────────────────────────────────── + # Step 1: check if the client has already answered our question. + # + # ``input_responses`` is a dict keyed by the same keys we used in + # ``input_requests`` on the prior round. Each value is the raw result + # the client produced (ElicitResult, CreateMessageResult, ListRootsResult + # — serialized to dict form over the wire). + # + # On the first round, ``input_responses`` is None. On subsequent rounds, + # it contains ONLY the answers to the most recent round's asks — not + # accumulated across rounds. If you need to accumulate, encode it in + # ``request_state`` (see option_f_ctx_once.py / option_g_tool_builder.py). + # ─────────────────────────────────────────────────────────────────────── + responses = params.input_responses or {} + prefs = responses.get("unit_prefs") + + if prefs is None or prefs.get("action") != "accept": + # ─────────────────────────────────────────────────────────────────── + # Step 2: ask. Return IncompleteResult with the embedded request. + # + # The client SDK receives this, dispatches the embedded ElicitRequest + # to its elicitation_callback, and re-invokes this handler with the + # answer in input_responses["unit_prefs"]. + # + # Keys are server-assigned and opaque to the client. Pick whatever + # makes the code readable — they just need to be consistent between + # the ask and the check above. + # ─────────────────────────────────────────────────────────────────── + return types.IncompleteResult( + input_requests={ + "unit_prefs": types.ElicitRequest( + params=types.ElicitRequestFormParams( + message="Which units for the temperature?", + requested_schema={ + "type": "object", + "properties": {"units": {"type": "string", "enum": ["metric", "imperial"]}}, + "required": ["units"], + }, + ) + ) + }, + # request_state is optional. Use it for anything that must + # survive across rounds without server-side storage — e.g. + # partially-computed results, progress markers, or (in F/G) + # idempotency guards. The client echoes it verbatim. + request_state=None, + ) + + # ─────────────────────────────────────────────────────────────────────── + # Step 3: we have the answer. Compute and return a normal result. + # ─────────────────────────────────────────────────────────────────────── + units = prefs["content"]["units"] + temp = "22°C" if units == "metric" else "72°F" + return types.CallToolResult(content=[types.TextContent(text=f"Weather in {location}: {temp}, partly cloudy.")]) + + +server = Server("mrtr-basic", on_list_tools=on_list_tools, on_call_tool=on_call_tool) + + +# ─── Demo driver ───────────────────────────────────────────────────────────── + + +async def elicitation_callback(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: + """What the app developer writes. Same signature as SSE-era callbacks.""" + assert isinstance(params, types.ElicitRequestFormParams) + print(f"[client] server asks: {params.message}") + # A real client presents params.requested_schema as a form. We hard-code. + return types.ElicitResult(action="accept", content={"units": "metric"}) + + +async def main() -> None: + async with Client(server, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("get_weather", {"location": "Tokyo"}) + print(f"[client] result: {result.content[0].text}") # type: ignore[union-attr] + + +if __name__ == "__main__": + anyio.run(main) diff --git a/examples/servers/mrtr-options/mrtr_options/basic_multiround.py b/examples/servers/mrtr-options/mrtr_options/basic_multiround.py new file mode 100644 index 000000000..798f9d9ef --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/basic_multiround.py @@ -0,0 +1,167 @@ +"""Multi-round MRTR with request_state accumulation. + +This is the ADO-custom-rules example from the SEP, translated. Resolving +a work item triggers cascading required fields: + + Rule 1: State → "Resolved" requires a Resolution field + Rule 2: Resolution = "Duplicate" requires a "Duplicate Of" link + +The server learns Rule 2 is needed only after the user answers Rule 1. +Two rounds of elicitation. The Rule 1 answer must survive across rounds +*without server-side storage* — that's what ``request_state`` is for. + +Key point: ``input_responses`` carries only the *latest* round's answers. +Round 2's retry has ``{"duplicate_of": ...}`` but NOT ``{"resolution": ...}``. +Anything the server needs to keep must be encoded in ``request_state``, +which the client echoes verbatim. + +Run against the in-memory client: + + uv run python -m mrtr_options.basic_multiround +""" + +from __future__ import annotations + +import base64 +import json +from typing import Any + +import anyio + +from mcp import types +from mcp.client import Client +from mcp.client.context import ClientRequestContext +from mcp.server import Server, ServerRequestContext + + +def encode_state(state: dict[str, Any]) -> str: + """Serialize state for the round trip through the client. + + Plain base64-JSON here. A production server handling sensitive data + MUST sign this — the client is an untrusted intermediary and could + forge or replay state otherwise. See SEP-2322 §Security Implications. + """ + return base64.b64encode(json.dumps(state).encode()).decode() + + +def decode_state(blob: str | None) -> dict[str, Any]: + if not blob: + return {} + return json.loads(base64.b64decode(blob)) + + +def ask(message: str, field: str) -> types.ElicitRequest: + """Build a form-mode elicitation for a single string field.""" + return types.ElicitRequest( + params=types.ElicitRequestFormParams( + message=message, + requested_schema={ + "type": "object", + "properties": {field: {"type": "string"}}, + "required": [field], + }, + ) + ) + + +async def on_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="resolve_work_item", + description="Resolve a work item. May need cascading follow-up fields.", + input_schema={ + "type": "object", + "properties": {"work_item_id": {"type": "integer"}}, + "required": ["work_item_id"], + }, + ) + ] + ) + + +async def on_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.IncompleteResult: + args = params.arguments or {} + work_item_id = args.get("work_item_id", 0) + responses = params.input_responses or {} + state = decode_state(params.request_state) + + # ─────────────────────────────────────────────────────────────────────── + # Round 1: State → Resolved triggers Rule 1 (require Resolution). + # + # If we don't yet have the resolution — neither in this round's + # input_responses nor in accumulated state — ask for it. + # ─────────────────────────────────────────────────────────────────────── + resolution = state.get("resolution") + if not resolution: + resp = responses.get("resolution") + if not resp or resp.get("action") != "accept": + return types.IncompleteResult( + input_requests={ + "resolution": ask( + f"Resolving #{work_item_id} requires a resolution. Fixed, Won't Fix, Duplicate, or By Design?", + "resolution", + ) + }, + # No state yet — the original tool arguments are re-sent on + # retry, so we don't need to encode anything for round 1. + ) + resolution = resp["content"]["resolution"] + + # ─────────────────────────────────────────────────────────────────────── + # Round 2: Resolution = "Duplicate" triggers Rule 2 (require link). + # + # If the resolution is Duplicate and we don't yet have the link, ask + # for it — but encode the already-gathered resolution in request_state + # so it survives the round trip regardless of which server instance + # handles the next retry. + # ─────────────────────────────────────────────────────────────────────── + if resolution == "Duplicate": + resp = responses.get("duplicate_of") + if not resp or resp.get("action") != "accept": + return types.IncompleteResult( + input_requests={"duplicate_of": ask("Which work item is the original?", "duplicate_of")}, + request_state=encode_state({"resolution": resolution}), + ) + dup = resp["content"]["duplicate_of"] + text = f"#{work_item_id} resolved as Duplicate of #{dup}." + else: + text = f"#{work_item_id} resolved as {resolution}." + + return types.CallToolResult(content=[types.TextContent(text=text)]) + + +server = Server("mrtr-multiround", on_list_tools=on_list_tools, on_call_tool=on_call_tool) + + +# ─── Demo driver ───────────────────────────────────────────────────────────── + + +ANSWERS = { + "resolution": "Duplicate", + "duplicate_of": "4301", +} + + +async def elicitation_callback(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: + assert isinstance(params, types.ElicitRequestFormParams) + print(f"[client] server asks: {params.message}") + # Pick the field name from the schema and answer from our table. + field = next(iter(params.requested_schema["properties"])) + answer = ANSWERS[field] + print(f"[client] answering {field}={answer}") + return types.ElicitResult(action="accept", content={field: answer}) + + +async def main() -> None: + async with Client(server, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("resolve_work_item", {"work_item_id": 4522}) + print(f"[client] final: {result.content[0].text}") # type: ignore[union-attr] + + +if __name__ == "__main__": + anyio.run(main) diff --git a/examples/servers/mrtr-options/mrtr_options/option_a_sse_shim.py b/examples/servers/mrtr-options/mrtr_options/option_a_sse_shim.py new file mode 100644 index 000000000..4e7c14736 --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/option_a_sse_shim.py @@ -0,0 +1,60 @@ +"""Option A: SDK shim emulates the MRTR retry loop over SSE. Hidden loop. + +Tool author writes MRTR-native code only. The SDK wrapper detects the +negotiated version: + - new client → pass ``IncompleteResult`` through, client drives retry + - old client → SDK runs the retry loop *locally*, fulfilling each + ``InputRequest`` via real SSE (``ctx.session.elicit_form()``), + re-invoking the handler until it returns a complete result + +Author experience: one code path. Re-entry is explicit in source (the +``if not prefs`` guard), so the handler is safe to re-invoke by +construction. But the *fact* that it's re-invoked for old clients is +invisible — the shim is doing work the author can't see. + +What makes this "clunky but possible": the SDK runs a loop on the +author's behalf. If the handler does something expensive before the +guard, the author won't find out until an old client connects in prod. +Works, but it's magic. + +Deployment hazard: ``sse_retry_shim`` calls real SSE under the hood. +On MRTR-only infra it fails at runtime when an old client connects — +a constraint that lives nowhere near the tool code. If that's the +deployment, use Option E. +""" + +from __future__ import annotations + +from mcp import types +from mcp.server import ServerRequestContext +from mcp.server.experimental.mrtr import input_response, sse_retry_shim + +from ._shared import UNITS_REQUEST, build_server, lookup_weather + +# ─────────────────────────────────────────────────────────────────────────── +# This is what the tool author writes. One function, MRTR-native. No +# version check, no SSE awareness. The ``if not prefs`` guard IS the +# re-entry contract; the author sees it, but doesn't see the shim +# calling this in a loop for old-client sessions. +# ─────────────────────────────────────────────────────────────────────────── + + +async def weather( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.IncompleteResult: + location = (params.arguments or {}).get("location", "?") + + prefs = input_response(params, "units") + if prefs is None: + return types.IncompleteResult(input_requests={"units": UNITS_REQUEST}) + + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, prefs["units"]))]) + + +# ─────────────────────────────────────────────────────────────────────────── +# Registration applies the shim. In a real SDK this could be a flag on +# ``add_tool`` or inferred from the handler signature — the author opts in +# once at registration, not per-call. +# ─────────────────────────────────────────────────────────────────────────── + +server = build_server("mrtr-option-a", on_call_tool=sse_retry_shim(weather)) diff --git a/examples/servers/mrtr-options/mrtr_options/option_b_await_shim.py b/examples/servers/mrtr-options/mrtr_options/option_b_await_shim.py new file mode 100644 index 000000000..92ba70875 --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/option_b_await_shim.py @@ -0,0 +1,93 @@ +"""Option B: exception-based shim, ``await elicit()`` canonical. The footgun. + +Tool author writes today's ``await ctx.elicit(...)`` style. The shim routes: + - old client → native SSE, blocks inline (today's behaviour exactly) + - new client → ``elicit()`` raises ``NeedsInputSignal``, shim catches, + emits ``IncompleteResult``. On retry the handler runs *from the top* + and this time ``elicit()`` finds the answer in ``input_responses``. + +Author experience: zero migration. Handlers that work today keep working. +The ``await`` reads linearly. + +The problem: the ``await`` is a lie on MRTR sessions. Everything above it +re-executes on retry. Uncomment the ``audit_log()`` call below — an MRTR +client triggers *two* audit entries for one tool call. A pre-MRTR client +triggers one. Same source, different observable behaviour, nothing warns. + +Only safe if you can enforce "no side-effects before await" as a lint +rule, which is hard in practice. + +**This is not a ship target — it's a cautionary comparison.** +""" + +from __future__ import annotations + +from mcp import types +from mcp.server import ServerRequestContext +from mcp.server.experimental.mrtr import input_response + +from ._shared import UNITS_REQUEST, UNITS_SCHEMA, build_server, lookup_weather + + +class NeedsInputSignal(Exception): + """Control-flow-by-exception. Unwound by the shim, packaged as IncompleteResult.""" + + def __init__(self, input_requests: types.InputRequests) -> None: + self.input_requests = input_requests + super().__init__("NeedsInputSignal (control flow, not an error)") + + +async def elicit_or_signal( + ctx: ServerRequestContext, params: types.CallToolRequestParams, key: str +) -> dict[str, str] | None: + """The ``await``-able elicit that looks linear but isn't on MRTR.""" + version = ctx.session.client_params.protocol_version if ctx.session.client_params else None + + # Old client: native SSE, no trickery. + if version is None or str(version) < "2026-06-01": + result = await ctx.session.elicit_form(message="Which units?", requested_schema=UNITS_SCHEMA) + if result.action != "accept" or not result.content: + return None + return {k: str(v) for k, v in result.content.items()} + + # New client: check input_responses first. + prefs = input_response(params, key) + if prefs is not None: + return {k: str(v) for k, v in prefs.items()} + + # Not pre-supplied → signal the shim. Everything on the stack unwinds. + # On retry the handler re-executes from line one. + raise NeedsInputSignal({key: UNITS_REQUEST}) + + +# ─────────────────────────────────────────────────────────────────────────── +# This is what the tool author writes. Looks linear. Isn't, on MRTR. +# ─────────────────────────────────────────────────────────────────────────── + + +async def _weather_inner(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + location = (params.arguments or {}).get("location", "?") + + # audit_log(location) + # ^^^^^^^^^^^^^^^^^^ + # On pre-MRTR: runs once. On MRTR: runs once on the initial call, + # once more on the retry. The await below isn't a suspension point + # on MRTR — it's a re-entry point. Nothing in this syntax says so. + + prefs = await elicit_or_signal(ctx, params, "units") + if not prefs: + return types.CallToolResult(content=[types.TextContent(text="Cancelled.")]) + + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, prefs["units"]))]) + + +async def weather( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.IncompleteResult: + try: + return await _weather_inner(ctx, params) + except NeedsInputSignal as signal: + return types.IncompleteResult(input_requests=signal.input_requests) + + +server = build_server("mrtr-option-b", on_call_tool=weather) diff --git a/examples/servers/mrtr-options/mrtr_options/option_c_version_branch.py b/examples/servers/mrtr-options/mrtr_options/option_c_version_branch.py new file mode 100644 index 000000000..8d7042e86 --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/option_c_version_branch.py @@ -0,0 +1,52 @@ +"""Option C: explicit version branch in the handler body. + +No shim. Tool author checks the negotiated version themselves and writes +both code paths inline. The SDK provides nothing except the version +accessor and the raw primitives for each path. + +Author experience: everything is visible. Both protocol behaviours are +right there in source, separated by an ``if``. No hidden re-entry, no +magic wrappers. A reader traces exactly what happens for each client +version. + +The cost is also visible: the elicitation schema is duplicated, the +cancel-handling is duplicated, and there's a conditional at the top of +every handler that uses elicitation. For one tool, fine. For twenty, +it's twenty copies of the same branch. +""" + +from __future__ import annotations + +from mcp import types +from mcp.server import ServerRequestContext +from mcp.server.experimental.mrtr import input_response + +from ._shared import UNITS_REQUEST, UNITS_SCHEMA, build_server, lookup_weather + + +async def weather( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.IncompleteResult: + location = (params.arguments or {}).get("location", "?") + version = ctx.session.client_params.protocol_version if ctx.session.client_params else None + + # ─────────────────────────────────────────────────────────────────────── + # The branch is the whole story. + # ─────────────────────────────────────────────────────────────────────── + + if version is not None and str(version) >= "2026-06-01": + # MRTR path: check input_responses, return IncompleteResult if missing. + prefs = input_response(params, "units") + if prefs is None: + return types.IncompleteResult(input_requests={"units": UNITS_REQUEST}) + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, prefs["units"]))]) + + # SSE path: inline await, blocks on the response stream. + result = await ctx.session.elicit_form(message="Which units?", requested_schema=UNITS_SCHEMA) + if result.action != "accept" or not result.content: + return types.CallToolResult(content=[types.TextContent(text="Cancelled.")]) + units = str(result.content.get("units", "metric")) + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, units))]) + + +server = build_server("mrtr-option-c", on_call_tool=weather) diff --git a/examples/servers/mrtr-options/mrtr_options/option_d_dual_handler.py b/examples/servers/mrtr-options/mrtr_options/option_d_dual_handler.py new file mode 100644 index 000000000..fc4d00fe0 --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/option_d_dual_handler.py @@ -0,0 +1,57 @@ +"""Option D: dual registration. Two handlers, SDK picks by version. + +Tool author writes two separate functions — one MRTR-native, one +SSE-native — and hands both to the SDK. Dispatch by negotiated version. +No shim converts between them; each path is exactly what the author +wrote for that protocol era. + +Author experience: no hidden control flow. Unlike Option C, the two +paths are structurally separated rather than tangled in one body. +Shared logic factors out naturally. Each handler readable in isolation. + +The cost: two functions per elicitation-using tool, both live until SSE +is deprecated. There's no mechanical link between them — if the MRTR +handler changes the schema and the SSE one doesn't, nothing catches it. +Also: the registration API grows a shape that only exists for the +transition period. +""" + +from __future__ import annotations + +from mcp import types +from mcp.server import ServerRequestContext +from mcp.server.experimental.mrtr import dispatch_by_version, input_response + +from ._shared import UNITS_REQUEST, UNITS_SCHEMA, build_server, lookup_weather + +# ─────────────────────────────────────────────────────────────────────────── +# Two functions. Each clean in isolation. +# ─────────────────────────────────────────────────────────────────────────── + + +async def weather_mrtr( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.IncompleteResult: + location = (params.arguments or {}).get("location", "?") + prefs = input_response(params, "units") + if prefs is None: + return types.IncompleteResult(input_requests={"units": UNITS_REQUEST}) + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, prefs["units"]))]) + + +async def weather_sse(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + location = (params.arguments or {}).get("location", "?") + result = await ctx.session.elicit_form(message="Which units?", requested_schema=UNITS_SCHEMA) + if result.action != "accept" or not result.content: + return types.CallToolResult(content=[types.TextContent(text="Cancelled.")]) + units = str(result.content.get("units", "metric")) + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, units))]) + + +# ─────────────────────────────────────────────────────────────────────────── +# Registration takes both. Real SDK shape might be an overload or a +# ``{mrtr:, sse:}`` dict — point is both handlers are visible at the +# registration site and the SDK owns the switch. +# ─────────────────────────────────────────────────────────────────────────── + +server = build_server("mrtr-option-d", on_call_tool=dispatch_by_version(mrtr=weather_mrtr, sse=weather_sse)) diff --git a/examples/servers/mrtr-options/mrtr_options/option_e_degrade.py b/examples/servers/mrtr-options/mrtr_options/option_e_degrade.py new file mode 100644 index 000000000..cc94d50a2 --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/option_e_degrade.py @@ -0,0 +1,68 @@ +"""Option E: graceful degradation. The SDK default. + +Tool author writes MRTR-native code only. Pre-MRTR clients get a result +with a default (or an error — author's choice) for *this tool*; everything +else on the server is unaffected. Version negotiation succeeds, tools/list +is complete, tools that don't elicit work normally. + +This is the only option that works on horizontally-scaled MRTR-only infra, +and it's also correct on SSE-capable infra — both quadrant rows collapse +here. That's why it's the default: a server adopting the new SDK gets this +behaviour without asking. A/C/D are opt-in for servers that choose to carry +SSE through the transition. + +Author experience: one code path, trivially understood. The version check +is one line at the top; everything below is plain MRTR. +""" + +from __future__ import annotations + +from mcp import types +from mcp.server import ServerRequestContext +from mcp.server.experimental.mrtr import input_response + +from ._shared import UNITS_REQUEST, build_server, lookup_weather + +MRTR_MIN_VERSION = "2026-06-01" + + +async def weather( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.IncompleteResult: + location = (params.arguments or {}).get("location", "?") + + # ─────────────────────────────────────────────────────────────────────── + # Pre-MRTR session: elicitation unavailable. Tool author decides what + # that means — not the SDK, not the spec. + # + # For weather, unit preference is nice-to-have. Defaulting to metric + # and returning the answer is a better old-client experience than + # "upgrade your client to check the weather." + # + # For a tool where the elicitation is essential — confirming a + # destructive action, collecting required auth — error instead: + # + # return types.CallToolResult( + # content=[types.TextContent( + # text=f"This tool requires protocol version {MRTR_MIN_VERSION}+." + # )], + # is_error=True, + # ) + # + # Either way: no SSE code path. The server is still a valid 2025-11 + # server — it just doesn't use the client's declared elicitation + # capability. Servers are already allowed to do that. No new flags, + # no special negotiation. + # ─────────────────────────────────────────────────────────────────────── + version = ctx.session.client_params.protocol_version if ctx.session.client_params else None + if version is None or str(version) < MRTR_MIN_VERSION: + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, "metric"))]) + + prefs = input_response(params, "units") + if prefs is None: + return types.IncompleteResult(input_requests={"units": UNITS_REQUEST}) + + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, prefs["units"]))]) + + +server = build_server("mrtr-option-e", on_call_tool=weather) diff --git a/examples/servers/mrtr-options/mrtr_options/option_f_ctx_once.py b/examples/servers/mrtr-options/mrtr_options/option_f_ctx_once.py new file mode 100644 index 000000000..3c61d0c6c --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/option_f_ctx_once.py @@ -0,0 +1,50 @@ +"""Option F: ``ctx.once`` idempotency guard inside the monolithic handler. + +Same MRTR-native shape as E, but side-effects get wrapped in +``ctx.once(key, fn)``. The guard lives in ``request_state`` — on retry, +keys marked executed skip their fn. Makes the hazard *visible* at the +call site without restructuring the handler. + +Opt-in: an unwrapped mutation still fires twice. The footgun isn't +eliminated — it's made reviewable. ``ctx.once("x", ...)`` reads +differently from a bare call; a reviewer can grep for effects that +aren't wrapped. + +When to reach for this over G (ToolBuilder): single elicitation, one +or two side-effects, handler fits in ten lines. When the step count +hits 3+, ToolBuilder's boilerplate pays for itself. +""" + +from __future__ import annotations + +from mcp import types +from mcp.server import ServerRequestContext +from mcp.server.experimental.mrtr import MrtrCtx, input_response + +from ._shared import UNITS_REQUEST, audit_log, build_server, lookup_weather + + +async def weather( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.IncompleteResult: + location = (params.arguments or {}).get("location", "?") + mrtr = MrtrCtx(params) + + # ─────────────────────────────────────────────────────────────────────── + # This is the hazard line. In E it would run on every retry. + # Here it runs once — ``once`` checks request_state, skips on retry. + # A reviewer sees ``mrtr.once`` and knows the author considered + # re-entry. A bare ``audit_log(location)`` would be the red flag. + # ─────────────────────────────────────────────────────────────────────── + mrtr.once("audit", lambda: audit_log(location)) + + prefs = input_response(params, "units") + if prefs is None: + # ``mrtr.incomplete()`` encodes the executed-keys set into + # request_state so the guard holds across retry. + return mrtr.incomplete({"units": UNITS_REQUEST}) + + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, prefs["units"]))]) + + +server = build_server("mrtr-option-f", on_call_tool=weather) diff --git a/examples/servers/mrtr-options/mrtr_options/option_g_tool_builder.py b/examples/servers/mrtr-options/mrtr_options/option_g_tool_builder.py new file mode 100644 index 000000000..efd97229c --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/option_g_tool_builder.py @@ -0,0 +1,68 @@ +"""Option G: ``ToolBuilder`` — explicit step decomposition. + +The monolithic handler becomes a sequence of named step functions. +``incomplete_step`` may return ``IncompleteResult`` (needs more input) +or a dict (satisfied, pass to next step). ``end_step`` receives +everything and runs exactly once — structurally unreachable until +every prior step has returned data. + +The footgun is eliminated by code shape, not discipline. There is no +"above the guard" zone because there is no guard — the SDK's step +tracking (via ``request_state``) *is* the guard. Side-effects go in +``end_step``; anything in an ``incomplete_step`` is documented as +must-be-idempotent, and the return-type split makes that distinction +visible at the function signature level. + +Boilerplate: two function defs + ``.build()`` to replace E's 3-line +guard. Worth it at 3+ rounds or when the side-effect story matters. +Overkill for a single-question tool where F is lighter. +""" + +from __future__ import annotations + +from typing import Any + +from mcp import types +from mcp.server.experimental.mrtr import ToolBuilder + +from ._shared import UNITS_REQUEST, audit_log, build_server, lookup_weather + +# ─────────────────────────────────────────────────────────────────────────── +# Step 1: ask for units. Returns IncompleteResult if not yet provided, +# or ``{"units": ...}`` to pass forward. MUST be idempotent — it can +# re-run if request_state is tampered with (unsigned in this draft) or +# on a partial replay. No side-effects here. +# ─────────────────────────────────────────────────────────────────────────── + + +def ask_units(args: dict[str, Any], inputs: dict[str, Any]) -> types.IncompleteResult | dict[str, Any]: + resp = inputs.get("units") + if not resp or resp.get("action") != "accept": + return types.IncompleteResult(input_requests={"units": UNITS_REQUEST}) + return {"units": resp["content"]["units"]} + + +# ─────────────────────────────────────────────────────────────────────────── +# End step: has everything, does the work. Runs exactly once. This is +# where side-effects live — the SDK guarantees this function is not +# reached until ``ask_units`` (and any other incomplete steps) have all +# returned data. ``audit_log`` here fires once regardless of how many +# MRTR rounds it took to collect the inputs. +# ─────────────────────────────────────────────────────────────────────────── + + +def fetch_weather(args: dict[str, Any], collected: dict[str, Any]) -> types.CallToolResult: + location = (args or {}).get("location", "?") + audit_log(location) + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, collected["units"]))]) + + +# ─────────────────────────────────────────────────────────────────────────── +# Assembly. Steps are named so reordering during development doesn't +# silently remap data. The builder output is directly a lowlevel +# ``on_call_tool`` handler — no extra wrapping. +# ─────────────────────────────────────────────────────────────────────────── + +weather = ToolBuilder[dict[str, Any]]().incomplete_step("ask_units", ask_units).end_step(fetch_weather).build() + +server = build_server("mrtr-option-g", on_call_tool=weather) diff --git a/examples/servers/mrtr-options/mrtr_options/option_h_linear.py b/examples/servers/mrtr-options/mrtr_options/option_h_linear.py new file mode 100644 index 000000000..fe47a0ad3 --- /dev/null +++ b/examples/servers/mrtr-options/mrtr_options/option_h_linear.py @@ -0,0 +1,63 @@ +"""Option H: continuation-based linear MRTR. ``await ctx.elicit()`` is genuine. + +The Option B footgun was: ``await elicit()`` *looks* like a suspension point +but is actually a re-entry point, so everything above it runs twice. This +fixes that by making it a *real* suspension point — the coroutine frame is +held in a ``ContinuationStore`` across MRTR rounds, keyed by +``request_state``. + +Handler code stays exactly as it was in the SSE era. Side-effects above +the await fire once because the function never restarts — it resumes. + +Trade-off: the server holds the frame in memory between rounds. Client +still sees pure MRTR (no SSE), but the server is stateful *within* a +single tool call. Horizontally-scaled deployments need sticky routing on +the ``request_state`` token. Same operational shape as Option A's SSE +hold, without the long-lived connection. + +When to use: migrating existing SSE-era tools to MRTR wire protocol +without rewriting the handler, or when the linear style is genuinely +clearer than guard-first (complex branching, many rounds). + +When not to: if you need true statelessness across server instances. +Use E/F/G — they encode everything the server needs in ``request_state`` +itself. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel + +from mcp.server.experimental.mrtr import ContinuationStore, LinearCtx, linear_mrtr + +from ._shared import audit_log, build_server, lookup_weather + + +class UnitsPref(BaseModel): + units: str + + +# ─────────────────────────────────────────────────────────────────────────── +# This is what the tool author writes. Linear, front-to-back, no re-entry +# contract to reason about. The ``audit_log`` above the await fires +# exactly once — the await is a real suspension point. +# ─────────────────────────────────────────────────────────────────────────── + + +async def weather(ctx: LinearCtx, args: dict[str, Any]) -> str: + location = args["location"] + audit_log(location) # runs once — unlike Option B + prefs = await ctx.elicit("Which units?", UnitsPref) + return lookup_weather(location, prefs.units) + + +# ─────────────────────────────────────────────────────────────────────────── +# Registration. The store must be entered as an async context manager +# around the server's run loop — it owns the task group that keeps the +# suspended coroutines alive. +# ─────────────────────────────────────────────────────────────────────────── + +store = ContinuationStore() +server = build_server("mrtr-option-h", on_call_tool=linear_mrtr(weather, store=store)) diff --git a/examples/servers/mrtr-options/pyproject.toml b/examples/servers/mrtr-options/pyproject.toml new file mode 100644 index 000000000..be396c68b --- /dev/null +++ b/examples/servers/mrtr-options/pyproject.toml @@ -0,0 +1,19 @@ +[project] +name = "mcp-mrtr-options" +version = "0.1.0" +description = "MRTR handler-shape comparison — seven options on the same weather tool (SEP-2322)" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] +license = { text = "MIT" } +dependencies = ["mcp"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mrtr_options"] + +[tool.uv.sources] +mcp = { workspace = true } diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 34d6a360f..8f2c69a74 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -12,17 +12,24 @@ from mcp.client.streamable_http import streamable_http_client from mcp.server import Server from mcp.server.mcpserver import MCPServer +from mcp.shared._context import RequestContext from mcp.shared.session import ProgressFnT from mcp.types import ( CallToolResult, CompleteResult, + CreateMessageRequest, + ElicitRequest, EmptyResult, + ErrorData, GetPromptResult, Implementation, + IncompleteResult, InitializeResult, + InputResponses, ListPromptsResult, ListResourcesResult, ListResourceTemplatesResult, + ListRootsRequest, ListToolsResult, LoggingLevel, PaginatedRequestParams, @@ -32,6 +39,9 @@ ResourceTemplateReference, ) +MRTR_MAX_ROUNDS = 8 +"""Bound on MRTR retry rounds. A well-formed handler converges; an unbounded loop is a bug.""" + @dataclass class Client: @@ -95,6 +105,9 @@ async def main(): elicitation_callback: ElicitationFnT | None = None """Callback for handling elicitation requests.""" + max_mrtr_rounds: int = MRTR_MAX_ROUNDS + """Maximum MRTR retry rounds before raising (SEP-2322). A server that never converges is a bug.""" + _session: ClientSession | None = field(init=False, default=None) _exit_stack: AsyncExitStack | None = field(init=False, default=None) _transport: Transport = field(init=False) @@ -238,6 +251,12 @@ async def call_tool( ) -> CallToolResult: """Call a tool on the server. + If the server returns an ``IncompleteResult`` (SEP-2322 MRTR), this + method drives the retry loop internally: each embedded input request + is dispatched to the matching callback (``elicitation_callback``, + ``sampling_callback``, or ``list_roots_callback``) and the tool is + re-called with the collected responses plus echoed ``request_state``. + Args: name: The name of the tool to call arguments: Arguments to pass to the tool @@ -248,14 +267,65 @@ async def call_tool( Returns: The tool result. """ - return await self.session.call_tool( - name=name, - arguments=arguments, - read_timeout_seconds=read_timeout_seconds, - progress_callback=progress_callback, - meta=meta, + input_responses: InputResponses | None = None + request_state: str | None = None + + for _round in range(self.max_mrtr_rounds): + result = await self.session.call_tool_mrtr( + name=name, + arguments=arguments, + read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, + meta=meta, + input_responses=input_responses, + request_state=request_state, + ) + + if isinstance(result, CallToolResult): + return result + + input_responses = await self._fulfil_input_requests(result) + request_state = result.request_state + + raise RuntimeError( + f"MRTR retry loop for tool {name!r} exceeded {self.max_mrtr_rounds} rounds without converging" ) + async def _fulfil_input_requests(self, incomplete: IncompleteResult) -> InputResponses | None: + """Dispatch each embedded input request to the matching callback.""" + if not incomplete.input_requests: + return None + + ctx = RequestContext[ClientSession](session=self.session) + responses: InputResponses = {} + + for key, req in incomplete.input_requests.items(): + match req: + case ElicitRequest(params=params): + if self.elicitation_callback is None: + raise RuntimeError( + f"Server sent elicitation input request {key!r} but no elicitation_callback is configured" + ) + result = await self.elicitation_callback(ctx, params) + case CreateMessageRequest(params=params): + if self.sampling_callback is None: # pragma: no cover + raise RuntimeError( + f"Server sent sampling input request {key!r} but no sampling_callback is configured" + ) + result = await self.sampling_callback(ctx, params) + case ListRootsRequest(): + if self.list_roots_callback is None: # pragma: no cover + raise RuntimeError( + f"Server sent roots input request {key!r} but no list_roots_callback is configured" + ) + result = await self.list_roots_callback(ctx) + + if isinstance(result, ErrorData): + raise RuntimeError(f"Input request {key!r} failed: {result.message}") + responses[key] = result + + return responses + async def list_prompts( self, *, diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7c964a334..27c3dca1f 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -97,6 +97,10 @@ async def _default_logging_callback( ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) +_call_tool_result_adapter: TypeAdapter[types.IncompleteResult | types.CallToolResult] = TypeAdapter( + types.IncompleteResult | types.CallToolResult +) + class ClientSession( BaseSession[ @@ -305,18 +309,60 @@ async def call_tool( *, meta: RequestParamsMeta | None = None, ) -> types.CallToolResult: - """Send a tools/call request with optional progress callback support.""" + """Send a tools/call request with optional progress callback support. - result = await self.send_request( + Raises: + RuntimeError: If the server returns an IncompleteResult. Use + ``Client.call_tool`` or ``call_tool_mrtr`` to handle MRTR flows. + """ + result = await self.call_tool_mrtr( + name, + arguments, + read_timeout_seconds, + progress_callback, + meta=meta, + ) + if isinstance(result, types.IncompleteResult): + raise RuntimeError( + f"Server returned IncompleteResult for tool {name!r}. " + "Use Client.call_tool or ClientSession.call_tool_mrtr to handle MRTR flows." + ) + return result + + async def call_tool_mrtr( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + meta: RequestParamsMeta | None = None, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + ) -> types.CallToolResult | types.IncompleteResult: + """Send a single tools/call request; returns IncompleteResult if server needs input. + + This is the MRTR-aware variant (SEP-2322). One request → one response. + Higher-level ``mcp.client.Client.call_tool`` drives the retry loop; this + method just surfaces whatever the server sent. + """ + + result: types.CallToolResult | types.IncompleteResult = await self.send_request( types.CallToolRequest( - params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=meta), + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + input_responses=input_responses, + request_state=request_state, + _meta=meta, + ), ), - types.CallToolResult, + _call_tool_result_adapter, request_read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, ) - if not result.is_error: + if isinstance(result, types.CallToolResult) and not result.is_error: await self._validate_tool_result(name, result) return result diff --git a/src/mcp/server/experimental/mrtr/__init__.py b/src/mcp/server/experimental/mrtr/__init__.py new file mode 100644 index 000000000..25ed342b9 --- /dev/null +++ b/src/mcp/server/experimental/mrtr/__init__.py @@ -0,0 +1,50 @@ +"""MRTR (SEP-2322) server-side primitives — footgun-prevention layers. + +!!! warning + These APIs are experimental and may change or be removed without notice. + +The naive MRTR handler is de-facto GOTO: re-entry jumps to the top, state +progression is implicit in ``input_responses`` checks, and side-effects +above the guard execute on every retry. Two primitives here make safe code +the easy path: + +- :class:`MrtrCtx` — ``ctx.once(key, fn)`` idempotency guard. Opt-in per + call site; unwrapped mutations still fire twice. Makes the hazard + *visually distinct* from safe code, which is reviewable. Lightweight — + use for single-question tools with one side-effect. + +- :class:`ToolBuilder` — structural decomposition into named steps. + ``end_step`` runs exactly once, structurally unreachable until every + ``incomplete_step`` has returned data. No "above the guard" zone to get + wrong. Boilerplate pays for itself at 3+ rounds. + +Both track progress in ``request_state`` (base64-JSON here; a production +SDK MUST HMAC-sign the blob — see :mod:`._state`). + +The :mod:`.compat` module holds the dual-path shims (Options A and D from +the comparison deck). They're comparison artifacts, not ship targets — +Option E (degrade-only) is the SDK default and requires neither. +""" + +from mcp.server.experimental.mrtr._state import decode_state, encode_state, input_response +from mcp.server.experimental.mrtr.builder import EndStep, IncompleteStep, ToolBuilder +from mcp.server.experimental.mrtr.compat import MrtrHandler, dispatch_by_version, sse_retry_shim +from mcp.server.experimental.mrtr.context import MrtrCtx +from mcp.server.experimental.mrtr.linear import ContinuationStore, ElicitDeclined, LinearCtx, linear_mrtr + +__all__ = [ + "MrtrCtx", + "ToolBuilder", + "IncompleteStep", + "EndStep", + "MrtrHandler", + "LinearCtx", + "ContinuationStore", + "ElicitDeclined", + "linear_mrtr", + "input_response", + "encode_state", + "decode_state", + "sse_retry_shim", + "dispatch_by_version", +] diff --git a/src/mcp/server/experimental/mrtr/_state.py b/src/mcp/server/experimental/mrtr/_state.py new file mode 100644 index 000000000..982608f1a --- /dev/null +++ b/src/mcp/server/experimental/mrtr/_state.py @@ -0,0 +1,53 @@ +"""request_state encode/decode and input_responses sugar. + +!!! warning "Unsigned — advisory only" + Plain base64-JSON. A production SDK MUST HMAC-sign the blob because the + client can otherwise forge step-done / once-executed markers and skip the + guards entirely. A per-session key derived from the initialize handshake + keeps it stateless. Without signing, the safety story of MrtrCtx / + ToolBuilder is advisory, not enforced. +""" + +from __future__ import annotations + +import base64 +import json +from typing import Any, cast + +from mcp.types import CallToolRequestParams + + +def encode_state(state: Any) -> str: + """Encode a JSON-serializable value into an opaque request_state string.""" + return base64.b64encode(json.dumps(state).encode()).decode() + + +def decode_state(blob: str | None) -> dict[str, Any]: + """Decode a request_state string. Returns {} on None/empty/malformed.""" + if not blob: + return {} + try: + result = json.loads(base64.b64decode(blob)) + return cast(dict[str, Any], result) if isinstance(result, dict) else {} + except (ValueError, json.JSONDecodeError): # pragma: no cover + return {} + + +def input_response(params: CallToolRequestParams, key: str) -> dict[str, Any] | None: + """Pull an accepted elicitation's content out of ``params.input_responses``. + + Returns ``None`` if the key is absent, declined, or cancelled. Sugar for + the common guard-first pattern:: + + units = input_response(params, "units") + if units is None: + return IncompleteResult(input_requests={"units": ...}) + """ + if not params.input_responses: + return None + entry = params.input_responses.get(key) + if not entry: + return None + if entry.get("action") != "accept": + return None + return entry.get("content") diff --git a/src/mcp/server/experimental/mrtr/builder.py b/src/mcp/server/experimental/mrtr/builder.py new file mode 100644 index 000000000..11612ad05 --- /dev/null +++ b/src/mcp/server/experimental/mrtr/builder.py @@ -0,0 +1,114 @@ +"""ToolBuilder — structural step decomposition (Option G).""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any, Generic, TypeVar + +from mcp.server.experimental.mrtr._state import decode_state, encode_state +from mcp.types import CallToolRequestParams, CallToolResult, IncompleteResult + +ArgsT = TypeVar("ArgsT") + + +IncompleteStep = Callable[[ArgsT, dict[str, Any]], IncompleteResult | dict[str, Any]] +"""An incomplete-step function. Receives args + all input responses collected +so far. Returns either an :class:`IncompleteResult` (needs more input) or a +dict to merge into the collected data passed to the next step. + +MUST be idempotent — it can re-run if the client tampers with ``request_state`` +(unsigned in this draft) or if a step before it wasn't the most-recently +completed. Side-effects belong in the end step. +""" + +EndStep = Callable[[ArgsT, dict[str, Any]], CallToolResult] +"""The end-step function. Receives args + the merged data from all prior +steps. Runs exactly once, when every incomplete step has returned data. +This is the safe zone — put side-effects here. +""" + + +class ToolBuilder(Generic[ArgsT]): + """Explicit step decomposition for MRTR handlers. + + The monolithic handler becomes a sequence of named step functions. + ``end_step`` is structurally unreachable until every ``incomplete_step`` + has returned data — the SDK's step-tracking (via ``request_state``) is + the guard, not developer discipline:: + + def ask_units(args, inputs): + u = inputs.get("units") + if not u or u.get("action") != "accept": + return IncompleteResult(input_requests={"units": ElicitRequest(...)}) + return {"units": u["content"]["u"]} + + def fetch_weather(args, collected): + audit_log(args) # runs exactly once + return CallToolResult(...) # uses collected["units"] + + handler = ( + ToolBuilder[dict[str, str]]() + .incomplete_step("ask_units", ask_units) + .end_step(fetch_weather) + .build() + ) + + server = Server("demo", on_call_tool=handler) + + Steps are named (not ordinal) so reordering during development doesn't + silently remap data. Each name must be unique; ``build()`` raises on + duplicates. + + Boilerplate vs a raw guard-first handler: two function defs + ``.build()`` + to replace a 3-line ``if not x: return IncompleteResult(...)``. Worth it at + 3+ rounds or when the side-effect story matters. Overkill for a single + question — use :class:`mcp.server.experimental.mrtr.MrtrCtx` instead. + """ + + def __init__(self) -> None: + self._steps: list[tuple[str, IncompleteStep[ArgsT]]] = [] + self._end: EndStep[ArgsT] | None = None + + def incomplete_step(self, name: str, fn: IncompleteStep[ArgsT]) -> ToolBuilder[ArgsT]: + """Append a step that may return IncompleteResult or data to collect.""" + self._steps.append((name, fn)) + return self + + def end_step(self, fn: EndStep[ArgsT]) -> ToolBuilder[ArgsT]: + """Set the final step that runs exactly once with all collected data.""" + self._end = fn + return self + + def build(self) -> Callable[[Any, CallToolRequestParams], Awaitable[CallToolResult | IncompleteResult]]: + """Produce a lowlevel ``on_call_tool`` handler.""" + if self._end is None: + raise ValueError("ToolBuilder: end_step is required") + names = [n for n, _ in self._steps] + if len(names) != len(set(names)): + raise ValueError(f"ToolBuilder: duplicate step names in {names}") + + steps = list(self._steps) + end = self._end + + async def handler(ctx: Any, params: CallToolRequestParams) -> CallToolResult | IncompleteResult: + args: ArgsT = params.arguments # type: ignore[assignment] + prior = decode_state(params.request_state) + done: set[str] = set(prior.get("done", [])) + inputs = params.input_responses or {} + collected: dict[str, Any] = dict(prior.get("collected", {})) + + for name, step in steps: + if name in done: + continue + result = step(args, inputs) + if isinstance(result, IncompleteResult): + return IncompleteResult( + input_requests=result.input_requests, + request_state=encode_state({"done": sorted(done), "collected": collected}), + ) + collected.update(result) + done.add(name) + + return end(args, collected) + + return handler diff --git a/src/mcp/server/experimental/mrtr/compat.py b/src/mcp/server/experimental/mrtr/compat.py new file mode 100644 index 000000000..d7e3a55b9 --- /dev/null +++ b/src/mcp/server/experimental/mrtr/compat.py @@ -0,0 +1,110 @@ +"""Dual-path compat shims for pre-MRTR clients (Options A and D). + +!!! warning "Comparison artifacts, not ship targets" + These exist so the option-comparison deck has concrete SDK machinery to + reference. Whether either ships depends on where SEP-2322 discussion + converges. Option E (degrade-only) is the SDK default and requires + neither. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + +from mcp.types import CallToolRequestParams, CallToolResult, ElicitRequest, ElicitRequestFormParams, IncompleteResult + +MrtrHandler = Callable[[Any, CallToolRequestParams], Awaitable[CallToolResult | IncompleteResult]] +"""Signature of an MRTR-native lowlevel handler.""" + +MRTR_MIN_VERSION = "2026-06-01" +"""Placeholder for the first protocol version where IncompleteResult is legal.""" + + +def sse_retry_shim(mrtr_handler: MrtrHandler, *, max_rounds: int = 8) -> MrtrHandler: # pragma: no cover + """Wrap an MRTR-native handler so pre-MRTR clients also get elicitation (Option A). + + When the negotiated version is pre-MRTR and the handler returns + ``IncompleteResult``, this shim drives the retry loop *locally* — it + fulfils each ``InputRequest`` via real SSE (``ctx.session.elicit_form()``), + collects the answers, and re-invokes the handler with ``input_responses`` + populated. Repeat until complete. + + This only works on infra that can actually hold SSE — the elicit call is + a real SSE round-trip. On a horizontally-scaled MRTR-only deployment (the + whole reason to adopt MRTR), this fails at runtime when an old client + connects. That constraint lives nowhere near the tool code. If that's the + deployment, use Option E (degrade) instead — it's the SDK default. + + Hidden cost: the handler is silently re-invoked. The MRTR shape makes + re-entry safe by construction (the guard is visible in source), but the + *loop* is invisible. If the handler does something expensive before the + guard, you won't find out until an old client connects in prod. + + Not tested against a pre-MRTR client in this draft because the SDK's + LATEST_PROTOCOL_VERSION is still 2025-11-25. Covered by E2E tests once + the version bumps. + """ + + async def wrapped(ctx: Any, params: CallToolRequestParams) -> CallToolResult | IncompleteResult: + version = ctx.session.client_params.protocol_version if ctx.session.client_params else None + if version is None or str(version) >= MRTR_MIN_VERSION: + return await mrtr_handler(ctx, params) + + responses: dict[str, Any] = dict(params.input_responses or {}) + state = params.request_state + + for _round in range(max_rounds): + retry = params.model_copy(update={"input_responses": responses or None, "request_state": state}) + result = await mrtr_handler(ctx, retry) + + if isinstance(result, CallToolResult): + return result + + state = result.request_state + + if not result.input_requests: + return CallToolResult( + content=[{"type": "text", "text": "IncompleteResult with no inputRequests on pre-MRTR session."}], # type: ignore[list-item] + is_error=True, + ) + + for key, req in result.input_requests.items(): + if not isinstance(req, ElicitRequest) or not isinstance(req.params, ElicitRequestFormParams): + continue + elicit_result = await ctx.session.elicit_form( + message=req.params.message, + requested_schema=req.params.requested_schema, + related_request_id=ctx.request_id, + ) + responses[key] = elicit_result.model_dump(by_alias=True, exclude_none=True) + + return CallToolResult( + content=[{"type": "text", "text": "SSE retry shim exceeded round limit."}], # type: ignore[list-item] + is_error=True, + ) + + return wrapped + + +def dispatch_by_version( + *, + mrtr: MrtrHandler, + sse: Callable[[Any, CallToolRequestParams], Awaitable[CallToolResult]], + min_mrtr_version: str = MRTR_MIN_VERSION, +) -> MrtrHandler: + """Two handlers, one per protocol era. SDK picks by negotiated version (Option D). + + No shim, no magic — the author wrote both paths, the SDK just routes. + Two functions per tool, both live until SSE is deprecated, and nothing + mechanically links them: if the MRTR handler changes the elicitation + schema and the SSE handler doesn't, nothing catches it. + """ + + async def wrapped(ctx: Any, params: CallToolRequestParams) -> CallToolResult | IncompleteResult: + version = ctx.session.client_params.protocol_version if ctx.session.client_params else None + if version is not None and str(version) >= min_mrtr_version: + return await mrtr(ctx, params) + return await sse(ctx, params) + + return wrapped diff --git a/src/mcp/server/experimental/mrtr/context.py b/src/mcp/server/experimental/mrtr/context.py new file mode 100644 index 000000000..a7f007da2 --- /dev/null +++ b/src/mcp/server/experimental/mrtr/context.py @@ -0,0 +1,73 @@ +"""MrtrCtx — idempotency guard for side-effects (Option F).""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from mcp.server.experimental.mrtr._state import decode_state, encode_state +from mcp.types import CallToolRequestParams, IncompleteResult, InputRequests + + +class MrtrCtx: + """MRTR context with a ``once`` guard tracked in ``request_state``. + + Handler stays monolithic (guard-first, like a raw MRTR handler), but + side-effects can be wrapped for at-most-once execution across retries:: + + ctx = MrtrCtx(params) + ctx.once("audit", lambda: audit_log(params.arguments["x"])) + + units = input_response(params, "units") + if units is None: + return ctx.incomplete({"units": ElicitRequest(...)}) + + return CallToolResult(...) + + Opt-in: an unwrapped mutation still fires twice. The footgun isn't + eliminated — it's made visually distinct from safe code, which is + reviewable. A bare ``db.write()`` above the guard is the red flag; + ``ctx.once("write", lambda: db.write())`` reads as "I considered + re-entry." + + Crash window: if the server dies between ``fn()`` completing and + ``request_state`` reaching the client, the next invocation re-executes. + At-most-once under normal operation, not crash-safe. For financial + operations use external idempotency (request ID as DB unique key). + """ + + def __init__(self, params: CallToolRequestParams) -> None: + self._params = params + prior = decode_state(params.request_state) + self._executed: set[str] = set(prior.get("executed", [])) + + @property + def input_responses(self) -> dict[str, Any] | None: # pragma: no cover + return self._params.input_responses + + def once(self, key: str, fn: Callable[[], Any]) -> None: + """Run ``fn`` at most once across all MRTR rounds for this tool call. + + On subsequent rounds where ``key`` is marked executed in + ``request_state``, ``fn`` is skipped entirely. + """ + if key in self._executed: + return + fn() + self._executed.add(key) + + def has_run(self, key: str) -> bool: + """Check if ``once(key, ...)`` has fired on a prior round.""" + return key in self._executed + + def incomplete(self, input_requests: InputRequests) -> IncompleteResult: + """Build an IncompleteResult that carries the executed-keys set. + + Call this instead of constructing ``IncompleteResult`` directly so + the ``once`` guard holds across retry. Without this, ``once`` is a + no-op on the next round. + """ + return IncompleteResult( + input_requests=input_requests, + request_state=encode_state({"executed": sorted(self._executed)}), + ) diff --git a/src/mcp/server/experimental/mrtr/linear.py b/src/mcp/server/experimental/mrtr/linear.py new file mode 100644 index 000000000..bb6383bc7 --- /dev/null +++ b/src/mcp/server/experimental/mrtr/linear.py @@ -0,0 +1,276 @@ +"""Linear MRTR — keep ``await ctx.elicit()`` genuinely linear (Option H). + +The Option B footgun was: ``await elicit()`` *looks* like a suspension point +but is actually a re-entry point, so everything above it runs twice. This +module fixes that by making it a *real* suspension point — the coroutine +frame is held in memory across MRTR rounds, keyed by ``request_state``. + +Handler code stays exactly as it was in the SSE era:: + + async def my_tool(ctx: LinearCtx, location: str) -> str: + audit_log(location) # runs exactly once + units = await ctx.elicit("Which units?", UnitsSchema) + audit_log("got units") # runs exactly once + return f"{location}: 22°{units.u}" + +The wrapper ``linear_mrtr(my_tool)`` translates this into a standard MRTR +``on_call_tool`` handler. Round 1 starts the coroutine; ``elicit()`` sends +an ``IncompleteResult`` back through the wrapper and parks on a stream. +Round 2's retry wakes it with the answer. The coroutine continues from +where it stopped — no re-entry, no double-execution. + +**Trade-off**: the server holds the frame in memory between rounds. The +client still sees pure MRTR (no SSE, independent HTTP requests), but the +server is stateful *within* a single tool call. Horizontally-scaled +deployments need sticky routing on the ``request_state`` token, or a +distributed continuation store. Same operational shape as Option A's SSE +hold, just without the long-lived connection. + +**When to use this**: migrating existing SSE-era tools to MRTR wire +protocol without rewriting the handler. Or when the linear style is +genuinely clearer than guard-first (complex branching, many rounds). + +**When not to**: if you need true statelessness across server instances. +Use Option E/F/G instead — they encode everything the server needs in +``request_state`` itself. +""" + +from __future__ import annotations + +import uuid +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from types import TracebackType +from typing import Any, TypeVar + +import anyio +import anyio.abc +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import BaseModel + +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ElicitRequest, + ElicitRequestFormParams, + IncompleteResult, + InputRequests, + TextContent, +) + +__all__ = ["LinearCtx", "linear_mrtr", "ContinuationStore"] + +T = TypeVar("T", bound=BaseModel) + + +# ─── Continuation plumbing ─────────────────────────────────────────────────── + + +@dataclass +class _Continuation: + """In-memory state for one suspended linear handler.""" + + ask_send: MemoryObjectSendStream[IncompleteResult | CallToolResult] + ask_recv: MemoryObjectReceiveStream[IncompleteResult | CallToolResult] + answer_send: MemoryObjectSendStream[dict[str, Any]] + answer_recv: MemoryObjectReceiveStream[dict[str, Any]] + + @classmethod + def new(cls) -> _Continuation: + ask_s, ask_r = anyio.create_memory_object_stream[IncompleteResult | CallToolResult](1) + ans_s, ans_r = anyio.create_memory_object_stream[dict[str, Any]](1) + return cls(ask_send=ask_s, ask_recv=ask_r, answer_send=ans_s, answer_recv=ans_r) + + def close(self) -> None: + self.ask_send.close() + self.ask_recv.close() + self.answer_send.close() + self.answer_recv.close() + + +class ContinuationStore: + """Owns the background task group and the token → continuation map. + + One per server (or per-process). Must be entered as an async context + manager so the task group is live before any handler runs:: + + store = ContinuationStore() + handler = linear_mrtr(my_tool, store=store) + server = Server("demo", on_call_tool=handler) + + async with store: + await server.run(...) + + Continuations expire after ``ttl_seconds`` of inactivity — if the client + never retries, the frame is reclaimed. Default 5 minutes. + """ + + def __init__(self, *, ttl_seconds: float = 300.0) -> None: + self._frames: dict[str, _Continuation] = {} + self._ttl = ttl_seconds + self._tg: anyio.abc.TaskGroup | None = None + + async def __aenter__(self) -> ContinuationStore: + self._tg = anyio.create_task_group() + await self._tg.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self._tg is not None: # pragma: no branch + self._tg.cancel_scope.cancel() + await self._tg.__aexit__(exc_type, exc_val, exc_tb) + self._tg = None + self._frames.clear() + + def _check_entered(self) -> None: + if self._tg is None: + raise RuntimeError("ContinuationStore not entered — use `async with store:` around server.run()") + + def _start(self, token: str, cont: _Continuation, runner: Callable[[], Awaitable[None]]) -> None: + assert self._tg is not None + self._frames[token] = cont + + async def _run_and_cleanup() -> None: + try: + with anyio.move_on_after(self._ttl): + await runner() + finally: + cont.close() + self._frames.pop(token, None) + + self._tg.start_soon(_run_and_cleanup) + + def get(self, token: str) -> _Continuation | None: + return self._frames.get(token) + + +# ─── The linear context ────────────────────────────────────────────────────── + + +class LinearCtx: + """The ``ctx`` handed to a linear handler. ``await ctx.elicit()`` genuinely suspends.""" + + def __init__(self, continuation: _Continuation) -> None: + self._cont = continuation + self._counter = 0 + + async def elicit(self, message: str, schema: type[T]) -> T: + """Ask the client a question. Suspends until the answer arrives on a later round. + + The schema is a Pydantic model; the elicitation requestedSchema is + derived from it, and the answer is validated back into an instance. + + Raises: + ElicitDeclined: if the user declined or cancelled. + """ + key = f"q{self._counter}" + self._counter += 1 + responses = await self.ask( + { + key: ElicitRequest( + params=ElicitRequestFormParams(message=message, requested_schema=schema.model_json_schema()) + ) + } + ) + answer = responses.get(key, {}) + if answer.get("action") != "accept": + raise ElicitDeclined(answer.get("action", "cancel")) + return schema.model_validate(answer.get("content", {})) + + async def ask(self, input_requests: InputRequests) -> dict[str, Any]: + """Send one or more input requests in a single round; returns the full responses dict. + + Lower-level than :meth:`elicit` — hand-rolled schemas, no validation, + multiple asks batched into one round. + """ + await self._cont.ask_send.send(IncompleteResult(input_requests=input_requests)) + return await self._cont.answer_recv.receive() + + +class ElicitDeclined(Exception): + """Raised inside a linear handler when the user declines or cancels an elicitation.""" + + def __init__(self, action: str) -> None: + self.action = action + super().__init__(f"Elicitation {action}") + + +# ─── The wrapper ───────────────────────────────────────────────────────────── + + +LinearHandler = Callable[[LinearCtx, dict[str, Any]], Awaitable[CallToolResult | str]] +"""Signature of a linear handler: ``(ctx, arguments) -> CallToolResult | str``.""" + + +class _LinearMrtrWrapper: + def __init__(self, handler: LinearHandler, store: ContinuationStore) -> None: + self._handler = handler + self._store = store + + async def __call__(self, ctx: Any, params: CallToolRequestParams) -> CallToolResult | IncompleteResult: + token = params.request_state + + if token is None: + return await self._start(params) + return await self._resume(token, params) + + async def _start(self, params: CallToolRequestParams) -> CallToolResult | IncompleteResult: + self._store._check_entered() # pyright: ignore[reportPrivateUsage] + token = uuid.uuid4().hex + cont = _Continuation.new() + linear_ctx = LinearCtx(cont) + args = dict(params.arguments or {}) + + async def runner() -> None: + try: + result = await self._handler(linear_ctx, args) + if isinstance(result, str): + result = CallToolResult(content=[TextContent(text=result)]) + await cont.ask_send.send(result) + except ElicitDeclined as exc: + await cont.ask_send.send( + CallToolResult(content=[TextContent(text=f"Cancelled ({exc.action}).")], is_error=False) + ) + except Exception as exc: # noqa: BLE001 + await cont.ask_send.send(CallToolResult(content=[TextContent(text=str(exc))], is_error=True)) + + self._store._start(token, cont, runner) # pyright: ignore[reportPrivateUsage] + return await self._next(token, cont) + + async def _resume(self, token: str, params: CallToolRequestParams) -> CallToolResult | IncompleteResult: + cont = self._store.get(token) + if cont is None: + return CallToolResult( + content=[TextContent(text="Continuation expired or unknown. Retry the tool call from scratch.")], + is_error=True, + ) + await cont.answer_send.send(params.input_responses or {}) + return await self._next(token, cont) + + async def _next(self, token: str, cont: _Continuation) -> CallToolResult | IncompleteResult: + msg = await cont.ask_recv.receive() + if isinstance(msg, IncompleteResult): + return IncompleteResult(input_requests=msg.input_requests, request_state=token) + return msg + + +def linear_mrtr(handler: LinearHandler, *, store: ContinuationStore) -> _LinearMrtrWrapper: + """Wrap a linear ``await ctx.elicit()``-style handler into an MRTR ``on_call_tool``. + + The handler runs exactly once, front to back. ``ctx.elicit()`` is a real + suspension point — the coroutine frame is held in ``store`` between MRTR + rounds, keyed by ``request_state``. + + Args: + handler: ``async (ctx: LinearCtx, arguments: dict) -> CallToolResult | str``. + Returning a ``str`` is shorthand for a single TextContent. + store: The :class:`ContinuationStore` that owns the background task + group. Must be entered as an async context manager around the + server's run loop. + """ + return _LinearMrtrWrapper(handler, store) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index c28842272..ba31f5ff2 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -118,7 +118,7 @@ def __init__( | None = None, on_call_tool: Callable[ [ServerRequestContext[LifespanResultT], types.CallToolRequestParams], - Awaitable[types.CallToolResult | types.CreateTaskResult], + Awaitable[types.CallToolResult | types.IncompleteResult | types.CreateTaskResult], ] | None = None, on_list_resources: Callable[ diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 6fc59923f..67ce58de4 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -4,7 +4,7 @@ from collections.abc import Callable from contextlib import AsyncExitStack from types import TracebackType -from typing import Any, Generic, Protocol, TypeVar +from typing import Any, Generic, Protocol, TypeVar, overload import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -230,6 +230,7 @@ async def __aexit__( self._task_group.cancel_scope.cancel() return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + @overload async def send_request( self, request: SendRequestT, @@ -237,7 +238,26 @@ async def send_request( request_read_timeout_seconds: float | None = None, metadata: MessageMetadata = None, progress_callback: ProgressFnT | None = None, - ) -> ReceiveResultT: + ) -> ReceiveResultT: ... + + @overload + async def send_request( + self, + request: SendRequestT, + result_type: TypeAdapter[Any], + request_read_timeout_seconds: float | None = None, + metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, + ) -> Any: ... + + async def send_request( + self, + request: SendRequestT, + result_type: type[ReceiveResultT] | TypeAdapter[Any], + request_read_timeout_seconds: float | None = None, + metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, + ) -> ReceiveResultT | Any: """Sends a request and waits for a response. Raises an MCPError if the response contains an error. If a request read timeout is provided, it will take @@ -280,6 +300,8 @@ async def send_request( if isinstance(response_or_error, JSONRPCError): raise MCPError.from_jsonrpc_error(response_or_error) + elif isinstance(result_type, TypeAdapter): + return result_type.validate_python(response_or_error.result) else: return result_type.model_validate(response_or_error.result, by_name=False) diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index b44230393..cc892de9e 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -74,10 +74,15 @@ ImageContent, Implementation, IncludeContext, + IncompleteResult, InitializedNotification, InitializeRequest, InitializeRequestParams, InitializeResult, + InputRequest, + InputRequests, + InputResponse, + InputResponses, ListPromptsRequest, ListPromptsResult, ListResourcesRequest, @@ -179,6 +184,7 @@ client_notification_adapter, client_request_adapter, client_result_adapter, + input_request_adapter, server_notification_adapter, server_request_adapter, server_result_adapter, @@ -342,6 +348,13 @@ "SubscribeRequestParams", "UnsubscribeRequest", "UnsubscribeRequestParams", + # MRTR (SEP-2322) + "IncompleteResult", + "InputRequest", + "InputRequests", + "InputResponse", + "InputResponses", + "input_request_adapter", # Results "CallToolResult", "CancelTaskResult", diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index 9005d253a..af1f7d356 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -76,6 +76,19 @@ class RequestParams(MCPModel): for task augmentation of specific request types in their capabilities. """ + input_responses: dict[str, Any] | None = None + """Responses to input requests from a prior IncompleteResult (SEP-2322 MRTR). + + Keys mirror the server's inputRequests keys; values are the corresponding + ElicitResult, CreateMessageResult, or ListRootsResult payloads. + """ + + request_state: str | None = None + """Opaque state echoed from a prior IncompleteResult (SEP-2322 MRTR). + + Clients MUST NOT inspect or modify this value. + """ + meta: RequestParamsMeta | None = Field(alias="_meta", default=None) @@ -1716,6 +1729,44 @@ class ElicitationRequiredErrorData(MCPModel): """List of URL mode elicitations that must be completed.""" +# ─── Multi Round-Trip Requests (SEP-2322) ─────────────────────────────────── + +InputRequest: TypeAlias = CreateMessageRequest | ElicitRequest | ListRootsRequest +"""A server-initiated request embedded in an IncompleteResult.""" + +InputResponse: TypeAlias = CreateMessageResult | CreateMessageResultWithTools | ElicitResult | ListRootsResult +"""A client's response to an InputRequest, sent on the retry.""" + +InputRequests: TypeAlias = dict[str, InputRequest] +"""Keyed map of input requests. Keys are server-assigned and opaque to the client.""" + +InputResponses: TypeAlias = dict[str, InputResponse] +"""Keyed map of input responses. Keys mirror the server's InputRequests keys.""" + + +class IncompleteResult(Result): + """A result indicating the server needs more input before completing (SEP-2322). + + The client MUST retry the original request with ``input_responses`` populated + for each key in ``input_requests``, and ``request_state`` echoed verbatim. + + At least one of ``input_requests`` or ``request_state`` must be present. + """ + + result_type: Literal["incomplete"] = "incomplete" + """Discriminator marking this as an incomplete result.""" + + input_requests: InputRequests | None = None + """Server-initiated requests the client must fulfil before retrying.""" + + request_state: str | None = None + """Opaque state the client must echo back on retry. Not inspected by the client.""" + + +input_request_adapter: TypeAdapter[InputRequest] = TypeAdapter(InputRequest) +"""Type adapter for validating embedded InputRequest payloads.""" + + ClientResult = ( EmptyResult | CreateMessageResult @@ -1774,5 +1825,6 @@ class ElicitationRequiredErrorData(MCPModel): | ListTasksResult | CancelTaskResult | CreateTaskResult + | IncompleteResult ) server_result_adapter = TypeAdapter[ServerResult](ServerResult) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 18368e6bb..80942648f 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -7,10 +7,12 @@ import anyio import pytest from inline_snapshot import snapshot +from pydantic import FileUrl from mcp import MCPError, types from mcp.client._memory import InMemoryTransport from mcp.client.client import Client +from mcp.client.context import ClientRequestContext from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer from mcp.types import ( @@ -320,3 +322,209 @@ async def test_client_uses_transport_directly(app: MCPServer): structured_content={"result": "Hello, Transport!"}, ) ) + + +# ─── MRTR (SEP-2322) ──────────────────────────────────────────────────────── + + +async def _mrtr_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult(tools=[]) + + +async def test_mrtr_single_round_elicitation(): + """Server returns IncompleteResult with one elicitation; Client drives retry transparently.""" + + async def on_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> types.CallToolResult | types.IncompleteResult: + units = params.input_responses.get("units") if params.input_responses else None + if units is None: + return types.IncompleteResult( + input_requests={ + "units": types.ElicitRequest( + params=types.ElicitRequestFormParams( + message="Which units?", + requested_schema={"type": "object", "properties": {"u": {"type": "string"}}}, + ) + ) + }, + ) + u = units["content"]["u"] + location = params.arguments["location"] if params.arguments else "?" + return types.CallToolResult(content=[TextContent(text=f"Weather in {location}: 22°{u}")]) + + server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) + + async def elicitation_cb(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: + return types.ElicitResult(action="accept", content={"u": "C"}) + + async with Client(server, elicitation_callback=elicitation_cb) as client: + result = await client.call_tool("weather", {"location": "Tokyo"}) + assert result == snapshot(CallToolResult(content=[TextContent(text="Weather in Tokyo: 22°C")])) + + +async def test_mrtr_multi_round_with_request_state(): + """Two-round elicitation accumulating state in request_state (the ADO-rules SEP example).""" + + async def on_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> types.CallToolResult | types.IncompleteResult: + responses = params.input_responses or {} + state = params.request_state + + if "resolution" not in responses and state is None: + return types.IncompleteResult( + input_requests={ + "resolution": types.ElicitRequest( + params=types.ElicitRequestFormParams(message="Resolution?", requested_schema={}) + ) + }, + ) + + if state is None: + resolution = responses["resolution"]["content"]["r"] + return types.IncompleteResult( + input_requests={ + "dup": types.ElicitRequest( + params=types.ElicitRequestFormParams(message="Duplicate of?", requested_schema={}) + ) + }, + request_state=f"resolution={resolution}", + ) + + dup = responses["dup"]["content"]["id"] + return types.CallToolResult(content=[TextContent(text=f"{state} dup={dup}")]) + + server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) + + answers = {"Resolution?": {"r": "Duplicate"}, "Duplicate of?": {"id": "4301"}} + + async def elicitation_cb(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: + assert isinstance(params, types.ElicitRequestFormParams) + return types.ElicitResult(action="accept", content=dict(answers[params.message])) + + async with Client(server, elicitation_callback=elicitation_cb) as client: + result = await client.call_tool("update_item", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="resolution=Duplicate dup=4301")])) + + +async def test_mrtr_round_limit_exceeded(): + """Server never converges → Client raises after max_mrtr_rounds.""" + + async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.IncompleteResult: + return types.IncompleteResult(request_state="spin") + + server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) + + async with Client(server, max_mrtr_rounds=3) as client: + with pytest.raises(RuntimeError, match="exceeded 3 rounds"): + await client.call_tool("stuck", {}) + + +async def test_mrtr_elicitation_without_callback_raises(): + """IncompleteResult with elicitation but no callback → clear error.""" + + async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.IncompleteResult: + return types.IncompleteResult( + input_requests={ + "ask": types.ElicitRequest(params=types.ElicitRequestFormParams(message="?", requested_schema={})) + }, + ) + + server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) + + async with Client(server) as client: + with pytest.raises(RuntimeError, match="no elicitation_callback"): + await client.call_tool("ask", {}) + + +async def test_mrtr_sampling_input_request(): + """IncompleteResult with a sampling input request is dispatched to sampling_callback.""" + + async def on_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> types.CallToolResult | types.IncompleteResult: + if params.input_responses and "q" in params.input_responses: + answer = params.input_responses["q"]["content"]["text"] + return types.CallToolResult(content=[TextContent(text=answer)]) + return types.IncompleteResult( + input_requests={ + "q": types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=[ + types.SamplingMessage(role="user", content=types.TextContent(text="Capital of France?")) + ], + max_tokens=50, + ) + ) + }, + ) + + server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) + + async def sampling_cb( + context: ClientRequestContext, params: types.CreateMessageRequestParams + ) -> types.CreateMessageResult: + return types.CreateMessageResult( + role="assistant", content=types.TextContent(text="Paris"), model="test", stop_reason="endTurn" + ) + + async with Client(server, sampling_callback=sampling_cb) as client: + result = await client.call_tool("ask", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="Paris")])) + + +async def test_mrtr_list_roots_input_request(): + """IncompleteResult with a roots/list input request is dispatched to list_roots_callback.""" + + async def on_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> types.CallToolResult | types.IncompleteResult: + if params.input_responses and "roots" in params.input_responses: + n = len(params.input_responses["roots"]["roots"]) + return types.CallToolResult(content=[TextContent(text=f"saw {n} roots")]) + return types.IncompleteResult(input_requests={"roots": types.ListRootsRequest()}) + + server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) + + async def list_roots_cb(context: ClientRequestContext) -> types.ListRootsResult: + return types.ListRootsResult(roots=[types.Root(uri=FileUrl("file:///a")), types.Root(uri=FileUrl("file:///b"))]) + + async with Client(server, list_roots_callback=list_roots_cb) as client: + result = await client.call_tool("scan", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="saw 2 roots")])) + + +async def test_mrtr_callback_returns_error_data(): + """Callback returning ErrorData surfaces as RuntimeError.""" + + async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.IncompleteResult: + return types.IncompleteResult( + input_requests={ + "ask": types.ElicitRequest(params=types.ElicitRequestFormParams(message="?", requested_schema={})) + }, + ) + + server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) + + async def elicitation_cb(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ErrorData: + return types.ErrorData(code=-1, message="user closed dialog") + + async with Client(server, elicitation_callback=elicitation_cb) as client: + with pytest.raises(RuntimeError, match="user closed dialog"): + await client.call_tool("ask", {}) + + +async def test_session_call_tool_raises_on_incomplete(): + """ClientSession.call_tool (non-MRTR) raises if server returns IncompleteResult.""" + + async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.IncompleteResult: + return types.IncompleteResult(request_state="x") + + server = Server("mrtr-test", on_call_tool=on_call_tool, on_list_tools=_mrtr_list_tools) + + async with Client(server) as client: + with pytest.raises(RuntimeError, match="Use Client.call_tool"): + await client.session.call_tool("stuck", {}) diff --git a/tests/experimental/test_mrtr.py b/tests/experimental/test_mrtr.py new file mode 100644 index 000000000..2f2a7293c --- /dev/null +++ b/tests/experimental/test_mrtr.py @@ -0,0 +1,459 @@ +"""E2E tests for MRTR server-side primitives (SEP-2322). + +Tests the ``mcp.server.experimental.mrtr`` module: ``MrtrCtx``, +``ToolBuilder``, ``input_response``, ``dispatch_by_version``. + +The footgun test measures side-effect counts to prove F and G actually +hold the guard. The invariant test parametrises all handler shapes against +the same Client to prove the server's internal choice doesn't leak. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp import types +from mcp.client.client import Client +from mcp.client.context import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.server.experimental.mrtr import ( + ContinuationStore, + LinearCtx, + MrtrCtx, + ToolBuilder, + dispatch_by_version, + input_response, + linear_mrtr, +) + +pytestmark = pytest.mark.anyio + + +# ─── Shared domain bits (mirror of examples/servers/mrtr-options) ──────────── + + +UNITS_REQUEST = types.ElicitRequest( + params=types.ElicitRequestFormParams( + message="Which units?", + requested_schema={ + "type": "object", + "properties": {"units": {"type": "string", "enum": ["metric", "imperial"]}}, + "required": ["units"], + }, + ) +) + + +def lookup_weather(location: str, units: str) -> str: + temp = "22°C" if units == "metric" else "72°F" + return f"Weather in {location}: {temp}" + + +async def no_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: + return types.ListToolsResult(tools=[]) + + +async def pick_metric(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: + return types.ElicitResult(action="accept", content={"units": "metric"}) + + +_audit: list[str] = [] + + +def audit_log(where: str) -> None: + _audit.append(where) + + +@pytest.fixture(autouse=True) +def reset_audit(): + _audit.clear() + yield + + +MrtrHandler = Callable[ + [ServerRequestContext, types.CallToolRequestParams], Awaitable[types.CallToolResult | types.IncompleteResult] +] + + +def make_server(handler: MrtrHandler) -> Server: + return Server("mrtr-test", on_call_tool=handler, on_list_tools=no_tools) + + +# ─── Handler shapes ────────────────────────────────────────────────────────── + + +async def option_e_degrade( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.IncompleteResult: + """Option E — SDK default. MRTR-native; pre-MRTR gets default.""" + location = (params.arguments or {}).get("location", "?") + prefs = input_response(params, "units") + if prefs is None: + return types.IncompleteResult(input_requests={"units": UNITS_REQUEST}) + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, prefs["units"]))]) + + +async def option_f_ctx_once( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.IncompleteResult: + """Option F — ctx.once idempotency guard.""" + location = (params.arguments or {}).get("location", "?") + mrtr = MrtrCtx(params) + mrtr.once("audit", lambda: audit_log(f"F:{location}")) + prefs = input_response(params, "units") + if prefs is None: + return mrtr.incomplete({"units": UNITS_REQUEST}) + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, prefs["units"]))]) + + +def ask_units(args: dict[str, Any], inputs: dict[str, Any]) -> types.IncompleteResult | dict[str, Any]: + resp = inputs.get("units") + if not resp or resp.get("action") != "accept": + return types.IncompleteResult(input_requests={"units": UNITS_REQUEST}) + return {"units": resp["content"]["units"]} + + +def fetch_weather(args: dict[str, Any], collected: dict[str, Any]) -> types.CallToolResult: + location = (args or {}).get("location", "?") + audit_log(f"G:{location}") + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, collected["units"]))]) + + +option_g_tool_builder = ( + ToolBuilder[dict[str, Any]]().incomplete_step("ask_units", ask_units).end_step(fetch_weather).build() +) + + +async def option_e_with_naive_audit( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.IncompleteResult: + """Option E with a naive side-effect above the guard — the footgun.""" + location = (params.arguments or {}).get("location", "?") + audit_log(f"naive:{location}") # runs on EVERY round + prefs = input_response(params, "units") + if prefs is None: + return types.IncompleteResult(input_requests={"units": UNITS_REQUEST}) + return types.CallToolResult(content=[types.TextContent(text=lookup_weather(location, prefs["units"]))]) + + +# ─── The invariant: client can't tell ──────────────────────────────────────── + + +@pytest.mark.parametrize( + "handler", + [option_e_degrade, option_f_ctx_once, option_g_tool_builder], + ids=["E-degrade", "F-ctx_once", "G-tool_builder"], +) +async def test_mrtr_wire_invariant(handler: MrtrHandler): + """All MRTR handler shapes produce identical wire behaviour. + + The server's internal choice (guard-first, ctx.once, ToolBuilder) doesn't + leak to the client. Same Client, same callback, same result. This is the + argument against per-feature ``-mrtr`` capability flags. + """ + async with Client(make_server(handler), elicitation_callback=pick_metric) as client: + result = await client.call_tool("weather", {"location": "Tokyo"}) + assert isinstance(result, types.CallToolResult) + assert result.content[0] == types.TextContent(text="Weather in Tokyo: 22°C") + + +# ─── The footgun: side-effect counts ───────────────────────────────────────── + + +async def test_mrtr_naive_handler_double_executes_side_effect(): + """The footgun, measured. Naive MRTR handler fires audit_log twice.""" + async with Client(make_server(option_e_with_naive_audit), elicitation_callback=pick_metric) as client: + await client.call_tool("weather", {"location": "Tokyo"}) + assert _audit == snapshot(["naive:Tokyo", "naive:Tokyo"]) + + +async def test_mrtr_ctx_once_holds_side_effect(): + """Option F: ctx.once guard holds the side-effect to one across retry.""" + async with Client(make_server(option_f_ctx_once), elicitation_callback=pick_metric) as client: + await client.call_tool("weather", {"location": "Tokyo"}) + assert _audit == snapshot(["F:Tokyo"]) + + +async def test_mrtr_tool_builder_end_step_runs_once(): + """Option G: end_step runs exactly once regardless of round count.""" + async with Client(make_server(option_g_tool_builder), elicitation_callback=pick_metric) as client: + await client.call_tool("weather", {"location": "Tokyo"}) + assert _audit == snapshot(["G:Tokyo"]) + + +# ─── ToolBuilder edge cases ────────────────────────────────────────────────── + + +def test_tool_builder_requires_end_step(): + with pytest.raises(ValueError, match="end_step is required"): + ToolBuilder[dict[str, Any]]().incomplete_step("x", ask_units).build() + + +def test_tool_builder_rejects_duplicate_step_names(): + with pytest.raises(ValueError, match="duplicate step names"): + ToolBuilder[dict[str, Any]]().incomplete_step("x", ask_units).incomplete_step("x", ask_units).end_step( + fetch_weather + ).build() + + +async def test_tool_builder_multi_step_accumulates(): + """Two incomplete_steps before end_step — collected dict merges.""" + + def ask_lang(args: dict[str, Any], inputs: dict[str, Any]) -> types.IncompleteResult | dict[str, Any]: + resp = inputs.get("lang") + if not resp or resp.get("action") != "accept": + return types.IncompleteResult( + input_requests={ + "lang": types.ElicitRequest( + params=types.ElicitRequestFormParams(message="Lang?", requested_schema={}) + ) + } + ) + return {"lang": resp["content"]["lang"]} + + def finish(args: dict[str, Any], collected: dict[str, Any]) -> types.CallToolResult: + return types.CallToolResult(content=[types.TextContent(text=f"{collected['units']}/{collected['lang']}")]) + + handler = ( + ToolBuilder[dict[str, Any]]() + .incomplete_step("ask_units", ask_units) + .incomplete_step("ask_lang", ask_lang) + .end_step(finish) + .build() + ) + + answers = {"Which units?": {"units": "metric"}, "Lang?": {"lang": "en"}} + + async def elicitation_cb(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: + assert isinstance(params, types.ElicitRequestFormParams) + return types.ElicitResult(action="accept", content=dict(answers[params.message])) + + async with Client(make_server(handler), elicitation_callback=elicitation_cb) as client: + result = await client.call_tool("multi", {}) + assert result == snapshot(types.CallToolResult(content=[types.TextContent(text="metric/en")])) + + +# ─── MrtrCtx edge cases ────────────────────────────────────────────────────── + + +async def test_mrtr_ctx_once_persists_across_multiple_rounds(): + """once() guard survives 3+ rounds — executed-keys round-trip through request_state.""" + + async def handler( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> types.CallToolResult | types.IncompleteResult: + mrtr = MrtrCtx(params) + mrtr.once("init", lambda: audit_log("init")) + + # Step progression tracked via executed keys, not raw input_responses + # (which only carries the latest round's answers per SEP). + if not mrtr.has_run("got_a"): + if not input_response(params, "a"): + return mrtr.incomplete({"a": UNITS_REQUEST}) + mrtr.once("got_a", lambda: audit_log("after_a")) + + if not input_response(params, "b"): + return mrtr.incomplete({"b": UNITS_REQUEST}) + mrtr.once("got_b", lambda: audit_log("after_b")) + return types.CallToolResult(content=[types.TextContent(text="done")]) + + async def elicitation_cb(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: + return types.ElicitResult(action="accept", content={"units": "metric"}) + + async with Client(make_server(handler), elicitation_callback=elicitation_cb) as client: + await client.call_tool("multi", {}) + + assert _audit == snapshot(["init", "after_a", "after_b"]) + + +# ─── input_response helper ─────────────────────────────────────────────────── + + +def test_input_response_returns_none_on_missing(): + params = types.CallToolRequestParams(name="x") + assert input_response(params, "key") is None + + +def test_input_response_returns_none_on_decline(): + params = types.CallToolRequestParams(name="x", input_responses={"key": {"action": "decline"}}) + assert input_response(params, "key") is None + + +def test_input_response_returns_content_on_accept(): + params = types.CallToolRequestParams(name="x", input_responses={"key": {"action": "accept", "content": {"v": 1}}}) + assert input_response(params, "key") == {"v": 1} + + +# ─── dispatch_by_version ───────────────────────────────────────────────────── + + +async def _mrtr_path(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + return types.CallToolResult(content=[types.TextContent(text="mrtr")]) + + +async def _sse_path(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + return types.CallToolResult(content=[types.TextContent(text="sse")]) + + +async def test_dispatch_by_version_routes_to_mrtr_when_at_or_above(): + """Negotiated version >= min → MRTR handler.""" + handler = dispatch_by_version(mrtr=_mrtr_path, sse=_sse_path, min_mrtr_version=types.LATEST_PROTOCOL_VERSION) + async with Client(make_server(handler)) as client: + result = await client.call_tool("x", {}) + assert result == snapshot(types.CallToolResult(content=[types.TextContent(text="mrtr")])) + + +async def test_dispatch_by_version_routes_to_sse_when_below(): + """Negotiated version < min → SSE handler.""" + handler = dispatch_by_version(mrtr=_mrtr_path, sse=_sse_path, min_mrtr_version="9999-01-01") + async with Client(make_server(handler)) as client: + result = await client.call_tool("x", {}) + assert result == snapshot(types.CallToolResult(content=[types.TextContent(text="sse")])) + + +# ─── Option H: linear_mrtr — continuation-based, genuine suspension ────────── + + +class Units(BaseModel): + units: str + + +async def test_linear_mrtr_side_effects_run_exactly_once(): + """The Option B footgun, fixed: ``await ctx.elicit()`` is a real suspension point. + + Side-effects above and below the await fire exactly once — the coroutine + frame is held in the ContinuationStore across MRTR rounds, so there is + no re-entry. + """ + + async def weather(ctx: LinearCtx, args: dict[str, Any]) -> str: + location = args["location"] + audit_log(f"before:{location}") # would fire twice under Option B + prefs = await ctx.elicit("Which units?", Units) + audit_log(f"after:{prefs.units}") + return lookup_weather(location, prefs.units) + + store = ContinuationStore() + server = make_server(linear_mrtr(weather, store=store)) + + async with store: + async with Client(server, elicitation_callback=pick_metric) as client: + result = await client.call_tool("weather", {"location": "Tokyo"}) + assert result == snapshot(types.CallToolResult(content=[types.TextContent(text="Weather in Tokyo: 22°C")])) + + assert _audit == snapshot(["before:Tokyo", "after:metric"]) + + +async def test_linear_mrtr_multiple_elicits(): + """Two sequential ``await ctx.elicit()`` calls — three MRTR rounds.""" + + class Lang(BaseModel): + lang: str + + async def handler(ctx: LinearCtx, args: dict[str, Any]) -> str: + audit_log("start") + u = await ctx.elicit("Which units?", Units) + audit_log(f"got units={u.units}") + lang = await ctx.elicit("Which language?", Lang) + audit_log(f"got lang={lang.lang}") + return f"{u.units}/{lang.lang}" + + store = ContinuationStore() + server = make_server(linear_mrtr(handler, store=store)) + + answers = {"Which units?": {"units": "metric"}, "Which language?": {"lang": "en"}} + + async def elicitation_cb(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: + assert isinstance(params, types.ElicitRequestFormParams) + return types.ElicitResult(action="accept", content=dict(answers[params.message])) + + async with store: + async with Client(server, elicitation_callback=elicitation_cb) as client: + result = await client.call_tool("multi", {}) + assert result == snapshot(types.CallToolResult(content=[types.TextContent(text="metric/en")])) + + assert _audit == snapshot(["start", "got units=metric", "got lang=en"]) + + +async def test_linear_mrtr_elicit_declined_propagates(): + """User declines → handler sees ElicitDeclined, wrapper returns a cancelled result.""" + + async def handler(ctx: LinearCtx, args: dict[str, Any]) -> str: + await ctx.elicit("Confirm?", Units) + return "never reached" # pragma: no cover + + store = ContinuationStore() + server = make_server(linear_mrtr(handler, store=store)) + + async def decline_cb(context: ClientRequestContext, params: types.ElicitRequestParams) -> types.ElicitResult: + return types.ElicitResult(action="decline") + + async with store: + async with Client(server, elicitation_callback=decline_cb) as client: + result = await client.call_tool("confirm", {}) + assert result == snapshot(types.CallToolResult(content=[types.TextContent(text="Cancelled (decline).")])) + + +async def test_linear_mrtr_handler_exception_surfaces(): + """Exception in handler → surfaced as is_error result.""" + + async def handler(ctx: LinearCtx, args: dict[str, Any]) -> str: + raise ValueError("boom") + + store = ContinuationStore() + server = make_server(linear_mrtr(handler, store=store)) + + async with store: + async with Client(server) as client: + result = await client.call_tool("fail", {}) + assert result == snapshot(types.CallToolResult(content=[types.TextContent(text="boom")], is_error=True)) + + +async def test_linear_mrtr_unknown_token_errors(): + """Retry with a request_state that isn't in the store → clear error.""" + + async def handler(ctx: LinearCtx, args: dict[str, Any]) -> str: # pragma: no cover + return "x" + + store = ContinuationStore() + wrapped = linear_mrtr(handler, store=store) + + async with store: + params = types.CallToolRequestParams(name="x", request_state="bogus") + result = await wrapped(None, params) + assert isinstance(result, types.CallToolResult) + assert result.is_error + assert "expired or unknown" in result.content[0].text # type: ignore[union-attr] + + +async def test_linear_mrtr_handler_can_return_call_tool_result(): + """Handler returning CallToolResult directly (not str shorthand).""" + + async def handler(ctx: LinearCtx, args: dict[str, Any]) -> types.CallToolResult: + return types.CallToolResult(content=[types.TextContent(text="direct")]) + + store = ContinuationStore() + server = make_server(linear_mrtr(handler, store=store)) + + async with store: + async with Client(server) as client: + result = await client.call_tool("direct", {}) + assert result == snapshot(types.CallToolResult(content=[types.TextContent(text="direct")])) + + +async def test_linear_mrtr_store_not_entered_raises(): + """Calling without entering the store → clear RuntimeError.""" + + async def handler(ctx: LinearCtx, args: dict[str, Any]) -> str: # pragma: no cover + return "x" + + store = ContinuationStore() + wrapped = linear_mrtr(handler, store=store) + + with pytest.raises(RuntimeError, match="ContinuationStore not entered"): + await wrapped(None, types.CallToolRequestParams(name="x")) diff --git a/uv.lock b/uv.lock index 4af3532ea..8a2c6402f 100644 --- a/uv.lock +++ b/uv.lock @@ -10,6 +10,7 @@ resolution-markers = [ members = [ "mcp", "mcp-everything-server", + "mcp-mrtr-options", "mcp-simple-auth", "mcp-simple-auth-client", "mcp-simple-chatbot", @@ -934,6 +935,17 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-mrtr-options" +version = "0.1.0" +source = { editable = "examples/servers/mrtr-options" } +dependencies = [ + { name = "mcp" }, +] + +[package.metadata] +requires-dist = [{ name = "mcp", editable = "." }] + [[package]] name = "mcp-simple-auth" version = "0.1.0"