From 0a22a9dc33ee5396877f2cdae2165e5a05108357 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 12 Feb 2026 15:55:54 +0000 Subject: [PATCH 01/84] refactor: replace lowlevel Server decorators with on_* constructor kwargs (#1985) --- README.v2.md | 314 ++++---- docs/experimental/index.md | 7 +- docs/migration.md | 302 +++++++- .../mcp_everything_server/server.py | 32 +- .../mcp_simple_pagination/server.py | 251 +++---- .../simple-prompt/mcp_simple_prompt/server.py | 57 +- .../mcp_simple_resource/server.py | 69 +- .../server.py | 119 +-- .../mcp_simple_streamablehttp/server.py | 141 ++-- .../mcp_simple_task_interactive/server.py | 89 ++- .../simple-task/mcp_simple_task/server.py | 78 +- .../simple-tool/mcp_simple_tool/server.py | 58 +- .../mcp_sse_polling_demo/server.py | 178 ++--- .../__main__.py | 96 ++- examples/snippets/servers/lowlevel/basic.py | 52 +- .../lowlevel/direct_call_tool_result.py | 64 +- .../snippets/servers/lowlevel/lifespan.py | 82 +- .../servers/lowlevel/structured_output.py | 92 ++- .../snippets/servers/pagination_example.py | 17 +- src/mcp/server/__init__.py | 3 +- src/mcp/server/context.py | 2 +- .../server/experimental/request_context.py | 5 +- .../experimental/task_result_handler.py | 19 +- src/mcp/server/lowlevel/__init__.py | 2 +- src/mcp/server/lowlevel/experimental.py | 248 +++---- src/mcp/server/lowlevel/func_inspection.py | 53 -- src/mcp/server/lowlevel/server.py | 699 ++++++------------ src/mcp/server/mcpserver/server.py | 159 +++- src/mcp/server/session.py | 24 +- src/mcp/server/streamable_http_manager.py | 2 +- src/mcp/shared/_context.py | 10 +- src/mcp/shared/experimental/tasks/helpers.py | 5 +- src/mcp/shared/message.py | 7 +- tests/client/test_client.py | 62 +- tests/client/test_http_unicode.py | 103 +-- tests/client/test_list_methods_cursor.py | 14 +- tests/client/test_output_schema_validation.py | 204 ++--- tests/client/transports/test_memory.py | 34 +- 
tests/experimental/tasks/client/test_tasks.py | 541 +++++--------- .../tasks/server/test_integration.py | 337 +++------ .../tasks/server/test_run_task_flow.py | 441 ++++------- .../experimental/tasks/server/test_server.py | 448 ++++------- .../tasks/test_elicitation_scenarios.py | 149 ++-- .../tasks/test_spec_compliance.py | 78 +- tests/issues/test_129_resource_templates.py | 37 +- tests/issues/test_152_resource_mime_type.py | 41 +- .../test_1574_resource_uri_validation.py | 75 +- tests/issues/test_342_base64_encoding.py | 89 +-- tests/issues/test_88_random_error.py | 54 +- tests/server/lowlevel/test_func_inspection.py | 292 -------- tests/server/lowlevel/test_server_listing.py | 143 ++-- .../server/lowlevel/test_server_pagination.py | 126 ++-- .../mcpserver/auth/test_auth_integration.py | 3 +- tests/server/mcpserver/prompts/test_base.py | 4 +- .../server/mcpserver/prompts/test_manager.py | 3 +- tests/server/mcpserver/test_server.py | 21 + tests/server/test_cancel_handling.py | 45 +- tests/server/test_completion_with_context.py | 129 ++-- tests/server/test_lifespan.py | 24 +- .../server/test_lowlevel_input_validation.py | 311 -------- .../server/test_lowlevel_output_validation.py | 476 ------------ .../server/test_lowlevel_tool_annotations.py | 122 +-- tests/server/test_read_resource.py | 134 ++-- tests/server/test_session.py | 58 +- tests/server/test_streamable_http_manager.py | 16 +- tests/shared/test_memory.py | 30 - tests/shared/test_progress_notifications.py | 181 +++-- tests/shared/test_session.py | 41 +- tests/shared/test_sse.py | 166 +++-- tests/shared/test_streamable_http.py | 586 +++++++-------- tests/shared/test_ws.py | 90 ++- 71 files changed, 3563 insertions(+), 5481 deletions(-) delete mode 100644 src/mcp/server/lowlevel/func_inspection.py delete mode 100644 tests/server/lowlevel/test_func_inspection.py delete mode 100644 tests/server/test_lowlevel_input_validation.py delete mode 100644 tests/server/test_lowlevel_output_validation.py delete mode 
100644 tests/shared/test_memory.py diff --git a/README.v2.md b/README.v2.md index 67f181811f..bd6927bf92 100644 --- a/README.v2.md +++ b/README.v2.md @@ -1642,12 +1642,11 @@ uv run examples/snippets/servers/lowlevel/lifespan.py from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any +from typing import TypedDict import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext # Mock database class for example @@ -1670,52 +1669,58 @@ class Database: return [{"id": "1", "name": "Example", "query": query_str}] +class AppContext(TypedDict): + db: Database + + @asynccontextmanager -async def server_lifespan(_server: Server) -> AsyncIterator[dict[str, Any]]: +async def server_lifespan(_server: Server[AppContext]) -> AsyncIterator[AppContext]: """Manage server startup and shutdown lifecycle.""" - # Initialize resources on startup db = await Database.connect() try: yield {"db": db} finally: - # Clean up on shutdown await db.disconnect() -# Pass lifespan to server -server = Server("example-server", lifespan=server_lifespan) - - -@server.list_tools() -async def handle_list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="query_db", - description="Query the database", - input_schema={ - "type": "object", - "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, - "required": ["query"], - }, - ) - ] + return types.ListToolsResult( + tools=[ + types.Tool( + name="query_db", + description="Query the database", + input_schema={ + "type": "object", + "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, + "required": ["query"], + }, + ) + ] + ) 
-@server.call_tool() -async def query_db(name: str, arguments: dict[str, Any]) -> list[types.TextContent]: +async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: types.CallToolRequestParams +) -> types.CallToolResult: """Handle database query tool call.""" - if name != "query_db": - raise ValueError(f"Unknown tool: {name}") + if params.name != "query_db": + raise ValueError(f"Unknown tool: {params.name}") - # Access lifespan context - ctx = server.request_context db = ctx.lifespan_context["db"] + results = await db.query((params.arguments or {})["query"]) - # Execute query - results = await db.query(arguments["query"]) + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Query results: {results}")]) - return [types.TextContent(type="text", text=f"Query results: {results}")] + +server = Server( + "example-server", + lifespan=server_lifespan, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -1724,14 +1729,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example-server", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -1760,32 +1758,30 @@ import asyncio import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions - -# Create a server instance -server = Server("example-server") +from mcp.server import Server, ServerRequestContext -@server.list_prompts() -async def handle_list_prompts() -> list[types.Prompt]: +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: """List available prompts.""" - return [ - types.Prompt( - name="example-prompt", - description="An example prompt template", - 
arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], - ) - ] + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="example-prompt", + description="An example prompt template", + arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], + ) + ] + ) -@server.get_prompt() -async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: """Get a specific prompt by name.""" - if name != "example-prompt": - raise ValueError(f"Unknown prompt: {name}") + if params.name != "example-prompt": + raise ValueError(f"Unknown prompt: {params.name}") - arg1_value = (arguments or {}).get("arg1", "default") + arg1_value = (params.arguments or {}).get("arg1", "default") return types.GetPromptResult( description="Example prompt", @@ -1798,20 +1794,20 @@ async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> type ) +server = Server( + "example-server", + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, +) + + async def run(): """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -1835,62 +1831,67 @@ uv run examples/snippets/servers/lowlevel/structured_output.py """ import asyncio -from typing import Any +import json import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -server = 
Server("example-server") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools with structured output schemas.""" - return [ - types.Tool( - name="get_weather", - description="Get current weather for a city", - input_schema={ - "type": "object", - "properties": {"city": {"type": "string", "description": "City name"}}, - "required": ["city"], - }, - output_schema={ - "type": "object", - "properties": { - "temperature": {"type": "number", "description": "Temperature in Celsius"}, - "condition": {"type": "string", "description": "Weather condition"}, - "humidity": {"type": "number", "description": "Humidity percentage"}, - "city": {"type": "string", "description": "City name"}, + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Get current weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string", "description": "City name"}}, + "required": ["city"], }, - "required": ["temperature", "condition", "humidity", "city"], - }, - ) - ] + output_schema={ + "type": "object", + "properties": { + "temperature": {"type": "number", "description": "Temperature in Celsius"}, + "condition": {"type": "string", "description": "Weather condition"}, + "humidity": {"type": "number", "description": "Humidity percentage"}, + "city": {"type": "string", "description": "City name"}, + }, + "required": ["temperature", "condition", "humidity", "city"], + }, + ) + ] + ) -@server.call_tool() -async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls with structured output.""" - if name == "get_weather": - city = arguments["city"] + if params.name == "get_weather": + city = (params.arguments 
or {})["city"] - # Simulated weather data - in production, call a weather API weather_data = { "temperature": 22.5, "condition": "partly cloudy", "humidity": 65, - "city": city, # Include the requested city + "city": city, } - # low-level server will validate structured output against the tool's - # output schema, and additionally serialize it into a TextContent block - # for backwards compatibility with pre-2025-06-18 clients. - return weather_data - else: - raise ValueError(f"Unknown tool: {name}") + return types.CallToolResult( + content=[types.TextContent(type="text", text=json.dumps(weather_data, indent=2))], + structured_content=weather_data, + ) + + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -1899,14 +1900,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="structured-output-example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -1917,18 +1911,11 @@ if __name__ == "__main__": _Full example: [examples/snippets/servers/lowlevel/structured_output.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/lowlevel/structured_output.py)_ -Tools can return data in four ways: +With the low-level server, handlers always return `CallToolResult` directly. You construct both the human-readable `content` and the machine-readable `structured_content` yourself, giving you full control over the response. -1. **Content only**: Return a list of content blocks (default behavior before spec revision 2025-06-18) -2. **Structured data only**: Return a dictionary that will be serialized to JSON (Introduced in spec revision 2025-06-18) -3. 
**Both**: Return a tuple of (content, structured_data) preferred option to use for backwards compatibility -4. **Direct CallToolResult**: Return `CallToolResult` directly for full control (including `_meta` field) +##### Returning CallToolResult with `_meta` -When an `outputSchema` is defined, the server automatically validates the structured output against the schema. This ensures type safety and helps catch errors early. - -##### Returning CallToolResult Directly - -For full control over the response including the `_meta` field (for passing data to client applications without exposing it to the model), return `CallToolResult` directly: +For passing data to client applications without exposing it to the model, use the `_meta` field on `CallToolResult`: ```python @@ -1937,44 +1924,49 @@ uv run examples/snippets/servers/lowlevel/direct_call_tool_result.py """ import asyncio -from typing import Any import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions - -server = Server("example-server") +from mcp.server import Server, ServerRequestContext -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="advanced_tool", - description="Tool with full control including _meta field", - input_schema={ - "type": "object", - "properties": {"message": {"type": "string"}}, - "required": ["message"], - }, - ) - ] + return types.ListToolsResult( + tools=[ + types.Tool( + name="advanced_tool", + description="Tool with full control including _meta field", + input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + ) + ] + ) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> 
types.CallToolResult: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls by returning CallToolResult directly.""" - if name == "advanced_tool": - message = str(arguments.get("message", "")) + if params.name == "advanced_tool": + message = (params.arguments or {}).get("message", "") return types.CallToolResult( content=[types.TextContent(type="text", text=f"Processed: {message}")], structured_content={"result": "success", "message": message}, _meta={"hidden": "data for client applications only"}, ) - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -1983,14 +1975,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -2001,8 +1986,6 @@ if __name__ == "__main__": _Full example: [examples/snippets/servers/lowlevel/direct_call_tool_result.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/lowlevel/direct_call_tool_result.py)_ -**Note:** When returning `CallToolResult`, you bypass the automatic content/structured conversion. You must construct the complete response yourself. - ### Pagination (Advanced) For servers that need to handle large datasets, the low-level server provides paginated versions of list operations. This is an optional optimization - most servers won't need pagination unless they're dealing with hundreds or thousands of items. 
@@ -2011,25 +1994,23 @@ For servers that need to handle large datasets, the low-level server provides pa ```python -"""Example of implementing pagination with MCP server decorators.""" +"""Example of implementing pagination with the low-level MCP server.""" from mcp import types -from mcp.server.lowlevel import Server - -# Initialize the server -server = Server("paginated-server") +from mcp.server import Server, ServerRequestContext # Sample data to paginate ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items -@server.list_resources() -async def list_resources_paginated(request: types.ListResourcesRequest) -> types.ListResourcesResult: +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: """List resources with pagination support.""" page_size = 10 # Extract cursor from request params - cursor = request.params.cursor if request.params is not None else None + cursor = params.cursor if params is not None else None # Parse cursor to get offset start = 0 if cursor is None else int(cursor) @@ -2045,6 +2026,9 @@ async def list_resources_paginated(request: types.ListResourcesRequest) -> types next_cursor = str(end) if end < len(ITEMS) else None return types.ListResourcesResult(resources=page_items, next_cursor=next_cursor) + + +server = Server("paginated-server", on_list_resources=handle_list_resources) ``` _Full example: [examples/snippets/servers/pagination_example.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/pagination_example.py)_ diff --git a/docs/experimental/index.md b/docs/experimental/index.md index 1d496b3f10..c97fe2a3d6 100644 --- a/docs/experimental/index.md +++ b/docs/experimental/index.md @@ -27,10 +27,9 @@ Tasks are useful for: Experimental features are accessed via the `.experimental` property: ```python -# Server-side -@server.experimental.get_task() -async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - 
... +# Server-side: enable task support (auto-registers default handlers) +server = Server(name="my-server") +server.experimental.enable_tasks() # Client-side result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) diff --git a/docs/migration.md b/docs/migration.md index 7d30f0ac92..17fd92bd06 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -351,7 +351,6 @@ The nested `RequestParams.Meta` Pydantic model class has been replaced with a to - `RequestParams.Meta` (Pydantic model) → `RequestParamsMeta` (TypedDict) - Attribute access (`meta.progress_token`) → Dictionary access (`meta.get("progress_token")`) - `progress_token` field changed from `ProgressToken | None = None` to `NotRequired[ProgressToken]` -` **In request context handlers:** @@ -364,11 +363,12 @@ async def handle_tool(name: str, arguments: dict) -> list[TextContent]: await ctx.session.send_progress_notification(ctx.meta.progress_token, 0.5, 100) # After (v2) -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> list[TextContent]: - ctx = server.request_context +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: if ctx.meta and "progress_token" in ctx.meta: await ctx.session.send_progress_notification(ctx.meta["progress_token"], 0.5, 100) + ... + +server = Server("my-server", on_call_tool=handle_call_tool) ``` ### `RequestContext` and `ProgressContext` type parameters simplified @@ -471,12 +471,292 @@ await client.read_resource("test://resource") await client.read_resource(str(my_any_url)) ``` +### Lowlevel `Server`: constructor parameters are now keyword-only + +All parameters after `name` are now keyword-only. 
If you were passing `version` or other parameters positionally, use keyword arguments instead: + +```python +# Before (v1) +server = Server("my-server", "1.0") + +# After (v2) +server = Server("my-server", version="1.0") +``` + +### Lowlevel `Server`: type parameter reduced from 2 to 1 + +The `Server` class previously had two type parameters: `Server[LifespanResultT, RequestT]`. The `RequestT` parameter has been removed — handlers now receive typed params directly rather than a generic request type. + +```python +# Before (v1) +from typing import Any + +from mcp.server.lowlevel.server import Server + +server: Server[dict[str, Any], Any] = Server(...) + +# After (v2) +from typing import Any + +from mcp.server import Server + +server: Server[dict[str, Any]] = Server(...) +``` + +### Lowlevel `Server`: `request_handlers` and `notification_handlers` attributes removed + +The public `server.request_handlers` and `server.notification_handlers` dictionaries have been removed. Handler registration is now done exclusively through constructor `on_*` keyword arguments. There is no public API to register handlers after construction. + +```python +# Before (v1) — direct dict access +from mcp.types import ListToolsRequest + +if ListToolsRequest in server.request_handlers: + ... + +# After (v2) — no public access to handler dicts +# Use the on_* constructor params to register handlers +server = Server("my-server", on_list_tools=handle_list_tools) +``` + +### Lowlevel `Server`: decorator-based handlers replaced with constructor `on_*` params + +The lowlevel `Server` class no longer uses decorator methods for handler registration. Instead, handlers are passed as `on_*` keyword arguments to the constructor. 
+ +**Before (v1):** + +```python +from mcp.server.lowlevel.server import Server + +server = Server("my-server") + +@server.list_tools() +async def handle_list_tools(): + return [types.Tool(name="my_tool", description="A tool", inputSchema={})] + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict): + return [types.TextContent(type="text", text=f"Called {name}")] +``` + +**After (v2):** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + +async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="my_tool", description="A tool", input_schema={})]) + + +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text=f"Called {params.name}")], + is_error=False, + ) + +server = Server("my-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) +``` + +**Key differences:** + +- Handlers receive `(ctx, params)` instead of the full request object or unpacked arguments. `ctx` is a `ServerRequestContext` with `session`, `lifespan_context`, and `experimental` fields (plus `request_id`, `meta`, etc. for request handlers). `params` is the typed request params object. +- Handlers return the full result type (e.g. `ListToolsResult`) rather than unwrapped values (e.g. `list[Tool]`). +- The automatic `jsonschema` input/output validation that the old `call_tool()` decorator performed has been removed. There is no built-in replacement — if you relied on schema validation in the lowlevel server, you will need to validate inputs yourself in your handler. 
+ +**Notification handlers:** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import ProgressNotificationParams + + +async def handle_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None: + print(f"Progress: {params.progress}/{params.total}") + +server = Server("my-server", on_progress=handle_progress) +``` + +### Lowlevel `Server`: automatic return value wrapping removed + +The old decorator-based handlers performed significant automatic wrapping of return values. This magic has been removed — handlers now return fully constructed result types. If you want these conveniences, use `MCPServer` (previously `FastMCP`) instead of the lowlevel `Server`. + +**`call_tool()` — structured output wrapping removed:** + +The old decorator accepted several return types and auto-wrapped them into `CallToolResult`: + +```python +# Before (v1) — returning a dict auto-wrapped into structured_content + JSON TextContent +@server.call_tool() +async def handle(name: str, arguments: dict) -> dict: + return {"temperature": 22.5, "city": "London"} + +# Before (v1) — returning a list auto-wrapped into CallToolResult.content +@server.call_tool() +async def handle(name: str, arguments: dict) -> list[TextContent]: + return [TextContent(type="text", text="Done")] +``` + +```python +# After (v2) — construct the full result yourself +import json + +async def handle(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + data = {"temperature": 22.5, "city": "London"} + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(data, indent=2))], + structured_content=data, + ) +``` + +Note: `params.arguments` can be `None` (the old decorator defaulted it to `{}`). Use `params.arguments or {}` to preserve the old behavior. 
+ +**`read_resource()` — content type wrapping removed:** + +The old decorator auto-wrapped `str` into `TextResourceContents` and `bytes` into `BlobResourceContents` (with base64 encoding), and applied a default mime type of `text/plain`: + +```python +# Before (v1) — str/bytes auto-wrapped with mime type defaulting +@server.read_resource() +async def handle(uri: str) -> str: + return "file contents" + +@server.read_resource() +async def handle(uri: str) -> bytes: + return b"\x89PNG..." +``` + +```python +# After (v2) — construct TextResourceContents or BlobResourceContents yourself +import base64 + +async def handle_read(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + # Text content + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text="file contents", mime_type="text/plain")] + ) + +async def handle_read(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + # Binary content — you must base64-encode it yourself + return ReadResourceResult( + contents=[BlobResourceContents( + uri=str(params.uri), + blob=base64.b64encode(b"\x89PNG...").decode("utf-8"), + mime_type="image/png", + )] + ) +``` + +**`list_tools()`, `list_resources()`, `list_prompts()` — list wrapping removed:** + +The old decorators accepted bare lists and wrapped them into the result type: + +```python +# Before (v1) +@server.list_tools() +async def handle() -> list[Tool]: + return [Tool(name="my_tool", ...)] + +# After (v2) +async def handle(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="my_tool", ...)]) +``` + +**Using `MCPServer` instead:** + +If you prefer the convenience of automatic wrapping, use `MCPServer` which still provides these features through its `@mcp.tool()`, `@mcp.resource()`, and `@mcp.prompt()` decorators. 
The lowlevel `Server` is intentionally minimal — it provides no magic and gives you full control over the MCP protocol types. + +### Lowlevel `Server`: `request_context` property removed + +The `server.request_context` property has been removed. Request context is now passed directly to handlers as the first argument (`ctx`). The `request_ctx` module-level contextvar is now an internal implementation detail and should not be relied upon. + +**Before (v1):** + +```python +from mcp.server.lowlevel.server import request_ctx + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict): + ctx = server.request_context # or request_ctx.get() + await ctx.session.send_log_message(level="info", data="Processing...") + return [types.TextContent(type="text", text="Done")] +``` + +**After (v2):** + +```python +from mcp.server import ServerRequestContext +from mcp.types import CallToolRequestParams, CallToolResult, TextContent + + +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + await ctx.session.send_log_message(level="info", data="Processing...") + return CallToolResult( + content=[TextContent(type="text", text="Done")], + is_error=False, + ) +``` + +### `RequestContext`: request-specific fields are now optional + +The `RequestContext` class now uses optional fields for request-specific data (`request_id`, `meta`, etc.) so it can be used for both request and notification handlers. In notification handlers, these fields are `None`. + +```python +from mcp.server import ServerRequestContext + +# request_id, meta, etc. are available in request handlers +# but None in notification handlers +``` + +### Experimental: task handler decorators removed + +The experimental decorator methods on `ExperimentalHandlers` (`@server.experimental.list_tasks()`, `@server.experimental.get_task()`, etc.) have been removed. + +Default task handlers are still registered automatically via `server.experimental.enable_tasks()`. 
Custom handlers can be passed as `on_*` kwargs to override specific defaults. + +**Before (v1):** + +```python +server = Server("my-server") +server.experimental.enable_tasks() + +@server.experimental.get_task() +async def custom_get_task(request: GetTaskRequest) -> GetTaskResult: + ... +``` + +**After (v2):** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import GetTaskRequestParams, GetTaskResult + + +async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + ... + + +server = Server("my-server") +server.experimental.enable_tasks(on_get_task=custom_get_task) +``` + ## Deprecations ## Bug Fixes +### Lowlevel `Server`: `subscribe` capability now correctly reported + +Previously, the lowlevel `Server` hardcoded `subscribe=False` in resource capabilities even when a `subscribe_resource()` handler was registered. The `subscribe` capability is now dynamically set to `True` when an `on_subscribe_resource` handler is provided. Clients that previously didn't see `subscribe: true` in capabilities will now see it when a handler is registered, which may change client behavior. + ### Extra fields no longer allowed on top-level MCP types MCP protocol types no longer accept arbitrary extra fields at the top level. This matches the MCP specification which only allows extra fields within `_meta` objects, not on the types themselves. @@ -506,16 +786,16 @@ params = CallToolRequestParams( The `streamable_http_app()` method is now available directly on the lowlevel `Server` class, not just `MCPServer`. This allows using the streamable HTTP transport without the MCPServer wrapper. ```python -from mcp.server.lowlevel.server import Server +from mcp.server import Server, ServerRequestContext +from mcp.types import ListToolsResult, PaginatedRequestParams -server = Server("my-server") -# Register handlers... -@server.list_tools() -async def list_tools(): - return [...] 
+async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[...]) + + +server = Server("my-server", on_list_tools=handle_list_tools) -# Create a Starlette app for streamable HTTP app = server.streamable_http_app( streamable_http_path="/mcp", json_response=False, diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index 4fb7d9a1d3..2101cff28f 100644 --- a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -10,6 +10,7 @@ import logging import click +from mcp.server import ServerRequestContext from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.prompts.base import UserMessage from mcp.server.session import ServerSession @@ -20,13 +21,17 @@ CompletionArgument, CompletionContext, EmbeddedResource, + EmptyResult, ImageContent, JSONRPCMessage, PromptReference, ResourceTemplateReference, SamplingMessage, + SetLevelRequestParams, + SubscribeRequestParams, TextContent, TextResourceContents, + UnsubscribeRequestParams, ) from pydantic import BaseModel, Field @@ -393,28 +398,29 @@ def test_prompt_with_image() -> list[UserMessage]: # Custom request handlers # TODO(felix): Add public APIs to MCPServer for subscribe_resource, unsubscribe_resource, # and set_logging_level to avoid accessing protected _lowlevel_server attribute. 
-@mcp._lowlevel_server.set_logging_level() # pyright: ignore[reportPrivateUsage] -async def handle_set_logging_level(level: str) -> None: +async def handle_set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: """Handle logging level changes""" - logger.info(f"Log level set to: {level}") - # In a real implementation, you would adjust the logging level here - # For conformance testing, we just acknowledge the request + logger.info(f"Log level set to: {params.level}") + return EmptyResult() -async def handle_subscribe(uri: str) -> None: +async def handle_subscribe(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: """Handle resource subscription""" - resource_subscriptions.add(str(uri)) - logger.info(f"Subscribed to resource: {uri}") + resource_subscriptions.add(str(params.uri)) + logger.info(f"Subscribed to resource: {params.uri}") + return EmptyResult() -async def handle_unsubscribe(uri: str) -> None: +async def handle_unsubscribe(ctx: ServerRequestContext, params: UnsubscribeRequestParams) -> EmptyResult: """Handle resource unsubscription""" - resource_subscriptions.discard(str(uri)) - logger.info(f"Unsubscribed from resource: {uri}") + resource_subscriptions.discard(str(params.uri)) + logger.info(f"Unsubscribed from resource: {params.uri}") + return EmptyResult() -mcp._lowlevel_server.subscribe_resource()(handle_subscribe) # pyright: ignore[reportPrivateUsage] -mcp._lowlevel_server.unsubscribe_resource()(handle_unsubscribe) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("logging/setLevel", handle_set_logging_level) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscribe) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("resources/unsubscribe", handle_unsubscribe) # pyright: ignore[reportPrivateUsage] @mcp.completion() diff --git 
a/examples/servers/simple-pagination/mcp_simple_pagination/server.py b/examples/servers/simple-pagination/mcp_simple_pagination/server.py index ff45ae2245..bac27a0f1f 100644 --- a/examples/servers/simple-pagination/mcp_simple_pagination/server.py +++ b/examples/servers/simple-pagination/mcp_simple_pagination/server.py @@ -1,17 +1,19 @@ """Simple MCP server demonstrating pagination for tools, resources, and prompts. -This example shows how to use the paginated decorators to handle large lists -of items that need to be split across multiple pages. +This example shows how to implement pagination with the low-level server API +to handle large lists of items that need to be split across multiple pages. """ -from typing import Any +from typing import TypeVar import anyio import click from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from starlette.requests import Request +T = TypeVar("T") + # Sample data - in real scenarios, this might come from a database SAMPLE_TOOLS = [ types.Tool( @@ -44,6 +46,102 @@ ] +def _paginate(cursor: str | None, items: list[T], page_size: int) -> tuple[list[T], str | None]: + """Helper to paginate a list of items given a cursor.""" + if cursor is not None: + try: + start_idx = int(cursor) + except (ValueError, TypeError): + return [], None + else: + start_idx = 0 + + page = items[start_idx : start_idx + page_size] + next_cursor = str(start_idx + page_size) if start_idx + page_size < len(items) else None + return page, next_cursor + + +# Paginated list_tools - returns 5 tools per page +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + cursor = params.cursor if params is not None else None + page, next_cursor = _paginate(cursor, SAMPLE_TOOLS, page_size=5) + return types.ListToolsResult(tools=page, next_cursor=next_cursor) + + +# Paginated list_resources - returns 10 resources per page +async def 
handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: + cursor = params.cursor if params is not None else None + page, next_cursor = _paginate(cursor, SAMPLE_RESOURCES, page_size=10) + return types.ListResourcesResult(resources=page, next_cursor=next_cursor) + + +# Paginated list_prompts - returns 7 prompts per page +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + cursor = params.cursor if params is not None else None + page, next_cursor = _paginate(cursor, SAMPLE_PROMPTS, page_size=7) + return types.ListPromptsResult(prompts=page, next_cursor=next_cursor) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + # Find the tool in our sample data + tool = next((t for t in SAMPLE_TOOLS if t.name == params.name), None) + if not tool: + raise ValueError(f"Unknown tool: {params.name}") + + return types.CallToolResult( + content=[ + types.TextContent( + type="text", + text=f"Called tool '{params.name}' with arguments: {params.arguments}", + ) + ] + ) + + +async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: + resource = next((r for r in SAMPLE_RESOURCES if r.uri == str(params.uri)), None) + if not resource: + raise ValueError(f"Unknown resource: {params.uri}") + + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=str(params.uri), + text=f"Content of {resource.name}: This is sample content for the resource.", + mime_type="text/plain", + ) + ] + ) + + +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: + prompt = next((p for p in SAMPLE_PROMPTS if p.name == params.name), None) + if not prompt: + raise ValueError(f"Unknown prompt: {params.name}") + + message_text = 
f"This is the prompt '{params.name}'" + if params.arguments: + message_text += f" with arguments: {params.arguments}" + + return types.GetPromptResult( + description=prompt.description, + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text=message_text), + ) + ], + ) + + @click.command() @click.option("--port", default=8000, help="Port to listen on for SSE") @click.option( @@ -53,142 +151,15 @@ help="Transport type", ) def main(port: int, transport: str) -> int: - app = Server("mcp-simple-pagination") - - # Paginated list_tools - returns 5 tools per page - @app.list_tools() - async def list_tools_paginated(request: types.ListToolsRequest) -> types.ListToolsResult: - page_size = 5 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListToolsResult(tools=[], next_cursor=None) - - # Get the page of tools - page_tools = SAMPLE_TOOLS[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_TOOLS): - next_cursor = str(start_idx + page_size) - - return types.ListToolsResult(tools=page_tools, next_cursor=next_cursor) - - # Paginated list_resources - returns 10 resources per page - @app.list_resources() - async def list_resources_paginated( - request: types.ListResourcesRequest, - ) -> types.ListResourcesResult: - page_size = 10 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListResourcesResult(resources=[], next_cursor=None) - - # Get the page of resources - page_resources = 
SAMPLE_RESOURCES[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_RESOURCES): - next_cursor = str(start_idx + page_size) - - return types.ListResourcesResult(resources=page_resources, next_cursor=next_cursor) - - # Paginated list_prompts - returns 7 prompts per page - @app.list_prompts() - async def list_prompts_paginated( - request: types.ListPromptsRequest, - ) -> types.ListPromptsResult: - page_size = 7 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListPromptsResult(prompts=[], next_cursor=None) - - # Get the page of prompts - page_prompts = SAMPLE_PROMPTS[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_PROMPTS): - next_cursor = str(start_idx + page_size) - - return types.ListPromptsResult(prompts=page_prompts, next_cursor=next_cursor) - - # Implement call_tool handler - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - # Find the tool in our sample data - tool = next((t for t in SAMPLE_TOOLS if t.name == name), None) - if not tool: - raise ValueError(f"Unknown tool: {name}") - - # Simple mock response - return [ - types.TextContent( - type="text", - text=f"Called tool '{name}' with arguments: {arguments}", - ) - ] - - # Implement read_resource handler - @app.read_resource() - async def read_resource(uri: str) -> str: - # Find the resource in our sample data - resource = next((r for r in SAMPLE_RESOURCES if r.uri == uri), None) - if not resource: - raise ValueError(f"Unknown resource: {uri}") - - # Return a simple string - the decorator will convert it to TextResourceContents - return f"Content 
of {resource.name}: This is sample content for the resource." - - # Implement get_prompt handler - @app.get_prompt() - async def get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: - # Find the prompt in our sample data - prompt = next((p for p in SAMPLE_PROMPTS if p.name == name), None) - if not prompt: - raise ValueError(f"Unknown prompt: {name}") - - # Simple mock response - message_text = f"This is the prompt '{name}'" - if arguments: - message_text += f" with arguments: {arguments}" - - return types.GetPromptResult( - description=prompt.description, - messages=[ - types.PromptMessage( - role="user", - content=types.TextContent(type="text", text=message_text), - ) - ], - ) + app = Server( + "mcp-simple-pagination", + on_list_tools=handle_list_tools, + on_list_resources=handle_list_resources, + on_list_prompts=handle_list_prompts, + on_call_tool=handle_call_tool, + on_read_resource=handle_read_resource, + on_get_prompt=handle_get_prompt, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index cbc5a9d68f..6cf99d4b69 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -1,7 +1,7 @@ import anyio import click from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from starlette.requests import Request @@ -30,20 +30,11 @@ def create_messages(context: str | None = None, topic: str | None = None) -> lis return messages -@click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-simple-prompt") - - @app.list_prompts() - async def 
list_prompts() -> list[types.Prompt]: - return [ +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ types.Prompt( name="simple", title="Simple Assistant Prompt", @@ -62,19 +53,35 @@ async def list_prompts() -> list[types.Prompt]: ], ) ] + ) - @app.get_prompt() - async def get_prompt(name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: - if name != "simple": - raise ValueError(f"Unknown prompt: {name}") - if arguments is None: - arguments = {} +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: + if params.name != "simple": + raise ValueError(f"Unknown prompt: {params.name}") - return types.GetPromptResult( - messages=create_messages(context=arguments.get("context"), topic=arguments.get("topic")), - description="A simple prompt with optional context and topic arguments", - ) + arguments = params.arguments or {} + + return types.GetPromptResult( + messages=create_messages(context=arguments.get("context"), topic=arguments.get("topic")), + description="A simple prompt with optional context and topic arguments", + ) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option( + "--transport", + type=click.Choice(["stdio", "sse"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-simple-prompt", + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 588d1044a8..b9b6a1d960 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ 
b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -1,8 +1,9 @@ +from urllib.parse import urlparse + import anyio import click from mcp import types -from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server import Server, ServerRequestContext from starlette.requests import Request SAMPLE_RESOURCES = { @@ -21,20 +22,11 @@ } -@click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-simple-resource") - - @app.list_resources() - async def list_resources() -> list[types.Resource]: - return [ +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: + return types.ListResourcesResult( + resources=[ types.Resource( uri=f"file:///{name}.txt", name=name, @@ -44,20 +36,45 @@ async def list_resources() -> list[types.Resource]: ) for name in SAMPLE_RESOURCES.keys() ] + ) - @app.read_resource() - async def read_resource(uri: str): - from urllib.parse import urlparse - parsed = urlparse(uri) - if not parsed.path: - raise ValueError(f"Invalid resource path: {uri}") - name = parsed.path.replace(".txt", "").lstrip("/") +async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: + parsed = urlparse(str(params.uri)) + if not parsed.path: + raise ValueError(f"Invalid resource path: {params.uri}") + name = parsed.path.replace(".txt", "").lstrip("/") - if name not in SAMPLE_RESOURCES: - raise ValueError(f"Unknown resource: {uri}") + if name not in SAMPLE_RESOURCES: + raise ValueError(f"Unknown resource: {params.uri}") - return [ReadResourceContents(content=SAMPLE_RESOURCES[name]["content"], mime_type="text/plain")] + return 
types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=str(params.uri), + text=SAMPLE_RESOURCES[name]["content"], + mime_type="text/plain", + ) + ] + ) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option( + "--transport", + type=click.Choice(["stdio", "sse"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-simple-resource", + on_list_resources=handle_list_resources, + on_read_resource=handle_read_resource, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index 9fed2f0aa6..cb4a6503ce 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -1,13 +1,12 @@ import contextlib import logging from collections.abc import AsyncIterator -from typing import Any import anyio import click import uvicorn from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware @@ -17,6 +16,64 @@ logger = logging.getLogger(__name__) +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="start-notification-stream", + description=("Sends a stream of notifications with configurable count and interval"), + input_schema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + 
"interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": ("Identifier of the caller to include in notifications"), + }, + }, + }, + ) + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + arguments = params.arguments or {} + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + await ctx.session.send_log_message( + level="info", + data=f"Notification {i + 1}/{count} from caller: {caller}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + return types.CallToolResult( + content=[ + types.TextContent( + type="text", + text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), + ) + ] + ) + + @click.command() @click.option("--port", default=3000, help="Port to listen on for HTTP") @click.option( @@ -41,59 +98,11 @@ def main( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - app = Server("mcp-streamable-http-stateless-demo") - - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - ctx = app.request_context - interval = arguments.get("interval", 1.0) - count = arguments.get("count", 5) - caller = arguments.get("caller", "unknown") - - # Send the specified number of notifications with the given interval - for i in range(count): - await ctx.session.send_log_message( - level="info", - data=f"Notification {i + 1}/{count} from caller: {caller}", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - if i < count - 1: # 
Don't wait after the last notification - await anyio.sleep(interval) - - return [ - types.TextContent( - type="text", - text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), - ) - ] - - @app.list_tools() - async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="start-notification-stream", - description=("Sends a stream of notifications with configurable count and interval"), - input_schema={ - "type": "object", - "required": ["interval", "count", "caller"], - "properties": { - "interval": { - "type": "number", - "description": "Interval between notifications in seconds", - }, - "count": { - "type": "number", - "description": "Number of notifications to send", - }, - "caller": { - "type": "string", - "description": ("Identifier of the caller to include in notifications"), - }, - }, - }, - ) - ] + app = Server( + "mcp-streamable-http-stateless-demo", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Create the session manager with true stateless mode session_manager = StreamableHTTPSessionManager( diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index ef03d9b08f..2f2a53b1b1 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -1,12 +1,11 @@ import contextlib import logging from collections.abc import AsyncIterator -from typing import Any import anyio import click from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware @@ -19,6 +18,75 @@ logger = logging.getLogger(__name__) +async def handle_list_tools( + ctx: ServerRequestContext, 
params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="start-notification-stream", + description="Sends a stream of notifications with configurable count and interval", + input_schema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": "Identifier of the caller to include in notifications", + }, + }, + }, + ) + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + arguments = params.arguments or {} + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + # Include more detailed message for resumability demonstration + notification_msg = f"[{i + 1}/{count}] Event from '{caller}' - Use Last-Event-ID to resume if disconnected" + await ctx.session.send_log_message( + level="info", + data=notification_msg, + logger="notification_stream", + # Associates this notification with the original request + # Ensures notifications are sent to the correct response stream + # Without this, notifications will either go to: + # - a standalone SSE stream (if GET request is supported) + # - nowhere (if GET request isn't supported) + related_request_id=ctx.request_id, + ) + logger.debug(f"Sent notification {i + 1}/{count} for caller: {caller}") + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + # This will send a resource notification through standalone SSE + # established by GET request + await 
ctx.session.send_resource_updated(uri="http:///test_resource") + return types.CallToolResult( + content=[ + types.TextContent( + type="text", + text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), + ) + ] + ) + + @click.command() @click.option("--port", default=3000, help="Port to listen on for HTTP") @click.option( @@ -43,70 +111,11 @@ def main( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - app = Server("mcp-streamable-http-demo") - - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - ctx = app.request_context - interval = arguments.get("interval", 1.0) - count = arguments.get("count", 5) - caller = arguments.get("caller", "unknown") - - # Send the specified number of notifications with the given interval - for i in range(count): - # Include more detailed message for resumability demonstration - notification_msg = f"[{i + 1}/{count}] Event from '{caller}' - Use Last-Event-ID to resume if disconnected" - await ctx.session.send_log_message( - level="info", - data=notification_msg, - logger="notification_stream", - # Associates this notification with the original request - # Ensures notifications are sent to the correct response stream - # Without this, notifications will either go to: - # - a standalone SSE stream (if GET request is supported) - # - nowhere (if GET request isn't supported) - related_request_id=ctx.request_id, - ) - logger.debug(f"Sent notification {i + 1}/{count} for caller: {caller}") - if i < count - 1: # Don't wait after the last notification - await anyio.sleep(interval) - - # This will send a resource notificaiton though standalone SSE - # established by GET request - await ctx.session.send_resource_updated(uri="http:///test_resource") - return [ - types.TextContent( - type="text", - text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), - ) - ] - - @app.list_tools() - async def list_tools() -> 
list[types.Tool]: - return [ - types.Tool( - name="start-notification-stream", - description=("Sends a stream of notifications with configurable count and interval"), - input_schema={ - "type": "object", - "required": ["interval", "count", "caller"], - "properties": { - "interval": { - "type": "number", - "description": "Interval between notifications in seconds", - }, - "count": { - "type": "number", - "description": "Number of notifications to send", - }, - "caller": { - "type": "string", - "description": ("Identifier of the caller to include in notifications"), - }, - }, - }, - ) - ] + app = Server( + "mcp-streamable-http-demo", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Create event store for resumability # The InMemoryEventStore enables resumability support for StreamableHTTP transport. diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py index dc689ed942..6938b6552a 100644 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py +++ b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py @@ -13,42 +13,39 @@ import click import uvicorn from mcp import types +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.lowlevel import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.routing import Mount -server = Server("simple-task-interactive") -# Enable task support - this auto-registers all handlers -server.experimental.enable_tasks() +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="confirm_delete", + description="Asks for confirmation before 
deleting (demonstrates elicitation)", + input_schema={ + "type": "object", + "properties": {"filename": {"type": "string"}}, + }, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ), + types.Tool( + name="write_haiku", + description="Asks LLM to write a haiku (demonstrates sampling)", + input_schema={"type": "object", "properties": {"topic": {"type": "string"}}}, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ), + ] + ) -@server.list_tools() -async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="confirm_delete", - description="Asks for confirmation before deleting (demonstrates elicitation)", - input_schema={ - "type": "object", - "properties": {"filename": {"type": "string"}}, - }, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - types.Tool( - name="write_haiku", - description="Asks LLM to write a haiku (demonstrates sampling)", - input_schema={"type": "object", "properties": {"topic": {"type": "string"}}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - ] - - -async def handle_confirm_delete(arguments: dict[str, Any]) -> types.CreateTaskResult: +async def handle_confirm_delete(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: """Handle the confirm_delete tool - demonstrates elicitation.""" - ctx = server.request_context ctx.experimental.validate_task_mode(types.TASK_REQUIRED) filename = arguments.get("filename", "unknown.txt") @@ -80,9 +77,8 @@ async def work(task: ServerTaskContext) -> types.CallToolResult: return await ctx.experimental.run_task(work) -async def handle_write_haiku(arguments: dict[str, Any]) -> types.CreateTaskResult: +async def handle_write_haiku(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: """Handle the write_haiku tool - demonstrates sampling.""" - ctx = server.request_context ctx.experimental.validate_task_mode(types.TASK_REQUIRED) topic = arguments.get("topic", 
"nature") @@ -111,18 +107,31 @@ async def work(task: ServerTaskContext) -> types.CallToolResult: return await ctx.experimental.run_task(work) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: +async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.CreateTaskResult: """Dispatch tool calls to their handlers.""" - if name == "confirm_delete": - return await handle_confirm_delete(arguments) - elif name == "write_haiku": - return await handle_write_haiku(arguments) - else: - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], - is_error=True, - ) + arguments = params.arguments or {} + + if params.name == "confirm_delete": + return await handle_confirm_delete(ctx, arguments) + elif params.name == "write_haiku": + return await handle_write_haiku(ctx, arguments) + + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], + is_error=True, + ) + + +server = Server( + "simple-task-interactive", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) + +# Enable task support - this auto-registers all handlers +server.experimental.enable_tasks() def create_app(session_manager: StreamableHTTPSessionManager) -> Starlette: diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py index ec16b15ae8..50ae3ca9af 100644 --- a/examples/servers/simple-task/mcp_simple_task/server.py +++ b/examples/servers/simple-task/mcp_simple_task/server.py @@ -2,66 +2,68 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any import anyio import click import uvicorn from mcp import types +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext 
-from mcp.server.lowlevel import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.routing import Mount -server = Server("simple-task-server") -# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task -server.experimental.enable_tasks() +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="long_running_task", + description="A task that takes a few seconds to complete with status updates", + input_schema={"type": "object", "properties": {}}, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ) + ] + ) -@server.list_tools() -async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="long_running_task", - description="A task that takes a few seconds to complete with status updates", - input_schema={"type": "object", "properties": {}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ) - ] +async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.CreateTaskResult: + """Dispatch tool calls to their handlers.""" + if params.name == "long_running_task": + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Starting work...") + await anyio.sleep(1) -async def handle_long_running_task(arguments: dict[str, Any]) -> types.CreateTaskResult: - """Handle the long_running_task tool - demonstrates status updates.""" - ctx = server.request_context - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + await task.update_status("Processing step 1...") + await anyio.sleep(1) - async def work(task: ServerTaskContext) -> types.CallToolResult: - await task.update_status("Starting work...") - await 
anyio.sleep(1) + await task.update_status("Processing step 2...") + await anyio.sleep(1) - await task.update_status("Processing step 1...") - await anyio.sleep(1) + return types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) - await task.update_status("Processing step 2...") - await anyio.sleep(1) + return await ctx.experimental.run_task(work) - return types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], + is_error=True, + ) - return await ctx.experimental.run_task(work) +server = Server( + "simple-task-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: - """Dispatch tool calls to their handlers.""" - if name == "long_running_task": - return await handle_long_running_task(arguments) - else: - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], - is_error=True, - ) +# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task +server.experimental.enable_tasks() @click.command() diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 1c253a22ec..9fe71e5b7a 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -1,9 +1,7 @@ -from typing import Any - import anyio import click from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.shared._httpx_utils import create_mcp_http_client from starlette.requests import Request @@ -18,28 +16,11 @@ async def fetch_website( return [types.TextContent(type="text", text=response.text)] -@click.command() 
-@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-website-fetcher") - - @app.call_tool() - async def fetch_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - if name != "fetch": - raise ValueError(f"Unknown tool: {name}") - if "url" not in arguments: - raise ValueError("Missing required argument 'url'") - return await fetch_website(arguments["url"]) - - @app.list_tools() - async def list_tools() -> list[types.Tool]: - return [ +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ types.Tool( name="fetch", title="Website Fetcher", @@ -56,6 +37,33 @@ async def list_tools() -> list[types.Tool]: }, ) ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name != "fetch": + raise ValueError(f"Unknown tool: {params.name}") + arguments = params.arguments or {} + if "url" not in arguments: + raise ValueError("Missing required argument 'url'") + content = await fetch_website(arguments["url"]) + return types.CallToolResult(content=content) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option( + "--transport", + type=click.Choice(["stdio", "sse"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-website-fetcher", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py index 9d7071ca70..c8178c35a4 100644 
--- a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py @@ -15,12 +15,11 @@ import contextlib import logging from collections.abc import AsyncIterator -from typing import Any import anyio import click from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.routing import Mount @@ -31,111 +30,124 @@ logger = logging.getLogger(__name__) -@click.command() -@click.option("--port", default=3000, help="Port to listen on") -@click.option( - "--log-level", - default="INFO", - help="Logging level (DEBUG, INFO, WARNING, ERROR)", -) -@click.option( - "--retry-interval", - default=100, - help="SSE retry interval in milliseconds (sent to client)", -) -def main(port: int, log_level: str, retry_interval: int) -> int: - """Run the SSE Polling Demo server.""" - logging.basicConfig( - level=getattr(logging, log_level.upper()), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + """List available tools.""" + return types.ListToolsResult( + tools=[ + types.Tool( + name="process_batch", + description=( + "Process a batch of items with periodic checkpoints. " + "Demonstrates SSE polling where server closes stream periodically." 
+ ), + input_schema={ + "type": "object", + "properties": { + "items": { + "type": "integer", + "description": "Number of items to process (1-100)", + "default": 10, + }, + "checkpoint_every": { + "type": "integer", + "description": "Close stream after this many items (1-20)", + "default": 3, + }, + }, + }, + ) + ] ) - # Create the lowlevel server - app = Server("sse-polling-demo") - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - """Handle tool calls.""" - ctx = app.request_context +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + """Handle tool calls.""" + arguments = params.arguments or {} - if name == "process_batch": - items = arguments.get("items", 10) - checkpoint_every = arguments.get("checkpoint_every", 3) + if params.name == "process_batch": + items = arguments.get("items", 10) + checkpoint_every = arguments.get("checkpoint_every", 3) - if items < 1 or items > 100: - return [types.TextContent(type="text", text="Error: items must be between 1 and 100")] - if checkpoint_every < 1 or checkpoint_every > 20: - return [types.TextContent(type="text", text="Error: checkpoint_every must be between 1 and 20")] + if items < 1 or items > 100: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Error: items must be between 1 and 100")] + ) + if checkpoint_every < 1 or checkpoint_every > 20: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Error: checkpoint_every must be between 1 and 20")] + ) + + await ctx.session.send_log_message( + level="info", + data=f"Starting batch processing of {items} items...", + logger="process_batch", + related_request_id=ctx.request_id, + ) + for i in range(1, items + 1): + # Simulate work + await anyio.sleep(0.5) + + # Report progress await ctx.session.send_log_message( level="info", - data=f"Starting batch processing of {items} items...", + 
data=f"[{i}/{items}] Processing item {i}", logger="process_batch", related_request_id=ctx.request_id, ) - for i in range(1, items + 1): - # Simulate work - await anyio.sleep(0.5) - - # Report progress + # Checkpoint: close stream to trigger client reconnect + if i % checkpoint_every == 0 and i < items: await ctx.session.send_log_message( level="info", - data=f"[{i}/{items}] Processing item {i}", + data=f"Checkpoint at item {i} - closing SSE stream for polling", logger="process_batch", related_request_id=ctx.request_id, ) - - # Checkpoint: close stream to trigger client reconnect - if i % checkpoint_every == 0 and i < items: - await ctx.session.send_log_message( - level="info", - data=f"Checkpoint at item {i} - closing SSE stream for polling", - logger="process_batch", - related_request_id=ctx.request_id, - ) - if ctx.close_sse_stream: - logger.info(f"Closing SSE stream at checkpoint {i}") - await ctx.close_sse_stream() - # Wait for client to reconnect (must be > retry_interval of 100ms) - await anyio.sleep(0.2) - - return [ + if ctx.close_sse_stream: + logger.info(f"Closing SSE stream at checkpoint {i}") + await ctx.close_sse_stream() + # Wait for client to reconnect (must be > retry_interval of 100ms) + await anyio.sleep(0.2) + + return types.CallToolResult( + content=[ types.TextContent( type="text", text=f"Successfully processed {items} items with checkpoints every {checkpoint_every} items", ) ] + ) - return [types.TextContent(type="text", text=f"Unknown tool: {name}")] + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")]) - @app.list_tools() - async def list_tools() -> list[types.Tool]: - """List available tools.""" - return [ - types.Tool( - name="process_batch", - description=( - "Process a batch of items with periodic checkpoints. " - "Demonstrates SSE polling where server closes stream periodically." 
- ), - input_schema={ - "type": "object", - "properties": { - "items": { - "type": "integer", - "description": "Number of items to process (1-100)", - "default": 10, - }, - "checkpoint_every": { - "type": "integer", - "description": "Close stream after this many items (1-20)", - "default": 3, - }, - }, - }, - ) - ] + +@click.command() +@click.option("--port", default=3000, help="Port to listen on") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR)", +) +@click.option( + "--retry-interval", + default=100, + help="SSE retry interval in milliseconds (sent to client)", +) +def main(port: int, log_level: str, retry_interval: int) -> int: + """Run the SSE Polling Demo server.""" + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server( + "sse-polling-demo", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Create event store for resumability event_store = InMemoryEventStore() diff --git a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py index fd73a54cdb..95fb908540 100644 --- a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py +++ b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py @@ -2,60 +2,54 @@ """Example low-level MCP server demonstrating structured output support. This example shows how to use the low-level server API to return -structured data from tools, with automatic validation against output -schemas. +structured data from tools. 
""" import asyncio +import json +import random from datetime import datetime -from typing import Any import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -# Create low-level server instance -server = Server("structured-output-lowlevel-example") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools with their schemas.""" - return [ - types.Tool( - name="get_weather", - description="Get weather information (simulated)", - input_schema={ - "type": "object", - "properties": {"city": {"type": "string", "description": "City name"}}, - "required": ["city"], - }, - output_schema={ - "type": "object", - "properties": { - "temperature": {"type": "number"}, - "conditions": {"type": "string"}, - "humidity": {"type": "integer", "minimum": 0, "maximum": 100}, - "wind_speed": {"type": "number"}, - "timestamp": {"type": "string", "format": "date-time"}, + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Get weather information (simulated)", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string", "description": "City name"}}, + "required": ["city"], + }, + output_schema={ + "type": "object", + "properties": { + "temperature": {"type": "number"}, + "conditions": {"type": "string"}, + "humidity": {"type": "integer", "minimum": 0, "maximum": 100}, + "wind_speed": {"type": "number"}, + "timestamp": {"type": "string", "format": "date-time"}, + }, + "required": ["temperature", "conditions", "humidity", "wind_speed", "timestamp"], }, - "required": ["temperature", "conditions", "humidity", "wind_speed", "timestamp"], - }, - ), - ] + ), + ] + ) -@server.call_tool() -async def call_tool(name: str, 
arguments: dict[str, Any]) -> Any: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool call with structured output.""" - if name == "get_weather": - # city = arguments["city"] # Would be used with real weather API - + if params.name == "get_weather": # Simulate weather data (in production, call a real weather API) - import random - weather_conditions = ["sunny", "cloudy", "rainy", "partly cloudy", "foggy"] weather_data = { @@ -66,12 +60,19 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any: "timestamp": datetime.now().isoformat(), } - # Return structured data only - # The low-level server will serialize this to JSON content automatically - return weather_data + return types.CallToolResult( + content=[types.TextContent(type="text", text=json.dumps(weather_data, indent=2))], + structured_content=weather_data, + ) + + raise ValueError(f"Unknown tool: {params.name}") - else: - raise ValueError(f"Unknown tool: {name}") + +server = Server( + "structured-output-lowlevel-example", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -80,14 +81,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="structured-output-lowlevel-example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/lowlevel/basic.py b/examples/snippets/servers/lowlevel/basic.py index 0d44325048..81f40e9945 100644 --- a/examples/snippets/servers/lowlevel/basic.py +++ b/examples/snippets/servers/lowlevel/basic.py @@ -6,32 +6,30 @@ import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext 
-# Create a server instance -server = Server("example-server") - -@server.list_prompts() -async def handle_list_prompts() -> list[types.Prompt]: +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: """List available prompts.""" - return [ - types.Prompt( - name="example-prompt", - description="An example prompt template", - arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], - ) - ] + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="example-prompt", + description="An example prompt template", + arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], + ) + ] + ) -@server.get_prompt() -async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: """Get a specific prompt by name.""" - if name != "example-prompt": - raise ValueError(f"Unknown prompt: {name}") + if params.name != "example-prompt": + raise ValueError(f"Unknown prompt: {params.name}") - arg1_value = (arguments or {}).get("arg1", "default") + arg1_value = (params.arguments or {}).get("arg1", "default") return types.GetPromptResult( description="Example prompt", @@ -44,20 +42,20 @@ async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> type ) +server = Server( + "example-server", + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, +) + + async def run(): """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + 
server.create_initialization_options(), ) diff --git a/examples/snippets/servers/lowlevel/direct_call_tool_result.py b/examples/snippets/servers/lowlevel/direct_call_tool_result.py index 725f5711af..7e8fc4dcb3 100644 --- a/examples/snippets/servers/lowlevel/direct_call_tool_result.py +++ b/examples/snippets/servers/lowlevel/direct_call_tool_result.py @@ -3,44 +3,49 @@ """ import asyncio -from typing import Any import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -server = Server("example-server") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="advanced_tool", - description="Tool with full control including _meta field", - input_schema={ - "type": "object", - "properties": {"message": {"type": "string"}}, - "required": ["message"], - }, - ) - ] - - -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="advanced_tool", + description="Tool with full control including _meta field", + input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + ) + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls by returning CallToolResult directly.""" - if name == "advanced_tool": - message = str(arguments.get("message", "")) + if params.name == "advanced_tool": + message = (params.arguments or {}).get("message", "") return types.CallToolResult( content=[types.TextContent(type="text", text=f"Processed: {message}")], structured_content={"result": 
"success", "message": message}, _meta={"hidden": "data for client applications only"}, ) - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -49,14 +54,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/lowlevel/lifespan.py b/examples/snippets/servers/lowlevel/lifespan.py index da8ff7bdfd..bcd96c8935 100644 --- a/examples/snippets/servers/lowlevel/lifespan.py +++ b/examples/snippets/servers/lowlevel/lifespan.py @@ -4,12 +4,11 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any +from typing import TypedDict import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext # Mock database class for example @@ -32,52 +31,58 @@ async def query(self, query_str: str) -> list[dict[str, str]]: return [{"id": "1", "name": "Example", "query": query_str}] +class AppContext(TypedDict): + db: Database + + @asynccontextmanager -async def server_lifespan(_server: Server) -> AsyncIterator[dict[str, Any]]: +async def server_lifespan(_server: Server[AppContext]) -> AsyncIterator[AppContext]: """Manage server startup and shutdown lifecycle.""" - # Initialize resources on startup db = await Database.connect() try: yield {"db": db} finally: - # Clean up on shutdown await db.disconnect() -# Pass lifespan to server -server = Server("example-server", lifespan=server_lifespan) - - -@server.list_tools() -async def 
handle_list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="query_db", - description="Query the database", - input_schema={ - "type": "object", - "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, - "required": ["query"], - }, - ) - ] - - -@server.call_tool() -async def query_db(name: str, arguments: dict[str, Any]) -> list[types.TextContent]: + return types.ListToolsResult( + tools=[ + types.Tool( + name="query_db", + description="Query the database", + input_schema={ + "type": "object", + "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, + "required": ["query"], + }, + ) + ] + ) + + +async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: types.CallToolRequestParams +) -> types.CallToolResult: """Handle database query tool call.""" - if name != "query_db": - raise ValueError(f"Unknown tool: {name}") + if params.name != "query_db": + raise ValueError(f"Unknown tool: {params.name}") - # Access lifespan context - ctx = server.request_context db = ctx.lifespan_context["db"] + results = await db.query((params.arguments or {})["query"]) + + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Query results: {results}")]) - # Execute query - results = await db.query(arguments["query"]) - return [types.TextContent(type="text", text=f"Query results: {results}")] +server = Server( + "example-server", + lifespan=server_lifespan, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -86,14 +91,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example-server", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - 
experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/lowlevel/structured_output.py b/examples/snippets/servers/lowlevel/structured_output.py index cad8f67da6..f93c8875fd 100644 --- a/examples/snippets/servers/lowlevel/structured_output.py +++ b/examples/snippets/servers/lowlevel/structured_output.py @@ -3,62 +3,67 @@ """ import asyncio -from typing import Any +import json import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -server = Server("example-server") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools with structured output schemas.""" - return [ - types.Tool( - name="get_weather", - description="Get current weather for a city", - input_schema={ - "type": "object", - "properties": {"city": {"type": "string", "description": "City name"}}, - "required": ["city"], - }, - output_schema={ - "type": "object", - "properties": { - "temperature": {"type": "number", "description": "Temperature in Celsius"}, - "condition": {"type": "string", "description": "Weather condition"}, - "humidity": {"type": "number", "description": "Humidity percentage"}, - "city": {"type": "string", "description": "City name"}, + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Get current weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string", "description": "City name"}}, + "required": ["city"], }, - "required": ["temperature", "condition", "humidity", "city"], - }, - ) - ] + output_schema={ + "type": "object", + "properties": { + "temperature": {"type": "number", "description": "Temperature in Celsius"}, + 
"condition": {"type": "string", "description": "Weather condition"}, + "humidity": {"type": "number", "description": "Humidity percentage"}, + "city": {"type": "string", "description": "City name"}, + }, + "required": ["temperature", "condition", "humidity", "city"], + }, + ) + ] + ) -@server.call_tool() -async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls with structured output.""" - if name == "get_weather": - city = arguments["city"] + if params.name == "get_weather": + city = (params.arguments or {})["city"] - # Simulated weather data - in production, call a weather API weather_data = { "temperature": 22.5, "condition": "partly cloudy", "humidity": 65, - "city": city, # Include the requested city + "city": city, } - # low-level server will validate structured output against the tool's - # output schema, and additionally serialize it into a TextContent block - # for backwards compatibility with pre-2025-06-18 clients. 
- return weather_data - else: - raise ValueError(f"Unknown tool: {name}") + return types.CallToolResult( + content=[types.TextContent(type="text", text=json.dumps(weather_data, indent=2))], + structured_content=weather_data, + ) + + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -67,14 +72,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="structured-output-example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/pagination_example.py b/examples/snippets/servers/pagination_example.py index bb406653e5..bcd0ffb106 100644 --- a/examples/snippets/servers/pagination_example.py +++ b/examples/snippets/servers/pagination_example.py @@ -1,22 +1,20 @@ -"""Example of implementing pagination with MCP server decorators.""" +"""Example of implementing pagination with the low-level MCP server.""" from mcp import types -from mcp.server.lowlevel import Server - -# Initialize the server -server = Server("paginated-server") +from mcp.server import Server, ServerRequestContext # Sample data to paginate ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items -@server.list_resources() -async def list_resources_paginated(request: types.ListResourcesRequest) -> types.ListResourcesResult: +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: """List resources with pagination support.""" page_size = 10 # Extract cursor from request params - cursor = request.params.cursor if request.params is not None else None + cursor = params.cursor if params is not None else None # Parse cursor to get offset start = 0 if cursor is None else 
int(cursor) @@ -32,3 +30,6 @@ async def list_resources_paginated(request: types.ListResourcesRequest) -> types next_cursor = str(end) if end < len(ITEMS) else None return types.ListResourcesResult(resources=page_items, next_cursor=next_cursor) + + +server = Server("paginated-server", on_list_resources=handle_list_resources) diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index a2dada3af7..aab5c33f7d 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -1,5 +1,6 @@ +from .context import ServerRequestContext from .lowlevel import NotificationOptions, Server from .mcpserver import MCPServer from .models import InitializationOptions -__all__ = ["Server", "MCPServer", "NotificationOptions", "InitializationOptions"] +__all__ = ["Server", "ServerRequestContext", "MCPServer", "NotificationOptions", "InitializationOptions"] diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 43b9d3800c..d8e11d78b2 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -10,7 +10,7 @@ from mcp.shared._context import RequestContext from mcp.shared.message import CloseSSEStreamCallback -LifespanContextT = TypeVar("LifespanContextT") +LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any]) RequestT = TypeVar("RequestT", default=Any) diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 80ae5912b0..91aa9a6450 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -160,10 +160,7 @@ async def run_task( RuntimeError: If task support is not enabled or task_metadata is missing Example: - @server.call_tool() - async def handle_tool(name: str, args: dict): - ctx = server.request_context - + async def handle_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: async def work(task: ServerTaskContext) -> CallToolResult: result = await task.elicit( message="Are you 
sure?", diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 991221bd0b..b2268bc1c8 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -44,17 +44,14 @@ class TaskResultHandler: 5. Returns the final result Usage: - # Create handler with store and queue - handler = TaskResultHandler(task_store, message_queue) - - # Register it with the server - @server.experimental.get_task_result() - async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - ctx = server.request_context - return await handler.handle(req, ctx.session, ctx.request_id) - - # Or use the convenience method - handler.register(server) + async def handle_task_result( + ctx: ServerRequestContext, params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: + ... + + server.experimental.enable_tasks( + on_task_result=handle_task_result, + ) """ def __init__( diff --git a/src/mcp/server/lowlevel/__init__.py b/src/mcp/server/lowlevel/__init__.py index 66df389916..37191ba1a0 100644 --- a/src/mcp/server/lowlevel/__init__.py +++ b/src/mcp/server/lowlevel/__init__.py @@ -1,3 +1,3 @@ from .server import NotificationOptions, Server -__all__ = ["Server", "NotificationOptions"] +__all__ = ["NotificationOptions", "Server"] diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 9b472c0232..8ac2687280 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -7,10 +7,12 @@ import logging from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING +from typing import Any, Generic +from typing_extensions import TypeVar + +from mcp.server.context import ServerRequestContext from mcp.server.experimental.task_support import TaskSupport -from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.shared.exceptions import MCPError 
from mcp.shared.experimental.tasks.helpers import cancel_task from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore @@ -18,16 +20,16 @@ from mcp.shared.experimental.tasks.store import TaskStore from mcp.types import ( INVALID_PARAMS, - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, GetTaskPayloadRequest, + GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + PaginatedRequestParams, ServerCapabilities, - ServerResult, ServerTasksCapability, ServerTasksRequestsCapability, TasksCallCapability, @@ -36,13 +38,12 @@ TasksToolsCapability, ) -if TYPE_CHECKING: - from mcp.server.lowlevel.server import Server - logger = logging.getLogger(__name__) +LifespanResultT = TypeVar("LifespanResultT", default=Any) + -class ExperimentalHandlers: +class ExperimentalHandlers(Generic[LifespanResultT]): """Experimental request/notification handlers. WARNING: These APIs are experimental and may change without notice. 
@@ -50,13 +51,13 @@ class ExperimentalHandlers: def __init__( self, - server: Server, - request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]], - notification_handlers: dict[type, Callable[..., Awaitable[None]]], - ): - self._server = server - self._request_handlers = request_handlers - self._notification_handlers = notification_handlers + add_request_handler: Callable[ + [str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]], None + ], + has_handler: Callable[[str], bool], + ) -> None: + self._add_request_handler = add_request_handler + self._has_handler = has_handler self._task_support: TaskSupport | None = None @property @@ -66,16 +67,13 @@ def task_support(self) -> TaskSupport | None: def update_capabilities(self, capabilities: ServerCapabilities) -> None: # Only add tasks capability if handlers are registered - if not any( - req_type in self._request_handlers - for req_type in [GetTaskRequest, ListTasksRequest, CancelTaskRequest, GetTaskPayloadRequest] - ): + if not any(self._has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"]): return capabilities.tasks = ServerTasksCapability() - if ListTasksRequest in self._request_handlers: + if self._has_handler("tasks/list"): capabilities.tasks.list = TasksListCapability() - if CancelTaskRequest in self._request_handlers: + if self._has_handler("tasks/cancel"): capabilities.tasks.cancel = TasksCancelCapability() capabilities.tasks.requests = ServerTasksRequestsCapability( @@ -86,15 +84,35 @@ def enable_tasks( self, store: TaskStore | None = None, queue: TaskMessageQueue | None = None, + *, + on_get_task: Callable[[ServerRequestContext[LifespanResultT], GetTaskRequestParams], Awaitable[GetTaskResult]] + | None = None, + on_task_result: Callable[ + [ServerRequestContext[LifespanResultT], GetTaskPayloadRequestParams], Awaitable[GetTaskPayloadResult] + ] + | None = None, + on_list_tasks: Callable[ + [ServerRequestContext[LifespanResultT], 
PaginatedRequestParams | None], Awaitable[ListTasksResult] + ] + | None = None, + on_cancel_task: Callable[ + [ServerRequestContext[LifespanResultT], CancelTaskRequestParams], Awaitable[CancelTaskResult] + ] + | None = None, ) -> TaskSupport: """Enable experimental task support. - This sets up the task infrastructure and auto-registers default handlers - for tasks/get, tasks/result, tasks/list, and tasks/cancel. + This sets up the task infrastructure and registers handlers for + tasks/get, tasks/result, tasks/list, and tasks/cancel. Custom handlers + can be provided via the on_* kwargs; any not provided will use defaults. Args: store: Custom TaskStore implementation (defaults to InMemoryTaskStore) queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue) + on_get_task: Custom handler for tasks/get + on_task_result: Custom handler for tasks/result + on_list_tasks: Custom handler for tasks/list + on_cancel_task: Custom handler for tasks/cancel Returns: The TaskSupport configuration object @@ -117,24 +135,27 @@ def enable_tasks( queue = InMemoryTaskMessageQueue() self._task_support = TaskSupport(store=store, queue=queue) - - # Auto-register default handlers - self._register_default_task_handlers() - - return self._task_support - - def _register_default_task_handlers(self) -> None: - """Register default handlers for task operations.""" - assert self._task_support is not None - support = self._task_support - - # Register get_task handler if not already registered - if GetTaskRequest not in self._request_handlers: - - async def _default_get_task(req: GetTaskRequest) -> ServerResult: - task = await support.store.get_task(req.params.task_id) + task_support = self._task_support + + # Register user-provided handlers + if on_get_task is not None: + self._add_request_handler("tasks/get", on_get_task) + if on_task_result is not None: + self._add_request_handler("tasks/result", on_task_result) + if on_list_tasks is not None: + 
self._add_request_handler("tasks/list", on_list_tasks) + if on_cancel_task is not None: + self._add_request_handler("tasks/cancel", on_cancel_task) + + # Fill in defaults for any not provided + if not self._has_handler("tasks/get"): + + async def _default_get_task( + ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams + ) -> GetTaskResult: + task = await task_support.store.get_task(params.task_id) if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {req.params.task_id}") + raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") return GetTaskResult( task_id=task.task_id, status=task.status, @@ -145,136 +166,39 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: poll_interval=task.poll_interval, ) - self._request_handlers[GetTaskRequest] = _default_get_task + self._add_request_handler("tasks/get", _default_get_task) - # Register get_task_result handler if not already registered - if GetTaskPayloadRequest not in self._request_handlers: + if not self._has_handler("tasks/result"): - async def _default_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - ctx = self._server.request_context - result = await support.handler.handle(req, ctx.session, ctx.request_id) + async def _default_get_task_result( + ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: + assert ctx.request_id is not None + req = GetTaskPayloadRequest(params=params) + result = await task_support.handler.handle(req, ctx.session, ctx.request_id) return result - self._request_handlers[GetTaskPayloadRequest] = _default_get_task_result + self._add_request_handler("tasks/result", _default_get_task_result) - # Register list_tasks handler if not already registered - if ListTasksRequest not in self._request_handlers: + if not self._has_handler("tasks/list"): - async def _default_list_tasks(req: ListTasksRequest) -> ListTasksResult: - cursor = 
req.params.cursor if req.params else None - tasks, next_cursor = await support.store.list_tasks(cursor) + async def _default_list_tasks( + ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListTasksResult: + cursor = params.cursor if params else None + tasks, next_cursor = await task_support.store.list_tasks(cursor) return ListTasksResult(tasks=tasks, next_cursor=next_cursor) - self._request_handlers[ListTasksRequest] = _default_list_tasks - - # Register cancel_task handler if not already registered - if CancelTaskRequest not in self._request_handlers: - - async def _default_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: - result = await cancel_task(support.store, req.params.task_id) - return result - - self._request_handlers[CancelTaskRequest] = _default_cancel_task - - def list_tasks( - self, - ) -> Callable[ - [Callable[[ListTasksRequest], Awaitable[ListTasksResult]]], - Callable[[ListTasksRequest], Awaitable[ListTasksResult]], - ]: - """Register a handler for listing tasks. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]], - ) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]: - logger.debug("Registering handler for ListTasksRequest") - wrapper = create_call_wrapper(func, ListTasksRequest) - - async def handler(req: ListTasksRequest) -> ListTasksResult: - result = await wrapper(req) - return result - - self._request_handlers[ListTasksRequest] = handler - return func - - return decorator - - def get_task( - self, - ) -> Callable[ - [Callable[[GetTaskRequest], Awaitable[GetTaskResult]]], Callable[[GetTaskRequest], Awaitable[GetTaskResult]] - ]: - """Register a handler for getting task status. - - WARNING: This API is experimental and may change without notice. 
- """ - - def decorator( - func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]], - ) -> Callable[[GetTaskRequest], Awaitable[GetTaskResult]]: - logger.debug("Registering handler for GetTaskRequest") - wrapper = create_call_wrapper(func, GetTaskRequest) - - async def handler(req: GetTaskRequest) -> GetTaskResult: - result = await wrapper(req) - return result - - self._request_handlers[GetTaskRequest] = handler - return func - - return decorator - - def get_task_result( - self, - ) -> Callable[ - [Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]], - Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], - ]: - """Register a handler for getting task results/payload. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], - ) -> Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]: - logger.debug("Registering handler for GetTaskPayloadRequest") - wrapper = create_call_wrapper(func, GetTaskPayloadRequest) - - async def handler(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - result = await wrapper(req) - return result - - self._request_handlers[GetTaskPayloadRequest] = handler - return func - - return decorator - - def cancel_task( - self, - ) -> Callable[ - [Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]], - Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], - ]: - """Register a handler for cancelling tasks. - - WARNING: This API is experimental and may change without notice. 
- """ + self._add_request_handler("tasks/list", _default_list_tasks) - def decorator( - func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], - ) -> Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]: - logger.debug("Registering handler for CancelTaskRequest") - wrapper = create_call_wrapper(func, CancelTaskRequest) + if not self._has_handler("tasks/cancel"): - async def handler(req: CancelTaskRequest) -> CancelTaskResult: - result = await wrapper(req) + async def _default_cancel_task( + ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams + ) -> CancelTaskResult: + result = await cancel_task(task_support.store, params.task_id) return result - self._request_handlers[CancelTaskRequest] = handler - return func + self._add_request_handler("tasks/cancel", _default_cancel_task) - return decorator + return task_support diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py deleted file mode 100644 index d176970902..0000000000 --- a/src/mcp/server/lowlevel/func_inspection.py +++ /dev/null @@ -1,53 +0,0 @@ -import inspect -from collections.abc import Callable -from typing import Any, TypeVar, get_type_hints - -T = TypeVar("T") -R = TypeVar("R") - - -def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callable[[T], R]: - """Create a wrapper function that knows how to call func with the request object. - - Returns a wrapper function that takes the request and calls func appropriately. - - The wrapper handles three calling patterns: - 1. Positional-only parameter typed as request_type (no default): func(req) - 2. Positional/keyword parameter typed as request_type (no default): func(**{param_name: req}) - 3. 
No request parameter or parameter with default: func() - """ - try: - sig = inspect.signature(func) - type_hints = get_type_hints(func) - except (ValueError, TypeError, NameError): # pragma: no cover - return lambda _: func() - - # Check for positional-only parameter typed as request_type - for param_name, param in sig.parameters.items(): - if param.kind == inspect.Parameter.POSITIONAL_ONLY: - param_type = type_hints.get(param_name) - if param_type == request_type: # pragma: no branch - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: # pragma: no cover - return lambda _: func() - # Found positional-only parameter with correct type and no default - return lambda req: func(req) - - # Check for any positional/keyword parameter typed as request_type - for param_name, param in sig.parameters.items(): - if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY): # pragma: no branch - param_type = type_hints.get(param_name) - if param_type == request_type: - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: # pragma: no cover - return lambda _: func() - - # Found keyword parameter with correct type and no default - # Need to capture param_name in closure properly - def make_keyword_wrapper(name: str) -> Callable[[Any], Any]: - return lambda req: func(**{name: req}) - - return make_keyword_wrapper(param_name) - - # No request parameter found - use old style - return lambda _: func() diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 7bd79bb37c..04404a3fca 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -2,82 +2,49 @@ This module provides a framework for creating an MCP (Model Context Protocol) server. It allows you to easily define and handle various types of requests and notifications -in an asynchronous manner. 
+using constructor-based handler registration. Usage: -1. Create a Server instance: - server = Server("your_server_name") - -2. Define request handlers using decorators: - @server.list_prompts() - async def handle_list_prompts(request: types.ListPromptsRequest) -> types.ListPromptsResult: - # Implementation - - @server.get_prompt() - async def handle_get_prompt( - name: str, arguments: dict[str, str] | None - ) -> types.GetPromptResult: - # Implementation - - @server.list_tools() - async def handle_list_tools(request: types.ListToolsRequest) -> types.ListToolsResult: - # Implementation - - @server.call_tool() - async def handle_call_tool( - name: str, arguments: dict | None - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - # Implementation - - @server.list_resource_templates() - async def handle_list_resource_templates() -> list[types.ResourceTemplate]: - # Implementation - -3. Define notification handlers if needed: - @server.progress_notification() - async def handle_progress( - progress_token: str | int, progress: float, total: float | None, - message: str | None - ) -> None: - # Implementation - -4. Run the server: +1. Define handler functions: + async def my_list_tools(ctx, params): + return types.ListToolsResult(tools=[...]) + + async def my_call_tool(ctx, params): + return types.CallToolResult(content=[...]) + +2. Create a Server instance with on_* handlers: + server = Server( + "your_server_name", + on_list_tools=my_list_tools, + on_call_tool=my_call_tool, + ) + +3. 
Run the server: async def main(): async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="your_server_name", - server_version="your_version", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) asyncio.run(main()) -The Server class provides methods to register handlers for various MCP requests and -notifications. It automatically manages the request context and handles incoming -messages from the client. +The Server class dispatches incoming requests and notifications to registered +handler callables by method string. """ from __future__ import annotations -import base64 import contextvars -import json import logging import warnings -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from importlib.metadata import version as importlib_version -from typing import Any, Generic, TypeAlias, cast +from typing import Any, Generic import anyio -import jsonschema from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.applications import Starlette from starlette.middleware import Middleware @@ -94,30 +61,20 @@ async def main(): from mcp.server.context import ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers -from mcp.server.lowlevel.func_inspection import create_call_wrapper -from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import 
StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError +from mcp.shared.exceptions import MCPError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder -from mcp.shared.tool_name_validation import validate_and_warn_tool_name logger = logging.getLogger(__name__) LifespanResultT = TypeVar("LifespanResultT", default=Any) -RequestT = TypeVar("RequestT", default=Any) - -# type aliases for tool call results -StructuredContent: TypeAlias = dict[str, Any] -UnstructuredContent: TypeAlias = Iterable[types.ContentBlock] -CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] -# This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[ServerRequestContext[Any, Any]] = contextvars.ContextVar("request_ctx") +request_ctx: contextvars.ContextVar[ServerRequestContext[Any]] = contextvars.ContextVar("request_ctx") class NotificationOptions: @@ -128,7 +85,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager -async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]: +async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. 
Args: @@ -140,10 +97,15 @@ async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[s yield {} -class Server(Generic[LifespanResultT, RequestT]): +async def _ping_handler(ctx: ServerRequestContext[Any], params: types.RequestParams | None) -> types.EmptyResult: + return types.EmptyResult() + + +class Server(Generic[LifespanResultT]): def __init__( self, name: str, + *, version: str | None = None, title: str | None = None, description: str | None = None, @@ -151,9 +113,80 @@ def __init__( website_url: str | None = None, icons: list[types.Icon] | None = None, lifespan: Callable[ - [Server[LifespanResultT, RequestT]], + [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, + # Request handlers + on_list_tools: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListToolsResult], + ] + | None = None, + on_call_tool: Callable[ + [ServerRequestContext[LifespanResultT], types.CallToolRequestParams], + Awaitable[types.CallToolResult | types.CreateTaskResult], + ] + | None = None, + on_list_resources: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListResourcesResult], + ] + | None = None, + on_list_resource_templates: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListResourceTemplatesResult], + ] + | None = None, + on_read_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.ReadResourceRequestParams], + Awaitable[types.ReadResourceResult], + ] + | None = None, + on_subscribe_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.SubscribeRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_unsubscribe_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.UnsubscribeRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_list_prompts: Callable[ + 
[ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListPromptsResult], + ] + | None = None, + on_get_prompt: Callable[ + [ServerRequestContext[LifespanResultT], types.GetPromptRequestParams], + Awaitable[types.GetPromptResult], + ] + | None = None, + on_completion: Callable[ + [ServerRequestContext[LifespanResultT], types.CompleteRequestParams], + Awaitable[types.CompleteResult], + ] + | None = None, + on_set_logging_level: Callable[ + [ServerRequestContext[LifespanResultT], types.SetLevelRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_ping: Callable[ + [ServerRequestContext[LifespanResultT], types.RequestParams | None], + Awaitable[types.EmptyResult], + ] = _ping_handler, + # Notification handlers + on_roots_list_changed: Callable[ + [ServerRequestContext[LifespanResultT], types.NotificationParams | None], + Awaitable[None], + ] + | None = None, + on_progress: Callable[ + [ServerRequestContext[LifespanResultT], types.ProgressNotificationParams], + Awaitable[None], + ] + | None = None, ): self.name = name self.version = version @@ -163,15 +196,64 @@ def __init__( self.website_url = website_url self.icons = icons self.lifespan = lifespan - self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { - types.PingRequest: _ping_handler, - } - self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} - self._tool_cache: dict[str, types.Tool] = {} - self._experimental_handlers: ExperimentalHandlers | None = None + self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {} + self._notification_handlers: dict[ + str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] + ] = {} + self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None logger.debug("Initializing server %r", name) + # Populate internal 
handler dicts from on_* kwargs + self._request_handlers.update( + { + method: handler + for method, handler in { + "ping": on_ping, + "prompts/list": on_list_prompts, + "prompts/get": on_get_prompt, + "resources/list": on_list_resources, + "resources/templates/list": on_list_resource_templates, + "resources/read": on_read_resource, + "resources/subscribe": on_subscribe_resource, + "resources/unsubscribe": on_unsubscribe_resource, + "tools/list": on_list_tools, + "tools/call": on_call_tool, + "logging/setLevel": on_set_logging_level, + "completion/complete": on_completion, + }.items() + if handler is not None + } + ) + + self._notification_handlers.update( + { + method: handler + for method, handler in { + "notifications/roots/list_changed": on_roots_list_changed, + "notifications/progress": on_progress, + }.items() + if handler is not None + } + ) + + def _add_request_handler( + self, + method: str, + handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]], + ) -> None: + """Add a request handler, silently replacing any existing handler for the same method.""" + self._request_handlers[method] = handler + + def _has_handler(self, method: str) -> bool: + """Check if a handler is registered for the given method.""" + return method in self._request_handlers or method in self._notification_handlers + + # TODO: Rethink capabilities API. Currently capabilities are derived from registered + # handlers but require NotificationOptions to be passed externally for list_changed + # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities + # entirely from server state (e.g. constructor params for list_changed) instead of + # requiring callers to assemble them at create_initialization_options() time. 
def create_initialization_options( self, notification_options: NotificationOptions | None = None, @@ -214,25 +296,26 @@ def get_capabilities( completions_capability = None # Set prompt capabilities if handler exists - if types.ListPromptsRequest in self.request_handlers: + if "prompts/list" in self._request_handlers: prompts_capability = types.PromptsCapability(list_changed=notification_options.prompts_changed) # Set resource capabilities if handler exists - if types.ListResourcesRequest in self.request_handlers: + if "resources/list" in self._request_handlers: resources_capability = types.ResourcesCapability( - subscribe=False, list_changed=notification_options.resources_changed + subscribe="resources/subscribe" in self._request_handlers, + list_changed=notification_options.resources_changed, ) # Set tool capabilities if handler exists - if types.ListToolsRequest in self.request_handlers: + if "tools/list" in self._request_handlers: tools_capability = types.ToolsCapability(list_changed=notification_options.tools_changed) # Set logging capabilities if handler exists - if types.SetLevelRequest in self.request_handlers: + if "logging/setLevel" in self._request_handlers: logging_capability = types.LoggingCapability() # Set completions capabilities if handler exists - if types.CompleteRequest in self.request_handlers: + if "completion/complete" in self._request_handlers: completions_capability = types.CompletionsCapability() capabilities = types.ServerCapabilities( @@ -248,12 +331,7 @@ def get_capabilities( return capabilities @property - def request_context(self) -> ServerRequestContext[LifespanResultT, RequestT]: - """If called outside of a request context, this will raise a LookupError.""" - return request_ctx.get() - - @property - def experimental(self) -> ExperimentalHandlers: + def experimental(self) -> ExperimentalHandlers[LifespanResultT]: """Experimental APIs for tasks and other features. WARNING: These APIs are experimental and may change without notice. 
@@ -261,7 +339,10 @@ def experimental(self) -> ExperimentalHandlers: # We create this inline so we only add these capabilities _if_ they're actually used if self._experimental_handlers is None: - self._experimental_handlers = ExperimentalHandlers(self, self.request_handlers, self.notification_handlers) + self._experimental_handlers = ExperimentalHandlers( + add_request_handler=self._add_request_handler, + has_handler=self._has_handler, + ) return self._experimental_handlers @property @@ -278,374 +359,6 @@ def session_manager(self) -> StreamableHTTPSessionManager: ) return self._session_manager # pragma: no cover - def list_prompts(self): - def decorator( - func: Callable[[], Awaitable[list[types.Prompt]]] - | Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], - ): - logger.debug("Registering handler for PromptListRequest") - - wrapper = create_call_wrapper(func, types.ListPromptsRequest) - - async def handler(req: types.ListPromptsRequest): - result = await wrapper(req) - # Handle both old style (list[Prompt]) and new style (ListPromptsResult) - if isinstance(result, types.ListPromptsResult): - return result - else: - # Old style returns list[Prompt] - return types.ListPromptsResult(prompts=result) - - self.request_handlers[types.ListPromptsRequest] = handler - return func - - return decorator - - def get_prompt(self): - def decorator( - func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]], - ): - logger.debug("Registering handler for GetPromptRequest") - - async def handler(req: types.GetPromptRequest): - prompt_get = await func(req.params.name, req.params.arguments) - return prompt_get - - self.request_handlers[types.GetPromptRequest] = handler - return func - - return decorator - - def list_resources(self): - def decorator( - func: Callable[[], Awaitable[list[types.Resource]]] - | Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], - ): - logger.debug("Registering handler for 
ListResourcesRequest") - - wrapper = create_call_wrapper(func, types.ListResourcesRequest) - - async def handler(req: types.ListResourcesRequest): - result = await wrapper(req) - # Handle both old style (list[Resource]) and new style (ListResourcesResult) - if isinstance(result, types.ListResourcesResult): - return result - else: - # Old style returns list[Resource] - return types.ListResourcesResult(resources=result) - - self.request_handlers[types.ListResourcesRequest] = handler - return func - - return decorator - - def list_resource_templates(self): - def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): - logger.debug("Registering handler for ListResourceTemplatesRequest") - - async def handler(_: Any): - templates = await func() - return types.ListResourceTemplatesResult(resource_templates=templates) - - self.request_handlers[types.ListResourceTemplatesRequest] = handler - return func - - return decorator - - def read_resource(self): - def decorator( - func: Callable[[str], Awaitable[str | bytes | Iterable[ReadResourceContents]]], - ): - logger.debug("Registering handler for ReadResourceRequest") - - async def handler(req: types.ReadResourceRequest): - result = await func(req.params.uri) - - def create_content(data: str | bytes, mime_type: str | None, meta: dict[str, Any] | None = None): - # Note: ResourceContents uses Field(alias="_meta"), so we must use the alias key - meta_kwargs: dict[str, Any] = {"_meta": meta} if meta is not None else {} - match data: - case str() as data: - return types.TextResourceContents( - uri=req.params.uri, - text=data, - mime_type=mime_type or "text/plain", - **meta_kwargs, - ) - case bytes() as data: # pragma: no branch - return types.BlobResourceContents( - uri=req.params.uri, - blob=base64.b64encode(data).decode(), - mime_type=mime_type or "application/octet-stream", - **meta_kwargs, - ) - - match result: - case str() | bytes() as data: # pragma: lax no cover - warnings.warn( - "Returning str or bytes 
from read_resource is deprecated. " - "Use Iterable[ReadResourceContents] instead.", - DeprecationWarning, - stacklevel=2, - ) - content = create_content(data, None) - case Iterable() as contents: - contents_list = [ - create_content( - content_item.content, content_item.mime_type, getattr(content_item, "meta", None) - ) - for content_item in contents - ] - return types.ReadResourceResult(contents=contents_list) - case _: # pragma: no cover - raise ValueError(f"Unexpected return type from read_resource: {type(result)}") - - return types.ReadResourceResult(contents=[content]) # pragma: no cover - - self.request_handlers[types.ReadResourceRequest] = handler - return func - - return decorator - - def set_logging_level(self): - def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): - logger.debug("Registering handler for SetLevelRequest") - - async def handler(req: types.SetLevelRequest): - await func(req.params.level) - return types.EmptyResult() - - self.request_handlers[types.SetLevelRequest] = handler - return func - - return decorator - - def subscribe_resource(self): - def decorator(func: Callable[[str], Awaitable[None]]): - logger.debug("Registering handler for SubscribeRequest") - - async def handler(req: types.SubscribeRequest): - await func(req.params.uri) - return types.EmptyResult() - - self.request_handlers[types.SubscribeRequest] = handler - return func - - return decorator - - def unsubscribe_resource(self): - def decorator(func: Callable[[str], Awaitable[None]]): - logger.debug("Registering handler for UnsubscribeRequest") - - async def handler(req: types.UnsubscribeRequest): - await func(req.params.uri) - return types.EmptyResult() - - self.request_handlers[types.UnsubscribeRequest] = handler - return func - - return decorator - - def list_tools(self): - def decorator( - func: Callable[[], Awaitable[list[types.Tool]]] - | Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], - ): - logger.debug("Registering handler for 
ListToolsRequest") - - wrapper = create_call_wrapper(func, types.ListToolsRequest) - - async def handler(req: types.ListToolsRequest): - result = await wrapper(req) - - # Handle both old style (list[Tool]) and new style (ListToolsResult) - if isinstance(result, types.ListToolsResult): - # Refresh the tool cache with returned tools - for tool in result.tools: - validate_and_warn_tool_name(tool.name) - self._tool_cache[tool.name] = tool - return result - else: - # Old style returns list[Tool] - # Clear and refresh the entire tool cache - self._tool_cache.clear() - for tool in result: - validate_and_warn_tool_name(tool.name) - self._tool_cache[tool.name] = tool - return types.ListToolsResult(tools=result) - - self.request_handlers[types.ListToolsRequest] = handler - return func - - return decorator - - def _make_error_result(self, error_message: str) -> types.CallToolResult: - """Create a CallToolResult with an error.""" - return types.CallToolResult( - content=[types.TextContent(type="text", text=error_message)], - is_error=True, - ) - - async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None: - """Get tool definition from cache, refreshing if necessary. - - Returns the Tool object if found, None otherwise. - """ - if tool_name not in self._tool_cache: - if types.ListToolsRequest in self.request_handlers: - logger.debug("Tool cache miss for %s, refreshing cache", tool_name) - await self.request_handlers[types.ListToolsRequest](None) - - tool = self._tool_cache.get(tool_name) - if tool is None: - logger.warning("Tool '%s' not listed, no validation will be performed", tool_name) - - return tool - - def call_tool(self, *, validate_input: bool = True): - """Register a tool call handler. - - Args: - validate_input: If True, validates input against inputSchema. Default is True. 
- - The handler validates input against inputSchema (if validate_input=True), calls the tool function, - and builds a CallToolResult with the results: - - Unstructured content (iterable of ContentBlock): returned in content - - Structured content (dict): returned in structuredContent, serialized JSON text returned in content - - Both: returned in content and structuredContent - - If outputSchema is defined, validates structuredContent or errors if missing. - """ - - def decorator( - func: Callable[ - [str, dict[str, Any]], - Awaitable[ - UnstructuredContent - | StructuredContent - | CombinationContent - | types.CallToolResult - | types.CreateTaskResult - ], - ], - ): - logger.debug("Registering handler for CallToolRequest") - - async def handler(req: types.CallToolRequest): - try: - tool_name = req.params.name - arguments = req.params.arguments or {} - tool = await self._get_cached_tool_definition(tool_name) - - # input validation - if validate_input and tool: - try: - jsonschema.validate(instance=arguments, schema=tool.input_schema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Input validation error: {e.message}") - - # tool call - results = await func(tool_name, arguments) - - # output normalization - unstructured_content: UnstructuredContent - maybe_structured_content: StructuredContent | None - if isinstance(results, types.CallToolResult): - return results - elif isinstance(results, types.CreateTaskResult): - # Task-augmented execution returns task info instead of result - return results - elif isinstance(results, tuple) and len(results) == 2: - # tool returned both structured and unstructured content - unstructured_content, maybe_structured_content = cast(CombinationContent, results) - elif isinstance(results, dict): - # tool returned structured content only - maybe_structured_content = cast(StructuredContent, results) - unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))] - elif 
hasattr(results, "__iter__"): - # tool returned unstructured content only - unstructured_content = cast(UnstructuredContent, results) - maybe_structured_content = None - else: # pragma: no cover - return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}") - - # output validation - if tool and tool.output_schema is not None: - if maybe_structured_content is None: - return self._make_error_result( - "Output validation error: outputSchema defined but no structured output returned" - ) - else: - try: - jsonschema.validate(instance=maybe_structured_content, schema=tool.output_schema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Output validation error: {e.message}") - - # result - return types.CallToolResult( - content=list(unstructured_content), - structured_content=maybe_structured_content, - is_error=False, - ) - except UrlElicitationRequiredError: - # Re-raise UrlElicitationRequiredError so it can be properly handled - # by _handle_request, which converts it to an error response with code -32042 - raise - except Exception as e: - return self._make_error_result(str(e)) - - self.request_handlers[types.CallToolRequest] = handler - return func - - return decorator - - def progress_notification(self): - def decorator( - func: Callable[[str | int, float, float | None, str | None], Awaitable[None]], - ): - logger.debug("Registering handler for ProgressNotification") - - async def handler(req: types.ProgressNotification): - await func( - req.params.progress_token, - req.params.progress, - req.params.total, - req.params.message, - ) - - self.notification_handlers[types.ProgressNotification] = handler - return func - - return decorator - - def completion(self): - """Provides completions for prompts and resource templates""" - - def decorator( - func: Callable[ - [ - types.PromptReference | types.ResourceTemplateReference, - types.CompletionArgument, - types.CompletionContext | None, - ], - 
Awaitable[types.Completion | None], - ], - ): - logger.debug("Registering handler for CompleteRequest") - - async def handler(req: types.CompleteRequest): - completion = await func(req.params.ref, req.params.argument, req.params.context) - return types.CompleteResult( - completion=completion - if completion is not None - else types.Completion(values=[], total=None, has_more=None), - ) - - self.request_handlers[types.CompleteRequest] = handler - return func - - return decorator - async def run( self, read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], @@ -715,7 +428,7 @@ async def _handle_message( if raise_exceptions: raise message case _: - await self._handle_notification(message) + await self._handle_notification(message, session, lifespan_context) for warning in w: # pragma: lax no cover logger.info("Warning: %s: %s", warning.category.__name__, warning.message) @@ -730,10 +443,9 @@ async def _handle_request( ): logger.info("Processing request of type %s", type(req).__name__) - if handler := self.request_handlers.get(type(req)): + if handler := self._request_handlers.get(req.method): logger.debug("Dispatching request of type %s", type(req).__name__) - token = None try: # Extract request context and close_sse_stream from message metadata request_data = None @@ -744,32 +456,32 @@ async def _handle_request( close_sse_stream_cb = message.message_metadata.close_sse_stream close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream - # Set our global state that can be retrieved via - # app.get_request_context() client_capabilities = session.client_params.capabilities if session.client_params else None task_support = self._experimental_handlers.task_support if self._experimental_handlers else None # Get task metadata from request params if present task_metadata = None if hasattr(req, "params") and req.params is not None: task_metadata = getattr(req.params, "task", None) - token = request_ctx.set( - ServerRequestContext( - 
request_id=message.request_id, - meta=message.request_meta, - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=task_metadata, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - request=request_data, - close_sse_stream=close_sse_stream_cb, - close_standalone_sse_stream=close_standalone_sse_stream_cb, - ) + ctx = ServerRequestContext( + request_id=message.request_id, + meta=message.request_meta, + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=task_metadata, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + request=request_data, + close_sse_stream=close_sse_stream_cb, + close_standalone_sse_stream=close_standalone_sse_stream_cb, ) - response = await handler(req) + token = request_ctx.set(ctx) + try: + response = await handler(ctx, req.params) + finally: + request_ctx.reset(token) except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): @@ -779,10 +491,6 @@ async def _handle_request( if raise_exceptions: # pragma: no cover raise err response = types.ErrorData(code=0, message=str(err), data=None) - finally: - # Reset the global state after we are done - if token is not None: # pragma: no branch - request_ctx.reset(token) await message.respond(response) else: # pragma: no cover @@ -790,12 +498,29 @@ async def _handle_request( logger.debug("Response sent") - async def _handle_notification(self, notify: Any): - if handler := self.notification_handlers.get(type(notify)): # type: ignore + async def _handle_notification( + self, + notify: types.ClientNotification, + session: ServerSession, + lifespan_context: LifespanResultT, + ) -> None: + if handler := self._notification_handlers.get(notify.method): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - await handler(notify) + client_capabilities = 
session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + ctx = ServerRequestContext( + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=None, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + ) + await handler(ctx, notify.params) except Exception: # pragma: no cover logger.exception("Uncaught exception in notification handler") @@ -910,7 +635,3 @@ def streamable_http_app( middleware=middleware, lifespan=lambda app: session_manager.run(), ) - - -async def _ping_handler(request: types.PingRequest) -> types.ServerResult: - return types.EmptyResult() diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 8c1fc342be..f26944a2d8 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -2,7 +2,9 @@ from __future__ import annotations +import base64 import inspect +import json import re from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager @@ -29,7 +31,7 @@ from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, UrlElicitationResult, elicit_with_validation from mcp.server.elicitation import elicit_url as _elicit_url from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import LifespanResultT, Server +from mcp.server.lowlevel.server import LifespanResultT, Server, request_ctx from mcp.server.lowlevel.server import lifespan as default_lifespan from mcp.server.mcpserver.exceptions import ResourceError from mcp.server.mcpserver.prompts import Prompt, PromptManager @@ -42,7 +44,30 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from 
mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Annotations, ContentBlock, GetPromptResult, Icon, ToolAnnotations +from mcp.shared.exceptions import MCPError +from mcp.types import ( + Annotations, + BlobResourceContents, + CallToolRequestParams, + CallToolResult, + CompleteRequestParams, + CompleteResult, + Completion, + ContentBlock, + GetPromptRequestParams, + GetPromptResult, + Icon, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextContent, + TextResourceContents, + ToolAnnotations, +) from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument from mcp.types import Resource as MCPResource @@ -91,9 +116,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]): def lifespan_wrapper( app: MCPServer[LifespanResultT], lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[Server[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]: +) -> Callable[[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]]: @asynccontextmanager - async def wrap(_: Server[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]: + async def wrap(_: Server[LifespanResultT]) -> AsyncIterator[LifespanResultT]: async with lifespan(app) as context: yield context @@ -132,6 +157,9 @@ def __init__( auth=auth, ) + self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) + self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) + self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) self._lowlevel_server = Server( name=name or "mcp-server", title=title, @@ -140,13 +168,17 @@ def __init__( website_url=website_url, icons=icons, 
version=version, + on_list_tools=self._handle_list_tools, + on_call_tool=self._handle_call_tool, + on_list_resources=self._handle_list_resources, + on_read_resource=self._handle_read_resource, + on_list_resource_templates=self._handle_list_resource_templates, + on_list_prompts=self._handle_list_prompts, + on_get_prompt=self._handle_get_prompt, # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an MCPServer and Server. # We need to create a Lifespan type that is a generic on the server type, like Starlette does. lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore ) - self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) - self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) # Validate auth configuration if self.settings.auth is not None: if auth_server_provider and token_verifier: # pragma: no cover @@ -164,9 +196,6 @@ def __init__( self._token_verifier = ProviderTokenVerifier(auth_server_provider) self._custom_starlette_routes: list[Route] = [] - # Set up MCP protocol handlers - self._setup_handlers() - # Configure logging configure_logging(self.settings.log_level) @@ -263,18 +292,83 @@ def run( case "streamable-http": # pragma: no cover anyio.run(lambda: self.run_streamable_http_async(**kwargs)) - def _setup_handlers(self) -> None: - """Set up core MCP protocol handlers.""" - self._lowlevel_server.list_tools()(self.list_tools) - # Note: we disable the lowlevel server's input validation. - # MCPServer does ad hoc conversion of incoming data before validating - - # for now we preserve this for backwards compatibility. 
- self._lowlevel_server.call_tool(validate_input=False)(self.call_tool) - self._lowlevel_server.list_resources()(self.list_resources) - self._lowlevel_server.read_resource()(self.read_resource) - self._lowlevel_server.list_prompts()(self.list_prompts) - self._lowlevel_server.get_prompt()(self.get_prompt) - self._lowlevel_server.list_resource_templates()(self.list_resource_templates) + async def _handle_list_tools( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult(tools=await self.list_tools()) + + async def _handle_call_tool( + self, ctx: ServerRequestContext[LifespanResultT], params: CallToolRequestParams + ) -> CallToolResult: + try: + result = await self.call_tool(params.name, params.arguments or {}) + except MCPError: + raise + except Exception as e: + return CallToolResult(content=[TextContent(type="text", text=str(e))], is_error=True) + if isinstance(result, CallToolResult): + return result + if isinstance(result, tuple) and len(result) == 2: + unstructured_content, structured_content = result + return CallToolResult( + content=list(unstructured_content), # type: ignore[arg-type] + structured_content=structured_content, # type: ignore[arg-type] + ) + if isinstance(result, dict): # pragma: no cover + # TODO: this code path is unreachable — convert_result never returns a raw dict. + # The call_tool return type (Sequence[ContentBlock] | dict[str, Any]) is wrong + # and needs to be cleaned up. 
+ return CallToolResult( + content=[TextContent(type="text", text=json.dumps(result, indent=2))], + structured_content=result, + ) + return CallToolResult(content=list(result)) + + async def _handle_list_resources( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=await self.list_resources()) + + async def _handle_read_resource( + self, ctx: ServerRequestContext[LifespanResultT], params: ReadResourceRequestParams + ) -> ReadResourceResult: + results = await self.read_resource(params.uri) + contents: list[TextResourceContents | BlobResourceContents] = [] + for item in results: + if isinstance(item.content, bytes): + contents.append( + BlobResourceContents( + uri=params.uri, + blob=base64.b64encode(item.content).decode(), + mime_type=item.mime_type or "application/octet-stream", + _meta=item.meta, + ) + ) + else: + contents.append( + TextResourceContents( + uri=params.uri, + text=item.content, + mime_type=item.mime_type or "text/plain", + _meta=item.meta, + ) + ) + return ReadResourceResult(contents=contents) + + async def _handle_list_resource_templates( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + return ListResourceTemplatesResult(resource_templates=await self.list_resource_templates()) + + async def _handle_list_prompts( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=await self.list_prompts()) + + async def _handle_get_prompt( + self, ctx: ServerRequestContext[LifespanResultT], params: GetPromptRequestParams + ) -> GetPromptResult: + return await self.get_prompt(params.name, params.arguments) async def list_tools(self) -> list[MCPTool]: """List all available tools.""" @@ -298,7 +392,7 @@ def get_context(self) -> Context[LifespanResultT, Request]: during a request; outside a 
request, most methods will error. """ try: - request_context = self._lowlevel_server.request_context + request_context = request_ctx.get() except LookupError: request_context = None return Context(request_context=request_context, mcp_server=self) @@ -486,7 +580,24 @@ async def handle_completion(ref, argument, context): return Completion(values=["option1", "option2"]) return None """ - return self._lowlevel_server.completion() + + def decorator(func: _CallableT) -> _CallableT: + async def handler( + ctx: ServerRequestContext[LifespanResultT], params: CompleteRequestParams + ) -> CompleteResult: + result = await func(params.ref, params.argument, params.context) + return CompleteResult( + completion=result if result is not None else Completion(values=[], total=None, has_more=None), + ) + + # TODO(maxisbey): remove private access — completion needs post-construction + # handler registration, find a better pattern for this + self._lowlevel_server._add_request_handler( # pyright: ignore[reportPrivateUsage] + "completion/complete", handler + ) + return func + + return decorator def add_resource(self, resource: Resource) -> None: """Add a resource to the server. 
diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f496121a37..6925aa556b 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -6,30 +6,22 @@ Common usage pattern: ``` - server = Server(name) - - @server.call_tool() - async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> Any: + async def handle_call_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: # Check client capabilities before proceeding if ctx.session.check_client_capability( types.ClientCapabilities(experimental={"advanced_tools": dict()}) ): - # Perform advanced tool operations - result = await perform_advanced_tool_operation(arguments) + result = await perform_advanced_tool_operation(params.arguments) else: - # Fall back to basic tool operations - result = await perform_basic_tool_operation(arguments) - + result = await perform_basic_tool_operation(params.arguments) return result - @server.list_prompts() - async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: - # Access session for any necessary checks or operations + async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: if ctx.session.client_params: - # Customize prompts based on client initialization parameters - return generate_custom_prompts(ctx.session.client_params) - else: - return default_prompts + return ListPromptsResult(prompts=generate_custom_prompts(ctx.session.client_params)) + return ListPromptsResult(prompts=default_prompts) + + server = Server(name, on_call_tool=handle_call_tool, on_list_prompts=handle_list_prompts) ``` The ServerSession class is typically used internally by the Server class and should not diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index ddc6e5014f..8eb29c4d48 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -60,7 +60,7 @@ class StreamableHTTPSessionManager: def __init__( 
self, - app: Server[Any, Any], + app: Server[Any], event_store: EventStore | None = None, json_response: bool = False, stateless: bool = False, diff --git a/src/mcp/shared/_context.py b/src/mcp/shared/_context.py index 2facc2a493..bbcee2d02c 100644 --- a/src/mcp/shared/_context.py +++ b/src/mcp/shared/_context.py @@ -13,8 +13,12 @@ @dataclass(kw_only=True) class RequestContext(Generic[SessionT]): - """Common context for handling incoming requests.""" + """Common context for handling incoming requests. + + For request handlers, request_id is always populated. + For notification handlers, request_id is None. + """ - request_id: RequestId - meta: RequestParamsMeta | None session: SessionT + request_id: RequestId | None = None + meta: RequestParamsMeta | None = None diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py index 38ca802daf..bd1781cb57 100644 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -72,9 +72,8 @@ async def cancel_task( - Task is already in a terminal state (completed, failed, cancelled) Example: - @server.experimental.cancel_task() - async def handle_cancel(request: CancelTaskRequest) -> CancelTaskResult: - return await cancel_task(store, request.params.taskId) + async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult: + return await cancel_task(store, params.task_id) """ task = await store.get_task(task_id) if task is None: diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 9dedd2e5d3..1858eeac31 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -6,6 +6,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass +from typing import Any from mcp.types import JSONRPCMessage, RequestId @@ -30,8 +31,10 @@ class ServerMessageMetadata: """Metadata specific to server messages.""" related_request_id: RequestId | None = None - # Request-specific context 
(e.g., headers, auth info) - request_context: object | None = None + # Transport-specific request context (e.g. starlette Request for HTTP + # transports, None for stdio). Typed as Any because the server layer is + # transport-agnostic. + request_context: Any = None # Callback to close SSE stream for the current request without terminating close_sse_stream: CloseSSEStreamCallback | None = None # Callback to close the standalone GET SSE stream (for unsolicited notifications) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index d483ae54b6..45300063a2 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -11,7 +11,7 @@ from mcp import types from mcp.client._memory import InMemoryTransport from mcp.client.client import Client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer from mcp.types import ( CallToolResult, @@ -41,33 +41,36 @@ @pytest.fixture def simple_server() -> Server: """Create a simple MCP server for testing.""" - server = Server(name="test_server") - @server.list_resources() - async def handle_list_resources(): - return [Resource(uri="memory://test", name="Test Resource", description="A test resource")] + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[Resource(uri="memory://test", name="Test Resource", description="A test resource")] + ) - @server.subscribe_resource() - async def handle_subscribe_resource(uri: str): - pass + async def handle_subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + return EmptyResult() - @server.unsubscribe_resource() - async def handle_unsubscribe_resource(uri: str): - pass + async def handle_unsubscribe_resource( + ctx: ServerRequestContext, params: types.UnsubscribeRequestParams + ) -> EmptyResult: + return EmptyResult() - 
@server.set_logging_level() - async def handle_set_logging_level(level: str): - pass + async def handle_set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + return EmptyResult() - @server.completion() - async def handle_completion( - ref: types.PromptReference | types.ResourceTemplateReference, - argument: types.CompletionArgument, - context: types.CompletionContext | None, - ) -> types.Completion | None: - return types.Completion(values=[]) + async def handle_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + return types.CompleteResult(completion=types.Completion(values=[])) - return server + return Server( + name="test_server", + on_list_resources=handle_list_resources, + on_subscribe_resource=handle_subscribe_resource, + on_unsubscribe_resource=handle_unsubscribe_resource, + on_set_logging_level=handle_set_logging_level, + on_completion=handle_completion, + ) @pytest.fixture @@ -202,19 +205,14 @@ async def test_client_send_progress_notification(): """Test sending progress notification.""" received_from_client = None event = anyio.Event() - server = Server(name="test_server") - - @server.progress_notification() - async def handle_progress_notification( - progress_token: str | int, - progress: float = 0.0, - total: float | None = None, - message: str | None = None, - ) -> None: + + async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: nonlocal received_from_client - received_from_client = {"progress_token": progress_token, "progress": progress} + received_from_client = {"progress_token": params.progress_token, "progress": params.progress} event.set() + server = Server(name="test_server", on_progress=handle_progress) + async with Client(server) as client: await client.send_progress_notification(progress_token="token123", progress=50.0) await event.wait() diff --git a/tests/client/test_http_unicode.py 
b/tests/client/test_http_unicode.py index 5cca8c1943..cc2e14e469 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -8,7 +8,6 @@ import socket from collections.abc import AsyncGenerator, Generator from contextlib import asynccontextmanager -from typing import Any import pytest from starlette.applications import Starlette @@ -17,7 +16,7 @@ from mcp import types from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool from tests.test_helpers import wait_for_server @@ -47,54 +46,56 @@ def run_unicode_server(port: int) -> None: # pragma: no cover import uvicorn # Need to recreate the server setup in this process - server = Server(name="unicode_test_server") - - @server.list_tools() - async def list_tools() -> list[Tool]: - """List tools with Unicode descriptions.""" - return [ - Tool( - name="echo_unicode", - description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to echo back"}, + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + Tool( + name="echo_unicode", + description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", + input_schema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "Text to echo back"}, + }, + "required": ["text"], }, - "required": ["text"], - }, - ), - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: - """Handle tool calls with Unicode content.""" - if name == "echo_unicode": - text = arguments.get("text", 
"") if arguments else "" - return [ - TextContent( - type="text", - text=f"Echo: {text}", - ) + ), ] - else: - raise ValueError(f"Unknown tool: {name}") - - @server.list_prompts() - async def list_prompts() -> list[types.Prompt]: - """List prompts with Unicode names and descriptions.""" - return [ - types.Prompt( - name="unicode_prompt", - description="Unicode prompt - Слой хранилища, где располагаются", - arguments=[], + ) + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name == "echo_unicode": + text = params.arguments.get("text", "") if params.arguments else "" + return types.CallToolResult( + content=[ + TextContent( + type="text", + text=f"Echo: {text}", + ) + ] ) - ] + else: + raise ValueError(f"Unknown tool: {params.name}") + + async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="unicode_prompt", + description="Unicode prompt - Слой хранилища, где располагаются", + arguments=[], + ) + ] + ) - @server.get_prompt() - async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPromptResult: - """Get a prompt with Unicode content.""" - if name == "unicode_prompt": + async def handle_get_prompt( + ctx: ServerRequestContext, params: types.GetPromptRequestParams + ) -> types.GetPromptResult: + if params.name == "unicode_prompt": return types.GetPromptResult( messages=[ types.PromptMessage( @@ -106,7 +107,15 @@ async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPr ) ] ) - raise ValueError(f"Unknown prompt: {name}") + raise ValueError(f"Unknown prompt: {params.name}") + + server = Server( + name="unicode_test_server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, + ) # Create the session manager 
session_manager = StreamableHTTPSessionManager( diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index 4d7c53db25..f70fb9277d 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -3,9 +3,9 @@ import pytest from mcp import Client, types -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer -from mcp.types import ListToolsRequest, ListToolsResult +from mcp.types import ListToolsResult from .conftest import StreamSpyCollection @@ -105,14 +105,16 @@ async def test_list_tools_with_strict_server_validation( async def test_list_tools_with_lowlevel_server(): """Test that list_tools works with a lowlevel Server using params.""" - server = Server("test-lowlevel") - @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListToolsResult: # Echo back what cursor we received in the tool description - cursor = request.params.cursor if request.params else None + cursor = params.cursor if params else None return ListToolsResult(tools=[types.Tool(name="test_tool", description=f"cursor={cursor}", input_schema={})]) + server = Server("test-lowlevel", on_list_tools=handle_list_tools) + async with Client(server) as client: result = await client.list_tools() assert result.tools[0].description == "cursor=None" diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index cc93d303b1..d78197b5c3 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -1,50 +1,41 @@ -import inspect import logging -from contextlib import contextmanager from typing import Any -from unittest.mock import patch -import jsonschema import pytest from mcp import Client -from 
mcp.server.lowlevel import Server -from mcp.types import Tool - - -@contextmanager -def bypass_server_output_validation(): - """Context manager that bypasses server-side output validation. - This simulates a malicious or non-compliant server that doesn't validate - its outputs, allowing us to test client-side validation. - """ - # Save the original validate function - original_validate = jsonschema.validate - - # Create a mock that tracks which module is calling it - def selective_mock(instance: Any = None, schema: Any = None, *args: Any, **kwargs: Any) -> None: - # Check the call stack to see where this is being called from - for frame_info in inspect.stack(): - # If called from the server module, skip validation - # TODO: fix this as it's a rather gross workaround and will eventually break - # Normalize path separators for cross-platform compatibility - normalized_path = frame_info.filename.replace("\\", "/") - if "mcp/server/lowlevel/server.py" in normalized_path: - return None - # Otherwise, use the real validation (for client-side) - return original_validate(instance=instance, schema=schema, *args, **kwargs) - - with patch("jsonschema.validate", selective_mock): - yield +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + + +def _make_server( + tools: list[Tool], + structured_content: dict[str, Any], +) -> Server: + """Create a low-level server that returns the given structured_content for any tool call.""" + + async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=tools) + + async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text="result")], + structured_content=structured_content, + ) + + return Server("test-server", 
on_list_tools=on_list_tools, on_call_tool=on_call_tool) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_basemodel(): """Test that client validates structured content against schema for BaseModel outputs""" - # Create a malicious low-level server that returns invalid structured content - server = Server("test-server") - - # Define the expected schema for our tool output_schema = { "type": "object", "properties": {"name": {"type": "string", "title": "Name"}, "age": {"type": "integer", "title": "Age"}}, @@ -52,39 +43,27 @@ async def test_tool_structured_output_client_side_validation_basemodel(): "title": "UserOutput", } - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="get_user", description="Get user data", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - age is string instead of integer - # The low-level server will wrap this in CallToolResult - return {"name": "John", "age": "invalid"} # Invalid: age should be int + ], + structured_content={"name": "John", "age": "invalid"}, # Invalid: age should be int + ) - # Test that client validates the structured content - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_user", {}) - # Verify it's a validation error - assert "Invalid structured content returned by tool get_user" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_user", {}) + assert "Invalid structured content returned by tool get_user" in str(exc_info.value) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_primitive(): """Test that client 
validates structured content for primitive outputs""" - server = Server("test-server") - - # Primitive types are wrapped in {"result": value} output_schema = { "type": "object", "properties": {"result": {"type": "integer", "title": "Result"}}, @@ -92,122 +71,95 @@ async def test_tool_structured_output_client_side_validation_primitive(): "title": "calculate_Output", } - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="calculate", description="Calculate something", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - result is string instead of integer - return {"result": "not_a_number"} # Invalid: should be int + ], + structured_content={"result": "not_a_number"}, # Invalid: should be int + ) - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("calculate", {}) - assert "Invalid structured content returned by tool calculate" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("calculate", {}) + assert "Invalid structured content returned by tool calculate" in str(exc_info.value) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_dict_typed(): """Test that client validates dict[str, T] structured content""" - server = Server("test-server") - - # dict[str, int] schema output_schema = {"type": "object", "additionalProperties": {"type": "integer"}, "title": "get_scores_Output"} - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="get_scores", description="Get scores", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - 
@server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - values should be integers - return {"alice": "100", "bob": "85"} # Invalid: values should be int + ], + structured_content={"alice": "100", "bob": "85"}, # Invalid: values should be int + ) - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_scores", {}) - assert "Invalid structured content returned by tool get_scores" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_scores", {}) + assert "Invalid structured content returned by tool get_scores" in str(exc_info.value) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_missing_required(): """Test that client validates missing required fields""" - server = Server("test-server") - output_schema = { "type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}, "email": {"type": "string"}}, - "required": ["name", "age", "email"], # All fields required + "required": ["name", "age", "email"], "title": "PersonOutput", } - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="get_person", description="Get person data", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return structured content missing required field 'email' - return {"name": "John", "age": 30} # Missing required 'email' + ], + structured_content={"name": "John", "age": 30}, # Missing required 'email' + ) - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with 
pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_person", {}) - assert "Invalid structured content returned by tool get_person" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_person", {}) + assert "Invalid structured content returned by tool get_person" in str(exc_info.value) @pytest.mark.anyio async def test_tool_not_listed_warning(caplog: pytest.LogCaptureFixture): """Test that client logs warning when tool is not in list_tools but has output_schema""" - server = Server("test-server") - @server.list_tools() - async def list_tools() -> list[Tool]: - # Return empty list - tool is not listed - return [] + async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[]) + + async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text="result")], + structured_content={"result": 42}, + ) - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - # Server still responds to the tool call with structured content - return {"result": 42} + server = Server("test-server", on_list_tools=on_list_tools, on_call_tool=on_call_tool) - # Set logging level to capture warnings caplog.set_level(logging.WARNING) - with bypass_server_output_validation(): - async with Client(server) as client: - # Call a tool that wasn't listed - result = await client.call_tool("mystery_tool", {}) - assert result.structured_content == {"result": 42} - assert result.is_error is False + async with Client(server) as client: + result = await client.call_tool("mystery_tool", {}) + assert result.structured_content == {"result": 42} + assert result.is_error is False - # Check that warning was logged - assert "Tool mystery_tool not listed" in caplog.text + assert "Tool 
mystery_tool not listed" in caplog.text diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 30ecb0ac33..47be3e2089 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -2,31 +2,31 @@ import pytest -from mcp import Client +from mcp import Client, types from mcp.client._memory import InMemoryTransport -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer -from mcp.types import Resource +from mcp.types import ListResourcesResult, Resource @pytest.fixture def simple_server() -> Server: """Create a simple MCP server for testing.""" - server = Server(name="test_server") - - # pragma: no cover - handler exists only to register a resource capability. - # Transport tests verify stream creation, not handler invocation. - @server.list_resources() - async def handle_list_resources(): # pragma: no cover - return [ - Resource( - uri="memory://test", - name="Test Resource", - description="A test resource", - ) - ] - return server + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: # pragma: no cover + return ListResourcesResult( + resources=[ + Resource( + uri="memory://test", + name="Test Resource", + description="A test resource", + ) + ] + ) + + return Server(name="test_server", on_list_resources=handle_list_resources) @pytest.fixture diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py index f21abf4d0f..613c794ebf 100644 --- a/tests/experimental/tasks/client/test_tasks.py +++ b/tests/experimental/tasks/client/test_tasks.py @@ -1,43 +1,38 @@ """Tests for the experimental client task methods (session.experimental).""" +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any import 
anyio import pytest from anyio import Event from anyio.abc import TaskGroup -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.shared.experimental.tasks.helpers import task_execution from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder from mcp.types import ( CallToolRequest, CallToolRequestParams, CallToolResult, - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, - ClientResult, CreateTaskResult, - GetTaskPayloadRequest, + GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, - ServerNotification, - ServerRequest, + ListToolsResult, + PaginatedRequestParams, TaskMetadata, TextContent, - Tool, ) +pytestmark = pytest.mark.anyio + @dataclass class AppContext: @@ -48,44 +43,53 @@ class AppContext: task_done_events: dict[str, Event] = field(default_factory=lambda: {}) -@pytest.mark.anyio -async def test_session_experimental_get_task() -> None: - """Test session.experimental.get_task() method.""" - # Note: We bypass the normal lifespan mechanism - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] - store = InMemoryTaskStore() +async def _handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None +) -> ListToolsResult: + raise NotImplementedError - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = 
server.request_context - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) +async def _handle_call_tool_with_done_event( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams, *, result_text: str = "Done" +) -> CallToolResult | CreateTaskResult: + app = ctx.lifespan_context + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) - done_event = Event() - app.task_done_events[task.task_id] = done_event + done_event = Event() + app.task_done_events[task.task_id] = done_event - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) - done_event.set() + async def do_work() -> None: + async with task_execution(task.task_id, app.store) as task_ctx: + await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) + done_event.set() - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) - raise NotImplementedError + raise NotImplementedError + + +def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): + @asynccontextmanager + async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: + async with anyio.create_task_group() as tg: + yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + return 
app_lifespan + + +async def test_session_experimental_get_task() -> None: + """Test session.experimental.get_task() method.""" + store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} + + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -96,280 +100,141 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... 
# pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool_with_done_event, + ) + server.experimental.enable_tasks(on_get_task=handle_get_task) + + async with Client(server) as client: + # Create a task + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await app_context.task_done_events[task_id].wait() + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Use session.experimental to get task status - task_status = await client_session.experimental.get_task(task_id) + # Wait for task to complete + await task_done_events[task_id].wait() - assert task_status.task_id == task_id - assert 
task_status.status == "completed" + # Use session.experimental to get task status + task_status = await client.session.experimental.get_task(task_id) - tg.cancel_scope.cancel() + assert task_status.task_id == task_id + assert task_status.status == "completed" -@pytest.mark.anyio async def test_session_experimental_get_task_result() -> None: """Test session.experimental.get_task_result() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: + return await _handle_call_tool_with_done_event(ctx, params, result_text="Task result content") - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete( - CallToolResult(content=[TextContent(type="text", text="Task result content")]) - ) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - @server.experimental.get_task_result() async def handle_get_task_result( - request: GetTaskPayloadRequest, + ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: - app = server.request_context.lifespan_context - result = await 
app.store.get_result(request.params.task_id) - assert result is not None, f"Test setup error: result for {request.params.task_id} should exist" + app = ctx.lifespan_context + result = await app.store.get_result(params.task_id) + assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CallToolResult) return GetTaskPayloadResult(**result.model_dump()) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks(on_task_result=handle_get_task_result) + + async with Client(server) as client: + # Create a task + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - 
server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await app_context.task_done_events[task_id].wait() + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Use TaskClient to get task result - task_result = await client_session.experimental.get_task_result(task_id, CallToolResult) + # Wait for task to complete + await task_done_events[task_id].wait() - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - assert content.text == "Task result content" + # Use TaskClient to get task result + task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - tg.cancel_scope.cancel() + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Task result content" -@pytest.mark.anyio async def test_session_experimental_list_tasks() -> None: """Test TaskClient.list_tasks() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_list_tasks( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListTasksResult: app = ctx.lifespan_context - if 
ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: - app = server.request_context.lifespan_context - tasks_list, next_cursor = await app.store.list_tasks(cursor=request.params.cursor if request.params else None) + cursor = params.cursor if params else None + tasks_list, next_cursor = await app.store.list_tasks(cursor=cursor) return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... 
# pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool_with_done_event, + ) + server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) + + async with Client(server) as client: + # Create two tasks + for _ in range(2): + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create two tasks - for _ in range(2): - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - await app_context.task_done_events[create_result.task.task_id].wait() - - # Use TaskClient to list tasks - list_result = await client_session.experimental.list_tasks() + CreateTaskResult, + ) + await task_done_events[create_result.task.task_id].wait() - assert len(list_result.tasks) == 2 + # Use TaskClient to list tasks + list_result = await 
client.session.experimental.list_tasks() - tg.cancel_scope.cancel() + assert len(list_result.tasks) == 2 -@pytest.mark.anyio async def test_session_experimental_cancel_task() -> None: """Test TaskClient.cancel_task() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool_no_work( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context if ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata @@ -377,14 +242,12 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextCon task = await app.store.create_task(task_metadata) # Don't start any work - task stays in "working" status return CreateTaskResult(task=task) - raise NotImplementedError - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -395,14 +258,14 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - 
@server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" - await app.store.update_task(request.params.task_id, status="cancelled") - # CancelTaskResult extends Task, so we need to return the updated task info - updated_task = await app.store.get_task(request.params.task_id) + async def handle_cancel_task( + ctx: ServerRequestContext[AppContext], params: CancelTaskRequestParams + ) -> CancelTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" + await app.store.update_task(params.task_id, status="cancelled") + updated_task = await app.store.get_task(params.task_id) assert updated_task is not None return CancelTaskResult( task_id=updated_task.task_id, @@ -412,63 +275,35 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: ttl=updated_task.ttl, ) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... 
# pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=handle_call_tool_no_work, + ) + server.experimental.enable_tasks(on_get_task=handle_get_task, on_cancel_task=handle_cancel_task) + + async with Client(server) as client: + # Create a task (but don't complete it) + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task (but don't complete it) - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Verify task is working - status_before = await client_session.experimental.get_task(task_id) - assert status_before.status == "working" + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Cancel the task - await client_session.experimental.cancel_task(task_id) + # Verify task is working + 
status_before = await client.session.experimental.get_task(task_id) + assert status_before.status == "working" - # Verify task is cancelled - status_after = await client_session.experimental.get_task(task_id) - assert status_after.status == "cancelled" + # Cancel the task + await client.session.experimental.cancel_task(task_id) - tg.cancel_scope.cancel() + # Verify task is cancelled + status_after = await client.session.experimental.get_task(task_id) + assert status_after.status == "cancelled" diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py index 41cecc1295..b5b79033d0 100644 --- a/tests/experimental/tasks/server/test_integration.py +++ b/tests/experimental/tasks/server/test_integration.py @@ -8,46 +8,37 @@ 5. Client retrieves result with tasks/result """ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any import anyio import pytest from anyio import Event from anyio.abc import TaskGroup -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.shared.experimental.tasks.helpers import task_execution from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder from mcp.types import ( - TASK_REQUIRED, CallToolRequest, CallToolRequestParams, CallToolResult, - ClientResult, CreateTaskResult, - GetTaskPayloadRequest, GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, - ServerNotification, - ServerRequest, + ListToolsResult, + 
PaginatedRequestParams, TaskMetadata, TextContent, - Tool, - ToolExecution, ) +pytestmark = pytest.mark.anyio + @dataclass class AppContext: @@ -55,77 +46,57 @@ class AppContext: task_group: TaskGroup store: InMemoryTaskStore - # Events to signal when tasks complete (for testing without sleeps) task_done_events: dict[str, Event] = field(default_factory=lambda: {}) -@pytest.mark.anyio +def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): + @asynccontextmanager + async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: + async with anyio.create_task_group() as tg: + yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) + + return app_lifespan + + async def test_task_lifecycle_with_task_execution() -> None: - """Test the complete task lifecycle using the task_execution pattern. - - This demonstrates the recommended way to implement task-augmented tools: - 1. Create task in store - 2. Spawn work using task_execution() context manager - 3. Return CreateTaskResult immediately - 4. 
Work executes in background, auto-fails on exception - """ - # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message - server: Server[AppContext, Any] = Server("test-tasks") # type: ignore[assignment] + """Test the complete task lifecycle using the task_execution pattern.""" store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} + + async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListToolsResult: + raise NotImplementedError - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="process_data", - description="Process data asynchronously", - input_schema={ - "type": "object", - "properties": {"input": {"type": "string"}}, - }, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context - if name == "process_data" and ctx.experimental.is_task: - # 1. Create task in store + if params.name == "process_data" and ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None task = await app.store.create_task(task_metadata) - # 2. Create event to signal completion (for testing) done_event = Event() app.task_done_events[task.task_id] = done_event - # 3. 
Define work function using task_execution for safety - async def do_work(): + async def do_work() -> None: async with task_execution(task.task_id, app.store) as task_ctx: await task_ctx.update_status("Processing input...") - # Simulate work - input_value = arguments.get("input", "") + input_value = (params.arguments or {}).get("input", "") result_text = f"Processed: {input_value.upper()}" await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) - # Signal completion done_event.set() - # 4. Spawn work in task group (from lifespan_context) app.task_group.start_soon(do_work) - - # 5. Return CreateTaskResult immediately return CreateTaskResult(task=task) raise NotImplementedError - # Register task query handlers (delegate to store) - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -136,134 +107,91 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - @server.experimental.get_task_result() async def handle_get_task_result( - request: GetTaskPayloadRequest, + ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: - app = server.request_context.lifespan_context - result = await app.store.get_result(request.params.task_id) - assert result is not None, f"Test setup error: result for {request.params.task_id} should exist" + app = ctx.lifespan_context + 
result = await app.store.get_result(params.task_id) + assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CallToolResult) - # Return as GetTaskPayloadResult (which accepts extra fields) return GetTaskPayloadResult(**result.model_dump()) - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def handle_list_tasks( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListTasksResult: raise NotImplementedError - # Set up client-server communication - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no cover - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + server: Server[AppContext] = Server( + "test-tasks", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks( + on_get_task=handle_get_task, + on_task_result=handle_get_task_result, + on_list_tasks=handle_list_tasks, + ) + + async with Client(server) as client: + # Step 1: Send task-augmented tool call + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="process_data", + arguments={"input": "hello world"}, + task=TaskMetadata(ttl=60000), ), ), - ) as server_session: - async for message in 
server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - # Create app context with task group and store - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # === Step 1: Send task-augmented tool call === - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="process_data", - arguments={"input": "hello world"}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - assert isinstance(create_result, CreateTaskResult) - assert create_result.task.status == "working" - task_id = create_result.task.task_id - - # === Step 2: Wait for task to complete === - await app_context.task_done_events[task_id].wait() + CreateTaskResult, + ) - task_status = await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)), - GetTaskResult, - ) + assert isinstance(create_result, CreateTaskResult) + assert create_result.task.status == "working" + task_id = create_result.task.task_id - assert task_status.task_id == task_id - assert task_status.status == "completed" + # Step 2: Wait for task to complete + await task_done_events[task_id].wait() - # === Step 3: Retrieve the actual result === - task_result = await client_session.send_request( - GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id)), - CallToolResult, - ) + task_status = await client.session.experimental.get_task(task_id) + assert task_status.task_id == task_id + assert task_status.status == "completed" - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - assert content.text == "Processed: HELLO WORLD" 
+ # Step 3: Retrieve the actual result + task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - tg.cancel_scope.cancel() + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Processed: HELLO WORLD" -@pytest.mark.anyio async def test_task_auto_fails_on_exception() -> None: """Test that task_execution automatically fails the task on unhandled exception.""" - # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message - server: Server[AppContext, Any] = Server("test-tasks-failure") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} + + async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListToolsResult: + raise NotImplementedError - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="failing_task", - description="A task that fails", - input_schema={"type": "object", "properties": {}}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context - if name == "failing_task" and ctx.experimental.is_task: + if params.name == "failing_task" and ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None task = await app.store.create_task(task_metadata) - # Create event to signal completion (for testing) done_event = Event() app.task_done_events[task.task_id] = done_event - async def do_failing_work(): + async def do_failing_work() -> None: async with task_execution(task.task_id, app.store) as task_ctx: await task_ctx.update_status("About to fail...") raise 
RuntimeError("Something went wrong!") - # Note: complete() is never called, but task_execution - # will automatically call fail() due to the exception # This line is reached because task_execution suppresses the exception done_event.set() @@ -272,11 +200,10 @@ async def do_failing_work(): raise NotImplementedError - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -287,64 +214,34 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... 
# pragma: no cover - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + server: Server[AppContext] = Server( + "test-tasks-failure", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks(on_get_task=handle_get_task) + + async with Client(server) as client: + # Send task request + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="failing_task", + arguments={}, + task=TaskMetadata(ttl=60000), ), ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Send task request - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="failing_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - task_id = create_result.task.task_id + CreateTaskResult, + ) - # Wait for task to complete (even though it fails) - await app_context.task_done_events[task_id].wait() + task_id = create_result.task.task_id - # Check that task was auto-failed - task_status = await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)), GetTaskResult - ) + # Wait for task to complete (even though it 
fails) + await task_done_events[task_id].wait() - assert task_status.status == "failed" - assert task_status.status_message == "Something went wrong!" + # Check that task was auto-failed + task_status = await client.session.experimental.get_task(task_id) - tg.cancel_scope.cancel() + assert task_status.status == "failed" + assert task_status.status_message == "Something went wrong!" diff --git a/tests/experimental/tasks/server/test_run_task_flow.py b/tests/experimental/tasks/server/test_run_task_flow.py index 0d5d1df77a..027382e69e 100644 --- a/tests/experimental/tasks/server/test_run_task_flow.py +++ b/tests/experimental/tasks/server/test_run_task_flow.py @@ -8,159 +8,102 @@ These are integration tests that verify the complete flow works end-to-end. """ -from typing import Any from unittest.mock import Mock import anyio import pytest from anyio import Event -from mcp.client.session import ClientSession -from mcp.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.experimental.task_support import TaskSupport from mcp.server.lowlevel import NotificationOptions from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue -from mcp.shared.message import SessionMessage from mcp.types import ( TASK_REQUIRED, + CallToolRequestParams, CallToolResult, - CancelTaskRequest, - CancelTaskResult, CreateTaskResult, - GetTaskPayloadRequest, - GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, - ListTasksResult, + ListToolsResult, + PaginatedRequestParams, TextContent, - Tool, - ToolExecution, ) +pytestmark = pytest.mark.anyio -@pytest.mark.anyio -async def test_run_task_basic_flow() -> None: - """Test the basic run_task flow without 
elicitation. - 1. enable_tasks() sets up handlers - 2. Client calls tool with task field - 3. run_task() spawns work, returns CreateTaskResult - 4. Work completes in background - 5. Client polls and sees completed status - """ - server = Server("test-run-task") +async def _handle_list_tools_simple_task( + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + raise NotImplementedError - # One-line setup - server.experimental.enable_tasks() - # Track when work completes and capture received meta +async def test_run_task_basic_flow() -> None: + """Test the basic run_task flow without elicitation.""" work_completed = Event() received_meta: list[str | None] = [None] - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="simple_task", - description="A simple task", - input_schema={"type": "object", "properties": {"input": {"type": "string"}}}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) - # Capture the meta from the request (if present) if ctx.meta is not None: # pragma: no branch received_meta[0] = ctx.meta.get("custom_field") async def work(task: ServerTaskContext) -> CallToolResult: await task.update_status("Working...") - input_val = arguments.get("input", "default") + input_val = (params.arguments or {}).get("input", "default") result = CallToolResult(content=[TextContent(type="text", text=f"Processed: {input_val}")]) work_completed.set() return result return await ctx.experimental.run_task(work) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - 
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server = Server( + "test-run-task", + on_list_tools=_handle_list_tools_simple_task, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() + + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task( + "simple_task", + {"input": "hello"}, + meta={"custom_field": "test_value"}, ) - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - # Initialize - await client_session.initialize() - - # Call tool as task (with meta to test that code path) - result = await client_session.experimental.call_tool_as_task( - "simple_task", - {"input": "hello"}, - meta={"custom_field": "test_value"}, - ) - - # Should get CreateTaskResult - task_id = result.task.task_id - assert result.task.status == "working" - - # Wait for work to complete - with anyio.fail_after(5): - await work_completed.wait() - - # Poll until task status is completed - with anyio.fail_after(5): - while True: - task_status = await client_session.experimental.get_task(task_id) - if task_status.status == "completed": # pragma: no branch - break - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - # Verify the meta was passed through correctly + task_id = result.task.task_id + assert result.task.status == "working" + + with anyio.fail_after(5): + await work_completed.wait() + + with anyio.fail_after(5): + while True: + task_status = await client.session.experimental.get_task(task_id) + if task_status.status == "completed": # pragma: no branch + break + assert received_meta[0] == "test_value" 
-@pytest.mark.anyio async def test_run_task_auto_fails_on_exception() -> None: """Test that run_task automatically fails the task when work raises.""" - server = Server("test-run-task-fail") - server.experimental.enable_tasks() - work_failed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="failing_task", - description="A task that fails", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -169,42 +112,29 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("failing_task", {}) - task_id = result.task.task_id + server = Server( + "test-run-task-fail", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - # Wait for work to fail - with 
anyio.fail_after(5): - await work_failed.wait() + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task("failing_task", {}) + task_id = result.task.task_id - # Poll until task status is failed - with anyio.fail_after(5): - while True: - task_status = await client_session.experimental.get_task(task_id) - if task_status.status == "failed": # pragma: no branch - break + with anyio.fail_after(5): + await work_failed.wait() - assert "Something went wrong" in (task_status.status_message or "") + with anyio.fail_after(5): + while True: + task_status = await client.session.experimental.get_task(task_id) + if task_status.status == "failed": # pragma: no branch + break - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + assert "Something went wrong" in (task_status.status_message or "") -@pytest.mark.anyio async def test_enable_tasks_auto_registers_handlers() -> None: """Test that enable_tasks() auto-registers get_task, list_tasks, cancel_task handlers.""" server = Server("test-enable-tasks") @@ -221,63 +151,41 @@ async def test_enable_tasks_auto_registers_handlers() -> None: assert caps_after.tasks is not None assert caps_after.tasks.list is not None assert caps_after.tasks.cancel is not None - # Verify nested call capability is present assert caps_after.tasks.requests is not None assert caps_after.tasks.requests.tools is not None assert caps_after.tasks.requests.tools.call is not None -@pytest.mark.anyio async def test_enable_tasks_with_custom_store_and_queue() -> None: """Test that enable_tasks() uses provided store and queue instead of defaults.""" server = Server("test-custom-store-queue") - # Create custom store and queue custom_store = InMemoryTaskStore() custom_queue = InMemoryTaskMessageQueue() - # Enable tasks with custom implementations task_support = server.experimental.enable_tasks(store=custom_store, queue=custom_queue) - # Verify our custom implementations are 
used assert task_support.store is custom_store assert task_support.queue is custom_queue -@pytest.mark.anyio async def test_enable_tasks_skips_default_handlers_when_custom_registered() -> None: """Test that enable_tasks() doesn't override already-registered handlers.""" server = Server("test-custom-handlers") - # Register custom handlers BEFORE enable_tasks (never called, just for registration) - @server.experimental.get_task() - async def custom_get_task(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError - - @server.experimental.get_task_result() - async def custom_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: + # Register custom handlers via enable_tasks kwargs + async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: raise NotImplementedError - @server.experimental.list_tasks() - async def custom_list_tasks(req: ListTasksRequest) -> ListTasksResult: - raise NotImplementedError - - @server.experimental.cancel_task() - async def custom_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=custom_get_task) - # Now enable tasks - should NOT override our custom handlers - server.experimental.enable_tasks() + # Verify handler is registered + assert server._has_handler("tasks/get") + assert server._has_handler("tasks/list") + assert server._has_handler("tasks/cancel") + assert server._has_handler("tasks/result") - # Verify our custom handlers are still registered (not replaced by defaults) - # The handlers dict should contain our custom handlers - assert GetTaskRequest in server.request_handlers - assert GetTaskPayloadRequest in server.request_handlers - assert ListTasksRequest in server.request_handlers - assert CancelTaskRequest in server.request_handlers - -@pytest.mark.anyio async def test_run_task_without_enable_tasks_raises() -> None: """Test that run_task raises when enable_tasks() wasn't called.""" 
experimental = Experimental( @@ -294,7 +202,6 @@ async def work(task: ServerTaskContext) -> CallToolResult: await experimental.run_task(work) -@pytest.mark.anyio async def test_task_support_task_group_before_run_raises() -> None: """Test that accessing task_group before run() raises RuntimeError.""" task_support = TaskSupport.in_memory() @@ -303,7 +210,6 @@ async def test_task_support_task_group_before_run_raises() -> None: _ = task_support.task_group -@pytest.mark.anyio async def test_run_task_without_session_raises() -> None: """Test that run_task raises when session is not available.""" task_support = TaskSupport.in_memory() @@ -322,7 +228,6 @@ async def work(task: ServerTaskContext) -> CallToolResult: await experimental.run_task(work) -@pytest.mark.anyio async def test_run_task_without_task_metadata_raises() -> None: """Test that run_task raises when request is not task-augmented.""" task_support = TaskSupport.in_memory() @@ -342,29 +247,17 @@ async def work(task: ServerTaskContext) -> CallToolResult: await experimental.run_task(work) -@pytest.mark.anyio async def test_run_task_with_model_immediate_response() -> None: """Test that run_task includes model_immediate_response in CreateTaskResult._meta.""" - server = Server("test-run-task-immediate") - server.experimental.enable_tasks() - work_completed = Event() immediate_response_text = "Processing your request..." 
- @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="task_with_immediate", - description="A task with immediate response", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -373,164 +266,102 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work, model_immediate_response=immediate_response_text) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("task_with_immediate", {}) + server = Server( + "test-run-task-immediate", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - # Verify the immediate response is in _meta - assert result.meta is not None - assert "io.modelcontextprotocol/model-immediate-response" in result.meta - assert result.meta["io.modelcontextprotocol/model-immediate-response"] == 
immediate_response_text + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task("task_with_immediate", {}) - with anyio.fail_after(5): - await work_completed.wait() + assert result.meta is not None + assert "io.modelcontextprotocol/model-immediate-response" in result.meta + assert result.meta["io.modelcontextprotocol/model-immediate-response"] == immediate_response_text - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + with anyio.fail_after(5): + await work_completed.wait() -@pytest.mark.anyio async def test_run_task_doesnt_complete_if_already_terminal() -> None: """Test that run_task doesn't auto-complete if work manually completed the task.""" - server = Server("test-already-complete") - server.experimental.enable_tasks() - work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="manual_complete_task", - description="A task that manually completes", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: - # Manually complete the task before returning manual_result = CallToolResult(content=[TextContent(type="text", text="Manually completed")]) await task.complete(manual_result, notify=False) work_completed.set() - # Return a different result - but it should be ignored since task is already terminal return 
CallToolResult(content=[TextContent(type="text", text="This should be ignored")]) return await ctx.experimental.run_task(work) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("manual_complete_task", {}) - task_id = result.task.task_id + server = Server( + "test-already-complete", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - with anyio.fail_after(5): - await work_completed.wait() + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task("manual_complete_task", {}) + task_id = result.task.task_id - # Poll until task status is completed - with anyio.fail_after(5): - while True: - status = await client_session.experimental.get_task(task_id) - if status.status == "completed": # pragma: no branch - break + with anyio.fail_after(5): + await work_completed.wait() - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + with anyio.fail_after(5): + while True: + status = await client.session.experimental.get_task(task_id) + if status.status == "completed": # pragma: no branch + break -@pytest.mark.anyio async def test_run_task_doesnt_fail_if_already_terminal() -> None: """Test that run_task doesn't auto-fail if work manually failed/cancelled the task.""" - server = Server("test-already-failed") - server.experimental.enable_tasks() - work_completed = Event() - @server.list_tools() 
- async def list_tools() -> list[Tool]: - return [ - Tool( - name="manual_cancel_task", - description="A task that manually cancels then raises", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: - # Manually fail the task first await task.fail("Manually failed", notify=False) work_completed.set() - # Then raise - but the auto-fail should be skipped since task is already terminal raise RuntimeError("This error should not change status") return await ctx.experimental.run_task(work) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("manual_cancel_task", {}) - task_id = result.task.task_id + server = Server( + "test-already-failed", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - with anyio.fail_after(5): - await work_completed.wait() + async with Client(server) as client: + result = 
await client.session.experimental.call_tool_as_task("manual_cancel_task", {}) + task_id = result.task.task_id - # Poll until task status is failed - with anyio.fail_after(5): - while True: - status = await client_session.experimental.get_task(task_id) - if status.status == "failed": # pragma: no branch - break + with anyio.fail_after(5): + await work_completed.wait() - # Task should still be failed (from manual fail, not auto-fail from exception) - assert status.status_message == "Manually failed" # Not "This error should not change status" + with anyio.fail_after(5): + while True: + status = await client.session.experimental.get_task(task_id) + if status.status == "failed": # pragma: no branch + break - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + assert status.status_message == "Manually failed" diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 8005380d28..6a28b274ea 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -6,8 +6,9 @@ import anyio import pytest +from mcp import Client from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -23,7 +24,6 @@ CallToolRequest, CallToolRequestParams, CallToolResult, - CancelTaskRequest, CancelTaskRequestParams, CancelTaskResult, ClientResult, @@ -31,21 +31,18 @@ GetTaskPayloadRequest, GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, GetTaskRequestParams, GetTaskResult, JSONRPCError, JSONRPCNotification, JSONRPCResponse, - ListTasksRequest, ListTasksResult, - ListToolsRequest, ListToolsResult, + PaginatedRequestParams, SamplingMessage, ServerCapabilities, ServerNotification, ServerRequest, - 
ServerResult, Task, TaskMetadata, TextContent, @@ -53,57 +50,37 @@ ToolExecution, ) +pytestmark = pytest.mark.anyio -@pytest.mark.anyio -async def test_list_tasks_handler() -> None: - """Test that experimental list_tasks handler works.""" - server = Server("test") +async def test_list_tasks_handler() -> None: + """Test that experimental list_tasks handler works via Client.""" now = datetime.now(timezone.utc) test_tasks = [ - Task( - task_id="task-1", - status="working", - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=1000, - ), - Task( - task_id="task-2", - status="completed", - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=1000, - ), + Task(task_id="task-1", status="working", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), + Task(task_id="task-2", status="completed", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), ] - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def handle_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: return ListTasksResult(tasks=test_tasks) - handler = server.request_handlers[ListTasksRequest] - request = ListTasksRequest(method="tasks/list") - result = await handler(request) + server = Server("test") + server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) - assert isinstance(result, ServerResult) - assert isinstance(result, ListTasksResult) - assert len(result.tasks) == 2 - assert result.tasks[0].task_id == "task-1" - assert result.tasks[1].task_id == "task-2" + async with Client(server) as client: + result = await client.session.experimental.list_tasks() + assert len(result.tasks) == 2 + assert result.tasks[0].task_id == "task-1" + assert result.tasks[1].task_id == "task-2" -@pytest.mark.anyio async def test_get_task_handler() -> None: - """Test that experimental get_task handler works.""" - server = Server("test") + """Test 
that experimental get_task handler works via Client.""" - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + async def handle_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: now = datetime.now(timezone.utc) return GetTaskResult( - task_id=request.params.task_id, + task_id=params.task_id, status="working", created_at=now, last_updated_at=now, @@ -111,85 +88,69 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=1000, ) - handler = server.request_handlers[GetTaskRequest] - request = GetTaskRequest( - method="tasks/get", - params=GetTaskRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server = Server("test") + server.experimental.enable_tasks(on_get_task=handle_get_task) - assert isinstance(result, ServerResult) - assert isinstance(result, GetTaskResult) - assert result.task_id == "test-task-123" - assert result.status == "working" + async with Client(server) as client: + result = await client.session.experimental.get_task("test-task-123") + assert result.task_id == "test-task-123" + assert result.status == "working" -@pytest.mark.anyio async def test_get_task_result_handler() -> None: - """Test that experimental get_task_result handler works.""" - server = Server("test") + """Test that experimental get_task_result handler works via Client.""" - @server.experimental.get_task_result() - async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPayloadResult: + async def handle_get_task_result( + ctx: ServerRequestContext, params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: return GetTaskPayloadResult() - handler = server.request_handlers[GetTaskPayloadRequest] - request = GetTaskPayloadRequest( - method="tasks/result", - params=GetTaskPayloadRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server = Server("test") + 
server.experimental.enable_tasks(on_task_result=handle_get_task_result) - assert isinstance(result, ServerResult) - assert isinstance(result, GetTaskPayloadResult) + async with Client(server) as client: + result = await client.session.send_request( + GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="test-task-123")), + GetTaskPayloadResult, + ) + assert isinstance(result, GetTaskPayloadResult) -@pytest.mark.anyio async def test_cancel_task_handler() -> None: - """Test that experimental cancel_task handler works.""" - server = Server("test") + """Test that experimental cancel_task handler works via Client.""" - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + async def handle_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: now = datetime.now(timezone.utc) return CancelTaskResult( - task_id=request.params.task_id, + task_id=params.task_id, status="cancelled", created_at=now, last_updated_at=now, ttl=60000, ) - handler = server.request_handlers[CancelTaskRequest] - request = CancelTaskRequest( - method="tasks/cancel", - params=CancelTaskRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server = Server("test") + server.experimental.enable_tasks(on_cancel_task=handle_cancel_task) - assert isinstance(result, ServerResult) - assert isinstance(result, CancelTaskResult) - assert result.task_id == "test-task-123" - assert result.status == "cancelled" + async with Client(server) as client: + result = await client.session.experimental.cancel_task("test-task-123") + assert result.task_id == "test-task-123" + assert result.status == "cancelled" -@pytest.mark.anyio async def test_server_capabilities_include_tasks() -> None: """Test that server capabilities include tasks when handlers are registered.""" server = Server("test") - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> 
ListTasksResult: + async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: raise NotImplementedError - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + async def noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: raise NotImplementedError - capabilities = server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ) + server.experimental.enable_tasks(on_list_tasks=noop_list_tasks, on_cancel_task=noop_cancel_task) + + capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) assert capabilities.tasks is not None assert capabilities.tasks.list is not None @@ -198,259 +159,164 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: assert capabilities.tasks.requests.tools is not None -@pytest.mark.anyio -async def test_server_capabilities_partial_tasks() -> None: +@pytest.mark.skip( + reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " + "so partial capabilities aren't possible yet. Low-level API should support " + "selectively enabling/disabling task capabilities." 
+) +async def test_server_capabilities_partial_tasks() -> None: # pragma: no cover """Test capabilities with only some task handlers registered.""" server = Server("test") - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: raise NotImplementedError # Only list_tasks registered, not cancel_task + server.experimental.enable_tasks(on_list_tasks=noop_list_tasks) - capabilities = server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ) + capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) assert capabilities.tasks is not None assert capabilities.tasks.list is not None assert capabilities.tasks.cancel is None # Not registered -@pytest.mark.anyio async def test_tool_with_task_execution_metadata() -> None: """Test that tools can declare task execution mode.""" - server = Server("test") - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="quick_tool", - description="Fast tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_FORBIDDEN), - ), - Tool( - name="long_tool", - description="Long running tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ), - Tool( - name="flexible_tool", - description="Can be either", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_OPTIONAL), - ), - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="quick_tool", + description="Fast tool", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_FORBIDDEN), + ), + Tool( + 
name="long_tool", + description="Long running tool", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ), + Tool( + name="flexible_tool", + description="Can be either", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_OPTIONAL), + ), + ] + ) - tools_handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list") - result = await tools_handler(request) + server = Server("test", on_list_tools=handle_list_tools) - assert isinstance(result, ServerResult) - assert isinstance(result, ListToolsResult) - tools = result.tools + async with Client(server) as client: + result = await client.list_tools() + tools = result.tools - assert tools[0].execution is not None - assert tools[0].execution.task_support == TASK_FORBIDDEN - assert tools[1].execution is not None - assert tools[1].execution.task_support == TASK_REQUIRED - assert tools[2].execution is not None - assert tools[2].execution.task_support == TASK_OPTIONAL + assert tools[0].execution is not None + assert tools[0].execution.task_support == TASK_FORBIDDEN + assert tools[1].execution is not None + assert tools[1].execution.task_support == TASK_REQUIRED + assert tools[2].execution is not None + assert tools[2].execution.task_support == TASK_OPTIONAL -@pytest.mark.anyio async def test_task_metadata_in_call_tool_request() -> None: - """Test that task metadata is accessible via RequestContext when calling a tool.""" - server = Server("test") + """Test that task metadata is accessible via ctx when calling a tool.""" captured_task_metadata: TaskMetadata | None = None - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="long_task", - description="A long running task", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support="optional"), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: 
PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal captured_task_metadata - ctx = server.request_context captured_task_metadata = ctx.experimental.task_metadata - return [TextContent(type="text", text="done")] - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + return CallToolResult(content=[TextContent(type="text", text="done")]) + + server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + + async with Client(server) as client: + # Call tool with task metadata + await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="long_task", + arguments={}, + task=TaskMetadata(ttl=60000), ), ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - 
client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Call tool with task metadata - await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="long_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CallToolResult, - ) - - tg.cancel_scope.cancel() + CallToolResult, + ) assert captured_task_metadata is not None assert captured_task_metadata.ttl == 60000 -@pytest.mark.anyio async def test_task_metadata_is_task_property() -> None: - """Test that RequestContext.experimental.is_task works correctly.""" - server = Server("test") + """Test that ctx.experimental.is_task works correctly.""" is_task_values: list[bool] = [] - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="test_tool", - description="Test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: is_task_values.append(ctx.experimental.is_task) - return [TextContent(type="text", text="done")] + return CallToolResult(content=[TextContent(type="text", text="done")]) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... 
# pragma: no branch + async with Client(server) as client: + # Call without task metadata + await client.session.send_request( + CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), + CallToolResult, + ) - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + # Call with task metadata + await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams(name="test_tool", arguments={}, task=TaskMetadata(ttl=60000)), ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Call without task metadata - await client_session.send_request( - CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), - CallToolResult, - ) - - # Call with task metadata - await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}, task=TaskMetadata(ttl=60000)), - ), - CallToolResult, - ) - - tg.cancel_scope.cancel() + CallToolResult, + ) assert len(is_task_values) == 2 assert is_task_values[0] is False # First call without task assert is_task_values[1] is True # Second call with task -@pytest.mark.anyio async def test_update_capabilities_no_handlers() -> None: """Test that update_capabilities returns early when no 
task handlers are registered.""" server = Server("test-no-handlers") - # Access experimental to initialize it, but don't register any task handlers _ = server.experimental caps = server.get_capabilities(NotificationOptions(), {}) - - # Without any task handlers registered, tasks capability should be None assert caps.tasks is None -@pytest.mark.anyio +async def test_update_capabilities_partial_handlers() -> None: + """Test that update_capabilities skips list/cancel when only tasks/get is registered.""" + server = Server("test-partial") + # Access .experimental to create the ExperimentalHandlers instance + exp = server.experimental + # Second access returns the same cached instance + assert server.experimental is exp + + async def noop_get(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + raise NotImplementedError + + server._add_request_handler("tasks/get", noop_get) + + caps = server.get_capabilities(NotificationOptions(), {}) + assert caps.tasks is not None + assert caps.tasks.list is None + assert caps.tasks.cancel is None + + async def test_default_task_handlers_via_enable_tasks() -> None: - """Test that enable_tasks() auto-registers working default handlers. 
- - This exercises the default handlers in lowlevel/experimental.py: - - _default_get_task (task not found) - - _default_get_task_result - - _default_list_tasks - - _default_cancel_task - """ + """Test that enable_tasks() auto-registers working default handlers.""" server = Server("test-default-handlers") - # Enable tasks with default handlers (no custom handlers registered) task_support = server.experimental.enable_tasks() store = task_support.store @@ -493,24 +359,18 @@ async def run_server() -> None: task = await store.create_task(TaskMetadata(ttl=60000)) # Test list_tasks (default handler) - list_result = await client_session.send_request(ListTasksRequest(), ListTasksResult) + list_result = await client_session.experimental.list_tasks() assert len(list_result.tasks) == 1 assert list_result.tasks[0].task_id == task.task_id # Test get_task (default handler - found) - get_result = await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id=task.task_id)), - GetTaskResult, - ) + get_result = await client_session.experimental.get_task(task.task_id) assert get_result.task_id == task.task_id assert get_result.status == "working" # Test get_task (default handler - not found path) with pytest.raises(MCPError, match="not found"): - await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id="nonexistent-task")), - GetTaskResult, - ) + await client_session.experimental.get_task("nonexistent-task") # Create a completed task to test get_task_result completed_task = await store.create_task(TaskMetadata(ttl=60000)) @@ -529,9 +389,7 @@ async def run_server() -> None: assert "io.modelcontextprotocol/related-task" in payload_result.meta # Test cancel_task (default handler) - cancel_result = await client_session.send_request( - CancelTaskRequest(params=CancelTaskRequestParams(task_id=task.task_id)), CancelTaskResult - ) + cancel_result = await client_session.experimental.cancel_task(task.task_id) assert cancel_result.task_id 
== task.task_id assert cancel_result.status == "cancelled" diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py index 57122da7b9..2d0378a9ce 100644 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -17,7 +17,7 @@ from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.lowlevel import NotificationOptions from mcp.shared._context import RequestContext @@ -26,6 +26,7 @@ from mcp.shared.message import SessionMessage from mcp.types import ( TASK_REQUIRED, + CallToolRequestParams, CallToolResult, CreateMessageRequestParams, CreateMessageResult, @@ -35,11 +36,12 @@ ErrorData, GetTaskPayloadResult, GetTaskResult, + ListToolsResult, + PaginatedRequestParams, SamplingMessage, TaskMetadata, TextContent, Tool, - ToolExecution, ) @@ -181,24 +183,21 @@ async def test_scenario1_normal_tool_normal_elicitation() -> None: Server calls session.elicit() directly, client responds immediately. 
""" - server = Server("test-scenario1") elicit_received = Event() tool_result: list[str] = [] - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Normal elicitation - expects immediate response result = await ctx.session.elicit( message="Please confirm the action", @@ -209,6 +208,8 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append("confirmed" if confirmed else "cancelled") return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + server = Server("test-scenario1", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + # Elicitation callback for client async def elicitation_callback( context: RequestContext[ClientSession], @@ -262,27 +263,24 @@ async def test_scenario2_normal_tool_task_augmented_elicitation() -> None: Server calls session.experimental.elicit_as_task(), client creates a task for the elicitation and returns CreateTaskResult. Server polls client. 
""" - server = Server("test-scenario2") elicit_received = Event() tool_result: list[str] = [] # Client-side task store for handling task-augmented elicitation client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Task-augmented elicitation - server polls client result = await ctx.session.experimental.elicit_as_task( message="Please confirm the action", @@ -294,6 +292,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append("confirmed" if confirmed else "cancelled") return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + server = Server("test-scenario2", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) task_handlers = create_client_task_handlers(client_task_store, elicit_received) # Set up streams @@ -342,26 +341,13 @@ async def test_scenario3_task_augmented_tool_normal_elicitation() -> None: Client calls tool as task. Inside the task, server uses task.elicit() which queues the request and delivers via tasks/result. 
""" - server = Server("test-scenario3") - server.experimental.enable_tasks() - elicit_received = Event() work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -377,6 +363,9 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario3", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() + # Elicitation callback for client async def elicitation_callback( context: RequestContext[ClientSession], @@ -452,29 +441,16 @@ async def test_scenario4_task_augmented_tool_task_augmented_elicitation() -> Non 5. Server gets the ElicitResult and completes the tool task 6. 
Client's tasks/result returns with the CallToolResult """ - server = Server("test-scenario4") - server.experimental.enable_tasks() - elicit_received = Event() work_completed = Event() # Client-side task store for handling task-augmented elicitation client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -491,6 +467,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario4", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() task_handlers = create_client_task_handlers(client_task_store, elicit_received) # Set up streams @@ -553,27 +531,24 @@ async def test_scenario2_sampling_normal_tool_task_augmented_sampling() -> None: Server calls session.experimental.create_message_as_task(), client creates a task for the sampling and returns CreateTaskResult. Server polls client. 
""" - server = Server("test-scenario2-sampling") sampling_received = Event() tool_result: list[str] = [] # Client-side task store for handling task-augmented sampling client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="generate_text", + description="Generate text using sampling", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Task-augmented sampling - server polls client result = await ctx.session.experimental.create_message_as_task( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], @@ -587,6 +562,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append(response_text) return CallToolResult(content=[TextContent(type="text", text=response_text)]) + server = Server("test-scenario2-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) # Set up streams @@ -636,29 +612,16 @@ async def test_scenario4_sampling_task_augmented_tool_task_augmented_sampling() which sends task-augmented sampling. Client creates its own task for the sampling, and server polls the client. 
""" - server = Server("test-scenario4-sampling") - server.experimental.enable_tasks() - sampling_received = Event() work_completed = Event() # Client-side task store for handling task-augmented sampling client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -677,6 +640,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario4-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) # Set up streams diff --git a/tests/experimental/tasks/test_spec_compliance.py b/tests/experimental/tasks/test_spec_compliance.py index d00ce40a45..38d7d0a664 100644 --- a/tests/experimental/tasks/test_spec_compliance.py +++ b/tests/experimental/tasks/test_spec_compliance.py @@ -10,17 +10,17 @@ import pytest -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY from mcp.types import ( - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, CreateTaskResult, - 
GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + PaginatedRequestParams, ServerCapabilities, Task, ) @@ -44,13 +44,22 @@ def test_server_without_task_handlers_has_no_tasks_capability() -> None: assert caps.tasks is None +async def _noop_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + raise NotImplementedError + + +async def _noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: + raise NotImplementedError + + +async def _noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: + raise NotImplementedError + + def test_server_with_list_tasks_handler_declares_list_capability() -> None: """Server with list_tasks handler declares tasks.list capability.""" server: Server = Server("test") - - @server.experimental.list_tasks() - async def handle_list(req: ListTasksRequest) -> ListTasksResult: - raise NotImplementedError + server.experimental.enable_tasks(on_list_tasks=_noop_list_tasks) caps = _get_capabilities(server) assert caps.tasks is not None @@ -60,10 +69,7 @@ async def handle_list(req: ListTasksRequest) -> ListTasksResult: def test_server_with_cancel_task_handler_declares_cancel_capability() -> None: """Server with cancel_task handler declares tasks.cancel capability.""" server: Server = Server("test") - - @server.experimental.cancel_task() - async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_cancel_task=_noop_cancel_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -75,10 +81,7 @@ def test_server_with_get_task_handler_declares_requests_tools_call_capability() (get_task is required for task-augmented tools/call support) """ server: Server = Server("test") - - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + 
server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -86,28 +89,30 @@ async def handle_get(req: GetTaskRequest) -> GetTaskResult: assert caps.tasks.requests.tools is not None -def test_server_without_list_handler_has_no_list_capability() -> None: +@pytest.mark.skip( + reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " + "so partial capabilities aren't possible yet. Low-level API should support " + "selectively enabling/disabling task capabilities." +) +def test_server_without_list_handler_has_no_list_capability() -> None: # pragma: no cover """Server without list_tasks handler has no tasks.list capability.""" server: Server = Server("test") - - # Register only get_task (not list_tasks) - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None assert caps.tasks.list is None -def test_server_without_cancel_handler_has_no_cancel_capability() -> None: +@pytest.mark.skip( + reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " + "so partial capabilities aren't possible yet. Low-level API should support " + "selectively enabling/disabling task capabilities." 
+) +def test_server_without_cancel_handler_has_no_cancel_capability() -> None: # pragma: no cover """Server without cancel_task handler has no tasks.cancel capability.""" server: Server = Server("test") - - # Register only get_task (not cancel_task) - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -117,18 +122,11 @@ async def handle_get(req: GetTaskRequest) -> GetTaskResult: def test_server_with_all_task_handlers_has_full_capability() -> None: """Server with all task handlers declares complete tasks capability.""" server: Server = Server("test") - - @server.experimental.list_tasks() - async def handle_list(req: ListTasksRequest) -> ListTasksResult: - raise NotImplementedError - - @server.experimental.cancel_task() - async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: - raise NotImplementedError - - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks( + on_list_tasks=_noop_list_tasks, + on_cancel_task=_noop_cancel_task, + on_get_task=_noop_get_task, + ) caps = _get_capabilities(server) assert caps.tasks is not None diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index 39e2c6f2ad..bb4735121f 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -1,15 +1,13 @@ import pytest -from mcp import types +from mcp import Client from mcp.server.mcpserver import MCPServer @pytest.mark.anyio async def test_resource_templates(): - # Create an MCP server mcp = MCPServer("Demo") - # Add a dynamic greeting resource @mcp.resource("greeting://{name}") def get_greeting(name: str) -> str: # pragma: no cover """Get a personalized greeting""" @@ -20,23 +18,16 
@@ def get_user_profile(user_id: str) -> str: # pragma: no cover """Dynamic user data""" return f"Profile data for user {user_id}" - # Get the list of resource templates using the underlying server - # Note: list_resource_templates() returns a decorator that wraps the handler - # The handler returns a ServerResult with a ListResourceTemplatesResult inside - result = await mcp._lowlevel_server.request_handlers[types.ListResourceTemplatesRequest]( - types.ListResourceTemplatesRequest(params=None) - ) - assert isinstance(result, types.ListResourceTemplatesResult) - templates = result.resource_templates - - # Verify we get both templates back - assert len(templates) == 2 - - # Verify template details - greeting_template = next(t for t in templates if t.name == "get_greeting") - assert greeting_template.uri_template == "greeting://{name}" - assert greeting_template.description == "Get a personalized greeting" - - profile_template = next(t for t in templates if t.name == "get_user_profile") - assert profile_template.uri_template == "users://{user_id}/profile" - assert profile_template.description == "Dynamic user data" + async with Client(mcp) as client: + result = await client.list_resource_templates() + templates = result.resource_templates + + assert len(templates) == 2 + + greeting_template = next(t for t in templates if t.name == "get_greeting") + assert greeting_template.uri_template == "greeting://{name}" + assert greeting_template.description == "Get a personalized greeting" + + profile_template = next(t for t in templates if t.name == "get_user_profile") + assert profile_template.uri_template == "users://{user_id}/profile" + assert profile_template.description == "Dynamic user data" diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index e738017f85..851e89979f 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -3,9 +3,16 @@ import pytest from mcp 
import Client, types -from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer +from mcp.types import ( + BlobResourceContents, + ListResourcesResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextResourceContents, +) pytestmark = pytest.mark.anyio @@ -58,7 +65,6 @@ def get_image_as_bytes() -> bytes: async def test_lowlevel_resource_mime_type(): """Test that mime_type parameter is respected for resources.""" - server = Server("test") # Create a small test image as bytes image_bytes = b"fake_image_data" @@ -74,17 +80,24 @@ async def test_lowlevel_resource_mime_type(): ), ] - @server.list_resources() - async def handle_list_resources(): - return test_resources - - @server.read_resource() - async def handle_read_resource(uri: str): - if str(uri) == "test://image": - return [ReadResourceContents(content=base64_string, mime_type="image/png")] - elif str(uri) == "test://image_bytes": - return [ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")] - raise Exception(f"Resource not found: {uri}") # pragma: no cover + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=test_resources) + + resource_contents: dict[str, list[TextResourceContents | BlobResourceContents]] = { + "test://image": [TextResourceContents(uri="test://image", text=base64_string, mime_type="image/png")], + "test://image_bytes": [ + BlobResourceContents( + uri="test://image_bytes", blob=base64.b64encode(image_bytes).decode("utf-8"), mime_type="image/png" + ) + ], + } + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult(contents=resource_contents[str(params.uri)]) + + server = Server("test", 
on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) # Test that resources are listed with correct mime type async with Client(server) as client: diff --git a/tests/issues/test_1574_resource_uri_validation.py b/tests/issues/test_1574_resource_uri_validation.py index e6ff568774..c677081282 100644 --- a/tests/issues/test_1574_resource_uri_validation.py +++ b/tests/issues/test_1574_resource_uri_validation.py @@ -13,8 +13,14 @@ import pytest from mcp import Client, types -from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + ListResourcesResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextResourceContents, +) pytestmark = pytest.mark.anyio @@ -26,24 +32,24 @@ async def test_relative_uri_roundtrip(): the server would fail to serialize resources with relative URIs, or the URI would be transformed during the roundtrip. 
""" - server = Server("test") - - @server.list_resources() - async def list_resources(): - return [ - types.Resource(name="user", uri="users/me"), - types.Resource(name="config", uri="./config"), - types.Resource(name="parent", uri="../parent/resource"), - ] - - @server.read_resource() - async def read_resource(uri: str): - return [ - ReadResourceContents( - content=f"data for {uri}", - mime_type="text/plain", - ) - ] + + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + types.Resource(name="user", uri="users/me"), + types.Resource(name="config", uri="./config"), + types.Resource(name="parent", uri="../parent/resource"), + ] + ) + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text=f"data for {params.uri}", mime_type="text/plain")] + ) + + server = Server("test", on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) async with Client(server) as client: # List should return the exact URIs we specified @@ -67,18 +73,23 @@ async def test_custom_scheme_uri_roundtrip(): Some MCP servers use custom schemes like "custom://resource". These should work end-to-end. 
""" - server = Server("test") - - @server.list_resources() - async def list_resources(): - return [ - types.Resource(name="custom", uri="custom://my-resource"), - types.Resource(name="file", uri="file:///path/to/file"), - ] - - @server.read_resource() - async def read_resource(uri: str): - return [ReadResourceContents(content="data", mime_type="text/plain")] + + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + types.Resource(name="custom", uri="custom://my-resource"), + types.Resource(name="file", uri="file:///path/to/file"), + ] + ) + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text="data", mime_type="text/plain")] + ) + + server = Server("test", on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) async with Client(server) as client: resources = await client.list_resources() diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index 44b17d3372..2bccedf8d2 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -1,83 +1,52 @@ """Test for base64 encoding issue in MCP server. -This test demonstrates the issue in server.py where the server uses -urlsafe_b64encode but the BlobResourceContents validator expects standard -base64 encoding. - -The test should FAIL before fixing server.py to use b64encode instead of -urlsafe_b64encode. -After the fix, the test should PASS. +This test verifies that binary resource data is encoded with standard base64 +(not urlsafe_b64encode), so BlobResourceContents validation succeeds. 
""" import base64 -from typing import cast import pytest -from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import Server -from mcp.types import ( - BlobResourceContents, - ReadResourceRequest, - ReadResourceRequestParams, - ReadResourceResult, - ServerResult, -) +from mcp import Client +from mcp.server.mcpserver import MCPServer +from mcp.types import BlobResourceContents +pytestmark = pytest.mark.anyio -@pytest.mark.anyio -async def test_server_base64_encoding_issue(): - """Tests that server response can be validated by BlobResourceContents. - This test will: - 1. Set up a server that returns binary data - 2. Extract the base64-encoded blob from the server's response - 3. Verify the encoded data can be properly validated by BlobResourceContents +async def test_server_base64_encoding(): + """Tests that binary resource data round-trips correctly through base64 encoding. - BEFORE FIX: The test will fail because server uses urlsafe_b64encode - AFTER FIX: The test will pass because server uses standard b64encode + The test uses binary data that produces different results with urlsafe vs standard + base64, ensuring the server uses standard encoding. 
""" - server = Server("test") + mcp = MCPServer("test") # Create binary data that will definitely result in + and / characters # when encoded with standard base64 binary_data = bytes(list(range(255)) * 4) - # Register a resource handler that returns our test data - @server.read_resource() - async def read_resource(uri: str) -> list[ReadResourceContents]: - return [ReadResourceContents(content=binary_data, mime_type="application/octet-stream")] - - # Get the handler directly from the server - handler = server.request_handlers[ReadResourceRequest] - - # Create a request - request = ReadResourceRequest( - params=ReadResourceRequestParams(uri="test://resource"), - ) - - # Call the handler to get the response - result: ServerResult = await handler(request) - - # After (fixed code): - read_result: ReadResourceResult = cast(ReadResourceResult, result) - blob_content = read_result.contents[0] - - # First verify our test data actually produces different encodings + # Sanity check: our test data produces different encodings urlsafe_b64 = base64.urlsafe_b64encode(binary_data).decode() standard_b64 = base64.b64encode(binary_data).decode() - assert urlsafe_b64 != standard_b64, "Test data doesn't demonstrate" - " encoding difference" + assert urlsafe_b64 != standard_b64, "Test data doesn't demonstrate encoding difference" + + @mcp.resource("test://binary", mime_type="application/octet-stream") + def get_binary() -> bytes: + """Return binary test data.""" + return binary_data + + async with Client(mcp) as client: + result = await client.read_resource("test://binary") + assert len(result.contents) == 1 - # Now validate the server's output with BlobResourceContents.model_validate - # Before the fix: This should fail with "Invalid base64" because server - # uses urlsafe_b64encode - # After the fix: This should pass because server will use standard b64encode - model_dict = blob_content.model_dump() + blob_content = result.contents[0] + assert isinstance(blob_content, 
BlobResourceContents) - # Direct validation - this will fail before fix, pass after fix - blob_model = BlobResourceContents.model_validate(model_dict) + # Verify standard base64 was used (not urlsafe) + assert blob_content.blob == standard_b64 - # Verify we can decode the data back correctly - decoded = base64.b64decode(blob_model.blob) - assert decoded == binary_data + # Verify we can decode the data back correctly + decoded = base64.b64decode(blob_content.blob) + assert decoded == binary_data diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index cd27698e66..6b593d2a54 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -1,8 +1,6 @@ """Test to reproduce issue #88: Random error thrown on response.""" -from collections.abc import Sequence from pathlib import Path -from typing import Any import anyio import pytest @@ -11,10 +9,10 @@ from mcp import types from mcp.client.session import ClientSession -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage -from mcp.types import ContentBlock, TextContent +from mcp.types import CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams, TextContent @pytest.mark.anyio @@ -32,36 +30,38 @@ async def test_notification_validation_error(tmp_path: Path): - Slow operations use minimal timeout (10ms) for quick test execution """ - server = Server(name="test") request_count = 0 slow_request_lock = anyio.Event() - @server.list_tools() - async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="slow", - description="A slow tool", - input_schema={"type": "object"}, - ), - types.Tool( - name="fast", - description="A fast tool", - input_schema={"type": "object"}, - ), - ] - - @server.call_tool() - async def slow_tool(name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock]: + async 
def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + types.Tool( + name="slow", + description="A slow tool", + input_schema={"type": "object"}, + ), + types.Tool( + name="fast", + description="A fast tool", + input_schema={"type": "object"}, + ), + ] + ) + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal request_count request_count += 1 + assert params.name in ("slow", "fast"), f"Unknown tool: {params.name}" - if name == "slow": + if params.name == "slow": await slow_request_lock.wait() # it should timeout here - return [TextContent(type="text", text=f"slow {request_count}")] - elif name == "fast": - return [TextContent(type="text", text=f"fast {request_count}")] - return [TextContent(type="text", text=f"unknown {request_count}")] # pragma: no cover + text = f"slow {request_count}" + else: + text = f"fast {request_count}" + return CallToolResult(content=[TextContent(type="text", text=text)]) + + server = Server(name="test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) async def server_handler( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py deleted file mode 100644 index 9cb2b561ac..0000000000 --- a/tests/server/lowlevel/test_func_inspection.py +++ /dev/null @@ -1,292 +0,0 @@ -"""Unit tests for func_inspection module. - -Tests the create_call_wrapper function which determines how to call handler functions -with different parameter signatures and type hints. 
-""" - -from typing import Any, Generic, TypeVar - -import pytest - -from mcp.server.lowlevel.func_inspection import create_call_wrapper -from mcp.types import ListPromptsRequest, ListResourcesRequest, ListToolsRequest, PaginatedRequestParams - -T = TypeVar("T") - - -@pytest.mark.anyio -async def test_no_params_returns_deprecated_wrapper() -> None: - """Test: def foo() - should call without request.""" - called_without_request = False - - async def handler() -> list[str]: - nonlocal called_without_request - called_without_request = True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test"] - - -@pytest.mark.anyio -async def test_param_with_default_returns_deprecated_wrapper() -> None: - """Test: def foo(thing: int = 1) - should call without request.""" - called_without_request = False - - async def handler(thing: int = 1) -> list[str]: - nonlocal called_without_request - called_without_request = True - return [f"test-{thing}"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request (uses default value) - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test-1"] - - -@pytest.mark.anyio -async def test_typed_request_param_passes_request() -> None: - """Test: def foo(req: ListPromptsRequest) - should pass request through.""" - received_request = None - - async def handler(req: ListPromptsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", 
params=PaginatedRequestParams(cursor="test-cursor")) - await wrapper(request) - - assert received_request is not None - assert received_request is request - params = getattr(received_request, "params", None) - assert params is not None - assert params.cursor == "test-cursor" - - -@pytest.mark.anyio -async def test_typed_request_with_default_param_passes_request() -> None: - """Test: def foo(req: ListPromptsRequest, thing: int = 1) - should pass request through.""" - received_request = None - received_thing = None - - async def handler(req: ListPromptsRequest, thing: int = 1) -> list[str]: - nonlocal received_request, received_thing - received_request = req - received_thing = thing - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - assert received_thing == 1 # default value - - -@pytest.mark.anyio -async def test_optional_typed_request_with_default_none_is_deprecated() -> None: - """Test: def foo(thing: int = 1, req: ListPromptsRequest | None = None) - old style.""" - called_without_request = False - - async def handler(thing: int = 1, req: ListPromptsRequest | None = None) -> list[str]: - nonlocal called_without_request - called_without_request = True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test"] - - -@pytest.mark.anyio -async def test_untyped_request_param_is_deprecated() -> None: - """Test: def foo(req) - should call without request.""" - called = False - - async def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[reportMissingParameterType] # pragma: no cover - nonlocal called - called 
= True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) # pyright: ignore[reportUnknownArgumentType] - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_any_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: Any) - should call without request.""" - - async def handler(req: Any) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_generic_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: Generic[T]) - should call without request.""" - - async def handler(req: Generic[T]) -> list[str]: # pyright: ignore[reportGeneralTypeIssues] # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_wrong_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: str) 
- should call without request.""" - - async def handler(req: str) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_required_param_before_typed_request_attempts_to_pass() -> None: - """Test: def foo(thing: int, req: ListPromptsRequest) - attempts to pass request (will fail at runtime).""" - received_request = None - - async def handler(thing: int, req: ListPromptsRequest) -> list[str]: # pragma: no cover - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper will attempt to pass request, but it will fail at runtime - # because 'thing' is required and has no default - request = ListPromptsRequest(method="prompts/list", params=None) - - # This will raise TypeError because 'thing' is missing - with pytest.raises(TypeError, match="missing 1 required positional argument: 'thing'"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_positional_only_param_with_correct_type() -> None: - """Test: def foo(req: ListPromptsRequest, /) - should pass request through.""" - received_request = None - - async def handler(req: ListPromptsRequest, /) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - - -@pytest.mark.anyio -async def 
test_keyword_only_param_with_correct_type() -> None: - """Test: def foo(*, req: ListPromptsRequest) - should pass request through.""" - received_request = None - - async def handler(*, req: ListPromptsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler with keyword argument - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - - -@pytest.mark.anyio -async def test_different_request_types() -> None: - """Test that wrapper works with different request types.""" - # Test with ListResourcesRequest - received_request = None - - async def handler(req: ListResourcesRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListResourcesRequest) - - request = ListResourcesRequest(method="resources/list", params=None) - await wrapper(request) - - assert received_request is request - - # Test with ListToolsRequest - received_request = None - - async def handler2(req: ListToolsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper2 = create_call_wrapper(handler2, ListToolsRequest) - - request2 = ListToolsRequest(method="tools/list", params=None) - await wrapper2(request2) - - assert received_request is request2 - - -@pytest.mark.anyio -async def test_mixed_params_with_typed_request() -> None: - """Test: def foo(a: str, req: ListPromptsRequest, b: int = 5) - attempts to pass request.""" - - async def handler(a: str, req: ListPromptsRequest, b: int = 5) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Will fail at runtime due to missing 'a' - request = ListPromptsRequest(method="prompts/list", params=None) - - with pytest.raises(TypeError, match="missing 1 
required positional argument: 'a'"): - await wrapper(request) diff --git a/tests/server/lowlevel/test_server_listing.py b/tests/server/lowlevel/test_server_listing.py index 6bf4cddb39..2c3d303a92 100644 --- a/tests/server/lowlevel/test_server_listing.py +++ b/tests/server/lowlevel/test_server_listing.py @@ -1,20 +1,16 @@ -"""Basic tests for list_prompts, list_resources, and list_tools decorators without pagination.""" - -import warnings +"""Basic tests for list_prompts, list_resources, and list_tools handlers without pagination.""" import pytest -from mcp.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.types import ( - ListPromptsRequest, ListPromptsResult, - ListResourcesRequest, ListResourcesResult, - ListToolsRequest, ListToolsResult, + PaginatedRequestParams, Prompt, Resource, - ServerResult, Tool, ) @@ -22,60 +18,44 @@ @pytest.mark.anyio async def test_list_prompts_basic() -> None: """Test basic prompt listing without pagination.""" - server = Server("test") - test_prompts = [ Prompt(name="prompt1", description="First prompt"), Prompt(name="prompt2", description="Second prompt"), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return test_prompts - - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=test_prompts) - assert isinstance(result, ServerResult) - assert isinstance(result, ListPromptsResult) - assert result.prompts == test_prompts + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + result = await client.list_prompts() + assert result.prompts == test_prompts 
@pytest.mark.anyio async def test_list_resources_basic() -> None: """Test basic resource listing without pagination.""" - server = Server("test") - test_resources = [ Resource(uri="file:///test1.txt", name="Test 1"), Resource(uri="file:///test2.txt", name="Test 2"), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return test_resources + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=test_resources) - handler = server.request_handlers[ListResourcesRequest] - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListResourcesResult) - assert result.resources == test_resources + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + result = await client.list_resources() + assert result.resources == test_resources @pytest.mark.anyio async def test_list_tools_basic() -> None: """Test basic tool listing without pagination.""" - server = Server("test") - test_tools = [ Tool( name="tool1", @@ -102,80 +82,53 @@ async def test_list_tools_basic() -> None: ), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=test_tools) - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return test_tools - - handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListToolsResult) - assert result.tools == test_tools 
+ server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + result = await client.list_tools() + assert result.tools == test_tools @pytest.mark.anyio async def test_list_prompts_empty() -> None: """Test listing with empty results.""" - server = Server("test") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return [] - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=[]) - assert isinstance(result, ServerResult) - assert isinstance(result, ListPromptsResult) - assert result.prompts == [] + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + result = await client.list_prompts() + assert result.prompts == [] @pytest.mark.anyio async def test_list_resources_empty() -> None: """Test listing with empty results.""" - server = Server("test") - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=[]) - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return [] - - handler = server.request_handlers[ListResourcesRequest] - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListResourcesResult) - assert result.resources == [] + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + result = await 
client.list_resources() + assert result.resources == [] @pytest.mark.anyio async def test_list_tools_empty() -> None: """Test listing with empty results.""" - server = Server("test") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [] - handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[]) - assert isinstance(result, ServerResult) - assert isinstance(result, ListToolsResult) - assert result.tools == [] + server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + result = await client.list_tools() + assert result.tools == [] diff --git a/tests/server/lowlevel/test_server_pagination.py b/tests/server/lowlevel/test_server_pagination.py index 081fb262ab..a4627b316d 100644 --- a/tests/server/lowlevel/test_server_pagination.py +++ b/tests/server/lowlevel/test_server_pagination.py @@ -1,111 +1,83 @@ import pytest -from mcp.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.types import ( - ListPromptsRequest, ListPromptsResult, - ListResourcesRequest, ListResourcesResult, - ListToolsRequest, ListToolsResult, PaginatedRequestParams, - ServerResult, ) @pytest.mark.anyio async def test_list_prompts_pagination() -> None: - server = Server("test") test_cursor = "test-cursor-123" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListPromptsRequest | None = None - - @server.list_prompts() - async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: - nonlocal received_request - received_request = request + async def handle_list_prompts( + 
ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + nonlocal received_params + received_params = params return ListPromptsResult(prompts=[], next_cursor="next") - handler = server.request_handlers[ListPromptsRequest] - - # Test: No cursor provided -> handler receives request with None params - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + # No cursor provided + await client.list_prompts() + assert received_params is not None + assert received_params.cursor is None - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor=test_cursor)) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + # Cursor provided + await client.list_prompts(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor @pytest.mark.anyio async def test_list_resources_pagination() -> None: - server = Server("test") test_cursor = "resource-cursor-456" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListResourcesRequest | None = None - - @server.list_resources() - async def handle_list_resources(request: ListResourcesRequest) -> ListResourcesResult: - nonlocal received_request - received_request = request + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + nonlocal received_params + received_params = params 
return ListResourcesResult(resources=[], next_cursor="next") - handler = server.request_handlers[ListResourcesRequest] + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + # No cursor provided + await client.list_resources() + assert received_params is not None + assert received_params.cursor is None - # Test: No cursor provided -> handler receives request with None params - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) - - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListResourcesRequest( - method="resources/list", params=PaginatedRequestParams(cursor=test_cursor) - ) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + # Cursor provided + await client.list_resources(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor @pytest.mark.anyio async def test_list_tools_pagination() -> None: - server = Server("test") test_cursor = "tools-cursor-789" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListToolsRequest | None = None - - @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: - nonlocal received_request - received_request = request + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + nonlocal received_params + received_params = params return ListToolsResult(tools=[], next_cursor="next") - handler = server.request_handlers[ListToolsRequest] - - # Test: No cursor provided -> handler receives 
request with None params - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) - - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListToolsRequest(method="tools/list", params=PaginatedRequestParams(cursor=test_cursor)) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + # No cursor provided + await client.list_tools() + assert received_params is not None + assert received_params.cursor is None + + # Cursor provided + await client.list_tools(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor diff --git a/tests/server/mcpserver/auth/test_auth_integration.py b/tests/server/mcpserver/auth/test_auth_integration.py index a78a86cf0b..602f5cc752 100644 --- a/tests/server/mcpserver/auth/test_auth_integration.py +++ b/tests/server/mcpserver/auth/test_auth_integration.py @@ -21,7 +21,8 @@ RefreshToken, construct_redirect_uri, ) -from mcp.server.auth.routes import ClientRegistrationOptions, RevocationOptions, create_auth_routes +from mcp.server.auth.routes import create_auth_routes +from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions from mcp.shared.auth import OAuthClientInformationFull, OAuthToken diff --git a/tests/server/mcpserver/prompts/test_base.py b/tests/server/mcpserver/prompts/test_base.py index 035e1cc81d..553e47363f 100644 --- a/tests/server/mcpserver/prompts/test_base.py +++ b/tests/server/mcpserver/prompts/test_base.py @@ -2,8 +2,8 @@ import pytest -from mcp.server.mcpserver.prompts.base import 
AssistantMessage, Message, Prompt, TextContent, UserMessage -from mcp.types import EmbeddedResource, TextResourceContents +from mcp.server.mcpserver.prompts.base import AssistantMessage, Message, Prompt, UserMessage +from mcp.types import EmbeddedResource, TextContent, TextResourceContents class TestRenderPrompt: diff --git a/tests/server/mcpserver/prompts/test_manager.py b/tests/server/mcpserver/prompts/test_manager.py index 0e30b2e697..02f91c6802 100644 --- a/tests/server/mcpserver/prompts/test_manager.py +++ b/tests/server/mcpserver/prompts/test_manager.py @@ -1,7 +1,8 @@ import pytest -from mcp.server.mcpserver.prompts.base import Prompt, TextContent, UserMessage +from mcp.server.mcpserver.prompts.base import Prompt, UserMessage from mcp.server.mcpserver.prompts.manager import PromptManager +from mcp.types import TextContent class TestPromptManager: diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 979dc580f8..3f253baa82 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -21,6 +21,9 @@ from mcp.types import ( AudioContent, BlobResourceContents, + Completion, + CompletionArgument, + CompletionContext, ContentBlock, EmbeddedResource, GetPromptResult, @@ -30,6 +33,7 @@ Prompt, PromptArgument, PromptMessage, + PromptReference, ReadResourceResult, Resource, ResourceTemplate, @@ -1401,6 +1405,23 @@ def prompt_fn(name: str) -> str: ... 
# pragma: no branch await client.get_prompt("prompt_fn") +async def test_completion_decorator() -> None: + """Test that the completion decorator registers a working handler.""" + mcp = MCPServer() + + @mcp.completion() + async def handle_completion( + ref: PromptReference, argument: CompletionArgument, context: CompletionContext | None + ) -> Completion: + assert argument.name == "style" + return Completion(values=["bold", "italic", "underline"]) + + async with Client(mcp) as client: + ref = PromptReference(type="ref/prompt", name="test") + result = await client.complete(ref=ref, argument={"name": "style", "value": "b"}) + assert result.completion.values == ["bold", "italic", "underline"] + + def test_streamable_http_no_redirect() -> None: """Test that streamable HTTP routes are correctly configured.""" mcp = MCPServer() diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 6d1634f2e9..297f3d6a5c 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -1,12 +1,10 @@ """Test that cancelled requests don't cause double responses.""" -from typing import Any - import anyio import pytest -from mcp import Client, types -from mcp.server.lowlevel.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.types import ( CallToolRequest, @@ -14,6 +12,9 @@ CallToolResult, CancelledNotification, CancelledNotificationParams, + ListToolsResult, + PaginatedRequestParams, + TextContent, Tool, ) @@ -22,34 +23,34 @@ async def test_server_remains_functional_after_cancel(): """Verify server can handle new requests after a cancellation.""" - server = Server("test-server") - # Track tool calls call_count = 0 ev_first_call = anyio.Event() first_request_id = None - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="Tool for testing", - 
input_schema={}, - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="Tool for testing", + input_schema={}, + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal call_count, first_request_id - if name == "test_tool": + if params.name == "test_tool": call_count += 1 if call_count == 1: - first_request_id = server.request_context.request_id + first_request_id = ctx.request_id ev_first_call.set() await anyio.sleep(5) # First call is slow - return [types.TextContent(type="text", text=f"Call number: {call_count}")] - raise ValueError(f"Unknown tool: {name}") # pragma: no cover + return CallToolResult(content=[TextContent(type="text", text=f"Call number: {call_count}")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + server = Server("test-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) async with Client(server) as client: # First request (will be cancelled) @@ -86,6 +87,6 @@ async def first_request(): # Type narrowing for pyright content = result.content[0] assert content.type == "text" - assert isinstance(content, types.TextContent) + assert isinstance(content, TextContent) assert content.text == "Call number: 2" assert call_count == 2 diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index 5a8d67f09e..a01d0d4d72 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -1,15 +1,13 @@ """Tests for completion handler with context functionality.""" -from typing import Any - import pytest from mcp import Client -from mcp.server.lowlevel import Server +from mcp.server import 
Server, ServerRequestContext from mcp.types import ( + CompleteRequestParams, + CompleteResult, Completion, - CompletionArgument, - CompletionContext, PromptReference, ResourceTemplateReference, ) @@ -18,23 +16,15 @@ @pytest.mark.anyio async def test_completion_handler_receives_context(): """Test that the completion handler receives context correctly.""" - server = Server("test-server") - # Track what the handler receives - received_args: dict[str, Any] = {} + received_params: CompleteRequestParams | None = None - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - received_args["ref"] = ref - received_args["argument"] = argument - received_args["context"] = context + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: + nonlocal received_params + received_params = params + return CompleteResult(completion=Completion(values=["test-completion"], total=1, has_more=False)) - # Return test completion - return Completion(values=["test-completion"], total=1, has_more=False) + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Test with context @@ -45,28 +35,23 @@ async def handle_completion( ) # Verify handler received the context - assert received_args["context"] is not None - assert received_args["context"].arguments == {"previous": "value"} + assert received_params is not None + assert received_params.context is not None + assert received_params.context.arguments == {"previous": "value"} assert result.completion.values == ["test-completion"] @pytest.mark.anyio async def test_completion_backward_compatibility(): """Test that completion works without context (backward compatibility).""" - server = Server("test-server") - context_was_none = False - @server.completion() - async def handle_completion( - ref: PromptReference 
| ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: nonlocal context_was_none - context_was_none = context is None + context_was_none = params.context is None + return CompleteResult(completion=Completion(values=["no-context-completion"], total=1, has_more=False)) - return Completion(values=["no-context-completion"], total=1, has_more=False) + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Test without context @@ -82,30 +67,31 @@ async def handle_completion( @pytest.mark.anyio async def test_dependent_completion_scenario(): """Test a real-world scenario with dependent completions.""" - server = Server("test-server") - - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: + + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: # Simulate database/table completion scenario - if isinstance(ref, ResourceTemplateReference): - if ref.uri == "db://{database}/{table}": - if argument.name == "database": - # Complete database names - return Completion(values=["users_db", "products_db", "analytics_db"], total=3, has_more=False) - elif argument.name == "table": - # Complete table names based on selected database - if context and context.arguments: - db = context.arguments.get("database") - if db == "users_db": - return Completion(values=["users", "sessions", "permissions"], total=3, has_more=False) - elif db == "products_db": - return Completion(values=["products", "categories", "inventory"], total=3, has_more=False) - - return Completion(values=[], total=0, has_more=False) # pragma: no cover + assert isinstance(params.ref, ResourceTemplateReference) 
+ assert params.ref.uri == "db://{database}/{table}" + + if params.argument.name == "database": + return CompleteResult( + completion=Completion(values=["users_db", "products_db", "analytics_db"], total=3, has_more=False) + ) + + assert params.argument.name == "table" + assert params.context and params.context.arguments + db = params.context.arguments.get("database") + if db == "users_db": + return CompleteResult( + completion=Completion(values=["users", "sessions", "permissions"], total=3, has_more=False) + ) + else: + assert db == "products_db" + return CompleteResult( + completion=Completion(values=["products", "categories", "inventory"], total=3, has_more=False) + ) + + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # First, complete database @@ -136,27 +122,20 @@ async def handle_completion( @pytest.mark.anyio async def test_completion_error_on_missing_context(): """Test that server can raise error when required context is missing.""" - server = Server("test-server") - - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - if isinstance(ref, ResourceTemplateReference): - if ref.uri == "db://{database}/{table}": - if argument.name == "table": - # Check if database context is provided - if not context or not context.arguments or "database" not in context.arguments: - # Raise an error instead of returning error as completion - raise ValueError("Please select a database first to see available tables") - # Normal completion if context is provided - db = context.arguments.get("database") - if db == "test_db": - return Completion(values=["users", "orders", "products"], total=3, has_more=False) - - return Completion(values=[], total=0, has_more=False) # pragma: no cover + + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> 
CompleteResult: + assert isinstance(params.ref, ResourceTemplateReference) + assert params.ref.uri == "db://{database}/{table}" + assert params.argument.name == "table" + + if not params.context or not params.context.arguments or "database" not in params.context.arguments: + raise ValueError("Please select a database first to see available tables") + + db = params.context.arguments.get("database") + assert db == "test_db" + return CompleteResult(completion=Completion(values=["users", "orders", "products"], total=3, has_more=False)) + + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Try to complete table without database context - should raise error diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index a303664a54..0f8840d291 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -2,18 +2,20 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any import anyio import pytest from pydantic import TypeAdapter +from mcp.server import ServerRequestContext from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.mcpserver import Context, MCPServer from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.types import ( + CallToolRequestParams, + CallToolResult, ClientCapabilities, Implementation, InitializeRequestParams, @@ -39,20 +41,20 @@ async def test_lifespan(server: Server) -> AsyncIterator[dict[str, bool]]: finally: context["shutdown"] = True - server = Server[dict[str, bool]]("test", lifespan=test_lifespan) - - # Create memory streams for testing - send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) - send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) - # Create a tool that accesses lifespan context - 
@server.call_tool() - async def check_lifespan(name: str, arguments: dict[str, Any]) -> list[TextContent]: - ctx = server.request_context + async def check_lifespan( + ctx: ServerRequestContext[dict[str, bool]], params: CallToolRequestParams + ) -> CallToolResult: assert isinstance(ctx.lifespan_context, dict) assert ctx.lifespan_context["started"] assert not ctx.lifespan_context["shutdown"] - return [TextContent(type="text", text="true")] + return CallToolResult(content=[TextContent(type="text", text="true")]) + + server = Server[dict[str, bool]]("test", lifespan=test_lifespan, on_call_tool=check_lifespan) + + # Create memory streams for testing + send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) # Run server in background task async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: diff --git a/tests/server/test_lowlevel_input_validation.py b/tests/server/test_lowlevel_input_validation.py deleted file mode 100644 index 3f977bcc1b..0000000000 --- a/tests/server/test_lowlevel_input_validation.py +++ /dev/null @@ -1,311 +0,0 @@ -"""Test input schema validation for lowlevel server.""" - -import logging -from collections.abc import Awaitable, Callable -from typing import Any - -import anyio -import pytest - -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import CallToolResult, ClientResult, ServerNotification, ServerRequest, TextContent, Tool - - -async def run_tool_test( - tools: list[Tool], - call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[list[TextContent]]], - test_callback: 
Callable[[ClientSession], Awaitable[CallToolResult]], -) -> CallToolResult | None: - """Helper to run a tool test with minimal boilerplate. - - Args: - tools: List of tools to register - call_tool_handler: Handler function for tool calls - test_callback: Async function that performs the test using the client session - - Returns: - The result of the tool call - """ - server = Server("test") - result = None - - @server.list_tools() - async def list_tools(): - return tools - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: - return await call_tool_handler(name, arguments) - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Message handler for client - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - # Server task - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - # Run the test - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - # Initialize the session - await 
client_session.initialize() - - # Run the test callback - result = await test_callback(client_session) - - # Cancel the server task - tg.cancel_scope.cancel() - - return result - - -def create_add_tool() -> Tool: - """Create a standard 'add' tool for testing.""" - return Tool( - name="add", - description="Add two numbers", - input_schema={ - "type": "object", - "properties": { - "a": {"type": "number"}, - "b": {"type": "number"}, - }, - "required": ["a", "b"], - "additionalProperties": False, - }, - ) - - -@pytest.mark.anyio -async def test_valid_tool_call(): - """Test that valid arguments pass validation.""" - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "add": - result = arguments["a"] + arguments["b"] - return [TextContent(type="text", text=f"Result: {result}")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("add", {"a": 5, "b": 3}) - - result = await run_tool_test([create_add_tool()], call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Result: 8" - - -@pytest.mark.anyio -async def test_invalid_tool_call_missing_required(): - """Test that missing required arguments fail validation.""" - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation - raise RuntimeError("Should not reach here") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("add", {"a": 5}) # missing 'b' - - result = await run_tool_test([create_add_tool()], call_tool_handler, test_callback) - - # Verify results - 
assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Input validation error" in result.content[0].text - assert "'b' is a required property" in result.content[0].text - - -@pytest.mark.anyio -async def test_invalid_tool_call_wrong_type(): - """Test that wrong argument types fail validation.""" - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation - raise RuntimeError("Should not reach here") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("add", {"a": "five", "b": 3}) # 'a' should be number - - result = await run_tool_test([create_add_tool()], call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Input validation error" in result.content[0].text - assert "'five' is not of type 'number'" in result.content[0].text - - -@pytest.mark.anyio -async def test_cache_refresh_on_missing_tool(): - """Test that tool cache is refreshed when tool is not found.""" - tools = [ - Tool( - name="multiply", - description="Multiply two numbers", - input_schema={ - "type": "object", - "properties": { - "x": {"type": "number"}, - "y": {"type": "number"}, - }, - "required": ["x", "y"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "multiply": - result = arguments["x"] * arguments["y"] - return [TextContent(type="text", text=f"Result: {result}")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - # Call tool without first 
listing tools (cache should be empty) - # The cache should be refreshed automatically - return await client_session.call_tool("multiply", {"x": 10, "y": 20}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - should work because cache will be refreshed - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Result: 200" - - -@pytest.mark.anyio -async def test_enum_constraint_validation(): - """Test that enum constraints are validated.""" - tools = [ - Tool( - name="greet", - description="Greet someone", - input_schema={ - "type": "object", - "properties": { - "name": {"type": "string"}, - "title": {"type": "string", "enum": ["Mr", "Ms", "Dr"]}, - }, - "required": ["name"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation failure - raise RuntimeError("Should not reach here") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("greet", {"name": "Smith", "title": "Prof"}) # Invalid title - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Input validation error" in result.content[0].text - assert "'Prof' is not one of" in result.content[0].text - - -@pytest.mark.anyio -async def test_tool_not_in_list_logs_warning(caplog: pytest.LogCaptureFixture): - """Test that calling a tool not in list_tools logs a warning and skips validation.""" - tools = [ - Tool( - name="add", - description="Add two numbers", - input_schema={ - "type": "object", - 
"properties": { - "a": {"type": "number"}, - "b": {"type": "number"}, - }, - "required": ["a", "b"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - # This should be reached since validation is skipped for unknown tools - if name == "unknown_tool": - # Even with invalid arguments, this should execute since validation is skipped - return [TextContent(type="text", text="Unknown tool executed without validation")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - # Call a tool that's not in the list with invalid arguments - # This should trigger the warning about validation not being performed - return await client_session.call_tool("unknown_tool", {"invalid": "args"}) - - with caplog.at_level(logging.WARNING): - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - should succeed because validation is skipped for unknown tools - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Unknown tool executed without validation" - - # Verify warning was logged - assert any( - "Tool 'unknown_tool' not listed, no validation will be performed" in record.message for record in caplog.records - ) diff --git a/tests/server/test_lowlevel_output_validation.py b/tests/server/test_lowlevel_output_validation.py deleted file mode 100644 index 92d9c047ca..0000000000 --- a/tests/server/test_lowlevel_output_validation.py +++ /dev/null @@ -1,476 +0,0 @@ -"""Test output schema validation for lowlevel server.""" - -import json -from collections.abc import Awaitable, Callable -from typing import Any - -import anyio -import pytest - -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel 
import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import CallToolResult, ClientResult, ServerNotification, ServerRequest, TextContent, Tool - - -async def run_tool_test( - tools: list[Tool], - call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[Any]], - test_callback: Callable[[ClientSession], Awaitable[CallToolResult]], -) -> CallToolResult | None: - """Helper to run a tool test with minimal boilerplate. - - Args: - tools: List of tools to register - call_tool_handler: Handler function for tool calls - test_callback: Async function that performs the test using the client session - - Returns: - The result of the tool call - """ - server = Server("test") - - result = None - - @server.list_tools() - async def list_tools(): - return tools - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - return await call_tool_handler(name, arguments) - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Message handler for client - async def message_handler( # pragma: no cover - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): - raise message - - # Server task - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in 
server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - # Run the test - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - # Initialize the session - await client_session.initialize() - - # Run the test callback - result = await test_callback(client_session) - - # Cancel the server task - tg.cancel_scope.cancel() - - return result - - -@pytest.mark.anyio -async def test_content_only_without_output_schema(): - """Test returning content only when no outputSchema is defined.""" - tools = [ - Tool( - name="echo", - description="Echo a message", - input_schema={ - "type": "object", - "properties": { - "message": {"type": "string"}, - }, - "required": ["message"], - }, - # No outputSchema defined - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "echo": - return [TextContent(type="text", text=f"Echo: {arguments['message']}")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("echo", {"message": "Hello"}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Echo: Hello" - assert result.structured_content is None - - -@pytest.mark.anyio -async def test_dict_only_without_output_schema(): - """Test returning dict only when no outputSchema is defined.""" - tools = [ - Tool( - name="get_info", - description="Get structured 
information", - input_schema={ - "type": "object", - "properties": {}, - }, - # No outputSchema defined - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "get_info": - return {"status": "ok", "data": {"value": 42}} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("get_info", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - # Check that the content is the JSON serialization - assert json.loads(result.content[0].text) == {"status": "ok", "data": {"value": 42}} - assert result.structured_content == {"status": "ok", "data": {"value": 42}} - - -@pytest.mark.anyio -async def test_both_content_and_dict_without_output_schema(): - """Test returning both content and dict when no outputSchema is defined.""" - tools = [ - Tool( - name="process", - description="Process data", - input_schema={ - "type": "object", - "properties": {}, - }, - # No outputSchema defined - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> tuple[list[TextContent], dict[str, Any]]: - if name == "process": - content = [TextContent(type="text", text="Processing complete")] - data = {"result": "success", "count": 10} - return (content, data) - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("process", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert 
result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Processing complete" - assert result.structured_content == {"result": "success", "count": 10} - - -@pytest.mark.anyio -async def test_content_only_with_output_schema_error(): - """Test error when outputSchema is defined but only content is returned.""" - tools = [ - Tool( - name="structured_tool", - description="Tool expecting structured output", - input_schema={ - "type": "object", - "properties": {}, - }, - output_schema={ - "type": "object", - "properties": { - "result": {"type": "string"}, - }, - "required": ["result"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - # This returns only content, but outputSchema expects structured data - return [TextContent(type="text", text="This is not structured")] - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("structured_tool", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify error - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Output validation error: outputSchema defined but no structured output returned" in result.content[0].text - - -@pytest.mark.anyio -async def test_valid_dict_with_output_schema(): - """Test valid dict output matching outputSchema.""" - tools = [ - Tool( - name="calc", - description="Calculate result", - input_schema={ - "type": "object", - "properties": { - "x": {"type": "number"}, - "y": {"type": "number"}, - }, - "required": ["x", "y"], - }, - output_schema={ - "type": "object", - "properties": { - "sum": {"type": "number"}, - "product": {"type": "number"}, - }, - "required": ["sum", "product"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: 
dict[str, Any]) -> dict[str, Any]: - if name == "calc": - x = arguments["x"] - y = arguments["y"] - return {"sum": x + y, "product": x * y} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("calc", {"x": 3, "y": 4}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - # Check JSON serialization - assert json.loads(result.content[0].text) == {"sum": 7, "product": 12} - assert result.structured_content == {"sum": 7, "product": 12} - - -@pytest.mark.anyio -async def test_invalid_dict_with_output_schema(): - """Test dict output that doesn't match outputSchema.""" - tools = [ - Tool( - name="user_info", - description="Get user information", - input_schema={ - "type": "object", - "properties": {}, - }, - output_schema={ - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"}, - }, - "required": ["name", "age"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "user_info": - # Missing required 'age' field - return {"name": "Alice"} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("user_info", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify error - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Output validation error:" in result.content[0].text - assert "'age' is a required property" in result.content[0].text - - -@pytest.mark.anyio 
-async def test_both_content_and_valid_dict_with_output_schema(): - """Test returning both content and valid dict with outputSchema.""" - tools = [ - Tool( - name="analyze", - description="Analyze data", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string"}, - }, - "required": ["text"], - }, - output_schema={ - "type": "object", - "properties": { - "sentiment": {"type": "string", "enum": ["positive", "negative", "neutral"]}, - "confidence": {"type": "number", "minimum": 0, "maximum": 1}, - }, - "required": ["sentiment", "confidence"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> tuple[list[TextContent], dict[str, Any]]: - if name == "analyze": - content = [TextContent(type="text", text=f"Analysis of: {arguments['text']}")] - data = {"sentiment": "positive", "confidence": 0.95} - return (content, data) - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("analyze", {"text": "Great job!"}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Analysis of: Great job!" 
- assert result.structured_content == {"sentiment": "positive", "confidence": 0.95} - - -@pytest.mark.anyio -async def test_tool_call_result(): - """Test returning ToolCallResult when no outputSchema is defined.""" - tools = [ - Tool( - name="get_info", - description="Get structured information", - input_schema={ - "type": "object", - "properties": {}, - }, - # No outputSchema for direct return of tool call result - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> CallToolResult: - if name == "get_info": - return CallToolResult( - content=[TextContent(type="text", text="Results calculated")], - structured_content={"status": "ok", "data": {"value": 42}}, - _meta={"some": "metadata"}, - ) - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("get_info", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Results calculated" - assert isinstance(result.content[0], TextContent) - assert result.structured_content == {"status": "ok", "data": {"value": 42}} - assert result.meta == {"some": "metadata"} - - -@pytest.mark.anyio -async def test_output_schema_type_validation(): - """Test outputSchema validates types correctly.""" - tools = [ - Tool( - name="stats", - description="Get statistics", - input_schema={ - "type": "object", - "properties": {}, - }, - output_schema={ - "type": "object", - "properties": { - "count": {"type": "integer"}, - "average": {"type": "number"}, - "items": {"type": "array", "items": {"type": "string"}}, - }, - "required": ["count", "average", "items"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "stats": 
- # Wrong type for 'count' - should be integer - return {"count": "five", "average": 2.5, "items": ["a", "b"]} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("stats", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify error - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert "Output validation error:" in result.content[0].text - assert "'five' is not of type 'integer'" in result.content[0].text diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 68543136eb..705abdfe8c 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -1,100 +1,44 @@ """Tests for tool annotations in low-level server.""" -import anyio import pytest -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ClientResult, ServerNotification, ServerRequest, Tool, ToolAnnotations +from mcp import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ListToolsResult, PaginatedRequestParams, Tool, ToolAnnotations @pytest.mark.anyio async def test_lowlevel_server_tool_annotations(): """Test that tool annotations work in low-level server.""" - server = Server("test") - # Create a tool with annotations - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="echo", - description="Echo a message back", - input_schema={ - "type": "object", - "properties": { - "message": {"type": "string"}, + 
async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo", + description="Echo a message back", + input_schema={ + "type": "object", + "properties": { + "message": {"type": "string"}, + }, + "required": ["message"], }, - "required": ["message"], - }, - annotations=ToolAnnotations( - title="Echo Tool", - read_only_hint=True, - ), - ) - ] - - tools_result = None - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Message handler for client - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - # Server task - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - # Run the test - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - # Initialize the session - await client_session.initialize() - - # List tools - tools_result = await client_session.list_tools() - - # Cancel the server task - tg.cancel_scope.cancel() - - # 
Verify results - assert tools_result is not None - assert len(tools_result.tools) == 1 - assert tools_result.tools[0].name == "echo" - assert tools_result.tools[0].annotations is not None - assert tools_result.tools[0].annotations.title == "Echo Tool" - assert tools_result.tools[0].annotations.read_only_hint is True + annotations=ToolAnnotations( + title="Echo Tool", + read_only_hint=True, + ), + ) + ] + ) + + server = Server("test", on_list_tools=handle_list_tools) + + async with Client(server) as client: + tools_result = await client.list_tools() + + assert len(tools_result.tools) == 1 + assert tools_result.tools[0].name == "echo" + assert tools_result.tools[0].annotations is not None + assert tools_result.tools[0].annotations.title == "Echo Tool" + assert tools_result.tools[0].annotations.read_only_hint is True diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index 88fd1e38ff..102a58d039 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -1,106 +1,58 @@ -from collections.abc import Iterable -from pathlib import Path -from tempfile import NamedTemporaryFile +import base64 import pytest -from mcp import types -from mcp.server.lowlevel.server import ReadResourceContents, Server +from mcp import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + BlobResourceContents, + ReadResourceRequestParams, + ReadResourceResult, + TextResourceContents, +) +pytestmark = pytest.mark.anyio -@pytest.fixture -def temp_file(): - """Create a temporary file for testing.""" - with NamedTemporaryFile(mode="w", delete=False) as f: - f.write("test content") - path = Path(f.name).resolve() - yield path - try: - path.unlink() - except FileNotFoundError: # pragma: no cover - pass +async def test_read_resource_text(): + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + 
contents=[TextResourceContents(uri=str(params.uri), text="Hello World", mime_type="text/plain")] + ) -@pytest.mark.anyio -async def test_read_resource_text(temp_file: Path): - server = Server("test") + server = Server("test", on_read_resource=handle_read_resource) - @server.read_resource() - async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ReadResourceContents(content="Hello World", mime_type="text/plain")] + async with Client(server) as client: + result = await client.read_resource("test://resource") + assert len(result.contents) == 1 - # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] + content = result.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Hello World" + assert content.mime_type == "text/plain" - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) - # Call the handler - result = await handler(request) - assert isinstance(result, types.ReadResourceResult) - assert len(result.contents) == 1 +async def test_read_resource_binary(): + binary_data = b"Hello World" - content = result.contents[0] - assert isinstance(content, types.TextResourceContents) - assert content.text == "Hello World" - assert content.mime_type == "text/plain" + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[ + BlobResourceContents( + uri=str(params.uri), + blob=base64.b64encode(binary_data).decode("utf-8"), + mime_type="application/octet-stream", + ) + ] + ) + server = Server("test", on_read_resource=handle_read_resource) -@pytest.mark.anyio -async def test_read_resource_binary(temp_file: Path): - server = Server("test") + async with Client(server) as client: + result = await client.read_resource("test://resource") + assert len(result.contents) == 1 - @server.read_resource() - 
async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ReadResourceContents(content=b"Hello World", mime_type="application/octet-stream")] - - # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] - - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) - - # Call the handler - result = await handler(request) - assert isinstance(result, types.ReadResourceResult) - assert len(result.contents) == 1 - - content = result.contents[0] - assert isinstance(content, types.BlobResourceContents) - assert content.mime_type == "application/octet-stream" - - -@pytest.mark.anyio -async def test_read_resource_default_mime(temp_file: Path): - server = Server("test") - - @server.read_resource() - async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ - ReadResourceContents( - content="Hello World", - # No mime_type specified, should default to text/plain - ) - ] - - # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] - - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) - - # Call the handler - result = await handler(request) - assert isinstance(result, types.ReadResourceResult) - assert len(result.contents) == 1 - - content = result.contents[0] - assert isinstance(content, types.TextResourceContents) - assert content.text == "Hello World" - assert content.mime_type == "text/plain" + content = result.contents[0] + assert isinstance(content, BlobResourceContents) + assert content.mime_type == "application/octet-stream" + assert base64.b64decode(content.blob) == binary_data diff --git a/tests/server/test_session.py b/tests/server/test_session.py index d353e46e45..a2786d865d 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -5,7 +5,7 @@ from mcp import 
types from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -14,17 +14,10 @@ from mcp.shared.session import RequestResponder from mcp.types import ( ClientNotification, - Completion, - CompletionArgument, - CompletionContext, CompletionsCapability, InitializedNotification, - Prompt, - PromptReference, PromptsCapability, - Resource, ResourcesCapability, - ResourceTemplateReference, ServerCapabilities, ) @@ -85,47 +78,50 @@ async def run_server(): @pytest.mark.anyio async def test_server_capabilities(): - server = Server("test") notification_options = NotificationOptions() experimental_capabilities: dict[str, Any] = {} - # Initially no capabilities + async def noop_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + raise NotImplementedError + + async def noop_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + raise NotImplementedError + + async def noop_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + raise NotImplementedError + + # No capabilities + server = Server("test") caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts is None assert caps.resources is None assert caps.completions is None - # Add a prompts handler - @server.list_prompts() - async def list_prompts() -> list[Prompt]: # pragma: no cover - return [] - + # With prompts handler + server = Server("test", on_list_prompts=noop_list_prompts) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources is None assert caps.completions is None - 
# Add a resources handler - @server.list_resources() - async def list_resources() -> list[Resource]: # pragma: no cover - return [] - + # With prompts + resources handlers + server = Server("test", on_list_prompts=noop_list_prompts, on_list_resources=noop_list_resources) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) assert caps.completions is None - # Add a complete handler - @server.completion() - async def complete( # pragma: no cover - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - return Completion( - values=["completion1", "completion2"], - ) - + # With prompts + resources + completion handlers + server = Server( + "test", + on_list_prompts=noop_list_prompts, + on_list_resources=noop_list_resources, + on_completion=noop_completion, + ) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 51572baa9c..475eaa167f 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -9,13 +9,12 @@ import pytest from starlette.types import Message -from mcp import Client, types +from mcp import Client from mcp.client.streamable_http import streamable_http_client -from mcp.server import streamable_http_manager -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext, streamable_http_manager from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import 
StreamableHTTPSessionManager -from mcp.types import INVALID_REQUEST +from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams @pytest.mark.anyio @@ -218,7 +217,7 @@ async def test_stateless_requests_memory_cleanup(): # Patch StreamableHTTPServerTransport constructor to track instances - original_constructor = streamable_http_manager.StreamableHTTPServerTransport + original_constructor = StreamableHTTPServerTransport def track_transport(*args: Any, **kwargs: Any) -> StreamableHTTPServerTransport: transport = original_constructor(*args, **kwargs) @@ -321,12 +320,11 @@ async def mock_receive(): @pytest.mark.anyio async def test_e2e_streamable_http_server_cleanup(): host = "testserver" - app = Server("test-server") - @app.list_tools() - async def list_tools(req: types.ListToolsRequest) -> types.ListToolsResult: - return types.ListToolsResult(tools=[]) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[]) + app = Server("test-server", on_list_tools=handle_list_tools) mcp_app = app.streamable_http_app(host=host) async with ( mcp_app.router.lifespan_context(mcp_app), diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py deleted file mode 100644 index 31238b9ffd..0000000000 --- a/tests/shared/test_memory.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest - -from mcp import Client -from mcp.server import Server -from mcp.types import EmptyResult, Resource - - -@pytest.fixture -def mcp_server() -> Server: - server = Server(name="test_server") - - @server.list_resources() - async def handle_list_resources(): # pragma: no cover - return [ - Resource( - uri="memory://test", - name="Test Resource", - description="A test resource", - ) - ] - - return server - - -@pytest.mark.anyio -async def test_memory_server_and_client_connection(mcp_server: Server): - """Shows how a client and server can communicate over memory streams.""" - async with 
Client(mcp_server) as client: - response = await client.send_ping() - assert isinstance(response, EmptyResult) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index ab117f1f02..6b87774c0c 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -6,7 +6,7 @@ from mcp import Client, types from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -35,9 +35,6 @@ async def run_server(): capabilities=server.get_capabilities(NotificationOptions(), {}), ), ) as server_session: - global serv_sesh - - serv_sesh = server_session async for message in server_session.incoming_messages: try: await server._handle_message(message, server_session, {}) @@ -52,79 +49,73 @@ async def run_server(): server_progress_token = "server_token_123" client_progress_token = "client_token_456" - # Create a server with progress capability - server = Server(name="ProgressTestServer") - # Register progress handler - @server.progress_notification() - async def handle_progress( - progress_token: str | int, - progress: float, - total: float | None, - message: str | None, - ): + async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: server_progress_updates.append( { - "token": progress_token, - "progress": progress, - "total": total, - "message": message, + "token": params.progress_token, + "progress": params.progress, + "total": params.total, + "message": params.message, } ) # Register list tool handler - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="test_tool", - description="A tool that sends progress notifications types.ListToolsResult: + return types.ListToolsResult( 
+ tools=[ + types.Tool( + name="test_tool", + description="A tool that sends progress notifications list[types.TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: # Make sure we received a progress token - if name == "test_tool": - if arguments and "_meta" in arguments: - progressToken = arguments["_meta"]["progressToken"] - - if not progressToken: # pragma: no cover - raise ValueError("Empty progress token received") - - if progressToken != client_progress_token: # pragma: no cover - raise ValueError("Server sending back incorrect progressToken") - - # Send progress notifications - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=0.25, - total=1.0, - message="Server progress 25%", - ) + if params.name == "test_tool": + assert params.meta is not None + progress_token = params.meta.get("progress_token") + assert progress_token is not None + assert progress_token == client_progress_token + + # Send progress notifications using ctx.session + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=0.25, + total=1.0, + message="Server progress 25%", + ) - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=0.5, - total=1.0, - message="Server progress 50%", - ) + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=0.5, + total=1.0, + message="Server progress 50%", + ) - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=1.0, - total=1.0, - message="Server progress 100%", - ) + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=1.0, + total=1.0, + message="Server progress 100%", + ) - else: # pragma: no cover - raise ValueError("Progress token not sent.") + return types.CallToolResult(content=[types.TextContent(type="text", text="Tool executed successfully")]) - return 
[types.TextContent(type="text", text="Tool executed successfully")] + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover - raise ValueError(f"Unknown tool: {name}") # pragma: no cover + # Create a server with progress capability + server = Server( + name="ProgressTestServer", + on_progress=handle_progress, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Client message handler to store progress notifications async def handle_client_message( @@ -164,7 +155,7 @@ async def handle_client_message( await client_session.list_tools() # Call test_tool with progress token - await client_session.call_tool("test_tool", {"_meta": {"progressToken": client_progress_token}}) + await client_session.call_tool("test_tool", meta={"progress_token": client_progress_token}) # Send progress notifications from client to server await client_session.send_progress_notification( @@ -217,22 +208,21 @@ async def test_progress_context_manager(): # Track progress updates server_progress_updates: list[dict[str, Any]] = [] - server = Server(name="ProgressContextTestServer") - progress_token = None # Register progress handler - @server.progress_notification() - async def handle_progress( - progress_token: str | int, - progress: float, - total: float | None, - message: str | None, - ): + async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: server_progress_updates.append( - {"token": progress_token, "progress": progress, "total": total, "message": message} + { + "token": params.progress_token, + "progress": params.progress, + "total": params.total, + "message": params.message, + } ) + server = Server(name="ProgressContextTestServer", on_progress=handle_progress) + # Run server session to receive progress updates async def run_server(): # Create a server session @@ -334,30 +324,37 @@ async def failing_progress_callback(progress: float, total: float | None, messag raise ValueError("Progress callback failed!") # 
Create a server with a tool that sends progress notifications - server = Server(name="TestProgressServer") - - @server.call_tool() - async def handle_call_tool(name: str, arguments: Any) -> list[types.TextContent]: - if name == "progress_tool": + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name == "progress_tool": + assert ctx.request_id is not None # Send a progress notification - await server.request_context.session.send_progress_notification( - progress_token=server.request_context.request_id, + await ctx.session.send_progress_notification( + progress_token=ctx.request_id, progress=50.0, total=100.0, message="Halfway done", ) - return [types.TextContent(type="text", text="progress_result")] - raise ValueError(f"Unknown tool: {name}") # pragma: no cover - - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="progress_tool", - description="A tool that sends progress notifications", - input_schema={}, - ) - ] + return types.CallToolResult(content=[types.TextContent(type="text", text="progress_result")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="progress_tool", + description="A tool that sends progress notifications", + input_schema={}, + ) + ] + ) + + server = Server( + name="TestProgressServer", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) # Test with mocked logging with patch("mcp.shared.session.logging.exception", side_effect=mock_log_exception): diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 182b4671df..2c220f7379 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,11 +1,9 @@ -from typing import Any - import 
anyio import pytest from mcp import Client, types from mcp.client.session import ClientSession -from mcp.server.lowlevel.server import Server +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.shared.memory import create_client_server_memory_streams from mcp.shared.message import SessionMessage @@ -17,7 +15,6 @@ JSONRPCError, JSONRPCRequest, JSONRPCResponse, - TextContent, ) @@ -42,29 +39,25 @@ async def test_request_cancellation(): request_id = None # Create a server with a slow tool - server = Server(name="TestSessionServer") - - # Register the tool handler - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: nonlocal request_id, ev_tool_called - if name == "slow_tool": - request_id = server.request_context.request_id + if params.name == "slow_tool": + request_id = ctx.request_id ev_tool_called.set() await anyio.sleep(10) # Long enough to ensure we can cancel - return [] # pragma: no cover - raise ValueError(f"Unknown tool: {name}") # pragma: no cover - - # Register the tool so it shows up in list_tools - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="slow_tool", - description="A slow tool that takes 10 seconds to complete", - input_schema={}, - ) - ] + return types.CallToolResult(content=[]) # pragma: no cover + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + raise NotImplementedError + + server = Server( + name="TestSessionServer", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) async def make_request(client: Client): nonlocal ev_cancelled diff --git a/tests/shared/test_sse.py 
b/tests/shared/test_sse.py index df2321ba16..207364cdcb 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -22,15 +22,20 @@ from mcp import types from mcp.client.session import ClientSession from mcp.client.sse import _extract_session_id_from_endpoint, sse_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError from mcp.types import ( + CallToolRequestParams, + CallToolResult, EmptyResult, Implementation, InitializeResult, JSONRPCResponse, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, ReadResourceResult, ServerCapabilities, TextContent, @@ -54,36 +59,48 @@ def server_url(server_port: int) -> str: return f"http://127.0.0.1:{server_port}" -# Test server implementation -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - if parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {parsed.netloc}" - - raise MCPError(code=404, message="OOPS! 
no resource with that URI was found") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] +async def _handle_read_resource( # pragma: no cover + ctx: ServerRequestContext, params: ReadResourceRequestParams +) -> ReadResourceResult: + uri = str(params.uri) + parsed = urlparse(uri) + if parsed.scheme == "foobar": + text = f"Read {parsed.netloc}" + elif parsed.scheme == "slow": + await anyio.sleep(2.0) + text = f"Slow response from {parsed.netloc}" + else: + raise MCPError(code=404, message="OOPS! no resource with that URI was found") + return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) + + +async def _handle_list_tools( # pragma: no cover + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) + - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - return [TextContent(type="text", text=f"Called {name}")] +async def _handle_call_tool( # pragma: no cover + ctx: ServerRequestContext, params: CallToolRequestParams +) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + + +def _create_server() -> Server: # pragma: no cover + return Server( + SERVER_NAME, + on_read_resource=_handle_read_resource, + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool, + ) # Test fixtures @@ -94,7 +111,7 @@ def make_server_app() -> Starlette: # pragma: no cover allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) sse = SseServerTransport("/messages/", security_settings=security_settings) - server = ServerTest() + server = 
_create_server() async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: @@ -336,47 +353,46 @@ async def test_sse_client_basic_connection_mounted_app(mounted_server: None, ser assert isinstance(ping_result, EmptyResult) -# Test server with request context that returns headers in the response -class RequestContextServer(Server[object, Request]): # pragma: no cover - def __init__(self): - super().__init__("request_context_server") - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - headers_info = {} - context = self.request_context - if context.request: - headers_info = dict(context.request.headers) - - if name == "echo_headers": - return [TextContent(type="text", text=json.dumps(headers_info))] - elif name == "echo_context": - context_data = { - "request_id": args.get("request_id"), - "headers": headers_info, - } - return [TextContent(type="text", text=json.dumps(context_data))] - - return [TextContent(type="text", text=f"Called {name}")] - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echoes request headers", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echoes request context", - input_schema={ - "type": "object", - "properties": {"request_id": {"type": "string"}}, - "required": ["request_id"], - }, - ), - ] +async def _handle_context_call_tool( # pragma: no cover + ctx: ServerRequestContext, params: CallToolRequestParams +) -> CallToolResult: + headers_info: dict[str, Any] = {} + if ctx.request: + headers_info = dict(ctx.request.headers) + + if params.name == "echo_headers": + return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) + elif params.name == "echo_context": + context_data = { + "request_id": (params.arguments or {}).get("request_id"), + "headers": 
headers_info, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) + + return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + + +async def _handle_context_list_tools( # pragma: no cover + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo_headers", + description="Echoes request headers", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echoes request context", + input_schema={ + "type": "object", + "properties": {"request_id": {"type": "string"}}, + "required": ["request_id"], + }, + ), + ] + ) def run_context_server(server_port: int) -> None: # pragma: no cover @@ -386,7 +402,11 @@ def run_context_server(server_port: int) -> None: # pragma: no cover allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = RequestContextServer() + context_server = Server( + "request_context_server", + on_call_tool=_handle_context_call_tool, + on_list_tools=_handle_context_list_tools, + ) async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index b04b920262..42b1a3698a 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -10,7 +10,9 @@ import socket import time import traceback -from collections.abc import Generator +from collections.abc import AsyncIterator, Generator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field from typing import Any from unittest.mock import MagicMock from urllib.parse import urlparse @@ -28,7 +30,7 @@ from mcp import MCPError, types from 
mcp.client.session import ClientSession from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, @@ -50,7 +52,19 @@ ) from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import InitializeResult, JSONRPCRequest, TextContent, TextResourceContents, Tool +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + InitializeResult, + JSONRPCRequest, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextContent, + TextResourceContents, + Tool, +) from tests.test_helpers import wait_for_server # Test constants @@ -124,263 +138,258 @@ async def replay_events_after( # pragma: no cover return target_stream_id -# Test server implementation that follows MCP protocol -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - self._lock = None # Will be initialized in async context - - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - if parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {parsed.netloc}" - - raise ValueError(f"Unknown resource: {uri}") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="test_tool_with_standalone_notification", - description="A test tool that sends a notification", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="long_running_with_checkpoints", - description="A long-running tool that 
sends periodic notifications", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="test_sampling_tool", - description="A tool that triggers server-side sampling", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="wait_for_lock_with_notification", - description="A tool that sends a notification and waits for lock", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="release_lock", - description="A tool that releases the lock", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_stream_close", - description="A tool that closes SSE stream mid-operation", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_notifications_and_close", - description="Tool that sends notification1, closes stream, sends notification2, notification3", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_stream_closes", - description="Tool that closes SSE stream multiple times during execution", - input_schema={ - "type": "object", - "properties": { - "checkpoints": {"type": "integer", "default": 3}, - "sleep_time": {"type": "number", "default": 0.2}, - }, +@dataclass +class ServerState: + lock: anyio.Event = field(default_factory=anyio.Event) + + +@asynccontextmanager +async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: no cover + yield ServerState() + + +async def _handle_read_resource( # pragma: no cover + ctx: ServerRequestContext[ServerState], params: ReadResourceRequestParams +) -> ReadResourceResult: + uri = str(params.uri) + parsed = urlparse(uri) + if parsed.scheme == "foobar": + text = f"Read {parsed.netloc}" + elif parsed.scheme == "slow": + await anyio.sleep(2.0) + text = f"Slow response from {parsed.netloc}" + else: + raise ValueError(f"Unknown resource: {uri}") + return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, 
mime_type="text/plain")]) + + +async def _handle_list_tools( # pragma: no cover + ctx: ServerRequestContext[ServerState], params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="test_tool_with_standalone_notification", + description="A test tool that sends a notification", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="long_running_with_checkpoints", + description="A long-running tool that sends periodic notifications", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="test_sampling_tool", + description="A tool that triggers server-side sampling", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="wait_for_lock_with_notification", + description="A tool that sends a notification and waits for lock", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="release_lock", + description="A tool that releases the lock", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_stream_close", + description="A tool that closes SSE stream mid-operation", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_notifications_and_close", + description="Tool that sends notification1, closes stream, sends notification2, notification3", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_stream_closes", + description="Tool that closes SSE stream multiple times during execution", + input_schema={ + "type": "object", + "properties": { + "checkpoints": {"type": "integer", "default": 3}, + "sleep_time": {"type": "number", "default": 0.2}, }, - ), - Tool( - name="tool_with_standalone_stream_close", - description="Tool that closes standalone GET stream mid-operation", - input_schema={"type": "object", "properties": {}}, - ), 
- ] - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context + }, + ), + Tool( + name="tool_with_standalone_stream_close", + description="Tool that closes standalone GET stream mid-operation", + input_schema={"type": "object", "properties": {}}, + ), + ] + ) - # When the tool is called, send a notification to test GET stream - if name == "test_tool_with_standalone_notification": - await ctx.session.send_resource_updated(uri="http://test_resource") - return [TextContent(type="text", text=f"Called {name}")] - elif name == "long_running_with_checkpoints": - # Send notifications that are part of the response stream - # This simulates a long-running tool that sends logs +async def _handle_call_tool( # pragma: no cover + ctx: ServerRequestContext[ServerState], params: CallToolRequestParams +) -> CallToolResult: + name = params.name + args = params.arguments or {} - await ctx.session.send_log_message( - level="info", - data="Tool started", - logger="tool", - related_request_id=ctx.request_id, # need for stream association - ) + # When the tool is called, send a notification to test GET stream + if name == "test_tool_with_standalone_notification": + await ctx.session.send_resource_updated(uri="http://test_resource") + return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - await anyio.sleep(0.1) + elif name == "long_running_with_checkpoints": + await ctx.session.send_log_message( + level="info", + data="Tool started", + logger="tool", + related_request_id=ctx.request_id, + ) - await ctx.session.send_log_message( - level="info", - data="Tool is almost done", - logger="tool", - related_request_id=ctx.request_id, - ) + await anyio.sleep(0.1) - return [TextContent(type="text", text="Completed!")] + await ctx.session.send_log_message( + level="info", + data="Tool is almost done", + logger="tool", + related_request_id=ctx.request_id, + ) - elif name == "test_sampling_tool": - 
# Test sampling by requesting the client to sample a message - sampling_result = await ctx.session.create_message( - messages=[ - types.SamplingMessage( - role="user", - content=types.TextContent(type="text", text="Server needs client sampling"), - ) - ], - max_tokens=100, - related_request_id=ctx.request_id, - ) + return CallToolResult(content=[TextContent(type="text", text="Completed!")]) - # Return the sampling result in the tool response - # Since we're not passing tools param, result.content is single content - if sampling_result.content.type == "text": - response = sampling_result.content.text - else: - response = str(sampling_result.content) - return [ - TextContent( - type="text", - text=f"Response from sampling: {response}", - ) - ] - - elif name == "wait_for_lock_with_notification": - # Initialize lock if not already done - if self._lock is None: - self._lock = anyio.Event() - - # First send a notification - await ctx.session.send_log_message( - level="info", - data="First notification before lock", - logger="lock_tool", - related_request_id=ctx.request_id, + elif name == "test_sampling_tool": + sampling_result = await ctx.session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent(type="text", text="Server needs client sampling"), ) + ], + max_tokens=100, + related_request_id=ctx.request_id, + ) - # Now wait for the lock to be released - await self._lock.wait() - - # Send second notification after lock is released - await ctx.session.send_log_message( - level="info", - data="Second notification after lock", - logger="lock_tool", - related_request_id=ctx.request_id, + if sampling_result.content.type == "text": + response = sampling_result.content.text + else: + response = str(sampling_result.content) + return CallToolResult( + content=[ + TextContent( + type="text", + text=f"Response from sampling: {response}", ) + ] + ) - return [TextContent(type="text", text="Completed")] + elif name == 
"wait_for_lock_with_notification": + await ctx.session.send_log_message( + level="info", + data="First notification before lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - elif name == "release_lock": - assert self._lock is not None, "Lock must be initialized before releasing" + await ctx.lifespan_context.lock.wait() - # Release the lock - self._lock.set() - return [TextContent(type="text", text="Lock released")] + await ctx.session.send_log_message( + level="info", + data="Second notification after lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - elif name == "tool_with_stream_close": - # Send notification before closing - await ctx.session.send_log_message( - level="info", - data="Before close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream (triggers client reconnect) - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Continue processing (events stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="After close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="Done")] - - elif name == "tool_with_multiple_notifications_and_close": - # Send notification1 - await ctx.session.send_log_message( - level="info", - data="notification1", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Send notification2, notification3 (stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="notification2", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - await ctx.session.send_log_message( - level="info", - data="notification3", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="All notifications sent")] 
+ return CallToolResult(content=[TextContent(type="text", text="Completed")]) - elif name == "tool_with_multiple_stream_closes": - num_checkpoints = args.get("checkpoints", 3) - sleep_time = args.get("sleep_time", 0.2) + elif name == "release_lock": + ctx.lifespan_context.lock.set() + return CallToolResult(content=[TextContent(type="text", text="Lock released")]) - for i in range(num_checkpoints): - await ctx.session.send_log_message( - level="info", - data=f"checkpoint_{i}", - logger="multi_close_tool", - related_request_id=ctx.request_id, - ) + elif name == "tool_with_stream_close": + await ctx.session.send_log_message( + level="info", + data="Before close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="After close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + return CallToolResult(content=[TextContent(type="text", text="Done")]) + + elif name == "tool_with_multiple_notifications_and_close": + await ctx.session.send_log_message( + level="info", + data="notification1", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="notification2", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + await ctx.session.send_log_message( + level="info", + data="notification3", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + return CallToolResult(content=[TextContent(type="text", text="All notifications sent")]) + + elif name == "tool_with_multiple_stream_closes": + num_checkpoints = args.get("checkpoints", 3) + sleep_time = args.get("sleep_time", 0.2) + + for i in range(num_checkpoints): + await ctx.session.send_log_message( + level="info", + 
data=f"checkpoint_{i}", + logger="multi_close_tool", + related_request_id=ctx.request_id, + ) - if ctx.close_sse_stream: - await ctx.close_sse_stream() + if ctx.close_sse_stream: + await ctx.close_sse_stream() - await anyio.sleep(sleep_time) + await anyio.sleep(sleep_time) - return [TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")] + return CallToolResult(content=[TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")]) - elif name == "tool_with_standalone_stream_close": - # Test for GET stream reconnection - # 1. Send unsolicited notification via GET stream (no related_request_id) - await ctx.session.send_resource_updated(uri="http://notification_1") + elif name == "tool_with_standalone_stream_close": + await ctx.session.send_resource_updated(uri="http://notification_1") + await anyio.sleep(0.1) - # Small delay to ensure notification is flushed before closing - await anyio.sleep(0.1) + if ctx.close_standalone_sse_stream: + await ctx.close_standalone_sse_stream() - # 2. Close the standalone GET stream - if ctx.close_standalone_sse_stream: - await ctx.close_standalone_sse_stream() + await anyio.sleep(1.5) + await ctx.session.send_resource_updated(uri="http://notification_2") - # 3. Wait for client to reconnect (uses retry_interval from server, default 1000ms) - await anyio.sleep(1.5) + return CallToolResult(content=[TextContent(type="text", text="Standalone stream close test done")]) - # 4. 
Send another notification on the new GET stream connection - await ctx.session.send_resource_updated(uri="http://notification_2") + return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - return [TextContent(type="text", text="Standalone stream close test done")] - return [TextContent(type="text", text=f"Called {name}")] +def _create_server() -> Server[ServerState]: # pragma: no cover + return Server( + SERVER_NAME, + lifespan=_server_lifespan, + on_read_resource=_handle_read_resource, + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool, + ) def create_app( @@ -396,7 +405,7 @@ def create_app( retry_interval: Retry interval in milliseconds for SSE polling. """ # Create server instance - server = ServerTest() + server = _create_server() # Create the session manager security_settings = TransportSecuritySettings( @@ -1385,69 +1394,68 @@ async def sampling_callback( # Context-aware server implementation for testing request context propagation -class ContextAwareServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__("ContextAwareServer") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echo request headers from context", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echo request context with custom data", - input_schema={ - "type": "object", - "properties": { - "request_id": {"type": "string"}, - }, - "required": ["request_id"], +async def _handle_context_list_tools( # pragma: no cover + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo_headers", + description="Echo request headers from context", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echo request context with custom data", + input_schema={ + "type": "object", + 
"properties": { + "request_id": {"type": "string"}, }, - ), - ] + "required": ["request_id"], + }, + ), + ] + ) - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context - - if name == "echo_headers": - # Access the request object from context - headers_info = {} - if ctx.request and isinstance(ctx.request, Request): - headers_info = dict(ctx.request.headers) - return [TextContent(type="text", text=json.dumps(headers_info))] - - elif name == "echo_context": - # Return full context information - context_data: dict[str, Any] = { - "request_id": args.get("request_id"), - "headers": {}, - "method": None, - "path": None, - } - if ctx.request and isinstance(ctx.request, Request): - request = ctx.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return [ - TextContent( - type="text", - text=json.dumps(context_data), - ) - ] - - return [TextContent(type="text", text=f"Unknown tool: {name}")] + +async def _handle_context_call_tool( # pragma: no cover + ctx: ServerRequestContext, params: CallToolRequestParams +) -> CallToolResult: + name = params.name + args = params.arguments or {} + + if name == "echo_headers": + headers_info: dict[str, Any] = {} + if ctx.request and isinstance(ctx.request, Request): + headers_info = dict(ctx.request.headers) + return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) + + elif name == "echo_context": + context_data: dict[str, Any] = { + "request_id": args.get("request_id"), + "headers": {}, + "method": None, + "path": None, + } + if ctx.request and isinstance(ctx.request, Request): + request = ctx.request + context_data["headers"] = dict(request.headers) + context_data["method"] = request.method + context_data["path"] = request.url.path + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) + + return 
CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) # Server runner for context-aware testing def run_context_aware_server(port: int): # pragma: no cover """Run the context-aware test server.""" - server = ContextAwareServerTest() + server = Server( + "ContextAwareServer", + on_list_tools=_handle_context_list_tools, + on_call_tool=_handle_context_call_tool, + ) session_manager = StreamableHTTPSessionManager( app=server, diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 07e19195d5..d828505295 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -2,7 +2,6 @@ import socket import time from collections.abc import AsyncGenerator, Generator -from typing import Any from urllib.parse import urlparse import anyio @@ -15,9 +14,21 @@ from mcp import MCPError from mcp.client.session import ClientSession from mcp.client.websocket import websocket_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.websocket import websocket_server -from mcp.types import EmptyResult, InitializeResult, ReadResourceResult, TextContent, TextResourceContents, Tool +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + EmptyResult, + InitializeResult, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextContent, + TextResourceContents, + Tool, +) from tests.test_helpers import wait_for_server SERVER_NAME = "test_server_for_WS" @@ -35,42 +46,59 @@ def server_url(server_port: int) -> str: return f"ws://127.0.0.1:{server_port}" -# Test server implementation -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - elif parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return 
f"Slow response from {parsed.netloc}" - - raise MCPError(code=404, message="OOPS! no resource with that URI was found") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, +async def handle_read_resource( # pragma: no cover + ctx: ServerRequestContext, params: ReadResourceRequestParams +) -> ReadResourceResult: + parsed = urlparse(str(params.uri)) + if parsed.scheme == "foobar": + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")] + ) + elif parsed.scheme == "slow": + await anyio.sleep(2.0) + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=str(params.uri), text=f"Slow response from {parsed.netloc}", mime_type="text/plain" ) ] + ) + raise MCPError(code=404, message="OOPS! no resource with that URI was found") - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - return [TextContent(type="text", text=f"Called {name}")] + +async def handle_list_tools( # pragma: no cover + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) + + +async def handle_call_tool( # pragma: no cover + ctx: ServerRequestContext, params: CallToolRequestParams +) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + + +def _create_server() -> Server: # pragma: no cover + return Server( + SERVER_NAME, + on_read_resource=handle_read_resource, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Test fixtures def make_server_app() -> Starlette: # pragma: no cover """Create test Starlette app with WebSocket transport""" - server = 
ServerTest() + server = _create_server() async def handle_ws(websocket: WebSocket): async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: From 1e0b5c04792cf59b6ab5557142cff6a72c618ea8 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 12 Feb 2026 16:22:07 +0000 Subject: [PATCH 02/84] fix: revert README.md to v1 documentation (#2045) Co-authored-by: Marcelo Trylesinski --- .pre-commit-config.yaml | 6 + README.md | 433 ++++++++++++++++++++++------------------ 2 files changed, 244 insertions(+), 195 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 03a8ae0389..42c12fdedd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,6 +55,12 @@ repos: language: system files: ^(pyproject\.toml|uv\.lock)$ pass_filenames: false + # TODO(Max): Drop this in v2. + - id: readme-v1-frozen + name: README.md is frozen (v1 docs) + entry: README.md is frozen at v1. Edit README.v2.md instead. + language: fail + files: ^README\.md$ - id: readme-snippets name: Check README snippets are up to date entry: uv run --frozen python scripts/update_readme_snippets.py --check diff --git a/README.md b/README.md index dc23d0d1d7..487d48bee4 100644 --- a/README.md +++ b/README.md @@ -13,12 +13,13 @@ -> [!IMPORTANT] -> **This is the `main` branch which contains v2 of the SDK (currently in development, pre-alpha).** -> -> We anticipate a stable v2 release in Q1 2026. Until then, **v1.x remains the recommended version** for production use. v1.x will continue to receive bug fixes and security updates for at least 6 months after v2 ships to give people time to upgrade. + + +> [!NOTE] +> **This README documents v1.x of the MCP Python SDK (the current stable release).** > -> For v1 documentation and code, see the [`v1.x` branch](https://github.com/modelcontextprotocol/python-sdk/tree/v1.x). 
+> For v1.x code and documentation, see the [`v1.x` branch](https://github.com/modelcontextprotocol/python-sdk/tree/v1.x). +> For the upcoming v2 documentation (pre-alpha, in development on `main`), see [`README.v2.md`](README.v2.md). ## Table of Contents @@ -45,7 +46,7 @@ - [Sampling](#sampling) - [Logging and Notifications](#logging-and-notifications) - [Authentication](#authentication) - - [MCPServer Properties](#mcpserver-properties) + - [FastMCP Properties](#fastmcp-properties) - [Session Properties and Methods](#session-properties-and-methods) - [Request Context Properties](#request-context-properties) - [Running Your Server](#running-your-server) @@ -134,18 +135,19 @@ uv run mcp Let's create a simple MCP server that exposes a calculator tool and some data: - + ```python -"""MCPServer quickstart example. +""" +FastMCP quickstart example. Run from the repository root: - uv run examples/snippets/servers/mcpserver_quickstart.py + uv run examples/snippets/servers/fastmcp_quickstart.py """ -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create an MCP server -mcp = MCPServer("Demo") +mcp = FastMCP("Demo", json_response=True) # Add an addition tool @@ -177,16 +179,16 @@ def greet_user(name: str, style: str = "friendly") -> str: # Run with streamable HTTP transport if __name__ == "__main__": - mcp.run(transport="streamable-http", json_response=True) + mcp.run(transport="streamable-http") ``` -_Full example: [examples/snippets/servers/mcpserver_quickstart.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/mcpserver_quickstart.py)_ +_Full example: [examples/snippets/servers/fastmcp_quickstart.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/fastmcp_quickstart.py)_ You can install this server in [Claude Code](https://docs.claude.com/en/docs/claude-code/mcp) and interact with it right away. 
First, run the server: ```bash -uv run --with mcp examples/snippets/servers/mcpserver_quickstart.py +uv run --with mcp examples/snippets/servers/fastmcp_quickstart.py ``` Then add it to Claude Code: @@ -216,7 +218,7 @@ The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you bui ### Server -The MCPServer server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing: +The FastMCP server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing: ```python @@ -226,7 +228,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession @@ -256,7 +258,7 @@ class AppContext: @asynccontextmanager -async def app_lifespan(server: MCPServer) -> AsyncIterator[AppContext]: +async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: """Manage application lifecycle with type-safe context.""" # Initialize on startup db = await Database.connect() @@ -268,7 +270,7 @@ async def app_lifespan(server: MCPServer) -> AsyncIterator[AppContext]: # Pass lifespan to server -mcp = MCPServer("My App", lifespan=app_lifespan) +mcp = FastMCP("My App", lifespan=app_lifespan) # Access type-safe lifespan context in tools @@ -288,9 +290,9 @@ Resources are how you expose data to LLMs. They're similar to GET endpoints in a ```python -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer(name="Resource Example") +mcp = FastMCP(name="Resource Example") @mcp.resource("file://documents/{name}") @@ -319,9 +321,9 @@ Tools let LLMs take actions through your server. 
Unlike resources, tools are exp ```python -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer(name="Tool Example") +mcp = FastMCP(name="Tool Example") @mcp.tool() @@ -340,14 +342,14 @@ def get_weather(city: str, unit: str = "celsius") -> str: _Full example: [examples/snippets/servers/basic_tool.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/basic_tool.py)_ -Tools can optionally receive a Context object by including a parameter with the `Context` type annotation. This context is automatically injected by the MCPServer framework and provides access to MCP capabilities: +Tools can optionally receive a Context object by including a parameter with the `Context` type annotation. This context is automatically injected by the FastMCP framework and provides access to MCP capabilities: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession -mcp = MCPServer(name="Progress Example") +mcp = FastMCP(name="Progress Example") @mcp.tool() @@ -395,7 +397,7 @@ validated data that clients can easily process. **Note:** For backward compatibility, unstructured results are also returned. Unstructured results are provided for backward compatibility with previous versions of the MCP specification, and are quirks-compatible -with previous versions of MCPServer in the current version of the SDK. +with previous versions of FastMCP in the current version of the SDK. 
**Note:** In cases where a tool function's return type annotation causes the tool to be classified as structured _and this is undesirable_, @@ -414,10 +416,10 @@ from typing import Annotated from pydantic import BaseModel -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP from mcp.types import CallToolResult, TextContent -mcp = MCPServer("CallToolResult Example") +mcp = FastMCP("CallToolResult Example") class ValidationModel(BaseModel): @@ -441,7 +443,7 @@ def validated_tool() -> Annotated[CallToolResult, ValidationModel]: """Return CallToolResult with structured output validation.""" return CallToolResult( content=[TextContent(type="text", text="Validated response")], - structured_content={"status": "success", "data": {"result": 42}}, + structuredContent={"status": "success", "data": {"result": 42}}, _meta={"internal": "metadata"}, ) @@ -465,9 +467,9 @@ from typing import TypedDict from pydantic import BaseModel, Field -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("Structured Output Example") +mcp = FastMCP("Structured Output Example") # Using Pydantic models for rich structured data @@ -567,10 +569,10 @@ Prompts are reusable templates that help LLMs interact with your server effectiv ```python -from mcp.server.mcpserver import MCPServer -from mcp.server.mcpserver.prompts import base +from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp.prompts import base -mcp = MCPServer(name="Prompt Example") +mcp = FastMCP(name="Prompt Example") @mcp.prompt(title="Code Review") @@ -595,7 +597,7 @@ _Full example: [examples/snippets/servers/basic_prompt.py](https://github.com/mo MCP servers can provide icons for UI display. 
Icons can be added to the server implementation, tools, resources, and prompts: ```python -from mcp.server.mcpserver import MCPServer, Icon +from mcp.server.fastmcp import FastMCP, Icon # Create an icon from a file path or URL icon = Icon( @@ -605,7 +607,7 @@ icon = Icon( ) # Add icons to server -mcp = MCPServer( +mcp = FastMCP( "My Server", website_url="https://example.com", icons=[icon] @@ -623,21 +625,21 @@ def my_resource(): return "content" ``` -_Full example: [examples/mcpserver/icons_demo.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/mcpserver/icons_demo.py)_ +_Full example: [examples/fastmcp/icons_demo.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/fastmcp/icons_demo.py)_ ### Images -MCPServer provides an `Image` class that automatically handles image data: +FastMCP provides an `Image` class that automatically handles image data: ```python -"""Example showing image handling with MCPServer.""" +"""Example showing image handling with FastMCP.""" from PIL import Image as PILImage -from mcp.server.mcpserver import Image, MCPServer +from mcp.server.fastmcp import FastMCP, Image -mcp = MCPServer("Image Example") +mcp = FastMCP("Image Example") @mcp.tool() @@ -660,9 +662,9 @@ The Context object is automatically injected into tool and resource functions th To use context in a tool or resource function, add a parameter with the `Context` type annotation: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP -mcp = MCPServer(name="Context Example") +mcp = FastMCP(name="Context Example") @mcp.tool() @@ -678,11 +680,11 @@ The Context object provides the following capabilities: - `ctx.request_id` - Unique ID for the current request - `ctx.client_id` - Client ID if available -- `ctx.mcp_server` - Access to the MCPServer server instance (see [MCPServer Properties](#mcpserver-properties)) +- `ctx.fastmcp` - Access to the FastMCP server instance (see [FastMCP 
Properties](#fastmcp-properties)) - `ctx.session` - Access to the underlying session for advanced communication (see [Session Properties and Methods](#session-properties-and-methods)) - `ctx.request_context` - Access to request-specific data and lifespan resources (see [Request Context Properties](#request-context-properties)) - `await ctx.debug(message)` - Send debug log message -- `await ctx.info(message)` - Send info log message +- `await ctx.info(message)` - Send info log message - `await ctx.warning(message)` - Send warning log message - `await ctx.error(message)` - Send error log message - `await ctx.log(level, message, logger_name=None)` - Send log with custom level @@ -692,10 +694,10 @@ The Context object provides the following capabilities: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession -mcp = MCPServer(name="Progress Example") +mcp = FastMCP(name="Progress Example") @mcp.tool() @@ -726,8 +728,9 @@ Client usage: ```python -"""cd to the `examples/snippets` directory and run: -uv run completion-client +""" +cd to the `examples/snippets` directory and run: + uv run completion-client """ import asyncio @@ -755,8 +758,8 @@ async def run(): # List available resource templates templates = await session.list_resource_templates() print("Available resource templates:") - for template in templates.resource_templates: - print(f" - {template.uri_template}") + for template in templates.resourceTemplates: + print(f" - {template.uriTemplate}") # List available prompts prompts = await session.list_prompts() @@ -765,20 +768,20 @@ async def run(): print(f" - {prompt.name}") # Complete resource template arguments - if templates.resource_templates: - template = templates.resource_templates[0] - print(f"\nCompleting arguments for resource template: {template.uri_template}") + if templates.resourceTemplates: + template = templates.resourceTemplates[0] + 
print(f"\nCompleting arguments for resource template: {template.uriTemplate}") # Complete without context result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri=template.uri_template), + ref=ResourceTemplateReference(type="ref/resource", uri=template.uriTemplate), argument={"name": "owner", "value": "model"}, ) print(f"Completions for 'owner' starting with 'model': {result.completion.values}") # Complete with context - repo suggestions based on owner result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri=template.uri_template), + ref=ResourceTemplateReference(type="ref/resource", uri=template.uriTemplate), argument={"name": "repo", "value": ""}, context_arguments={"owner": "modelcontextprotocol"}, ) @@ -824,12 +827,12 @@ import uuid from pydantic import BaseModel, Field -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams -mcp = MCPServer(name="Elicitation Example") +mcp = FastMCP(name="Elicitation Example") class BookingPreferences(BaseModel): @@ -908,7 +911,7 @@ async def connect_service(service_name: str, ctx: Context[ServerSession, None]) mode="url", message=f"Authorization required to connect to {service_name}", url=f"https://{service_name}.example.com/oauth/authorize?elicit={elicitation_id}", - elicitation_id=elicitation_id, + elicitationId=elicitation_id, ) ] ) @@ -931,11 +934,11 @@ Tools can interact with LLMs through sampling (generating text): ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.types import SamplingMessage, TextContent -mcp = MCPServer(name="Sampling Example") +mcp = FastMCP(name="Sampling Example") @mcp.tool() @@ -968,10 +971,10 @@ Tools can send logs and 
notifications through the context: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession -mcp = MCPServer(name="Notifications Example") +mcp = FastMCP(name="Notifications Example") @mcp.tool() @@ -1002,15 +1005,16 @@ MCP servers can use authentication by providing an implementation of the `TokenV ```python -"""Run from the repository root: -uv run examples/snippets/servers/oauth_server.py +""" +Run from the repository root: + uv run examples/snippets/servers/oauth_server.py """ from pydantic import AnyHttpUrl from mcp.server.auth.provider import AccessToken, TokenVerifier from mcp.server.auth.settings import AuthSettings -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP class SimpleTokenVerifier(TokenVerifier): @@ -1020,9 +1024,10 @@ class SimpleTokenVerifier(TokenVerifier): pass # This is where you would implement actual token validation -# Create MCPServer instance as a Resource Server -mcp = MCPServer( +# Create FastMCP instance as a Resource Server +mcp = FastMCP( "Weather Service", + json_response=True, # Token verifier for authentication token_verifier=SimpleTokenVerifier(), # Auth settings for RFC 9728 Protected Resource Metadata @@ -1046,7 +1051,7 @@ async def get_weather(city: str = "London") -> dict[str, str]: if __name__ == "__main__": - mcp.run(transport="streamable-http", json_response=True) + mcp.run(transport="streamable-http") ``` _Full example: [examples/snippets/servers/oauth_server.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/oauth_server.py)_ @@ -1062,19 +1067,19 @@ For a complete example with separate Authorization Server and Resource Server im See [TokenVerifier](src/mcp/server/auth/provider.py) for more details on implementing token validation. 
-### MCPServer Properties +### FastMCP Properties -The MCPServer server instance accessible via `ctx.mcp_server` provides access to server configuration and metadata: +The FastMCP server instance accessible via `ctx.fastmcp` provides access to server configuration and metadata: -- `ctx.mcp_server.name` - The server's name as defined during initialization -- `ctx.mcp_server.instructions` - Server instructions/description provided to clients -- `ctx.mcp_server.website_url` - Optional website URL for the server -- `ctx.mcp_server.icons` - Optional list of icons for UI display -- `ctx.mcp_server.settings` - Complete server configuration object containing: +- `ctx.fastmcp.name` - The server's name as defined during initialization +- `ctx.fastmcp.instructions` - Server instructions/description provided to clients +- `ctx.fastmcp.website_url` - Optional website URL for the server +- `ctx.fastmcp.icons` - Optional list of icons for UI display +- `ctx.fastmcp.settings` - Complete server configuration object containing: - `debug` - Debug mode flag - `log_level` - Current logging level - `host` and `port` - Server network configuration - - `sse_path`, `streamable_http_path` - Transport paths + - `mount_path`, `sse_path`, `streamable_http_path` - Transport paths - `stateless_http` - Whether the server operates in stateless mode - And other configuration options @@ -1083,12 +1088,12 @@ The MCPServer server instance accessible via `ctx.mcp_server` provides access to def server_info(ctx: Context) -> dict: """Get information about the current server.""" return { - "name": ctx.mcp_server.name, - "instructions": ctx.mcp_server.instructions, - "debug_mode": ctx.mcp_server.settings.debug, - "log_level": ctx.mcp_server.settings.log_level, - "host": ctx.mcp_server.settings.host, - "port": ctx.mcp_server.settings.port, + "name": ctx.fastmcp.name, + "instructions": ctx.fastmcp.instructions, + "debug_mode": ctx.fastmcp.settings.debug, + "log_level": ctx.fastmcp.settings.log_level, + 
"host": ctx.fastmcp.settings.host, + "port": ctx.fastmcp.settings.port, } ``` @@ -1110,13 +1115,13 @@ The session object accessible via `ctx.session` provides advanced control over c async def notify_data_update(resource_uri: str, ctx: Context) -> str: """Update data and notify clients of the change.""" # Perform data update logic here - + # Notify clients that this specific resource changed await ctx.session.send_resource_updated(AnyUrl(resource_uri)) - + # If this affects the overall resource list, notify about that too await ctx.session.send_resource_list_changed() - + return f"Updated {resource_uri} and notified clients" ``` @@ -1145,11 +1150,11 @@ def query_with_config(query: str, ctx: Context) -> str: """Execute a query using shared database and configuration.""" # Access typed lifespan context app_ctx: AppContext = ctx.request_context.lifespan_context - + # Use shared resources connection = app_ctx.db settings = app_ctx.config - + # Execute query with configuration result = connection.execute(query, timeout=settings.query_timeout) return str(result) @@ -1203,9 +1208,9 @@ cd to the `examples/snippets` directory and run: python servers/direct_execution.py """ -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("My App") +mcp = FastMCP("My App") @mcp.tool() @@ -1234,7 +1239,7 @@ python servers/direct_execution.py uv run mcp run servers/direct_execution.py ``` -Note that `uv run mcp run` or `uv run mcp dev` only supports server using MCPServer and not the low-level server variant. +Note that `uv run mcp run` or `uv run mcp dev` only supports server using FastMCP and not the low-level server variant. 
### Streamable HTTP Transport @@ -1242,13 +1247,22 @@ Note that `uv run mcp run` or `uv run mcp dev` only supports server using MCPSer ```python -"""Run from the repository root: -uv run examples/snippets/servers/streamable_config.py +""" +Run from the repository root: + uv run examples/snippets/servers/streamable_config.py """ -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("StatelessServer") +# Stateless server with JSON responses (recommended) +mcp = FastMCP("StatelessServer", stateless_http=True, json_response=True) + +# Other configuration options: +# Stateless server with SSE streaming responses +# mcp = FastMCP("StatelessServer", stateless_http=True) + +# Stateful server with session persistence +# mcp = FastMCP("StatefulServer") # Add a simple tool to demonstrate the server @@ -1259,28 +1273,20 @@ def greet(name: str = "World") -> str: # Run server with streamable_http transport -# Transport-specific options (stateless_http, json_response) are passed to run() if __name__ == "__main__": - # Stateless server with JSON responses (recommended) - mcp.run(transport="streamable-http", stateless_http=True, json_response=True) - - # Other configuration options: - # Stateless server with SSE streaming responses - # mcp.run(transport="streamable-http", stateless_http=True) - - # Stateful server with session persistence - # mcp.run(transport="streamable-http") + mcp.run(transport="streamable-http") ``` _Full example: [examples/snippets/servers/streamable_config.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/streamable_config.py)_ -You can mount multiple MCPServer servers in a Starlette application: +You can mount multiple FastMCP servers in a Starlette application: ```python -"""Run from the repository root: -uvicorn examples.snippets.servers.streamable_starlette_mount:app --reload +""" +Run from the repository root: + uvicorn 
examples.snippets.servers.streamable_starlette_mount:app --reload """ import contextlib @@ -1288,10 +1294,10 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create the Echo server -echo_mcp = MCPServer(name="EchoServer") +echo_mcp = FastMCP(name="EchoServer", stateless_http=True, json_response=True) @echo_mcp.tool() @@ -1301,7 +1307,7 @@ def echo(message: str) -> str: # Create the Math server -math_mcp = MCPServer(name="MathServer") +math_mcp = FastMCP(name="MathServer", stateless_http=True, json_response=True) @math_mcp.tool() @@ -1322,16 +1328,16 @@ async def lifespan(app: Starlette): # Create the Starlette app and mount the MCP servers app = Starlette( routes=[ - Mount("/echo", echo_mcp.streamable_http_app(stateless_http=True, json_response=True)), - Mount("/math", math_mcp.streamable_http_app(stateless_http=True, json_response=True)), + Mount("/echo", echo_mcp.streamable_http_app()), + Mount("/math", math_mcp.streamable_http_app()), ], lifespan=lifespan, ) # Note: Clients connect to http://localhost:8000/echo/mcp and http://localhost:8000/math/mcp # To mount at the root of each path (e.g., /echo instead of /echo/mcp): -# echo_mcp.streamable_http_app(streamable_http_path="/", stateless_http=True, json_response=True) -# math_mcp.streamable_http_app(streamable_http_path="/", stateless_http=True, json_response=True) +# echo_mcp.settings.streamable_http_path = "/" +# math_mcp.settings.streamable_http_path = "/" ``` _Full example: [examples/snippets/servers/streamable_starlette_mount.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/streamable_starlette_mount.py)_ @@ -1389,7 +1395,8 @@ You can mount the StreamableHTTP server to an existing ASGI server using the `st ```python -"""Basic example showing how to mount StreamableHTTP server in Starlette. 
+""" +Basic example showing how to mount StreamableHTTP server in Starlette. Run from the repository root: uvicorn examples.snippets.servers.streamable_http_basic_mounting:app --reload @@ -1400,10 +1407,10 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create MCP server -mcp = MCPServer("My App") +mcp = FastMCP("My App", json_response=True) @mcp.tool() @@ -1420,10 +1427,9 @@ async def lifespan(app: Starlette): # Mount the StreamableHTTP server to the existing ASGI server -# Transport-specific options are passed to streamable_http_app() app = Starlette( routes=[ - Mount("/", app=mcp.streamable_http_app(json_response=True)), + Mount("/", app=mcp.streamable_http_app()), ], lifespan=lifespan, ) @@ -1436,7 +1442,8 @@ _Full example: [examples/snippets/servers/streamable_http_basic_mounting.py](htt ```python -"""Example showing how to mount StreamableHTTP server using Host-based routing. +""" +Example showing how to mount StreamableHTTP server using Host-based routing. 
Run from the repository root: uvicorn examples.snippets.servers.streamable_http_host_mounting:app --reload @@ -1447,10 +1454,10 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Host -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create MCP server -mcp = MCPServer("MCP Host App") +mcp = FastMCP("MCP Host App", json_response=True) @mcp.tool() @@ -1467,10 +1474,9 @@ async def lifespan(app: Starlette): # Mount using Host-based routing -# Transport-specific options are passed to streamable_http_app() app = Starlette( routes=[ - Host("mcp.acme.corp", app=mcp.streamable_http_app(json_response=True)), + Host("mcp.acme.corp", app=mcp.streamable_http_app()), ], lifespan=lifespan, ) @@ -1483,7 +1489,8 @@ _Full example: [examples/snippets/servers/streamable_http_host_mounting.py](http ```python -"""Example showing how to mount multiple StreamableHTTP servers with path configuration. +""" +Example showing how to mount multiple StreamableHTTP servers with path configuration. 
Run from the repository root: uvicorn examples.snippets.servers.streamable_http_multiple_servers:app --reload @@ -1494,11 +1501,11 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create multiple MCP servers -api_mcp = MCPServer("API Server") -chat_mcp = MCPServer("Chat Server") +api_mcp = FastMCP("API Server", json_response=True) +chat_mcp = FastMCP("Chat Server", json_response=True) @api_mcp.tool() @@ -1513,6 +1520,12 @@ def send_message(message: str) -> str: return f"Message sent: {message}" +# Configure servers to mount at the root of each path +# This means endpoints will be at /api and /chat instead of /api/mcp and /chat/mcp +api_mcp.settings.streamable_http_path = "/" +chat_mcp.settings.streamable_http_path = "/" + + # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager async def lifespan(app: Starlette): @@ -1522,12 +1535,11 @@ async def lifespan(app: Starlette): yield -# Mount the servers with transport-specific options passed to streamable_http_app() -# streamable_http_path="/" means endpoints will be at /api and /chat instead of /api/mcp and /chat/mcp +# Mount the servers app = Starlette( routes=[ - Mount("/api", app=api_mcp.streamable_http_app(json_response=True, streamable_http_path="/")), - Mount("/chat", app=chat_mcp.streamable_http_app(json_response=True, streamable_http_path="/")), + Mount("/api", app=api_mcp.streamable_http_app()), + Mount("/chat", app=chat_mcp.streamable_http_app()), ], lifespan=lifespan, ) @@ -1540,7 +1552,8 @@ _Full example: [examples/snippets/servers/streamable_http_multiple_servers.py](h ```python -"""Example showing path configuration when mounting MCPServer. +""" +Example showing path configuration during FastMCP initialization. 
Run from the repository root: uvicorn examples.snippets.servers.streamable_http_path_config:app --reload @@ -1549,10 +1562,15 @@ Run from the repository root: from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -# Create a simple MCPServer server -mcp_at_root = MCPServer("My Server") +# Configure streamable_http_path during initialization +# This server will mount at the root of wherever it's mounted +mcp_at_root = FastMCP( + "My Server", + json_response=True, + streamable_http_path="/", +) @mcp_at_root.tool() @@ -1561,14 +1579,10 @@ def process_data(data: str) -> str: return f"Processed: {data}" -# Mount at /process with streamable_http_path="/" so the endpoint is /process (not /process/mcp) -# Transport-specific options like json_response are passed to streamable_http_app() +# Mount at /process - endpoints will be at /process instead of /process/mcp app = Starlette( routes=[ - Mount( - "/process", - app=mcp_at_root.streamable_http_app(json_response=True, streamable_http_path="/"), - ), + Mount("/process", app=mcp_at_root.streamable_http_app()), ] ) ``` @@ -1585,10 +1599,10 @@ You can mount the SSE server to an existing ASGI server using the `sse_app` meth ```python from starlette.applications import Starlette from starlette.routing import Mount, Host -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("My App") +mcp = FastMCP("My App") # Mount the SSE server to the existing ASGI server app = Starlette( @@ -1601,28 +1615,41 @@ app = Starlette( app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app())) ``` -You can also mount multiple MCP servers at different sub-paths. 
The SSE transport automatically detects the mount path via ASGI's `root_path` mechanism, so message endpoints are correctly routed: +When mounting multiple MCP servers under different paths, you can configure the mount path in several ways: ```python from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create multiple MCP servers -github_mcp = MCPServer("GitHub API") -browser_mcp = MCPServer("Browser") -search_mcp = MCPServer("Search") +github_mcp = FastMCP("GitHub API") +browser_mcp = FastMCP("Browser") +curl_mcp = FastMCP("Curl") +search_mcp = FastMCP("Search") + +# Method 1: Configure mount paths via settings (recommended for persistent configuration) +github_mcp.settings.mount_path = "/github" +browser_mcp.settings.mount_path = "/browser" -# Mount each server at its own sub-path -# The SSE transport automatically uses ASGI's root_path to construct -# the correct message endpoint (e.g., /github/messages/, /browser/messages/) +# Method 2: Pass mount path directly to sse_app (preferred for ad-hoc mounting) +# This approach doesn't modify the server's settings permanently + +# Create Starlette app with multiple mounted servers app = Starlette( routes=[ + # Using settings-based configuration Mount("/github", app=github_mcp.sse_app()), Mount("/browser", app=browser_mcp.sse_app()), - Mount("/search", app=search_mcp.sse_app()), + # Using direct mount path parameter + Mount("/curl", app=curl_mcp.sse_app("/curl")), + Mount("/search", app=search_mcp.sse_app("/search")), ] ) + +# Method 3: For direct execution, you can also pass the mount path to run() +if __name__ == "__main__": + search_mcp.run(transport="sse", mount_path="/search") ``` For more information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes). 
@@ -1635,8 +1662,9 @@ For more control, you can use the low-level server implementation directly. This ```python -"""Run from the repository root: -uv run examples/snippets/servers/lowlevel/lifespan.py +""" +Run from the repository root: + uv run examples/snippets/servers/lowlevel/lifespan.py """ from collections.abc import AsyncIterator @@ -1644,7 +1672,7 @@ from contextlib import asynccontextmanager from typing import Any import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1692,7 +1720,7 @@ async def handle_list_tools() -> list[types.Tool]: types.Tool( name="query_db", description="Query the database", - input_schema={ + inputSchema={ "type": "object", "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, "required": ["query"], @@ -1751,14 +1779,15 @@ The lifespan API provides: ```python -"""Run from the repository root: +""" +Run from the repository root: uv run examples/snippets/servers/lowlevel/basic.py """ import asyncio import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1829,15 +1858,16 @@ The low-level server supports structured output for tools, allowing you to retur ```python -"""Run from the repository root: -uv run examples/snippets/servers/lowlevel/structured_output.py +""" +Run from the repository root: + uv run examples/snippets/servers/lowlevel/structured_output.py """ import asyncio from typing import Any import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1851,12 +1881,12 @@ async def list_tools() -> list[types.Tool]: types.Tool( name="get_weather", description="Get current weather for a city", - input_schema={ 
+ inputSchema={ "type": "object", "properties": {"city": {"type": "string", "description": "City name"}}, "required": ["city"], }, - output_schema={ + outputSchema={ "type": "object", "properties": { "temperature": {"type": "number", "description": "Temperature in Celsius"}, @@ -1931,15 +1961,16 @@ For full control over the response including the `_meta` field (for passing data ```python -"""Run from the repository root: -uv run examples/snippets/servers/lowlevel/direct_call_tool_result.py +""" +Run from the repository root: + uv run examples/snippets/servers/lowlevel/direct_call_tool_result.py """ import asyncio from typing import Any import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1953,7 +1984,7 @@ async def list_tools() -> list[types.Tool]: types.Tool( name="advanced_tool", description="Tool with full control including _meta field", - input_schema={ + inputSchema={ "type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"], @@ -1969,7 +2000,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallTo message = str(arguments.get("message", "")) return types.CallToolResult( content=[types.TextContent(type="text", text=f"Processed: {message}")], - structured_content={"result": "success", "message": message}, + structuredContent={"result": "success", "message": message}, _meta={"hidden": "data for client applications only"}, ) @@ -2010,9 +2041,13 @@ For servers that need to handle large datasets, the low-level server provides pa ```python -"""Example of implementing pagination with MCP server decorators.""" +""" +Example of implementing pagination with MCP server decorators. 
+""" + +from pydantic import AnyUrl -from mcp import types +import mcp.types as types from mcp.server.lowlevel import Server # Initialize the server @@ -2036,14 +2071,14 @@ async def list_resources_paginated(request: types.ListResourcesRequest) -> types # Get page of resources page_items = [ - types.Resource(uri=f"resource://items/{item}", name=item, description=f"Description for {item}") + types.Resource(uri=AnyUrl(f"resource://items/{item}"), name=item, description=f"Description for {item}") for item in ITEMS[start:end] ] # Determine next cursor next_cursor = str(end) if end < len(ITEMS) else None - return types.ListResourcesResult(resources=page_items, next_cursor=next_cursor) + return types.ListResourcesResult(resources=page_items, nextCursor=next_cursor) ``` _Full example: [examples/snippets/servers/pagination_example.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/pagination_example.py)_ @@ -2053,7 +2088,9 @@ _Full example: [examples/snippets/servers/pagination_example.py](https://github. ```python -"""Example of consuming paginated MCP endpoints from a client.""" +""" +Example of consuming paginated MCP endpoints from a client. 
+""" import asyncio @@ -2082,8 +2119,8 @@ async def list_all_resources() -> None: print(f"Fetched {len(result.resources)} resources") # Check if there are more pages - if result.next_cursor: - cursor = result.next_cursor + if result.nextCursor: + cursor = result.nextCursor else: break @@ -2112,28 +2149,31 @@ The SDK provides a high-level client interface for connecting to MCP servers usi ```python -"""cd to the `examples/snippets/clients` directory and run: -uv run client +""" +cd to the `examples/snippets/clients` directory and run: + uv run client """ import asyncio import os +from pydantic import AnyUrl + from mcp import ClientSession, StdioServerParameters, types -from mcp.client.context import ClientRequestContext from mcp.client.stdio import stdio_client +from mcp.shared.context import RequestContext # Create server parameters for stdio connection server_params = StdioServerParameters( command="uv", # Using uv to run the server - args=["run", "server", "mcpserver_quickstart", "stdio"], # We're already in snippets dir + args=["run", "server", "fastmcp_quickstart", "stdio"], # We're already in snippets dir env={"UV_INDEX": os.environ.get("UV_INDEX", "")}, ) # Optional: create a sampling callback async def handle_sampling_message( - context: ClientRequestContext, params: types.CreateMessageRequestParams + context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( @@ -2143,7 +2183,7 @@ async def handle_sampling_message( text="Hello, world! 
from model", ), model="gpt-3.5-turbo", - stop_reason="endTurn", + stopReason="endTurn", ) @@ -2157,7 +2197,7 @@ async def run(): prompts = await session.list_prompts() print(f"Available prompts: {[p.name for p in prompts.prompts]}") - # Get a prompt (greet_user prompt from mcpserver_quickstart) + # Get a prompt (greet_user prompt from fastmcp_quickstart) if prompts.prompts: prompt = await session.get_prompt("greet_user", arguments={"name": "Alice", "style": "friendly"}) print(f"Prompt result: {prompt.messages[0].content}") @@ -2170,18 +2210,18 @@ async def run(): tools = await session.list_tools() print(f"Available tools: {[t.name for t in tools.tools]}") - # Read a resource (greeting resource from mcpserver_quickstart) - resource_content = await session.read_resource("greeting://World") + # Read a resource (greeting resource from fastmcp_quickstart) + resource_content = await session.read_resource(AnyUrl("greeting://World")) content_block = resource_content.contents[0] if isinstance(content_block, types.TextContent): print(f"Resource content: {content_block.text}") - # Call a tool (add tool from mcpserver_quickstart) + # Call a tool (add tool from fastmcp_quickstart) result = await session.call_tool("add", arguments={"a": 5, "b": 3}) result_unstructured = result.content[0] if isinstance(result_unstructured, types.TextContent): print(f"Tool result: {result_unstructured.text}") - result_structured = result.structured_content + result_structured = result.structuredContent print(f"Structured tool result: {result_structured}") @@ -2201,8 +2241,9 @@ Clients can also connect using [Streamable HTTP transport](https://modelcontextp ```python -"""Run from the repository root: -uv run examples/snippets/clients/streamable_basic.py +""" +Run from the repository root: + uv run examples/snippets/clients/streamable_basic.py """ import asyncio @@ -2240,8 +2281,9 @@ When building MCP clients, the SDK provides utilities to help display human-read ```python -"""cd to the 
`examples/snippets` directory and run: -uv run display-utilities-client +""" +cd to the `examples/snippets` directory and run: + uv run display-utilities-client """ import asyncio @@ -2254,7 +2296,7 @@ from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection server_params = StdioServerParameters( command="uv", # Using uv to run the server - args=["run", "server", "mcpserver_quickstart", "stdio"], + args=["run", "server", "fastmcp_quickstart", "stdio"], env={"UV_INDEX": os.environ.get("UV_INDEX", "")}, ) @@ -2280,7 +2322,7 @@ async def display_resources(session: ClientSession): print(f"Resource: {display_name} ({resource.uri})") templates_response = await session.list_resource_templates() - for template in templates_response.resource_templates: + for template in templates_response.resourceTemplates: display_name = get_display_name(template) print(f"Resource Template: {display_name}") @@ -2324,7 +2366,8 @@ The SDK includes [authorization support](https://modelcontextprotocol.io/specifi ```python -"""Before running, specify running MCP RS server URL. +""" +Before running, specify running MCP RS server URL. To spin up RS server locally, see examples/servers/simple-auth/README.md From 29a14ab9e53b3188b7c13742bd3712fd22ea738f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 13 Feb 2026 10:30:30 +0000 Subject: [PATCH 03/84] fix: skip readme-v1-frozen in CI and add diff-based README.md check (#2048) --- .github/workflows/shared.yml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index fe97895f63..72e328b541 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -26,7 +26,19 @@ jobs: with: extra_args: --all-files --verbose env: - SKIP: no-commit-to-branch + SKIP: no-commit-to-branch,readme-v1-frozen + + # TODO(Max): Drop this in v2. 
+ - name: Check README.md is not modified + if: github.event_name == 'pull_request' + run: | + git fetch --no-tags --depth=1 origin "$BASE_SHA" + if git diff --name-only "$BASE_SHA" -- README.md | grep -q .; then + echo "::error::README.md is frozen at v1. Edit README.v2.md instead." + exit 1 + fi + env: + BASE_SHA: ${{ github.event.pull_request.base.sha }} test: name: test (${{ matrix.python-version }}, ${{ matrix.dep-resolution.name }}, ${{ matrix.os }}) From a287a40184c2024e0cda1ba9cd6366f00ae9b179 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 13 Feb 2026 14:10:00 +0100 Subject: [PATCH 04/84] docs: add coverage verification instruction to CLAUDE.md (#2050) --- CLAUDE.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index d7b175636b..0913b7d8ee 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,6 +28,9 @@ This document contains critical information about working with this codebase. Fo - Bug fixes require regression tests - IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns. - IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible. + - IMPORTANT: Before pushing, verify 100% branch coverage on changed files by running + `uv run --frozen pytest -x` (coverage is configured in `pyproject.toml` with `fail_under = 100` + and `branch = true`). If any branch is uncovered, add a test for it before pushing. Test files mirror the source tree: `src/mcp/client/streamable_http.py` → `tests/client/test_streamable_http.py` Add tests to the existing file for that module. 
From 8f669a77e3a99bed00bec0d9aa9956298ccd5285 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:25:10 +0000 Subject: [PATCH 05/84] fix: explicitly load required pytest plugins in addopts (#2055) --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index bfc3067137..008ee4c957 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,6 +173,8 @@ xfail_strict = true addopts = """ --color=yes --capture=fd + -p anyio + -p examples """ filterwarnings = [ "error", From 2fe56e56de2aff8fcb964ff7e26e7c6df4d14653 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 14 Feb 2026 09:49:42 +0100 Subject: [PATCH 06/84] fix: handle HTTP error status codes in streamable HTTP client (#2047) --- src/mcp/client/streamable_http.py | 9 +++- tests/client/test_notification_response.py | 52 ++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9d45bec6ee..8ebd22d359 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -19,6 +19,7 @@ from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( + INTERNAL_ERROR, INVALID_REQUEST, PARSE_ERROR, ErrorData, @@ -273,7 +274,13 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: await ctx.read_stream_writer.send(session_message) return - response.raise_for_status() + if response.status_code >= 400: + if isinstance(message, JSONRPCRequest): + error_data = ErrorData(code=INTERNAL_ERROR, message="Server returned an error response") + session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) + await ctx.read_stream_writer.send(session_message) + return + if is_initialization: self._maybe_extract_session_id_from_response(response) diff --git 
a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 9e233acc32..69c8afeb84 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -116,6 +116,58 @@ async def test_unexpected_content_type_sends_jsonrpc_error() -> None: await session.list_tools() +def _create_http_error_app(error_status: int, *, error_on_notifications: bool = False) -> Starlette: + """Create a server that returns an HTTP error for non-init requests.""" + + async def handle_mcp_request(request: Request) -> Response: + body = await request.body() + data = json.loads(body) + + if data.get("method") == "initialize": + return _init_json_response(data) + + if "id" not in data: + if error_on_notifications: + return Response(status_code=error_status) + return Response(status_code=202) + + return Response(status_code=error_status) + + return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])]) + + +async def test_http_error_status_sends_jsonrpc_error() -> None: + """Verify HTTP 5xx errors unblock the pending request with an MCPError. + + When a server returns a non-2xx status code (e.g. 500), the client should + send a JSONRPCError so the pending request resolves immediately instead of + raising an unhandled httpx.HTTPStatusError that causes the caller to hang. + """ + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_http_error_app(500))) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + with pytest.raises(MCPError, match="Server returned an error response"): # pragma: no branch + await session.list_tools() + + +async def test_http_error_on_notification_does_not_hang() -> None: + """Verify HTTP errors on notifications are silently ignored. 
+ + When a notification gets an HTTP error, there is no pending request to + unblock, so the client should just return without sending a JSONRPCError. + """ + app = _create_http_error_app(500, error_on_notifications=True) + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + # Should not raise or hang — the error is silently ignored for notifications + await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed")) + + def _create_invalid_json_response_app() -> Starlette: """Create a server that returns invalid JSON for requests.""" From 3b53fb9a003b370ada409483b839fc3edb1739d7 Mon Sep 17 00:00:00 2001 From: BabyChrist666 Date: Tue, 17 Feb 2026 02:34:47 -0500 Subject: [PATCH 07/84] fix: add HTTP readiness check to wait_for_server and remove dead code in SSE tests (#2073) --- tests/shared/test_sse.py | 11 ----------- tests/shared/test_ws.py | 6 ------ 2 files changed, 17 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 207364cdcb..7b2bc0a139 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,7 +1,6 @@ import json import multiprocessing import socket -import time from collections.abc import AsyncGenerator, Generator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -134,11 +133,6 @@ def run_server(server_port: int) -> None: # pragma: no cover print(f"starting server on {server_port}") server.run() - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: @@ -313,11 +307,6 @@ def run_mounted_server(server_port: int) -> None: # pragma: no cover 
print(f"starting server on {server_port}") server.run() - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - @pytest.fixture() def mounted_server(server_port: int) -> Generator[None, None, None]: diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index d828505295..9addb661de 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -1,6 +1,5 @@ import multiprocessing import socket -import time from collections.abc import AsyncGenerator, Generator from urllib.parse import urlparse @@ -114,11 +113,6 @@ def run_server(server_port: int) -> None: # pragma: no cover print(f"starting server on {server_port}") server.run() - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: From 705497a59369eec487b04c82672d4ea60e795298 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 17 Feb 2026 10:30:34 +0000 Subject: [PATCH 08/84] fix: allow null id in JSONRPCError per JSON-RPC 2.0 spec (#2056) --- src/mcp/client/sse.py | 2 +- src/mcp/client/stdio.py | 2 +- src/mcp/client/streamable_http.py | 2 +- src/mcp/client/websocket.py | 2 +- src/mcp/server/lowlevel/server.py | 2 +- src/mcp/server/sse.py | 2 +- src/mcp/server/stdio.py | 2 +- src/mcp/server/streamable_http.py | 19 ++- src/mcp/server/streamable_http_manager.py | 4 +- src/mcp/server/websocket.py | 2 +- src/mcp/shared/exceptions.py | 5 +- src/mcp/shared/session.py | 8 +- src/mcp/types/jsonrpc.py | 2 +- tests/server/test_streamable_http_manager.py | 2 +- tests/shared/test_session.py | 119 +++++++++++++++++++ 15 files changed, 151 insertions(+), 24 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 8f8e4dadc7..7c309ecb52 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -138,7 +138,7 @@ async def 
post_writer(endpoint_url: str): json=session_message.message.model_dump( by_alias=True, mode="json", - exclude_none=True, + exclude_unset=True, ), ) response.raise_for_status() diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 605c5ea24e..5b8209eeb5 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -167,7 +167,7 @@ async def stdin_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + json = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await process.stdin.send( (json + "\n").encode( encoding=server.encoding, diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 8ebd22d359..d161e3c2a3 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -260,7 +260,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: async with ctx.client.stream( "POST", self.url, - json=message.model_dump(by_alias=True, mode="json", exclude_none=True), + json=message.model_dump(by_alias=True, mode="json", exclude_unset=True), headers=headers, ) as response: if response.status_code == 202: diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index cf4b86e992..bda199f36d 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -65,7 +65,7 @@ async def ws_writer(): async with write_stream_reader: async for session_message in write_stream_reader: # Convert to a dict, then to JSON - msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True) + msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True) await ws.send(json.dumps(msg_dict)) async with anyio.create_task_group() as tg: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 04404a3fca..9ca5ac4fc9 100644 --- 
a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -490,7 +490,7 @@ async def _handle_request( except Exception as err: if raise_exceptions: # pragma: no cover raise err - response = types.ErrorData(code=0, message=str(err), data=None) + response = types.ErrorData(code=0, message=str(err)) await message.respond(response) else: # pragma: no cover diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 5be6b78ca9..674294c5c3 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -170,7 +170,7 @@ async def sse_writer(): await sse_stream_writer.send( { "event": "message", - "data": session_message.message.model_dump_json(by_alias=True, exclude_none=True), + "data": session_message.message.model_dump_json(by_alias=True, exclude_unset=True), } ) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 7f3aa2ac2f..864d387bdf 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -71,7 +71,7 @@ async def stdout_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + json = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await stdout.write(json + "\n") await stdout.flush() except anyio.ClosedResourceError: # pragma: no cover diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 54ac7374a1..a8202e3857 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -298,12 +298,12 @@ def _create_error_response( # Return a properly formatted JSON error response error_response = JSONRPCError( jsonrpc="2.0", - id="server-error", # We don't have a request ID for general errors + id=None, error=ErrorData(code=error_code, message=error_message), ) return Response( - error_response.model_dump_json(by_alias=True, exclude_none=True), + error_response.model_dump_json(by_alias=True, exclude_unset=True), 
status_code=status_code, headers=response_headers, ) @@ -323,7 +323,7 @@ def _create_json_response( response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id return Response( - response_message.model_dump_json(by_alias=True, exclude_none=True) if response_message else None, + response_message.model_dump_json(by_alias=True, exclude_unset=True) if response_message else None, status_code=status_code, headers=response_headers, ) @@ -336,7 +336,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: """Create event data dictionary from an EventMessage.""" event_data = { "event": "message", - "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True), + "data": event_message.message.model_dump_json(by_alias=True, exclude_unset=True), } # If an event ID was provided, include it @@ -975,12 +975,11 @@ async def message_router(): # Determine which request stream(s) should receive this message message = session_message.message target_request_id = None - # Check if this is a response - if isinstance(message, JSONRPCResponse | JSONRPCError): - response_id = str(message.id) - # If this response is for an existing request stream, - # send it there - target_request_id = response_id + # Check if this is a response with a known request id. + # Null-id errors (e.g., parse errors) fall through to + # the GET stream since they can't be correlated. 
+ if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: + target_request_id = str(message.id) # Extract related_request_id from meta if it exists elif ( # pragma: no cover session_message.metadata is not None diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 8eb29c4d48..9ffabf109d 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -244,11 +244,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 error_response = JSONRPCError( jsonrpc="2.0", - id="server-error", + id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found"), ) response = Response( - content=error_response.model_dump_json(by_alias=True, exclude_none=True), + content=error_response.model_dump_json(by_alias=True, exclude_unset=True), status_code=HTTPStatus.NOT_FOUND, media_type="application/json", ) diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index a4c8448112..7b00f79055 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -47,7 +47,7 @@ async def ws_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - obj = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + obj = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await websocket.send_text(obj) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index 7a2b2ded4d..6c3a7745c1 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -12,7 +12,10 @@ class MCPError(Exception): def __init__(self, code: int, message: str, data: Any = None): super().__init__(code, message, data) - self.error = ErrorData(code=code, message=message, data=data) + if data is not 
None: + self.error = ErrorData(code=code, message=message, data=data) + else: + self.error = ErrorData(code=code, message=message) @property def code(self) -> int: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 453e36274e..5ee8f3baad 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -142,7 +142,7 @@ async def cancel(self) -> None: # Send an error response to indicate cancellation await self._session._send_response( # type: ignore[reportPrivateUsage] request_id=self.request_id, - response=ErrorData(code=0, message="Request cancelled", data=None), + response=ErrorData(code=0, message="Request cancelled"), ) @property @@ -458,6 +458,12 @@ async def _handle_response(self, message: SessionMessage) -> None: if not isinstance(message.message, JSONRPCResponse | JSONRPCError): return # pragma: no cover + if message.message.id is None: + # Narrows to JSONRPCError since JSONRPCResponse.id is always RequestId + error = message.message.error + logging.warning(f"Received error with null ID: {error.message}") + await self._handle_incoming(MCPError(error.code, error.message, error.data)) + return # Normalize response ID to handle type mismatches (e.g., "0" vs 0) response_id = self._normalize_request_id(message.message.id) diff --git a/src/mcp/types/jsonrpc.py b/src/mcp/types/jsonrpc.py index 0cfdc993a5..84304a37c1 100644 --- a/src/mcp/types/jsonrpc.py +++ b/src/mcp/types/jsonrpc.py @@ -75,7 +75,7 @@ class JSONRPCError(BaseModel): """A response to a request that indicates an error occurred.""" jsonrpc: Literal["2.0"] - id: RequestId + id: RequestId | None error: ErrorData diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 475eaa167f..e9a8720f11 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -312,7 +312,7 @@ async def mock_receive(): # Verify JSON-RPC error format error_data = json.loads(response_body) assert 
error_data["jsonrpc"] == "2.0" - assert error_data["id"] == "server-error" + assert error_data["id"] is None assert error_data["error"]["code"] == INVALID_REQUEST assert error_data["error"]["message"] == "Session not found" diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 2c220f7379..d7c6cc3b5f 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -7,14 +7,19 @@ from mcp.shared.exceptions import MCPError from mcp.shared.memory import create_client_server_memory_streams from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder from mcp.types import ( + PARSE_ERROR, CancelledNotification, CancelledNotificationParams, + ClientResult, EmptyResult, ErrorData, JSONRPCError, JSONRPCRequest, JSONRPCResponse, + ServerNotification, + ServerRequest, ) @@ -297,3 +302,117 @@ async def mock_server(): await ev_closed.wait() with anyio.fail_after(1): # pragma: no branch await ev_response.wait() + + +@pytest.mark.anyio +async def test_null_id_error_surfaced_via_message_handler(): + """Test that a JSONRPCError with id=None is surfaced to the message handler. + + Per JSON-RPC 2.0, error responses use id=null when the request id could not + be determined (e.g., parse errors). These cannot be correlated to any pending + request, so they are forwarded to the message handler as MCPError. 
+ """ + ev_error_received = anyio.Event() + error_holder: list[MCPError] = [] + + async def capture_errors( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + assert isinstance(message, MCPError) + error_holder.append(message) + ev_error_received.set() + + sent_error = ErrorData(code=PARSE_ERROR, message="Parse error") + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + _server_read, server_write = server_streams + + async def mock_server(): + """Send a null-id error (simulating a parse error).""" + error_response = JSONRPCError(jsonrpc="2.0", id=None, error=sent_error) + await server_write.send(SessionMessage(message=error_response)) + + async with ( + anyio.create_task_group() as tg, + ClientSession( + read_stream=client_read, + write_stream=client_write, + message_handler=capture_errors, + ) as _client_session, + ): + tg.start_soon(mock_server) + + with anyio.fail_after(2): # pragma: no branch + await ev_error_received.wait() + + assert len(error_holder) == 1 + assert error_holder[0].error == sent_error + + +@pytest.mark.anyio +async def test_null_id_error_does_not_affect_pending_request(): + """Test that a null-id error doesn't interfere with an in-flight request. + + When a null-id error arrives while a request is pending, the error should + go to the message handler and the pending request should still complete + normally with its own response. 
+ """ + ev_error_received = anyio.Event() + ev_response_received = anyio.Event() + error_holder: list[MCPError] = [] + result_holder: list[EmptyResult] = [] + + async def capture_errors( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + assert isinstance(message, MCPError) + error_holder.append(message) + ev_error_received.set() + + sent_error = ErrorData(code=PARSE_ERROR, message="Parse error") + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Read a request, inject a null-id error, then respond normally.""" + message = await server_read.receive() + assert isinstance(message, SessionMessage) + assert isinstance(message.message, JSONRPCRequest) + request_id = message.message.id + + # First, send a null-id error (should go to message handler) + await server_write.send(SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=sent_error))) + + # Then, respond normally to the pending request + await server_write.send(SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) + + async def make_request(client_session: ClientSession): + result = await client_session.send_ping() + result_holder.append(result) + ev_response_received.set() + + async with ( + anyio.create_task_group() as tg, + ClientSession( + read_stream=client_read, + write_stream=client_write, + message_handler=capture_errors, + ) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(2): # pragma: no branch + await ev_error_received.wait() + await ev_response_received.wait() + + # Null-id error reached the message handler + assert len(error_holder) == 1 + assert error_holder[0].error == sent_error + + # Pending request completed successfully + assert len(result_holder) == 1 + assert 
isinstance(result_holder[0], EmptyResult) From be5bb7c4f2f2b5f07a8a5b087f9db1b998194abe Mon Sep 17 00:00:00 2001 From: Felix Weinberger <3823880+felixweinberger@users.noreply.github.com> Date: Tue, 17 Feb 2026 14:34:59 +0000 Subject: [PATCH 09/84] fix: normalize trailing slashes before length check in check_resource_allowed (#2074) --- src/mcp/client/auth/oauth2.py | 6 ------ src/mcp/shared/auth_utils.py | 13 ++++--------- tests/shared/test_auth_utils.py | 2 +- 3 files changed, 5 insertions(+), 16 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 41aecc6f28..f464077549 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -493,12 +493,6 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None if not prm_resource: return # pragma: no cover default_resource = resource_url_from_server_url(self.context.server_url) - # Normalize: Pydantic AnyHttpUrl adds trailing slash to root URLs - # (e.g. "https://example.com/") while resource_url_from_server_url may not. - if not default_resource.endswith("/"): - default_resource += "/" - if not prm_resource.endswith("/"): - prm_resource += "/" if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource): raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}") diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index 8f3c542f22..3ba880f40d 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -51,22 +51,17 @@ def check_resource_allowed(requested_resource: str, configured_resource: str) -> if requested.scheme.lower() != configured.scheme.lower() or requested.netloc.lower() != configured.netloc.lower(): return False - # Handle cases like requested=/foo and configured=/foo/ + # Normalize trailing slashes before comparison so that + # "/foo" and "/foo/" are treated as equivalent. 
requested_path = requested.path configured_path = configured.path - - # If requested path is shorter, it cannot be a child - if len(requested_path) < len(configured_path): - return False - - # Check if the requested path starts with the configured path - # Ensure both paths end with / for proper comparison - # This ensures that paths like "/api123" don't incorrectly match "/api" if not requested_path.endswith("/"): requested_path += "/" if not configured_path.endswith("/"): configured_path += "/" + # Check hierarchical match: requested must start with configured path. + # The trailing-slash normalization ensures "/api123/" won't match "/api/". return requested_path.startswith(configured_path) diff --git a/tests/shared/test_auth_utils.py b/tests/shared/test_auth_utils.py index 2c1c16dc32..5ae0e22b0c 100644 --- a/tests/shared/test_auth_utils.py +++ b/tests/shared/test_auth_utils.py @@ -104,7 +104,7 @@ def test_check_resource_allowed_trailing_slash_handling(): """Trailing slashes should be handled correctly.""" # With and without trailing slashes assert check_resource_allowed("https://example.com/api/", "https://example.com/api") is True - assert check_resource_allowed("https://example.com/api", "https://example.com/api/") is False + assert check_resource_allowed("https://example.com/api", "https://example.com/api/") is True assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api") is True assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True From 92140e508663680dea629e192d1db64e1debdfaa Mon Sep 17 00:00:00 2001 From: Felix Weinberger <3823880+felixweinberger@users.noreply.github.com> Date: Wed, 18 Feb 2026 10:47:02 +0000 Subject: [PATCH 10/84] Add idle session timeout to StreamableHTTPSessionManager (#2022) --- CLAUDE.md | 5 ++ src/mcp/server/streamable_http.py | 2 + src/mcp/server/streamable_http_manager.py | 60 +++++++++++---- tests/server/test_streamable_http_manager.py | 77 
++++++++++++++++++++ 4 files changed, 129 insertions(+), 15 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 0913b7d8ee..e48ce6e70c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -31,6 +31,11 @@ This document contains critical information about working with this codebase. Fo - IMPORTANT: Before pushing, verify 100% branch coverage on changed files by running `uv run --frozen pytest -x` (coverage is configured in `pyproject.toml` with `fail_under = 100` and `branch = true`). If any branch is uncovered, add a test for it before pushing. + - Avoid `anyio.sleep()` with a fixed duration to wait for async operations. Instead: + - Use `anyio.Event` — set it in the callback/handler, `await event.wait()` in the test + - For stream messages, use `await stream.receive()` instead of `sleep()` + `receive_nowait()` + - Exception: `sleep()` is appropriate when testing time-based features (e.g., timeouts) + - Wrap indefinite waits (`event.wait()`, `stream.receive()`) in `anyio.fail_after(5)` to prevent hangs Test files mirror the source tree: `src/mcp/client/streamable_http.py` → `tests/client/test_streamable_http.py` Add tests to the existing file for that module. diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index a8202e3857..bcee3a4748 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -169,6 +169,8 @@ def __init__( ] = {} self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {} self._terminated = False + # Idle timeout cancel scope; managed by the session manager. + self.idle_scope: anyio.CancelScope | None = None @property def is_terminated(self) -> bool: diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 9ffabf109d..50bcd5e791 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -39,6 +39,7 @@ class StreamableHTTPSessionManager: 2. 
Resumability via an optional event store 3. Connection management and lifecycle 4. Request handling and transport setup + 5. Idle session cleanup via optional timeout Important: Only one StreamableHTTPSessionManager instance should be created per application. The instance cannot be reused after its run() context has @@ -46,16 +47,20 @@ class StreamableHTTPSessionManager: Args: app: The MCP server instance - event_store: Optional event store for resumability support. - If provided, enables resumable connections where clients - can reconnect and receive missed events. - If None, sessions are still tracked but not resumable. + event_store: Optional event store for resumability support. If provided, enables resumable connections + where clients can reconnect and receive missed events. If None, sessions are still tracked but not + resumable. json_response: Whether to use JSON responses instead of SSE streams - stateless: If True, creates a completely fresh transport for each request - with no session tracking or state persistence between requests. + stateless: If True, creates a completely fresh transport for each request with no session tracking or + state persistence between requests. security_settings: Optional transport security settings. - retry_interval: Retry interval in milliseconds to suggest to clients in SSE - retry field. Used for SSE polling behavior. + retry_interval: Retry interval in milliseconds to suggest to clients in SSE retry field. Used for SSE + polling behavior. + session_idle_timeout: Optional idle timeout in seconds for stateful sessions. If set, sessions that + receive no HTTP requests for this duration will be automatically terminated and removed. When + retry_interval is also configured, ensure the idle timeout comfortably exceeds the retry interval to + avoid reaping sessions during normal SSE polling gaps. Default is None (no timeout). A value of 1800 + (30 minutes) is recommended for most deployments. 
""" def __init__( @@ -66,13 +71,20 @@ def __init__( stateless: bool = False, security_settings: TransportSecuritySettings | None = None, retry_interval: int | None = None, + session_idle_timeout: float | None = None, ): + if session_idle_timeout is not None and session_idle_timeout <= 0: + raise ValueError("session_idle_timeout must be a positive number of seconds") + if stateless and session_idle_timeout is not None: + raise RuntimeError("session_idle_timeout is not supported in stateless mode") + self.app = app self.event_store = event_store self.json_response = json_response self.stateless = stateless self.security_settings = security_settings self.retry_interval = retry_interval + self.session_idle_timeout = session_idle_timeout # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() @@ -184,6 +196,9 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") + # Push back idle deadline on activity + if transport.idle_scope is not None and self.session_idle_timeout is not None: + transport.idle_scope.deadline = anyio.current_time() + self.session_idle_timeout # pragma: no cover await transport.handle_request(scope, receive, send) return @@ -210,16 +225,31 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE read_stream, write_stream = streams task_status.started() try: - await self.app.run( - read_stream, - write_stream, - self.app.create_initialization_options(), - stateless=False, # Stateful mode - ) + # Use a cancel scope for idle timeout — when the + # deadline passes the scope cancels app.run() and + # execution continues after the ``with`` block. + # Incoming requests push the deadline forward. 
+ idle_scope = anyio.CancelScope() + if self.session_idle_timeout is not None: + idle_scope.deadline = anyio.current_time() + self.session_idle_timeout + http_transport.idle_scope = idle_scope + + with idle_scope: + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=False, + ) + + if idle_scope.cancelled_caught: + assert http_transport.mcp_session_id is not None + logger.info(f"Session {http_transport.mcp_session_id} idle timeout") + self._server_instances.pop(http_transport.mcp_session_id, None) + await http_transport.terminate() except Exception: logger.exception(f"Session {http_transport.mcp_session_id} crashed") finally: - # Only remove from instances if not terminated if ( # pragma: no branch http_transport.mcp_session_id and http_transport.mcp_session_id in self._server_instances diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index e9a8720f11..54a898cc5c 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -333,3 +333,80 @@ async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestP Client(streamable_http_client(f"http://{host}/mcp", http_client=http_client)) as client, ): await client.list_tools() + + +@pytest.mark.anyio +async def test_idle_session_is_reaped(): + """After idle timeout fires, the session returns 404.""" + app = Server("test-idle-reap") + manager = StreamableHTTPSessionManager(app=app, session_idle_timeout=0.05) + + async with manager.run(): + sent_messages: list[Message] = [] + + async def mock_send(message: Message): + sent_messages.append(message) + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [(b"content-type", b"application/json")], + } + + async def mock_receive(): # pragma: no cover + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request(scope, mock_receive, 
mock_send) + + session_id = None + for msg in sent_messages: # pragma: no branch + if msg["type"] == "http.response.start": # pragma: no branch + for header_name, header_value in msg.get("headers", []): # pragma: no branch + if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): + session_id = header_value.decode() + break + if session_id: # pragma: no branch + break + + assert session_id is not None, "Session ID not found in response headers" + + # Wait for the 50ms idle timeout to fire and cleanup to complete + await anyio.sleep(0.1) + + # Verify via public API: old session ID now returns 404 + response_messages: list[Message] = [] + + async def capture_send(message: Message): + response_messages.append(message) + + scope_with_session = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [ + (b"content-type", b"application/json"), + (b"mcp-session-id", session_id.encode()), + ], + } + + await manager.handle_request(scope_with_session, mock_receive, capture_send) + + response_start = next( + (msg for msg in response_messages if msg["type"] == "http.response.start"), + None, + ) + assert response_start is not None + assert response_start["status"] == 404 + + +def test_session_idle_timeout_rejects_non_positive(): + with pytest.raises(ValueError, match="positive number"): + StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=-1) + with pytest.raises(ValueError, match="positive number"): + StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=0) + + +def test_session_idle_timeout_rejects_stateless(): + with pytest.raises(RuntimeError, match="not supported in stateless"): + StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) From fc57c2c4c5c0176c1ca4d608ae213554dda66448 Mon Sep 17 00:00:00 2001 From: Akshan Krithick <97239696+akshan-main@users.noreply.github.com> Date: Wed, 18 Feb 2026 17:23:03 +0530 Subject: [PATCH 11/84] test: fix progress notification assertions for 
related_request_id (#2038) Co-authored-by: Lee Hubbard Co-authored-by: Max Isbey <224885523+maxisbey@users.noreply.github.com> --- src/mcp/server/mcpserver/server.py | 1 + tests/issues/test_176_progress_token.py | 12 ++++++--- tests/server/mcpserver/test_server.py | 35 ++++++++++++++++++++++++- 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index f26944a2d8..cd459589ad 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -1163,6 +1163,7 @@ async def report_progress(self, progress: float, total: float | None = None, mes progress=progress, total=total, message=message, + related_request_id=self.request_id, ) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index fb4bb0101f..5d5f8b8fc9 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -35,6 +35,12 @@ async def test_progress_token_zero_first_call(): # Verify progress notifications assert mock_session.send_progress_notification.call_count == 3, "All progress notifications should be sent" - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=0.0, total=10.0, message=None) - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=5.0, total=10.0, message=None) - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=10.0, total=10.0, message=None) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, progress=0.0, total=10.0, message=None, related_request_id="test-request" + ) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, progress=5.0, total=10.0, message=None, related_request_id="test-request" + ) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, 
progress=10.0, total=10.0, message=None, related_request_id="test-request" + ) diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3f253baa82..cfbe6587bb 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1,7 +1,7 @@ import base64 from pathlib import Path from typing import Any -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from inline_snapshot import snapshot @@ -10,6 +10,8 @@ from starlette.routing import Mount, Route from mcp.client import Client +from mcp.server.context import ServerRequestContext +from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.prompts.base import Message, UserMessage @@ -1436,3 +1438,34 @@ def test_streamable_http_no_redirect() -> None: # Verify path values assert streamable_routes[0].path == "/mcp", "Streamable route path should be /mcp" + + +async def test_report_progress_passes_related_request_id(): + """Test that report_progress passes the request_id as related_request_id. + + Without related_request_id, the streamable HTTP transport cannot route + progress notifications to the correct SSE stream, causing them to be + silently dropped. See #953 and #2001. 
+ """ + mock_session = AsyncMock() + mock_session.send_progress_notification = AsyncMock() + + request_context = ServerRequestContext( + request_id="req-abc-123", + session=mock_session, + meta={"progress_token": "tok-1"}, + lifespan_context=None, + experimental=Experimental(), + ) + + ctx = Context(request_context=request_context, mcp_server=MagicMock()) + + await ctx.report_progress(50, 100, message="halfway") + + mock_session.send_progress_notification.assert_awaited_once_with( + progress_token="tok-1", + progress=50, + total=100, + message="halfway", + related_request_id="req-abc-123", + ) From e82203bfc442482f7ee7ac3e0f2f300f6e0698a0 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Feb 2026 13:10:02 +0000 Subject: [PATCH 12/84] refactor: remove unused `mcp.shared.progress` module (#2080) --- docs/migration.md | 45 +++++--- src/mcp/shared/progress.py | 45 -------- tests/shared/test_progress_notifications.py | 113 -------------------- 3 files changed, 32 insertions(+), 171 deletions(-) delete mode 100644 src/mcp/shared/progress.py diff --git a/docs/migration.md b/docs/migration.md index 17fd92bd06..6316836938 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -371,7 +371,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar server = Server("my-server", on_call_tool=handle_call_tool) ``` -### `RequestContext` and `ProgressContext` type parameters simplified +### `RequestContext` type parameters simplified The `RequestContext` class has been split to separate shared fields from server-specific fields. The shared `RequestContext` now only takes 1 type parameter (the session type) instead of 3. 
@@ -380,40 +380,59 @@ The `RequestContext` class has been split to separate shared fields from server- - Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]` - Server-specific fields (`lifespan_context`, `experimental`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context` -**`ProgressContext` changes:** - -- Type parameters reduced from `ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]` to `ProgressContext[SessionT]` - **Before (v1):** ```python from mcp.client.session import ClientSession from mcp.shared.context import RequestContext, LifespanContextT, RequestT -from mcp.shared.progress import ProgressContext # RequestContext with 3 type parameters ctx: RequestContext[ClientSession, LifespanContextT, RequestT] - -# ProgressContext with 5 type parameters -progress_ctx: ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT] ``` **After (v2):** ```python from mcp.client.context import ClientRequestContext -from mcp.client.session import ClientSession from mcp.server.context import ServerRequestContext, LifespanContextT, RequestT -from mcp.shared.progress import ProgressContext # For client-side context (sampling, elicitation, list_roots callbacks) ctx: ClientRequestContext # For server-specific context with lifespan and request types server_ctx: ServerRequestContext[LifespanContextT, RequestT] +``` + +### `ProgressContext` and `progress()` context manager removed + +The `mcp.shared.progress` module (`ProgressContext`, `Progress`, and the `progress()` context manager) has been removed. This module had no real-world adoption — all users send progress notifications via `Context.report_progress()` or `session.send_progress_notification()` directly. 
+ +**Before:** + +```python +from mcp.shared.progress import progress -# ProgressContext with 1 type parameter -progress_ctx: ProgressContext[ClientSession] +with progress(ctx, total=100) as p: + await p.progress(25) +``` + +**After — use `Context.report_progress()` (recommended):** + +```python +@server.tool() +async def my_tool(x: int, ctx: Context) -> str: + await ctx.report_progress(25, 100) + return "done" +``` + +**After — use `session.send_progress_notification()` (low-level):** + +```python +await session.send_progress_notification( + progress_token=progress_token, + progress=25, + total=100, +) ``` ### Resource URI type changed from `AnyUrl` to `str` diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py deleted file mode 100644 index 510bd81632..0000000000 --- a/src/mcp/shared/progress.py +++ /dev/null @@ -1,45 +0,0 @@ -from collections.abc import Generator -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Generic - -from pydantic import BaseModel - -from mcp.shared._context import RequestContext, SessionT -from mcp.types import ProgressToken - - -class Progress(BaseModel): - progress: float - total: float | None - - -@dataclass -class ProgressContext(Generic[SessionT]): - session: SessionT - progress_token: ProgressToken - total: float | None - current: float = field(default=0.0, init=False) - - async def progress(self, amount: float, message: str | None = None) -> None: - self.current += amount - - await self.session.send_progress_notification( - self.progress_token, self.current, total=self.total, message=message - ) - - -@contextmanager -def progress( - ctx: RequestContext[SessionT], - total: float | None = None, -) -> Generator[ProgressContext[SessionT], None]: - progress_token = ctx.meta.get("progress_token") if ctx.meta else None - if progress_token is None: # pragma: no cover - raise ValueError("No progress token provided") - - progress_ctx = ProgressContext(ctx.session, 
progress_token, total) - try: - yield progress_ctx - finally: - pass diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 6b87774c0c..aad9e5d439 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -10,9 +10,7 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.progress import progress from mcp.shared.session import RequestResponder @@ -198,117 +196,6 @@ async def handle_client_message( assert server_progress_updates[2]["progress"] == 1.0 -@pytest.mark.anyio -async def test_progress_context_manager(): - """Test client using progress context manager for sending progress notifications.""" - # Create memory streams for client/server - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) - - # Track progress updates - server_progress_updates: list[dict[str, Any]] = [] - - progress_token = None - - # Register progress handler - async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: - server_progress_updates.append( - { - "token": params.progress_token, - "progress": params.progress, - "total": params.total, - "message": params.message, - } - ) - - server = Server(name="ProgressContextTestServer", on_progress=handle_progress) - - # Run server session to receive progress updates - async def run_server(): - # Create a server session - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="ProgressContextTestServer", - server_version="0.1.0", - 
capabilities=server.get_capabilities(NotificationOptions(), {}), - ), - ) as server_session: - async for message in server_session.incoming_messages: - try: - await server._handle_message(message, server_session, {}) - except Exception as e: # pragma: no cover - raise e - - # Client message handler - async def handle_client_message( - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - # run client session - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=handle_client_message, - ) as client_session, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - - await client_session.initialize() - - progress_token = "client_token_456" - - # Create request context - request_context = RequestContext( - request_id="test-request", - session=client_session, - meta={"progress_token": progress_token}, - ) - - # Utilize progress context manager - with progress(request_context, total=100) as p: - await p.progress(10, message="Loading configuration...") - await p.progress(30, message="Connecting to database...") - await p.progress(40, message="Fetching data...") - await p.progress(20, message="Processing results...") - - # Wait for all messages to be processed - await anyio.sleep(0.5) - tg.cancel_scope.cancel() - - # Verify progress updates were received by server - assert len(server_progress_updates) == 4 - - # first update - assert server_progress_updates[0]["token"] == progress_token - assert server_progress_updates[0]["progress"] == 10 - assert server_progress_updates[0]["total"] == 100 - assert server_progress_updates[0]["message"] == "Loading configuration..." 
- - # second update - assert server_progress_updates[1]["token"] == progress_token - assert server_progress_updates[1]["progress"] == 40 - assert server_progress_updates[1]["total"] == 100 - assert server_progress_updates[1]["message"] == "Connecting to database..." - - # third update - assert server_progress_updates[2]["token"] == progress_token - assert server_progress_updates[2]["progress"] == 80 - assert server_progress_updates[2]["total"] == 100 - assert server_progress_updates[2]["message"] == "Fetching data..." - - # final update - assert server_progress_updates[3]["token"] == progress_token - assert server_progress_updates[3]["progress"] == 100 - assert server_progress_updates[3]["total"] == 100 - assert server_progress_updates[3]["message"] == "Processing results..." - - @pytest.mark.anyio async def test_progress_callback_exception_logging(): """Test that exceptions in progress callbacks are logged and \ From b9431d483fbab2e62f5851aef30cb7e2824dd488 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Feb 2026 15:16:44 +0000 Subject: [PATCH 13/84] fix: prevent command injection in example URL opening (#2082) --- .../clients/url_elicitation_client.py | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py index 9888c588e6..2aecbeeee6 100644 --- a/examples/snippets/clients/url_elicitation_client.py +++ b/examples/snippets/clients/url_elicitation_client.py @@ -24,8 +24,6 @@ import asyncio import json -import subprocess -import sys import webbrowser from typing import Any from urllib.parse import urlparse @@ -56,15 +54,19 @@ async def handle_elicitation( ) +ALLOWED_SCHEMES = {"http", "https"} + + async def handle_url_elicitation( params: types.ElicitRequestParams, ) -> types.ElicitResult: """Handle URL mode elicitation - show security warning and optionally open browser. 
This function demonstrates the security-conscious approach to URL elicitation: - 1. Display the full URL and domain for user inspection - 2. Show the server's reason for requesting this interaction - 3. Require explicit user consent before opening any URL + 1. Validate the URL scheme before prompting the user + 2. Display the full URL and domain for user inspection + 3. Show the server's reason for requesting this interaction + 4. Require explicit user consent before opening any URL """ # Extract URL parameters - these are available on URL mode requests url = getattr(params, "url", None) @@ -75,6 +77,12 @@ async def handle_url_elicitation( print("Error: No URL provided in elicitation request") return types.ElicitResult(action="cancel") + # Reject dangerous URL schemes before prompting the user + parsed = urlparse(str(url)) + if parsed.scheme.lower() not in ALLOWED_SCHEMES: + print(f"\nRejecting URL with disallowed scheme '{parsed.scheme}': {url}") + return types.ElicitResult(action="decline") + # Extract domain for security display domain = extract_domain(url) @@ -105,7 +113,11 @@ async def handle_url_elicitation( # Open the browser print(f"\nOpening browser to: {url}") - open_browser(url) + try: + webbrowser.open(url) + except Exception as e: + print(f"Failed to open browser: {e}") + print(f"Please manually open: {url}") print("Waiting for you to complete the interaction in your browser...") print("(The server will continue once you've finished)") @@ -121,20 +133,6 @@ def extract_domain(url: str) -> str: return "unknown" -def open_browser(url: str) -> None: - """Open URL in the default browser.""" - try: - if sys.platform == "darwin": - subprocess.run(["open", url], check=False) - elif sys.platform == "win32": - subprocess.run(["start", url], shell=True, check=False) - else: - webbrowser.open(url) - except Exception as e: - print(f"Failed to open browser: {e}") - print(f"Please manually open: {url}") - - async def call_tool_with_error_handling( session: 
ClientSession, tool_name: str, From 0e96aecd1dcf95252692015d21a756044e35d0c8 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Feb 2026 19:40:52 +0000 Subject: [PATCH 14/84] fix: use exact match for loopback hosts in issuer URL validation (#2089) --- src/mcp/server/auth/routes.py | 14 ++++------ tests/server/auth/test_routes.py | 47 ++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 9 deletions(-) create mode 100644 tests/server/auth/test_routes.py diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 08f735f362..9a10ac57fa 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -31,19 +31,15 @@ def validate_issuer_url(url: AnyHttpUrl): ValueError: If the issuer URL is invalid """ - # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing - if ( - url.scheme != "https" - and url.host != "localhost" - and (url.host is not None and not url.host.startswith("127.0.0.1")) - ): - raise ValueError("Issuer URL must be HTTPS") # pragma: no cover + # RFC 8414 requires HTTPS, but we allow loopback/localhost HTTP for testing + if url.scheme != "https" and url.host not in ("localhost", "127.0.0.1", "[::1]"): + raise ValueError("Issuer URL must be HTTPS") # No fragments or query parameters allowed if url.fragment: - raise ValueError("Issuer URL must not have a fragment") # pragma: no cover + raise ValueError("Issuer URL must not have a fragment") if url.query: - raise ValueError("Issuer URL must not have a query string") # pragma: no cover + raise ValueError("Issuer URL must not have a query string") AUTHORIZATION_PATH = "/authorize" diff --git a/tests/server/auth/test_routes.py b/tests/server/auth/test_routes.py new file mode 100644 index 0000000000..3d13b5ba53 --- /dev/null +++ b/tests/server/auth/test_routes.py @@ -0,0 +1,47 @@ +import pytest +from pydantic import AnyHttpUrl + +from mcp.server.auth.routes import validate_issuer_url + + +def 
test_validate_issuer_url_https_allowed(): + validate_issuer_url(AnyHttpUrl("https://example.com/path")) + + +def test_validate_issuer_url_http_localhost_allowed(): + validate_issuer_url(AnyHttpUrl("http://localhost:8080/path")) + + +def test_validate_issuer_url_http_127_0_0_1_allowed(): + validate_issuer_url(AnyHttpUrl("http://127.0.0.1:8080/path")) + + +def test_validate_issuer_url_http_ipv6_loopback_allowed(): + validate_issuer_url(AnyHttpUrl("http://[::1]:8080/path")) + + +def test_validate_issuer_url_http_non_loopback_rejected(): + with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): + validate_issuer_url(AnyHttpUrl("http://evil.com/path")) + + +def test_validate_issuer_url_http_127_prefix_domain_rejected(): + """A domain like 127.0.0.1.evil.com is not loopback.""" + with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): + validate_issuer_url(AnyHttpUrl("http://127.0.0.1.evil.com/path")) + + +def test_validate_issuer_url_http_127_prefix_subdomain_rejected(): + """A domain like 127.0.0.1something.example.com is not loopback.""" + with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): + validate_issuer_url(AnyHttpUrl("http://127.0.0.1something.example.com/path")) + + +def test_validate_issuer_url_fragment_rejected(): + with pytest.raises(ValueError, match="fragment"): + validate_issuer_url(AnyHttpUrl("https://example.com/path#frag")) + + +def test_validate_issuer_url_query_rejected(): + with pytest.raises(ValueError, match="query"): + validate_issuer_url(AnyHttpUrl("https://example.com/path?q=1")) From 43d709c976b7984df717c48abc59b2e0efc23bf4 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Feb 2026 19:42:00 +0000 Subject: [PATCH 15/84] ci: pin all GitHub Actions to commit SHAs (#2088) --- .github/workflows/claude-code-review.yml | 4 ++-- .github/workflows/claude.yml | 4 ++-- .github/workflows/weekly-lockfile-update.yml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) 
diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml index 36c88040ea..514f979d7c 100644 --- a/.github/workflows/claude-code-review.yml +++ b/.github/workflows/claude-code-review.yml @@ -19,13 +19,13 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 1 - name: Run Claude Code Review id: claude-review - uses: anthropics/claude-code-action@v1 + uses: anthropics/claude-code-action@2f8ba26a219c06cfb0f468eef8d97055fa814f97 # v1.0.53 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} plugin_marketplaces: "https://github.com/anthropics/claude-code.git" diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 490e9ae2cc..8421cf954c 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -27,13 +27,13 @@ jobs: actions: read # Required for Claude to read CI results on PRs steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 1 - name: Run Claude Code id: claude - uses: anthropics/claude-code-action@v1 + uses: anthropics/claude-code-action@2f8ba26a219c06cfb0f468eef8d97055fa814f97 # v1.0.53 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} use_commit_signing: true diff --git a/.github/workflows/weekly-lockfile-update.yml b/.github/workflows/weekly-lockfile-update.yml index 8808822476..96507d7936 100644 --- a/.github/workflows/weekly-lockfile-update.yml +++ b/.github/workflows/weekly-lockfile-update.yml @@ -14,9 +14,9 @@ jobs: update-lockfile: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - uses: astral-sh/setup-uv@v7.2.1 + - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: version: 0.9.5 From 688c6e3adeaab0864bb63dd3c0b7bc605f65e571 Mon 
Sep 17 00:00:00 2001 From: Den Delimarsky <53200638+localden@users.noreply.github.com> Date: Wed, 18 Feb 2026 21:19:25 -0800 Subject: [PATCH 16/84] Update SECURITY.md to use GitHub Security Advisories (#2092) --- SECURITY.md | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/SECURITY.md b/SECURITY.md index 6545156105..5029242009 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,15 +1,21 @@ # Security Policy -Thank you for helping us keep the SDKs and systems they interact with secure. +Thank you for helping keep the Model Context Protocol and its ecosystem secure. ## Reporting Security Issues -This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model Context Protocol project. +If you discover a security vulnerability in this repository, please report it through +the [GitHub Security Advisory process](https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing-information-about-vulnerabilities/privately-reporting-a-security-vulnerability) +for this repository. -The security of our systems and user data is Anthropic’s top priority. We appreciate the work of security researchers acting in good faith in identifying and reporting potential vulnerabilities. +Please **do not** report security vulnerabilities through public GitHub issues, discussions, +or pull requests. -Our security program is managed on HackerOne and we ask that any validated vulnerability in this functionality be reported through their [submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability). +## What to Include -## Vulnerability Disclosure Program +To help us triage and respond quickly, please include: -Our Vulnerability Program Guidelines are defined on our [HackerOne program page](https://hackerone.com/anthropic-vdp). 
+- A description of the vulnerability +- Steps to reproduce the issue +- The potential impact +- Any suggested fixes (optional) From c0328540c97de2d18925cd115fdc526a1fd2fa22 Mon Sep 17 00:00:00 2001 From: Jonathan Hefner Date: Wed, 18 Feb 2026 23:45:59 -0600 Subject: [PATCH 17/84] docs: fix docstrings across public API surface (#2095) --- src/mcp/cli/cli.py | 8 +-- .../auth/extensions/client_credentials.py | 2 +- src/mcp/client/auth/oauth2.py | 3 +- src/mcp/client/auth/utils.py | 20 +++---- src/mcp/client/client.py | 8 +-- src/mcp/client/experimental/tasks.py | 4 +- src/mcp/client/session_group.py | 10 ++-- src/mcp/client/sse.py | 1 + src/mcp/client/stdio.py | 6 +- src/mcp/client/streamable_http.py | 2 +- src/mcp/client/websocket.py | 4 +- src/mcp/os/win32/utilities.py | 15 ++--- src/mcp/server/auth/handlers/revoke.py | 2 +- src/mcp/server/auth/middleware/client_auth.py | 8 ++- src/mcp/server/auth/provider.py | 17 +++--- src/mcp/server/auth/routes.py | 6 +- src/mcp/server/elicitation.py | 4 +- .../server/experimental/request_context.py | 8 +-- src/mcp/server/lowlevel/server.py | 3 - src/mcp/server/mcpserver/resources/types.py | 2 +- src/mcp/server/mcpserver/server.py | 56 +++++++++++-------- .../mcpserver/utilities/func_metadata.py | 30 +++++----- src/mcp/server/mcpserver/utilities/logging.py | 6 +- src/mcp/server/models.py | 2 +- src/mcp/server/session.py | 28 +++++----- src/mcp/server/sse.py | 8 +-- src/mcp/server/streamable_http.py | 15 ++--- src/mcp/server/websocket.py | 4 +- src/mcp/shared/auth.py | 7 +-- src/mcp/shared/memory.py | 2 +- src/mcp/shared/metadata_utils.py | 2 +- src/mcp/shared/session.py | 5 +- src/mcp/types/_types.py | 42 +++++++------- 33 files changed, 177 insertions(+), 163 deletions(-) diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index 858ab7db29..62334a4a2c 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -317,12 +317,12 @@ def run( ) -> None: # pragma: no cover """Run an MCP server. 
- The server can be specified in two ways:\n - 1. Module approach: server.py - runs the module directly, expecting a server.run() call.\n - 2. Import approach: server.py:app - imports and runs the specified server object.\n\n + The server can be specified in two ways: + 1. Module approach: server.py - runs the module directly, expecting a server.run() call. + 2. Import approach: server.py:app - imports and runs the specified server object. Note: This command runs the server directly. You are responsible for ensuring - all dependencies are available.\n + all dependencies are available. For dependency management, use `mcp install` or `mcp dev` instead. """ # noqa: E501 file, server_object = _parse_file_path(file_spec) diff --git a/src/mcp/client/auth/extensions/client_credentials.py b/src/mcp/client/auth/extensions/client_credentials.py index 07f6180bf1..cb6dafb407 100644 --- a/src/mcp/client/auth/extensions/client_credentials.py +++ b/src/mcp/client/auth/extensions/client_credentials.py @@ -450,7 +450,7 @@ def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # prag # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2 token_data["client_assertion"] = assertion token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" - # We need to set the audience to the resource server, the audience is difference from the one in claims + # We need to set the audience to the resource server, the audience is different from the one in claims # it represents the resource server that will validate the token token_data["audience"] = self.context.get_resource_url() diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index f464077549..7f5af51867 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -215,6 +215,7 @@ def prepare_token_auth( class OAuthClientProvider(httpx.Auth): """OAuth2 authentication for httpx. 
+ Handles OAuth flow with automatic client registration and token storage. """ @@ -241,7 +242,7 @@ def __init__( callback_handler: Handler for authorization callbacks. timeout: Timeout for the OAuth flow. client_metadata_url: URL-based client ID. When provided and the server - advertises client_id_metadata_document_supported=true, this URL will be + advertises client_id_metadata_document_supported=True, this URL will be used as the client_id instead of performing dynamic client registration. Must be a valid HTTPS URL with a non-root pathname. validate_resource_url: Optional callback to override resource URL validation. diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index 1aa960b9ce..0ca36b98d8 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -38,7 +38,7 @@ def extract_field_from_www_auth(response: Response, field_name: str) -> str | No def extract_scope_from_www_auth(response: Response) -> str | None: - """Extract scope parameter from WWW-Authenticate header as per RFC6750. + """Extract scope parameter from WWW-Authenticate header as per RFC 6750. Returns: Scope string if found in WWW-Authenticate header, None otherwise @@ -47,7 +47,7 @@ def extract_scope_from_www_auth(response: Response) -> str | None: def extract_resource_metadata_from_www_auth(response: Response) -> str | None: - """Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + """Extract protected resource metadata URL from WWW-Authenticate header as per RFC 9728. Returns: Resource metadata URL if found in WWW-Authenticate header, None otherwise @@ -67,8 +67,8 @@ def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, s 3. 
Fall back to root-based well-known URI: /.well-known/oauth-protected-resource Args: - www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header - server_url: server url + www_auth_url: Optional resource_metadata URL extracted from the WWW-Authenticate header + server_url: Server URL Returns: Ordered list of URLs to try for discovery @@ -120,10 +120,10 @@ def get_client_metadata_scopes( def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]: - """Generate ordered list of (url, type) tuples for discovery attempts. + """Generate an ordered list of URLs for authorization server metadata discovery. Args: - auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None + auth_server_url: OAuth Authorization Server Metadata URL if found, otherwise None server_url: URL for the MCP server, used as a fallback if auth_server_url is None """ @@ -170,7 +170,7 @@ async def handle_protected_resource_response( Per SEP-985, supports fallback when discovery fails at one URL. Returns: - True if metadata was successfully discovered, False if we should try next URL + ProtectedResourceMetadata if successfully discovered, None if we should try next URL """ if response.status_code == 200: try: @@ -206,7 +206,7 @@ def create_oauth_metadata_request(url: str) -> Request: def create_client_registration_request( auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str ) -> Request: - """Build registration request or skip if already registered.""" + """Build a client registration request.""" if auth_server_metadata and auth_server_metadata.registration_endpoint: registration_url = str(auth_server_metadata.registration_endpoint) @@ -261,7 +261,7 @@ def should_use_client_metadata_url( """Determine if URL-based client ID (CIMD) should be used instead of DCR. URL-based client IDs should be used when: - 1. 
The server advertises client_id_metadata_document_supported=true + 1. The server advertises client_id_metadata_document_supported=True 2. The client has a valid client_metadata_url configured Args: @@ -306,7 +306,7 @@ def create_client_info_from_metadata_url( async def handle_token_response_scopes( response: Response, ) -> OAuthToken: - """Parse and validate token response with optional scope validation. + """Parse and validate a token response. Parses token response JSON. Callers should check response.status_code before calling. diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 29d4a7035d..7dc67c5844 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -37,8 +37,8 @@ class Client: """A high-level MCP client for connecting to MCP servers. - Currently supports in-memory transport for testing. Pass a Server or - MCPServer instance directly to the constructor. + Supports in-memory transport for testing (pass a Server or MCPServer instance), + Streamable HTTP transport (pass a URL string), or a custom Transport instance. Example: ```python @@ -205,7 +205,7 @@ async def read_resource(self, uri: str, *, meta: RequestParamsMeta | None = None Args: uri: The URI of the resource to read. - meta: Additional metadata for the request + meta: Additional metadata for the request. Returns: The resource content. @@ -239,7 +239,7 @@ async def call_tool( meta: Additional metadata for the request Returns: - The tool result + The tool result. 
""" return await self.session.call_tool( name=name, diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py index 8ddc4faceb..2e2fdf7355 100644 --- a/src/mcp/client/experimental/tasks.py +++ b/src/mcp/client/experimental/tasks.py @@ -83,7 +83,7 @@ async def call_tool_as_task( status = await session.experimental.get_task(task_id) if status.status == "completed": break - await asyncio.sleep(0.5) + await anyio.sleep(0.5) # Get result final = await session.experimental.get_task_result(task_id, CallToolResult) @@ -177,7 +177,7 @@ async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: """Poll a task until it reaches a terminal status. Yields GetTaskResult for each poll, allowing the caller to react to - status changes (e.g., handle input_required). Exits when task reaches + status changes (e.g., handle input_required). Exits when the task reaches a terminal status (completed, failed, cancelled). Respects the pollInterval hint from the server. diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index f4e6293b71..17f41025bd 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -3,7 +3,7 @@ Tools, resources, and prompts are aggregated across servers. Servers may be connected to or disconnected from at any point after initialization. -This abstractions can handle naming collisions using a custom user-provided hook. +This abstraction can handle naming collisions using a custom user-provided hook. """ import contextlib @@ -30,7 +30,7 @@ class SseServerParameters(BaseModel): - """Parameters for initializing a sse_client.""" + """Parameters for initializing an sse_client.""" # The endpoint URL. 
url: str @@ -67,8 +67,8 @@ class StreamableHttpParameters(BaseModel): ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters -# Use dataclass instead of pydantic BaseModel -# because pydantic BaseModel cannot handle Protocol fields. +# Use dataclass instead of Pydantic BaseModel +# because Pydantic BaseModel cannot handle Protocol fields. @dataclass class ClientSessionParameters: """Parameters for establishing a client session to an MCP server.""" @@ -119,7 +119,7 @@ class _ComponentNames(BaseModel): _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] # Optional fn consuming (component_name, server_info) for custom names. - # This is provide a means to mitigate naming conflicts across servers. + # This is to provide a means to mitigate naming conflicts across servers. # Example: (tool_name, server_info) => "{result.server_info.name}.{tool_name}" _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str] _component_name_hook: _ComponentNameHook | None diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 7c309ecb52..61026aa0c9 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -47,6 +47,7 @@ async def sse_client( headers: Optional headers to include in requests. timeout: HTTP timeout for regular operations (in seconds). sse_read_timeout: Timeout for SSE read operations (in seconds). + httpx_client_factory: Factory function for creating the HTTPX client. auth: Optional HTTPX authentication handler. on_session_created: Optional callback invoked with the session ID when received. 
""" diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 5b8209eeb5..902dc8576c 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -87,9 +87,9 @@ class StdioServerParameters(BaseModel): encoding: str = "utf-8" """ - The text encoding used when sending/receiving messages to the server + The text encoding used when sending/receiving messages to the server. - defaults to utf-8 + Defaults to utf-8. """ encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict" @@ -97,7 +97,7 @@ class StdioServerParameters(BaseModel): The text encoding error handler. See https://docs.python.org/3/library/codecs.html#codec-base-classes for - explanations of possible values + explanations of possible values. """ diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index d161e3c2a3..9f3dd5e0ba 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -358,7 +358,7 @@ async def _handle_sse_response( resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), is_initialization=is_initialization, ) - # If the SSE event indicates completion, like returning respose/error + # If the SSE event indicates completion, like returning response/error # break the loop if is_complete: await response.aclose() diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index bda199f36d..79e75fad18 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -25,8 +25,8 @@ async def websocket_client( (read_stream, write_stream) - read_stream: As you read from this stream, you'll receive either valid - JSONRPCMessage objects or Exception objects (when validation fails). - - write_stream: Write JSONRPCMessage objects to this stream to send them + SessionMessage objects or Exception objects (when validation fails). + - write_stream: Write SessionMessage objects to this stream to send them over the WebSocket to the server. 
""" diff --git a/src/mcp/os/win32/utilities.py b/src/mcp/os/win32/utilities.py index fa4e4b399b..0e188691f1 100644 --- a/src/mcp/os/win32/utilities.py +++ b/src/mcp/os/win32/utilities.py @@ -138,9 +138,9 @@ async def create_windows_process( ) -> Process | FallbackProcess: """Creates a subprocess in a Windows-compatible way with Job Object support. - Attempt to use anyio's open_process for async subprocess creation. - In some cases this will throw NotImplementedError on Windows, e.g. - when using the SelectorEventLoop which does not support async subprocesses. + Attempts to use anyio's open_process for async subprocess creation. + In some cases this will throw NotImplementedError on Windows, e.g., + when using the SelectorEventLoop, which does not support async subprocesses. In that case, we fall back to using subprocess.Popen. The process is automatically added to a Job Object to ensure all child @@ -242,8 +242,9 @@ def _create_job_object() -> int | None: def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: JobHandle | None) -> None: - """Try to assign a process to a job object. If assignment fails - for any reason, the job handle is closed. + """Try to assign a process to a job object. + + If assignment fails for any reason, the job handle is closed. """ if not job: return @@ -312,8 +313,8 @@ async def terminate_windows_process(process: Process | FallbackProcess): Note: On Windows, terminating a process with process.terminate() doesn't always guarantee immediate process termination. - So we give it 2s to exit, or we call process.kill() - which sends a SIGKILL equivalent signal. + If the process does not exit within 2 seconds, process.kill() is called + to send a SIGKILL-equivalent signal. 
Args: process: The process to terminate diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 68a3392b4f..4efd154001 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -15,7 +15,7 @@ class RevocationRequest(BaseModel): - """# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1""" + """See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1""" token: str token_type_hint: Literal["access_token", "refresh_token"] | None = None diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 8a6a1b5188..2832f83523 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -19,15 +19,17 @@ def __init__(self, message: str): class ClientAuthenticator: """ClientAuthenticator is a callable which validates requests from a client application, used to verify /token calls. + If, during registration, the client requested to be issued a secret, the authenticator asserts that /token calls must be authenticated with - that same token. + that same secret. + NOTE: clients can opt for no authentication during registration, in which case this logic is skipped. """ def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]): - """Initialize the dependency. + """Initialize the authenticator. 
Args: provider: Provider to look up client information @@ -83,7 +85,7 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation elif client.token_endpoint_auth_method == "client_secret_post": raw_form_data = form_data.get("client_secret") - # form_data.get() can return a UploadFile or None, so we need to check if it's a string + # form_data.get() can return an UploadFile or None, so we need to check if it's a string if isinstance(raw_form_data, str): request_client_secret = str(raw_form_data) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 5eb577fd43..957082a854 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -131,8 +131,9 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None """ async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: - """Called as part of the /authorize endpoint, and returns a URL that the client + """Handle the /authorize endpoint and return a URL that the client will be redirected to. + Many MCP implementations will redirect to a third-party provider to perform a second OAuth exchange with that provider. In this sort of setup, the client has an OAuth connection with the MCP server, and the MCP server has an OAuth @@ -151,7 +152,7 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat | | +------------+ - Implementations will need to define another handler on the MCP server return + Implementations will need to define another handler on the MCP server's return flow to perform the second redirect, and generate and store an authorization code as part of completing the OAuth authorization step. @@ -182,7 +183,7 @@ async def load_authorization_code( authorization_code: The authorization code to get the challenge for. Returns: - The AuthorizationCode, or None if not found + The AuthorizationCode, or None if not found. """ ... 
@@ -199,7 +200,7 @@ async def exchange_authorization_code( The OAuth token, containing access and refresh tokens. Raises: - TokenError: If the request is invalid + TokenError: If the request is invalid. """ ... @@ -234,18 +235,18 @@ async def exchange_refresh_token( The OAuth token, containing access and refresh tokens. Raises: - TokenError: If the request is invalid + TokenError: If the request is invalid. """ ... async def load_access_token(self, token: str) -> AccessTokenT | None: - """Loads an access token by its token. + """Loads an access token by its token string. Args: token: The access token to verify. Returns: - The AuthInfo, or None if the token is invalid. + The access token, or None if the token is invalid. """ async def revoke_token( @@ -261,7 +262,7 @@ async def revoke_token( provided. Args: - token: the token to revoke + token: The token to revoke. """ diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 9a10ac57fa..a72e819477 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -25,10 +25,10 @@ def validate_issuer_url(url: AnyHttpUrl): """Validate that the issuer URL meets OAuth 2.0 requirements. Args: - url: The issuer URL to validate + url: The issuer URL to validate. Raises: - ValueError: If the issuer URL is invalid + ValueError: If the issuer URL is invalid. 
""" # RFC 8414 requires HTTPS, but we allow loopback/localhost HTTP for testing @@ -213,6 +213,8 @@ def create_protected_resource_routes( resource_url: The URL of this resource server authorization_servers: List of authorization servers that can issue tokens scopes_supported: Optional list of scopes supported by this resource + resource_name: Optional human-readable name for this resource + resource_documentation: Optional URL to documentation for this resource Returns: List of Starlette routes for protected resource metadata diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 58e9fe4485..731c914edc 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -112,8 +112,8 @@ async def elicit_with_validation( This method can be used to interactively ask for additional information from the client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. Or in case a - client is an agent, it might decide how to handle the elicitation -- either by asking + user and collect a response according to the provided schema. If the client + is an agent, it might decide how to handle the elicitation -- either by asking the user or automatically generating a response. For sensitive data like credentials or OAuth flows, use elicit_url() instead. diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 91aa9a6450..138c021108 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -62,8 +62,8 @@ def validate_task_mode( """Validate that the request is compatible with the tool's task execution mode. Per MCP spec: - - "required": Clients MUST invoke as task. Server returns -32601 if not. - - "forbidden" (or None): Clients MUST NOT invoke as task. Server returns -32601 if they do. + - "required": Clients MUST invoke as a task. 
Server returns -32601 if not. + - "forbidden" (or None): Clients MUST NOT invoke as a task. Server returns -32601 if they do. - "optional": Either is acceptable. Args: @@ -111,7 +111,7 @@ def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool: """Check if this client can use a tool with the given task mode. Useful for filtering tool lists or providing warnings. - Returns False if tool requires "required" but client doesn't support tasks. + Returns False if the tool's task mode is "required" but the client doesn't support tasks. Args: tool_task_mode: The tool's execution.taskSupport value @@ -164,7 +164,7 @@ async def handle_tool(ctx: RequestContext, params: CallToolRequestParams) -> Cal async def work(task: ServerTaskContext) -> CallToolResult: result = await task.elicit( message="Are you sure?", - requestedSchema={"type": "object", ...} + requested_schema={"type": "object", ...} ) confirmed = result.content.get("confirm", False) return CallToolResult(content=[TextContent(text="Done" if confirmed else "Cancelled")]) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 9ca5ac4fc9..aee6440402 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -88,9 +88,6 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. - Args: - server: The server instance this lifespan is managing - Returns: An empty context object """ diff --git a/src/mcp/server/mcpserver/resources/types.py b/src/mcp/server/mcpserver/resources/types.py index 64e6338060..42aecd6e39 100644 --- a/src/mcp/server/mcpserver/resources/types.py +++ b/src/mcp/server/mcpserver/resources/types.py @@ -109,7 +109,7 @@ def from_function( class FileResource(Resource): """A resource that reads from a file. - Set is_binary=True to read file as binary data instead of text. 
+ Set is_binary=True to read the file as binary data instead of text. """ path: Path = Field(description="Path to the file") diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index cd459589ad..21b6af7b9b 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -108,7 +108,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): warn_on_duplicate_prompts: bool lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None - """A async context manager that will be called when the server is started.""" + """An async context manager that will be called when the server is started.""" auth: AuthSettings | None @@ -388,8 +388,10 @@ async def list_tools(self) -> list[MCPTool]: ] def get_context(self) -> Context[LifespanResultT, Request]: - """Returns a Context object. Note that the context will only be valid - during a request; outside a request, most methods will error. + """Return a Context object. + + Note that the context will only be valid during a request; outside a + request, most methods will error. 
""" try: request_context = request_ctx.get() @@ -475,6 +477,8 @@ def add_tool( title: Optional human-readable title for the tool description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information + icons: Optional list of icons for the tool + meta: Optional metadata dictionary for the tool structured_output: Controls whether the tool's output is structured or unstructured - If None, auto-detects based on the function's return type annotation - If True, creates a structured tool (return type annotation permitting) @@ -523,6 +527,8 @@ def tool( title: Optional human-readable title for the tool description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information + icons: Optional list of icons for the tool + meta: Optional metadata dictionary for the tool structured_output: Controls whether the tool's output is structured or unstructured - If None, auto-detects based on the function's return type annotation - If True, creates a structured tool (return type annotation permitting) @@ -534,8 +540,8 @@ def my_tool(x: int) -> str: return str(x) @server.tool() - def tool_with_context(x: int, ctx: Context) -> str: - ctx.info(f"Processing {x}") + async def tool_with_context(x: int, ctx: Context) -> str: + await ctx.info(f"Processing {x}") return str(x) @server.tool() @@ -636,6 +642,8 @@ def resource( title: Optional human-readable title for the resource description: Optional description of the resource mime_type: Optional MIME type for the resource + icons: Optional list of icons for the resource + annotations: Optional annotations for the resource meta: Optional metadata dictionary for the resource Example: @@ -644,7 +652,7 @@ def get_data() -> str: return "Hello, world!" @server.resource("resource://my-resource") - async get_data() -> str: + async def get_data() -> str: data = await fetch_data() return f"Hello, world! 
{data}" @@ -736,6 +744,7 @@ def prompt( name: Optional name for the prompt (defaults to function name) title: Optional human-readable title for the prompt description: Optional description of what the prompt does + icons: Optional list of icons for the prompt Example: @server.prompt() @@ -1092,18 +1101,18 @@ class Context(BaseModel, Generic[LifespanContextT, RequestT]): ```python @server.tool() - def my_tool(x: int, ctx: Context) -> str: + async def my_tool(x: int, ctx: Context) -> str: # Log messages to the client - ctx.info(f"Processing {x}") - ctx.debug("Debug info") - ctx.warning("Warning message") - ctx.error("Error message") + await ctx.info(f"Processing {x}") + await ctx.debug("Debug info") + await ctx.warning("Warning message") + await ctx.error("Error message") # Report progress - ctx.report_progress(50, 100) + await ctx.report_progress(50, 100) # Access resources - data = ctx.read_resource("resource://data") + data = await ctx.read_resource("resource://data") # Get request info request_id = ctx.request_id @@ -1149,9 +1158,9 @@ async def report_progress(self, progress: float, total: float | None = None, mes """Report progress for the current operation. Args: - progress: Current progress value e.g. 24 - total: Optional total value e.g. 100 - message: Optional message e.g. Starting render... + progress: Current progress value (e.g., 24) + total: Optional total value (e.g., 100) + message: Optional message (e.g., "Starting render...") """ progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None @@ -1187,15 +1196,14 @@ async def elicit( This method can be used to interactively ask for additional information from the client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. 
Or in case a - client is an agent, it might decide how to handle the elicitation -- either by asking + user and collect a response according to the provided schema. If the client + is an agent, it might decide how to handle the elicitation -- either by asking the user or automatically generating a response. Args: - schema: A Pydantic model class defining the expected response structure, according to the specification, - only primitive types are allowed. - message: Optional message to present to the user. If not provided, will use - a default message based on the schema + message: Message to present to the user + schema: A Pydantic model class defining the expected response structure. + According to the specification, only primitive types are allowed. Returns: An ElicitationResult containing the action taken and the data if accepted @@ -1229,7 +1237,7 @@ async def elicit_url( The response indicates whether the user consented to navigate to the URL. The actual interaction happens out-of-band. When the elicitation completes, - call `self.session.send_elicit_complete(elicitation_id)` to notify the client. + call `ctx.session.send_elicit_complete(elicitation_id)` to notify the client. Args: message: Human-readable explanation of why the interaction is needed @@ -1299,7 +1307,7 @@ async def close_sse_stream(self) -> None: be replayed when the client reconnects with Last-Event-ID. Use this to implement polling behavior during long-running operations - - client will reconnect after the retry interval specified in the priming event. + the client will reconnect after the retry interval specified in the priming event. Note: This is a no-op if not using StreamableHTTP transport with event_store. 
diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 4b539ce1f2..062b47d0ff 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -46,7 +46,7 @@ class ArgModelBase(BaseModel): def model_dump_one_level(self) -> dict[str, Any]: """Return a dict of the model's fields, one level deep. - That is, sub-models etc are not dumped - they are kept as pydantic models. + That is, sub-models etc are not dumped - they are kept as Pydantic models. """ kwargs: dict[str, Any] = {} for field_name, field_info in self.__class__.model_fields.items(): @@ -89,8 +89,7 @@ async def call_fn_with_arg_validation( return await anyio.to_thread.run_sync(functools.partial(fn, **arguments_parsed_dict)) def convert_result(self, result: Any) -> Any: - """Convert the result of a function call to the appropriate format for - the lowlevel server tool call handler: + """Convert a function call result to the format for the lowlevel tool call handler. - If output_model is None, return the unstructured content directly. - If output_model is not None, convert the result to structured output format @@ -126,11 +125,11 @@ def convert_result(self, result: Any) -> Any: def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: """Pre-parse data from JSON. - Return a dict with same keys as input but with values parsed from JSON + Return a dict with the same keys as input but with values parsed from JSON if appropriate. This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside - a string rather than an actual list. Claude desktop is prone to this - in fact + a string rather than an actual list. Claude Desktop is prone to this - in fact it seems incapable of NOT doing this. For sub-models, it tends to pass dicts (JSON objects) as JSON strings, which can be pre-parsed here. 
""" @@ -173,8 +172,7 @@ def func_metadata( skip_names: Sequence[str] = (), structured_output: bool | None = None, ) -> FuncMetadata: - """Given a function, return metadata including a pydantic model representing its - signature. + """Given a function, return metadata including a Pydantic model representing its signature. The use case for this is ``` @@ -183,11 +181,11 @@ def func_metadata( return func(**validated_args.model_dump_one_level()) ``` - **critically** it also provides pre-parse helper to attempt to parse things from + **critically** it also provides a pre-parse helper to attempt to parse things from JSON. Args: - func: The function to convert to a pydantic model + func: The function to convert to a Pydantic model skip_names: A list of parameter names to skip. These will not be included in the model. structured_output: Controls whether the tool's output is structured or unstructured @@ -195,8 +193,8 @@ def func_metadata( - If True, creates a structured tool (return type annotation permitting) - If False, unconditionally creates an unstructured tool - If structured, creates a Pydantic model for the function's result based on its annotation. - Supports various return types: + If structured, creates a Pydantic model for the function's result based on its annotation. + Supports various return types: - BaseModel subclasses (used directly) - Primitive types (str, int, float, bool, bytes, None) - wrapped in a model with a 'result' field @@ -206,9 +204,9 @@ def func_metadata( Returns: A FuncMetadata object containing: - - arg_model: A pydantic model representing the function's arguments - - output_model: A pydantic model for the return type if output is structured - - output_conversion: Records how function output should be converted before returning. 
+ - arg_model: A Pydantic model representing the function's arguments + - output_model: A Pydantic model for the return type if the output is structured + - wrap_output: Whether the function result needs to be wrapped in `{"result": ...}` for structured output. """ try: sig = inspect.signature(func, eval_str=True) @@ -296,7 +294,7 @@ def func_metadata( ] # pragma: no cover else: # We only had `Annotated[CallToolResult, ReturnType]`, treat the original annotation - # as beging `ReturnType`: + # as being `ReturnType`: original_annotation = return_type_expr else: return FuncMetadata(arg_model=arguments_model) @@ -355,7 +353,7 @@ def _try_create_model_and_schema( if origin is dict: args = get_args(type_expr) if len(args) == 2 and args[0] is str: - # TODO: should we use the original annotation? We are loosing any potential `Annotated` + # TODO: should we use the original annotation? We are losing any potential `Annotated` # metadata for Pydantic here: model = _create_dict_model(func_name, type_expr) else: diff --git a/src/mcp/server/mcpserver/utilities/logging.py b/src/mcp/server/mcpserver/utilities/logging.py index c394f2bfaf..04ca38853b 100644 --- a/src/mcp/server/mcpserver/utilities/logging.py +++ b/src/mcp/server/mcpserver/utilities/logging.py @@ -8,10 +8,10 @@ def get_logger(name: str) -> logging.Logger: """Get a logger nested under MCP namespace. Args: - name: the name of the logger + name: The name of the logger. Returns: - a configured logger instance + A configured logger instance. """ return logging.getLogger(name) @@ -22,7 +22,7 @@ def configure_logging( """Configure logging for MCP. Args: - level: the log level to use + level: The log level to use. 
""" handlers: list[logging.Handler] = [] try: diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py index 41b9224c1d..3861f42a7e 100644 --- a/src/mcp/server/models.py +++ b/src/mcp/server/models.py @@ -1,4 +1,4 @@ -"""This module provides simpler types to use with the server for managing prompts +"""This module provides simplified types to use with the server for managing prompts and tools. """ diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 6925aa556b..759d2131a1 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -363,12 +363,12 @@ async def elicit( """Send a form mode elicitation/create request. Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - related_request_id: Optional ID of the request that triggered this elicitation + message: The message to present to the user. + requested_schema: Schema defining the expected response structure. + related_request_id: Optional ID of the request that triggered this elicitation. Returns: - The client's response + The client's response. Note: This method is deprecated in favor of elicit_form(). It remains for @@ -385,12 +385,12 @@ async def elicit_form( """Send a form mode elicitation/create request. Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - related_request_id: Optional ID of the request that triggered this elicitation + message: The message to present to the user. + requested_schema: Schema defining the expected response structure. + related_request_id: Optional ID of the request that triggered this elicitation. Returns: - The client's response with form data + The client's response with form data. Raises: StatelessModeNotSupported: If called in stateless HTTP mode. @@ -421,13 +421,13 @@ async def elicit_url( like OAuth flows, credential collection, or payment processing. 
Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - related_request_id: Optional ID of the request that triggered this elicitation + message: Human-readable explanation of why the interaction is needed. + url: The URL the user should navigate to. + elicitation_id: Unique identifier for tracking this elicitation. + related_request_id: Optional ID of the request that triggered this elicitation. Returns: - The client's response indicating acceptance, decline, or cancellation + The client's response indicating acceptance, decline, or cancellation. Raises: StatelessModeNotSupported: If called in stateless HTTP mode. @@ -499,7 +499,7 @@ async def send_elicit_complete( Args: elicitation_id: The unique identifier of the completed elicitation - related_request_id: Optional ID of the request that triggered this + related_request_id: Optional ID of the request that triggered this notification """ await self.send_notification( types.ElicitCompleteNotification( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 674294c5c3..827ec3591f 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -29,8 +29,8 @@ async def handle_sse(request): uvicorn.run(starlette_app, host="127.0.0.1", port=port) ``` -Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType' -object is not callable" error when client disconnects. The example above returns +Note: The handle_sse function must return a Response to avoid a +"TypeError: 'NoneType' object is not callable" error when client disconnects. The example above returns an empty Response() after the SSE connection ends to fix this. See SseServerTransport class documentation for more details. @@ -61,8 +61,8 @@ async def handle_sse(request): class SseServerTransport: - """SSE server transport for MCP. 
This class provides _two_ ASGI applications, - suitable to be used with a framework like Starlette and a server like Hypercorn: + """SSE server transport for MCP. This class provides two ASGI applications, + suitable for use with a framework like Starlette and a server like Hypercorn: 1. connect_sse() is an ASGI application which receives incoming GET requests, and sets up a new SSE stream to send server messages to the client. diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index bcee3a4748..04aed345e0 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -89,7 +89,7 @@ async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) message: The JSON-RPC message to store, or None for priming events Returns: - The generated event ID for the stored event + The generated event ID for the stored event. """ pass # pragma: no cover @@ -106,7 +106,7 @@ async def replay_events_after( send_callback: A callback function to send events to the client Returns: - The stream ID of the replayed events + The stream ID of the replayed events, or None if no events were found. """ pass # pragma: no cover @@ -185,7 +185,7 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover be replayed when the client reconnects with Last-Event-ID. Use this to implement polling behavior during long-running operations - - client will reconnect after the retry interval specified in the priming event. + the client will reconnect after the retry interval specified in the priming event. Args: request_id: The request ID whose SSE stream should be closed. @@ -213,7 +213,7 @@ def close_standalone_sse_stream(self) -> None: # pragma: no cover with Last-Event-ID to resume receiving notifications. Use this to implement polling behavior for the notification stream - - client will reconnect after the retry interval specified in the priming event. 
+ the client will reconnect after the retry interval specified in the priming event. Note: This is a no-op if there is no active standalone SSE stream. @@ -316,7 +316,7 @@ def _create_json_response( status_code: HTTPStatus = HTTPStatus.OK, headers: dict[str, str] | None = None, ) -> Response: - """Create a JSON response from a JSONRPCMessage""" + """Create a JSON response from a JSONRPCMessage.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} if headers: # pragma: lax no cover response_headers.update(headers) @@ -362,7 +362,7 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None: self._request_streams.pop(request_id, None) async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: - """Application entry point that handles all HTTP requests""" + """Application entry point that handles all HTTP requests.""" request = Request(scope, receive) # Validate request headers for DNS rebinding protection @@ -536,7 +536,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): response_message = event_message.message break - # For notifications and request, keep waiting + # For notifications and requests, keep waiting else: # pragma: no cover logger.debug(f"received: {event_message.message.method}") @@ -860,6 +860,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover """Replays events that would have been sent after the specified event ID. + Only used when resumability is enabled. 
""" event_store = self._event_store diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 7b00f79055..3e675da5fd 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -12,8 +12,8 @@ @asynccontextmanager # pragma: no cover async def websocket_server(scope: Scope, receive: Receive, send: Send): - """WebSocket server transport for MCP. This is an ASGI application, suitable to be - used with a framework like Starlette and a server like Hypercorn. + """WebSocket server transport for MCP. This is an ASGI application, suitable for use + with a framework like Starlette and a server like Hypercorn. """ websocket = WebSocket(scope, receive, send) diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index bf03a8b8dd..ca5b7b45ab 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -33,9 +33,8 @@ def __init__(self, message: str): class OAuthClientMetadata(BaseModel): - """RFC 7591 OAuth 2.0 Dynamic Client Registration metadata. + """RFC 7591 OAuth 2.0 Dynamic Client Registration Metadata. See https://datatracker.ietf.org/doc/html/rfc7591#section-2 - for the full specification. 
""" redirect_uris: list[AnyUrl] | None = Field(..., min_length=1) @@ -145,9 +144,9 @@ class ProtectedResourceMetadata(BaseModel): resource_documentation: AnyHttpUrl | None = None resource_policy_uri: AnyHttpUrl | None = None resource_tos_uri: AnyHttpUrl | None = None - # tls_client_certificate_bound_access_tokens default is False, but ommited here for clarity + # tls_client_certificate_bound_access_tokens default is False, but omitted here for clarity tls_client_certificate_bound_access_tokens: bool | None = None authorization_details_types_supported: list[str] | None = None dpop_signing_alg_values_supported: list[str] | None = None - # dpop_bound_access_tokens_required default is False, but ommited here for clarity + # dpop_bound_access_tokens_required default is False, but omitted here for clarity dpop_bound_access_tokens_required: bool | None = None diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index d01d28b808..f2d5e2b9ad 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -17,7 +17,7 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]: """Creates a pair of bidirectional memory streams for client-server communication. - Returns: + Yields: A tuple of (client_streams, server_streams) where each is a tuple of (read_stream, write_stream) """ diff --git a/src/mcp/shared/metadata_utils.py b/src/mcp/shared/metadata_utils.py index 2b66996bde..3b6b27dfeb 100644 --- a/src/mcp/shared/metadata_utils.py +++ b/src/mcp/shared/metadata_utils.py @@ -1,7 +1,7 @@ """Utility functions for working with metadata in MCP types. These utilities are primarily intended for client-side usage to properly display -human-readable names in user interfaces in a spec compliant way. +human-readable names in user interfaces in a spec-compliant way. 
""" from mcp.types import Implementation, Prompt, Resource, ResourceTemplate, Tool diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 5ee8f3baad..d3f36ded4b 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -115,6 +115,7 @@ async def respond(self, response: SendResultT | ErrorData) -> None: """Send a response for this request. Must be called within a context manager block. + Raises: RuntimeError: If not used within a context manager AssertionError: If request was already responded to @@ -235,7 +236,7 @@ async def send_request( metadata: MessageMetadata = None, progress_callback: ProgressFnT | None = None, ) -> ReceiveResultT: - """Sends a request and wait for a response. + """Sends a request and waits for a response. Raises an MCPError if the response contains an error. If a request read timeout is provided, it will take precedence over the session read timeout. @@ -512,4 +513,4 @@ async def send_progress_notification( async def _handle_incoming( self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception ) -> None: - """A generic handler for incoming messages. Overwritten by subclasses.""" + """A generic handler for incoming messages. Overridden by subclasses.""" diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index 320422636a..9005d253af 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -22,7 +22,7 @@ provided by the client. See the "Protocol Version Header" at -https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#protocol-version-header). +https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#protocol-version-header. 
""" ProgressToken = str | int @@ -108,8 +108,7 @@ class Request(MCPModel, Generic[RequestParamsT, MethodT]): class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): - """Base class for paginated requests, - matching the schema's PaginatedRequest interface.""" + """Base class for paginated requests, matching the schema's PaginatedRequest interface.""" params: PaginatedRequestParams | None = None @@ -174,10 +173,10 @@ class Icon(MCPModel): theme: IconTheme | None = None """Optional theme specifier. - - `"light"` indicates the icon is designed for a light background, `"dark"` indicates the icon + + `"light"` indicates the icon is designed for a light background, `"dark"` indicates the icon is designed for a dark background. - + See https://modelcontextprotocol.io/specification/2025-11-25/schema#icon for more details. """ @@ -536,7 +535,7 @@ class TaskStatusNotificationParams(NotificationParams, Task): class TaskStatusNotification(Notification[TaskStatusNotificationParams, Literal["notifications/tasks/status"]]): """An optional notification from the receiver to the requestor, informing them that a task's status has changed. - Receivers are not required to send these notifications + Receivers are not required to send these notifications. """ method: Literal["notifications/tasks/status"] = "notifications/tasks/status" @@ -608,7 +607,7 @@ class ProgressNotificationParams(NotificationParams): message: str | None = None """Message related to progress. - This should provide relevant human readable progress information. + This should provide relevant human-readable progress information. """ @@ -999,7 +998,9 @@ class ToolResultContent(MCPModel): SamplingContent: TypeAlias = TextContent | ImageContent | AudioContent """Basic content types for sampling responses (without tool use). -Used for backwards-compatible CreateMessageResult when tools are not used.""" + +Used for backwards-compatible CreateMessageResult when tools are not used. 
+""" class SamplingMessage(MCPModel): @@ -1117,7 +1118,7 @@ class ToolAnnotations(MCPModel): idempotent_hint: bool | None = None """ If true, calling the tool repeatedly with the same arguments - will have no additional effect on the its environment. + will have no additional effect on its environment. (This property is meaningful only when `read_only_hint == false`) Default: false """ @@ -1265,7 +1266,7 @@ class ModelPreferences(MCPModel): sampling. Because LLMs can vary along multiple dimensions, choosing the "best" model is - rarely straightforward. Different models excel in different areas—some are + rarely straightforward. Different models excel in different areas—some are faster but less capable, others are more capable but more expensive, and so on. This interface allows servers to express their priorities across multiple dimensions to help clients make an appropriate selection for their use case. @@ -1369,7 +1370,7 @@ class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling class CreateMessageResult(Result): - """The client's response to a sampling/create_message request from the server. + """The client's response to a sampling/createMessage request from the server. This is the backwards-compatible version that returns single content (no arrays). Used when the request does not include tools. @@ -1386,7 +1387,7 @@ class CreateMessageResult(Result): class CreateMessageResultWithTools(Result): - """The client's response to a sampling/create_message request when tools were provided. + """The client's response to a sampling/createMessage request when tools were provided. This version supports array content for tool use flows. 
""" @@ -1426,14 +1427,14 @@ class PromptReference(MCPModel): type: Literal["ref/prompt"] = "ref/prompt" name: str - """The name of the prompt or prompt template""" + """The name of the prompt or prompt template.""" class CompletionArgument(MCPModel): """The argument's information for completion requests.""" name: str - """The name of the argument""" + """The name of the argument.""" value: str """The value of the argument to use for completion matching.""" @@ -1451,7 +1452,7 @@ class CompleteRequestParams(RequestParams): ref: ResourceTemplateReference | PromptReference argument: CompletionArgument context: CompletionContext | None = None - """Additional, optional context for completions""" + """Additional, optional context for completions.""" class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]): @@ -1479,7 +1480,7 @@ class Completion(MCPModel): class CompleteResult(Result): - """The server's response to a completion/complete request""" + """The server's response to a completion/complete request.""" completion: Completion @@ -1522,6 +1523,7 @@ class Root(MCPModel): class ListRootsResult(Result): """The client's response to a roots/list request from the server. + This result contains an array of Root objects, each representing a root directory or file that the server can operate on. """ @@ -1643,7 +1645,7 @@ class ElicitRequestFormParams(RequestParams): requested_schema: ElicitRequestedSchema """ - A restricted subset of JSON Schema defining the structure of expected response. + A restricted subset of JSON Schema defining the structure of the expected response. Only top-level properties are allowed, without nesting. """ @@ -1697,8 +1699,8 @@ class ElicitResult(Result): content: dict[str, str | int | float | bool | list[str] | None] | None = None """ The submitted form data, only present when action is "accept" in form mode. - Contains values matching the requested schema. 
Values can be strings, integers, - booleans, or arrays of strings. + Contains values matching the requested schema. Values can be strings, integers, floats, + booleans, arrays of strings, or null. For URL mode, this field is omitted. """ From cb07adeca345effa111598889f6a8924a8722e6d Mon Sep 17 00:00:00 2001 From: Jonathan Hefner Date: Thu, 19 Feb 2026 14:06:11 -0600 Subject: [PATCH 18/84] docs: add code fences to `Example:` docstring blocks (#2104) --- src/mcp/client/experimental/task_handlers.py | 2 ++ src/mcp/client/experimental/tasks.py | 6 +++++ src/mcp/client/session.py | 2 ++ src/mcp/client/session_group.py | 5 +++-- .../server/experimental/request_context.py | 2 ++ src/mcp/server/experimental/task_context.py | 2 ++ src/mcp/server/experimental/task_support.py | 10 +++++++-- src/mcp/server/lowlevel/experimental.py | 10 +++++++-- src/mcp/server/mcpserver/server.py | 14 ++++++++++++ src/mcp/server/sse.py | 6 ++--- src/mcp/server/stdio.py | 6 ++--- src/mcp/shared/_httpx_utils.py | 22 ++++++++++++++----- src/mcp/shared/exceptions.py | 2 ++ src/mcp/shared/experimental/tasks/helpers.py | 2 ++ src/mcp/shared/metadata_utils.py | 2 ++ src/mcp/shared/response_router.py | 2 ++ src/mcp/shared/session.py | 2 ++ tests/client/conftest.py | 4 +++- 18 files changed, 83 insertions(+), 18 deletions(-) diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py index 28ff2b1f29..0ab513236a 100644 --- a/src/mcp/client/experimental/task_handlers.py +++ b/src/mcp/client/experimental/task_handlers.py @@ -187,11 +187,13 @@ class ExperimentalTaskHandlers: WARNING: These APIs are experimental and may change without notice. 
Example: + ```python handlers = ExperimentalTaskHandlers( get_task=my_get_task_handler, list_tasks=my_list_tasks_handler, ) session = ClientSession(..., experimental_task_handlers=handlers) + ``` """ # Pure task request handlers diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py index 2e2fdf7355..a566df766b 100644 --- a/src/mcp/client/experimental/tasks.py +++ b/src/mcp/client/experimental/tasks.py @@ -5,6 +5,7 @@ WARNING: These APIs are experimental and may change without notice. Example: + ```python # Call a tool as a task result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) task_id = result.task.task_id @@ -21,6 +22,7 @@ # Cancel a task await session.experimental.cancel_task(task_id) + ``` """ from collections.abc import AsyncIterator @@ -72,6 +74,7 @@ async def call_tool_as_task( CreateTaskResult containing the task reference Example: + ```python # Create task result = await session.experimental.call_tool_as_task( "long_running_tool", {"input": "data"} @@ -87,6 +90,7 @@ async def call_tool_as_task( # Get result final = await session.experimental.get_task_result(task_id, CallToolResult) + ``` """ return await self._session.send_request( types.CallToolRequest( @@ -189,6 +193,7 @@ async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: GetTaskResult for each poll Example: + ```python async for status in session.experimental.poll_task(task_id): print(f"Status: {status.status}") if status.status == "input_required": @@ -197,6 +202,7 @@ async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: # Task is now terminal, get the result result = await session.experimental.get_task_result(task_id, CallToolResult) + ``` """ async for status in poll_until_terminal(self.get_task, task_id): yield status diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0687f98c3a..a0ca751bd7 100644 --- a/src/mcp/client/session.py +++ 
b/src/mcp/client/session.py @@ -206,8 +206,10 @@ def experimental(self) -> ExperimentalClientFeatures: These APIs are experimental and may change without notice. Example: + ```python status = await session.experimental.get_task(task_id) result = await session.experimental.get_task_result(task_id, CallToolResult) + ``` """ if self._experimental_features is None: self._experimental_features = ExperimentalClientFeatures(self) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 17f41025bd..9610212642 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -91,13 +91,14 @@ class ClientSessionGroup: For auxiliary handlers, such as resource subscription, this is delegated to the client and can be accessed via the session. - Example Usage: + Example: + ```python name_fn = lambda name, server_info: f"{(server_info.name)}_{name}" async with ClientSessionGroup(component_name_hook=name_fn) as group: for server_param in server_params: await group.connect_to_server(server_param) ... - + ``` """ class _ComponentNames(BaseModel): diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 138c021108..3eba65822a 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -160,6 +160,7 @@ async def run_task( RuntimeError: If task support is not enabled or task_metadata is missing Example: + ```python async def handle_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: async def work(task: ServerTaskContext) -> CallToolResult: result = await task.elicit( @@ -170,6 +171,7 @@ async def work(task: ServerTaskContext) -> CallToolResult: return CallToolResult(content=[TextContent(text="Done" if confirmed else "Cancelled")]) return await ctx.experimental.run_task(work) + ``` WARNING: This API is experimental and may change without notice. 
""" diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py index 9b626c9862..1fc45badfd 100644 --- a/src/mcp/server/experimental/task_context.py +++ b/src/mcp/server/experimental/task_context.py @@ -56,6 +56,7 @@ class ServerTaskContext: - Status notifications via the session Example: + ```python async def my_task_work(task: ServerTaskContext) -> CallToolResult: await task.update_status("Starting...") @@ -68,6 +69,7 @@ async def my_task_work(task: ServerTaskContext) -> CallToolResult: return CallToolResult(content=[TextContent(text="Done!")]) else: return CallToolResult(content=[TextContent(text="Cancelled")]) + ``` """ def __init__( diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py index 23b5d9cc89..b542195048 100644 --- a/src/mcp/server/experimental/task_support.py +++ b/src/mcp/server/experimental/task_support.py @@ -31,14 +31,20 @@ class TaskSupport: - Manages a task group for background task execution Example: - # Simple in-memory setup + Simple in-memory setup: + + ```python server.experimental.enable_tasks() + ``` + + Custom store/queue for distributed systems: - # Custom store/queue for distributed systems + ```python server.experimental.enable_tasks( store=RedisTaskStore(redis_url), queue=RedisTaskMessageQueue(redis_url), ) + ``` """ store: TaskStore diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 8ac2687280..5a907b6407 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -118,14 +118,20 @@ def enable_tasks( The TaskSupport configuration object Example: - # Simple in-memory setup + Simple in-memory setup: + + ```python server.experimental.enable_tasks() + ``` + + Custom store/queue for distributed systems: - # Custom store/queue for distributed systems + ```python server.experimental.enable_tasks( store=RedisTaskStore(redis_url), 
queue=RedisTaskMessageQueue(redis_url), ) + ``` WARNING: This API is experimental and may change without notice. """ diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 21b6af7b9b..9c7105a7b4 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -535,19 +535,25 @@ def tool( - If False, unconditionally creates an unstructured tool Example: + ```python @server.tool() def my_tool(x: int) -> str: return str(x) + ``` + ```python @server.tool() async def tool_with_context(x: int, ctx: Context) -> str: await ctx.info(f"Processing {x}") return str(x) + ``` + ```python @server.tool() async def async_tool(x: int, context: Context) -> str: await context.report_progress(50, 100) return str(x) + ``` """ # Check if user passed function directly instead of calling decorator if callable(name): @@ -579,12 +585,14 @@ def completion(self): - context: Optional CompletionContext with previously resolved arguments Example: + ```python @mcp.completion() async def handle_completion(ref, argument, context): if isinstance(ref, ResourceTemplateReference): # Return completions based on ref, argument, and context return Completion(values=["option1", "option2"]) return None + ``` """ def decorator(func: _CallableT) -> _CallableT: @@ -647,6 +655,7 @@ def resource( meta: Optional metadata dictionary for the resource Example: + ```python @server.resource("resource://my-resource") def get_data() -> str: return "Hello, world!" 
@@ -664,6 +673,7 @@ def get_weather(city: str) -> str: async def get_weather(city: str) -> str: data = await fetch_weather(city) return f"Weather for {city}: {data}" + ``` """ # Check if user passed function directly instead of calling decorator if callable(uri): @@ -747,6 +757,7 @@ def prompt( icons: Optional list of icons for the prompt Example: + ```python @server.prompt() def analyze_table(table_name: str) -> list[Message]: schema = read_table_schema(table_name) @@ -772,6 +783,7 @@ async def analyze_file(path: str) -> list[Message]: } } ] + ``` """ # Check if user passed function directly instead of calling decorator if callable(name): @@ -813,9 +825,11 @@ def custom_route( include_in_schema: Whether to include in OpenAPI schema, defaults to True Example: + ```python @server.custom_route("/health", methods=["GET"]) async def health_check(request: Request) -> Response: return JSONResponse({"status": "ok"}) + ``` """ def decorator( # pragma: no cover diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 827ec3591f..9007230cea 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -2,8 +2,8 @@ This module implements a Server-Sent Events (SSE) transport layer for MCP servers. -Example usage: -``` +Example: + ```python # Create an SSE transport at an endpoint sse = SseServerTransport("/messages/") @@ -27,7 +27,7 @@ async def handle_sse(request): # Create and run Starlette app starlette_app = Starlette(routes=routes) uvicorn.run(starlette_app, host="127.0.0.1", port=port) -``` + ``` Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType' object is not callable" error when client disconnects. The example above returns diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 864d387bdf..e526bab569 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -4,8 +4,8 @@ that can be used to communicate with an MCP client through standard input/output streams. 
-Example usage: -``` +Example: + ```python async def run_server(): async with stdio_server() as (read_stream, write_stream): # read_stream contains incoming JSONRPCMessages from stdin @@ -14,7 +14,7 @@ async def run_server(): await server.run(read_stream, write_stream, init_options) anyio.run(run_server) -``` + ``` """ import sys diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index 8cf7bda2ad..251469eaa1 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -44,26 +44,38 @@ def create_mcp_http_client( The returned AsyncClient must be used as a context manager to ensure proper cleanup of connections. - Examples: - # Basic usage with MCP defaults + Example: + Basic usage with MCP defaults: + + ```python async with create_mcp_http_client() as client: response = await client.get("https://api.example.com") + ``` + + With custom headers: - # With custom headers + ```python headers = {"Authorization": "Bearer token"} async with create_mcp_http_client(headers) as client: response = await client.get("/endpoint") + ``` - # With both custom headers and timeout + With both custom headers and timeout: + + ```python timeout = httpx.Timeout(60.0, read=300.0) async with create_mcp_http_client(headers, timeout) as client: response = await client.get("/long-request") + ``` + + With authentication: - # With authentication + ```python from httpx import BasicAuth auth = BasicAuth(username="user", password="pass") async with create_mcp_http_client(headers, timeout, auth) as client: response = await client.get("/protected-endpoint") + ``` """ # Set MCP defaults kwargs: dict[str, Any] = {"follow_redirects": True} diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index 6c3a7745c1..f153ea319d 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -65,6 +65,7 @@ class UrlElicitationRequiredError(MCPError): must complete one or more URL elicitations before the request can be processed. 
Example: + ```python raise UrlElicitationRequiredError([ ElicitRequestURLParams( message="Authorization required for your files", @@ -72,6 +73,7 @@ class UrlElicitationRequiredError(MCPError): elicitation_id="auth-001" ) ]) + ``` """ def __init__(self, elicitations: list[ElicitRequestURLParams], message: str | None = None): diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py index bd1781cb57..3f91cd0d06 100644 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -72,8 +72,10 @@ async def cancel_task( - Task is already in a terminal state (completed, failed, cancelled) Example: + ```python async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult: return await cancel_task(store, params.task_id) + ``` """ task = await store.get_task(task_id) if task is None: diff --git a/src/mcp/shared/metadata_utils.py b/src/mcp/shared/metadata_utils.py index 3b6b27dfeb..6e4d33da0f 100644 --- a/src/mcp/shared/metadata_utils.py +++ b/src/mcp/shared/metadata_utils.py @@ -18,11 +18,13 @@ def get_display_name(obj: Tool | Resource | Prompt | ResourceTemplate | Implemen For other objects: title > name Example: + ```python # In a client displaying available tools tools = await session.list_tools() for tool in tools.tools: display_name = get_display_name(tool) print(f"Available tool: {display_name}") + ``` Args: obj: An MCP object with name and optional title fields diff --git a/src/mcp/shared/response_router.py b/src/mcp/shared/response_router.py index 7ec4a443c1..fe24b016f1 100644 --- a/src/mcp/shared/response_router.py +++ b/src/mcp/shared/response_router.py @@ -25,6 +25,7 @@ class ResponseRouter(Protocol): and deliver the response/error to the appropriate handler. 
Example: + ```python class TaskResultHandler(ResponseRouter): def route_response(self, request_id, response): resolver = self._pending_requests.pop(request_id, None) @@ -32,6 +33,7 @@ def route_response(self, request_id, response): resolver.set_result(response) return True return False + ``` """ def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index d3f36ded4b..b617d702fe 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -60,8 +60,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): cancellation handling: Example: + ```python with request_responder as resp: await resp.respond(result) + ``` The context manager ensures: 1. Proper cancellation scope setup and cleanup diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 268e968aa4..2e39f13630 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -77,7 +77,8 @@ def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNot def stream_spy() -> Generator[Callable[[], StreamSpyCollection], None, None]: """Fixture that provides spies for both client and server write streams. - Example usage: + Example: + ```python async def test_something(stream_spy): # ... set up server and client ... 
@@ -92,6 +93,7 @@ async def test_something(stream_spy): # Clear for the next operation spies.clear() + ``` """ client_spy = None server_spy = None From 0fe16dd5fddf71f6b07734748172c40b107d6932 Mon Sep 17 00:00:00 2001 From: Jonathan Hefner Date: Thu, 19 Feb 2026 16:12:51 -0600 Subject: [PATCH 19/84] fix: silence mkdocs social plugin warnings in strict mode (#2109) --- .github/workflows/publish-docs-manually.yml | 2 + mkdocs.yml | 3 +- pyproject.toml | 2 +- uv.lock | 87 +++++++++++++++++++-- 4 files changed, 87 insertions(+), 7 deletions(-) diff --git a/.github/workflows/publish-docs-manually.yml b/.github/workflows/publish-docs-manually.yml index f058174ab1..ee45ab5c8a 100644 --- a/.github/workflows/publish-docs-manually.yml +++ b/.github/workflows/publish-docs-manually.yml @@ -31,3 +31,5 @@ jobs: - run: uv sync --frozen --group docs - run: uv run --frozen --no-sync mkdocs gh-deploy --force + env: + ENABLE_SOCIAL_CARDS: "true" diff --git a/mkdocs.yml b/mkdocs.yml index 3019f5214b..070c533e31 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -112,7 +112,8 @@ watch: plugins: - search - - social + - social: + enabled: !ENV [ENABLE_SOCIAL_CARDS, false] - glightbox - mkdocstrings: handlers: diff --git a/pyproject.toml b/pyproject.toml index 008ee4c957..737839a23c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ dev = [ docs = [ "mkdocs>=1.6.1", "mkdocs-glightbox>=0.4.0", - "mkdocs-material>=9.5.45", + "mkdocs-material[imaging]>=9.5.45", "mkdocstrings-python>=2.0.1", ] diff --git a/uv.lock b/uv.lock index 5d3a83f376..d01d510f17 100644 --- a/uv.lock +++ b/uv.lock @@ -128,6 +128,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, ] +[[package]] +name = "cairocffi" +version = "1.7.1" +source = { registry = 
"https://pypi.org/simple" } +dependencies = [ + { name = "cffi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/70/c5/1a4dc131459e68a173cbdab5fad6b524f53f9c1ef7861b7698e998b837cc/cairocffi-1.7.1.tar.gz", hash = "sha256:2e48ee864884ec4a3a34bfa8c9ab9999f688286eb714a15a43ec9d068c36557b", size = 88096, upload-time = "2024-06-18T10:56:06.741Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/d8/ba13451aa6b745c49536e87b6bf8f629b950e84bd0e8308f7dc6883b67e2/cairocffi-1.7.1-py3-none-any.whl", hash = "sha256:9803a0e11f6c962f3b0ae2ec8ba6ae45e957a146a004697a1ac1bbf16b073b3f", size = 75611, upload-time = "2024-06-18T10:55:59.489Z" }, +] + +[[package]] +name = "cairosvg" +version = "2.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cairocffi" }, + { name = "cssselect2" }, + { name = "defusedxml" }, + { name = "pillow" }, + { name = "tinycss2" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/b9/5106168bd43d7cd8b7cc2a2ee465b385f14b63f4c092bb89eee2d48c8e67/cairosvg-2.8.2.tar.gz", hash = "sha256:07cbf4e86317b27a92318a4cac2a4bb37a5e9c1b8a27355d06874b22f85bef9f", size = 8398590, upload-time = "2025-05-15T06:56:32.653Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/48/816bd4aaae93dbf9e408c58598bc32f4a8c65f4b86ab560864cb3ee60adb/cairosvg-2.8.2-py3-none-any.whl", hash = "sha256:eab46dad4674f33267a671dce39b64be245911c901c70d65d2b7b0821e852bf5", size = 45773, upload-time = "2025-05-15T06:56:28.552Z" }, +] + [[package]] name = "certifi" version = "2025.8.3" @@ -468,6 +496,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/58/6b3d24e6b9bc474a2dcdee65dfd1f008867015408a271562e4b690561a4d/cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7", size = 3407605, upload-time = "2026-02-10T19:18:29.233Z" }, ] +[[package]] +name = "cssselect2" +version = "0.9.0" +source = { registry = 
"https://pypi.org/simple" } +dependencies = [ + { name = "tinycss2" }, + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e0/20/92eaa6b0aec7189fa4b75c890640e076e9e793095721db69c5c81142c2e1/cssselect2-0.9.0.tar.gz", hash = "sha256:759aa22c216326356f65e62e791d66160a0f9c91d1424e8d8adc5e74dddfc6fb", size = 35595, upload-time = "2026-02-12T17:16:39.614Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/0e/8459ca4413e1a21a06c97d134bfaf18adfd27cea068813dc0faae06cbf00/cssselect2-0.9.0-py3-none-any.whl", hash = "sha256:6a99e5f91f9a016a304dd929b0966ca464bcfda15177b6fb4a118fc0fb5d9563", size = 15453, upload-time = "2026-02-12T17:16:38.317Z" }, +] + +[[package]] +name = "defusedxml" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520, upload-time = "2021-03-08T10:59:26.269Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, +] + [[package]] name = "dirty-equals" version = "0.9.0" @@ -781,7 +831,7 @@ dev = [ docs = [ { name = "mkdocs" }, { name = "mkdocs-glightbox" }, - { name = "mkdocs-material" }, + { name = "mkdocs-material", extra = ["imaging"] }, { name = "mkdocstrings-python" }, ] @@ -829,7 +879,7 @@ dev = [ docs = [ { name = "mkdocs", specifier = ">=1.6.1" }, { name = "mkdocs-glightbox", specifier = ">=0.4.0" }, - { name = "mkdocs-material", specifier = ">=9.5.45" }, + { name = "mkdocs-material", extras = ["imaging"], specifier = ">=9.5.45" }, { name = "mkdocstrings-python", specifier = ">=2.0.1" }, 
] @@ -1469,7 +1519,7 @@ wheels = [ [[package]] name = "mkdocs-material" -version = "9.7.1" +version = "9.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "babel" }, @@ -1484,9 +1534,15 @@ dependencies = [ { name = "pymdown-extensions" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/27/e2/2ffc356cd72f1473d07c7719d82a8f2cbd261666828614ecb95b12169f41/mkdocs_material-9.7.1.tar.gz", hash = "sha256:89601b8f2c3e6c6ee0a918cc3566cb201d40bf37c3cd3c2067e26fadb8cce2b8", size = 4094392, upload-time = "2025-12-18T09:49:00.308Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/57/5d3c8c9e2ff9d66dc8f63aa052eb0bac5041fecff7761d8689fe65c39c13/mkdocs_material-9.7.2.tar.gz", hash = "sha256:6776256552290b9b7a7aa002780e25b1e04bc9c3a8516b6b153e82e16b8384bd", size = 4097818, upload-time = "2026-02-18T15:53:07.763Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3e/32/ed071cb721aca8c227718cffcf7bd539620e9799bbf2619e90c757bfd030/mkdocs_material-9.7.1-py3-none-any.whl", hash = "sha256:3f6100937d7d731f87f1e3e3b021c97f7239666b9ba1151ab476cabb96c60d5c", size = 9297166, upload-time = "2025-12-18T09:48:56.664Z" }, + { url = "https://files.pythonhosted.org/packages/cd/19/d194e75e82282b1d688f0720e21b5ac250ed64ddea333a228aaf83105f2e/mkdocs_material-9.7.2-py3-none-any.whl", hash = "sha256:9bf6f53452d4a4d527eac3cef3f92b7b6fc4931c55d57766a7d87890d47e1b92", size = 9305052, upload-time = "2026-02-18T15:53:05.221Z" }, +] + +[package.optional-dependencies] +imaging = [ + { name = "cairosvg" }, + { name = "pillow" }, ] [[package]] @@ -2406,6 +2462,18 @@ dependencies = [ { name = "pydantic" }, ] +[[package]] +name = "tinycss2" +version = "1.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/ae/2ca4913e5c0f09781d75482874c3a95db9105462a92ddd303c7d285d3df2/tinycss2-1.5.1.tar.gz", hash = 
"sha256:d339d2b616ba90ccce58da8495a78f46e55d4d25f9fd71dfd526f07e7d53f957", size = 88195, upload-time = "2025-11-23T10:29:10.082Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/45/c7b5c3168458db837e8ceab06dc77824e18202679d0463f0e8f002143a97/tinycss2-1.5.1-py3-none-any.whl", hash = "sha256:3415ba0f5839c062696996998176c4a3751d18b7edaaeeb658c9ce21ec150661", size = 28404, upload-time = "2025-11-23T10:29:08.676Z" }, +] + [[package]] name = "tomli" version = "2.2.1" @@ -2554,6 +2622,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "webencodings" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/02/ae6ceac1baeda530866a85075641cec12989bd8d31af6d5ab4a3e8c92f47/webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923", size = 9721, upload-time = "2017-04-05T20:21:34.189Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774, upload-time = "2017-04-05T20:21:32.581Z" }, +] + [[package]] name = "websockets" version = "15.0.1" From 62575edabd84bfa6e7f143e08468db44abd33fd6 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 26 Feb 2026 15:36:46 +0000 Subject: [PATCH 20/84] ci: sign weekly lockfile commits as github-actions[bot] (#2148) --- .github/workflows/weekly-lockfile-update.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/weekly-lockfile-update.yml 
b/.github/workflows/weekly-lockfile-update.yml index 96507d7936..5d79d06d52 100644 --- a/.github/workflows/weekly-lockfile-update.yml +++ b/.github/workflows/weekly-lockfile-update.yml @@ -32,6 +32,7 @@ jobs: uses: peter-evans/create-pull-request@c0f553fe549906ede9cf27b5156039d195d2ece0 # v7 with: commit-message: "chore: update uv.lock with latest dependencies" + sign-commits: true title: "chore: weekly dependency update" body-path: pr_body.md branch: weekly-lockfile-update From cc22bf54641330046fe95666e8c4e041095bb883 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:23:02 +0000 Subject: [PATCH 21/84] refactor: remove request_ctx ContextVar, thread Context explicitly (#2203) Co-authored-by: Marcelo Trylesinski --- docs/migration.md | 33 +- src/mcp/server/lowlevel/server.py | 9 +- src/mcp/server/mcpserver/__init__.py | 3 +- src/mcp/server/mcpserver/context.py | 280 ++++++++++++++++ src/mcp/server/mcpserver/prompts/base.py | 12 +- src/mcp/server/mcpserver/prompts/manager.py | 8 +- .../mcpserver/resources/resource_manager.py | 6 +- .../server/mcpserver/resources/templates.py | 10 +- src/mcp/server/mcpserver/server.py | 316 ++---------------- src/mcp/server/mcpserver/tools/base.py | 10 +- .../server/mcpserver/tools/tool_manager.py | 6 +- .../mcpserver/utilities/context_injection.py | 4 +- tests/client/test_list_roots_callback.py | 3 +- tests/client/test_logging_callback.py | 13 +- tests/client/test_sampling_callback.py | 10 +- tests/server/mcpserver/prompts/test_base.py | 31 +- .../server/mcpserver/prompts/test_manager.py | 9 +- .../resources/test_resource_manager.py | 7 +- .../resources/test_resource_template.py | 13 +- tests/server/mcpserver/test_server.py | 21 +- tests/server/mcpserver/test_tool_manager.py | 83 ++--- tests/server/mcpserver/tools/__init__.py | 0 tests/server/mcpserver/tools/test_base.py | 10 + 23 files changed, 484 insertions(+), 413 deletions(-) create mode 100644 
src/mcp/server/mcpserver/context.py create mode 100644 tests/server/mcpserver/tools/__init__.py create mode 100644 tests/server/mcpserver/tools/test_base.py diff --git a/docs/migration.md b/docs/migration.md index 6316836938..7cf0325533 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -288,6 +288,37 @@ app = Starlette(routes=[Mount("/", app=mcp.streamable_http_app(json_response=Tru **Note:** DNS rebinding protection is automatically enabled when `host` is `127.0.0.1`, `localhost`, or `::1`. This now happens in `sse_app()` and `streamable_http_app()` instead of the constructor. +### `MCPServer.get_context()` removed + +`MCPServer.get_context()` has been removed. Context is now injected by the framework and passed explicitly — there is no ambient ContextVar to read from. + +**If you were calling `get_context()` from inside a tool/resource/prompt:** use the `ctx: Context` parameter injection instead. + +**Before (v1):** + +```python +@mcp.tool() +async def my_tool(x: int) -> str: + ctx = mcp.get_context() + await ctx.info("Processing...") + return str(x) +``` + +**After (v2):** + +```python +@mcp.tool() +async def my_tool(x: int, ctx: Context) -> str: + await ctx.info("Processing...") + return str(x) +``` + +### `MCPServer.call_tool()`, `read_resource()`, `get_prompt()` now accept a `context` parameter + +`MCPServer.call_tool()`, `MCPServer.read_resource()`, and `MCPServer.get_prompt()` now accept an optional `context: Context | None = None` parameter. The framework passes this automatically during normal request handling. If you call these methods directly and omit `context`, a Context with no active request is constructed for you — tools that don't use `ctx` work normally, but any attempt to use `ctx.session`, `ctx.request_id`, etc. will raise. + +The internal layers (`ToolManager.call_tool`, `Tool.run`, `Prompt.render`, `ResourceTemplate.create_resource`, etc.) now require `context` as a positional argument. 
+ ### Replace `RootModel` by union types with `TypeAdapter` validation The following union types are no longer `RootModel` subclasses: @@ -694,7 +725,7 @@ If you prefer the convenience of automatic wrapping, use `MCPServer` which still ### Lowlevel `Server`: `request_context` property removed -The `server.request_context` property has been removed. Request context is now passed directly to handlers as the first argument (`ctx`). The `request_ctx` module-level contextvar is now an internal implementation detail and should not be relied upon. +The `server.request_context` property has been removed. Request context is now passed directly to handlers as the first argument (`ctx`). The `request_ctx` module-level contextvar has been removed entirely. **Before (v1):** diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index aee6440402..1c84c86107 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -36,7 +36,6 @@ async def main(): from __future__ import annotations -import contextvars import logging import warnings from collections.abc import AsyncIterator, Awaitable, Callable @@ -74,8 +73,6 @@ async def main(): LifespanResultT = TypeVar("LifespanResultT", default=Any) -request_ctx: contextvars.ContextVar[ServerRequestContext[Any]] = contextvars.ContextVar("request_ctx") - class NotificationOptions: def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False): @@ -474,11 +471,7 @@ async def _handle_request( close_sse_stream=close_sse_stream_cb, close_standalone_sse_stream=close_standalone_sse_stream_cb, ) - token = request_ctx.set(ctx) - try: - response = await handler(ctx, req.params) - finally: - request_ctx.reset(token) + response = await handler(ctx, req.params) except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): diff --git a/src/mcp/server/mcpserver/__init__.py b/src/mcp/server/mcpserver/__init__.py index 
f51c0b0ed8..0857e38bd4 100644 --- a/src/mcp/server/mcpserver/__init__.py +++ b/src/mcp/server/mcpserver/__init__.py @@ -2,7 +2,8 @@ from mcp.types import Icon -from .server import Context, MCPServer +from .context import Context +from .server import MCPServer from .utilities.types import Audio, Image __all__ = ["MCPServer", "Context", "Image", "Audio", "Icon"] diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py new file mode 100644 index 0000000000..1538adc7c7 --- /dev/null +++ b/src/mcp/server/mcpserver/context.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Generic, Literal + +from pydantic import AnyUrl, BaseModel + +from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext +from mcp.server.elicitation import ( + ElicitationResult, + ElicitSchemaModelT, + UrlElicitationResult, + elicit_url, + elicit_with_validation, +) +from mcp.server.lowlevel.helper_types import ReadResourceContents + +if TYPE_CHECKING: + from mcp.server.mcpserver.server import MCPServer + + +class Context(BaseModel, Generic[LifespanContextT, RequestT]): + """Context object providing access to MCP capabilities. + + This provides a cleaner interface to MCP's RequestContext functionality. + It gets injected into tool and resource functions that request it via type hints. 
+ + To use context in a tool function, add a parameter with the Context type annotation: + + ```python + @server.tool() + async def my_tool(x: int, ctx: Context) -> str: + # Log messages to the client + await ctx.info(f"Processing {x}") + await ctx.debug("Debug info") + await ctx.warning("Warning message") + await ctx.error("Error message") + + # Report progress + await ctx.report_progress(50, 100) + + # Access resources + data = await ctx.read_resource("resource://data") + + # Get request info + request_id = ctx.request_id + client_id = ctx.client_id + + return str(x) + ``` + + The context parameter name can be anything as long as it's annotated with Context. + The context is optional - tools that don't need it can omit the parameter. + """ + + _request_context: ServerRequestContext[LifespanContextT, RequestT] | None + _mcp_server: MCPServer | None + + # TODO(maxisbey): Consider making request_context/mcp_server required, or refactor Context entirely. + def __init__( + self, + *, + request_context: ServerRequestContext[LifespanContextT, RequestT] | None = None, + mcp_server: MCPServer | None = None, + # TODO(Marcelo): We should drop this kwargs parameter. 
+ **kwargs: Any, + ): + super().__init__(**kwargs) + self._request_context = request_context + self._mcp_server = mcp_server + + @property + def mcp_server(self) -> MCPServer: + """Access to the MCPServer instance.""" + if self._mcp_server is None: # pragma: no cover + raise ValueError("Context is not available outside of a request") + return self._mcp_server # pragma: no cover + + @property + def request_context(self) -> ServerRequestContext[LifespanContextT, RequestT]: + """Access to the underlying request context.""" + if self._request_context is None: # pragma: no cover + raise ValueError("Context is not available outside of a request") + return self._request_context + + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for the current operation. + + Args: + progress: Current progress value (e.g., 24) + total: Optional total value (e.g., 100) + message: Optional message (e.g., "Starting render...") + """ + progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None + + if progress_token is None: # pragma: no cover + return + + await self.request_context.session.send_progress_notification( + progress_token=progress_token, + progress=progress, + total=total, + message=message, + related_request_id=self.request_id, + ) + + async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: + """Read a resource by URI. + + Args: + uri: Resource URI to read + + Returns: + The resource content as either text or bytes + """ + assert self._mcp_server is not None, "Context is not available outside of a request" + return await self._mcp_server.read_resource(uri, self) + + async def elicit( + self, + message: str, + schema: type[ElicitSchemaModelT], + ) -> ElicitationResult[ElicitSchemaModelT]: + """Elicit information from the client/user. 
+ + This method can be used to interactively ask for additional information from the + client within a tool's execution. The client might display the message to the + user and collect a response according to the provided schema. If the client + is an agent, it might decide how to handle the elicitation -- either by asking + the user or automatically generating a response. + + Args: + message: Message to present to the user + schema: A Pydantic model class defining the expected response structure. + According to the specification, only primitive types are allowed. + + Returns: + An ElicitationResult containing the action taken and the data if accepted + + Note: + Check the result.action to determine if the user accepted, declined, or cancelled. + The result.data will only be populated if action is "accept" and validation succeeded. + """ + + return await elicit_with_validation( + session=self.request_context.session, + message=message, + schema=schema, + related_request_id=self.request_id, + ) + + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + ) -> UrlElicitationResult: + """Request URL mode elicitation from the client. + + This directs the user to an external URL for out-of-band interactions + that must not pass through the MCP client. Use this for: + - Collecting sensitive credentials (API keys, passwords) + - OAuth authorization flows with third-party services + - Payment and subscription flows + - Any interaction where data should not pass through the LLM context + + The response indicates whether the user consented to navigate to the URL. + The actual interaction happens out-of-band. When the elicitation completes, + call `ctx.session.send_elicit_complete(elicitation_id)` to notify the client. 
+ + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + + Returns: + UrlElicitationResult indicating accept, decline, or cancel + """ + return await elicit_url( + session=self.request_context.session, + message=message, + url=url, + elicitation_id=elicitation_id, + related_request_id=self.request_id, + ) + + async def log( + self, + level: Literal["debug", "info", "warning", "error"], + message: str, + *, + logger_name: str | None = None, + extra: dict[str, Any] | None = None, + ) -> None: + """Send a log message to the client. + + Args: + level: Log level (debug, info, warning, error) + message: Log message + logger_name: Optional logger name + extra: Optional dictionary with additional structured data to include + """ + + if extra: + log_data = {"message": message, **extra} + else: + log_data = message + + await self.request_context.session.send_log_message( + level=level, + data=log_data, + logger=logger_name, + related_request_id=self.request_id, + ) + + @property + def client_id(self) -> str | None: + """Get the client ID if available.""" + return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover + + @property + def request_id(self) -> str: + """Get the unique ID for this request.""" + return str(self.request_context.request_id) + + @property + def session(self): + """Access to the underlying session for advanced usage.""" + return self.request_context.session + + async def close_sse_stream(self) -> None: + """Close the SSE stream to trigger client reconnection. + + This method closes the HTTP connection for the current request, triggering + client reconnection. Events continue to be stored in the event store and will + be replayed when the client reconnects with Last-Event-ID. 
+ + Use this to implement polling behavior during long-running operations - + the client will reconnect after the retry interval specified in the priming event. + + Note: + This is a no-op if not using StreamableHTTP transport with event_store. + The callback is only available when event_store is configured. + """ + if self._request_context and self._request_context.close_sse_stream: # pragma: no cover + await self._request_context.close_sse_stream() + + async def close_standalone_sse_stream(self) -> None: + """Close the standalone GET SSE stream to trigger client reconnection. + + This method closes the HTTP connection for the standalone GET stream used + for unsolicited server-to-client notifications. The client SHOULD reconnect + with Last-Event-ID to resume receiving notifications. + + Note: + This is a no-op if not using StreamableHTTP transport with event_store. + Currently, client reconnection for standalone GET streams is NOT + implemented - this is a known gap. + """ + if self._request_context and self._request_context.close_standalone_sse_stream: # pragma: no cover + await self._request_context.close_standalone_sse_stream() + + # Convenience methods for common log levels + async def debug(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: + """Send a debug log message.""" + await self.log("debug", message, logger_name=logger_name, extra=extra) + + async def info(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: + """Send an info log message.""" + await self.log("info", message, logger_name=logger_name, extra=extra) + + async def warning( + self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None + ) -> None: + """Send a warning log message.""" + await self.log("warning", message, logger_name=logger_name, extra=extra) + + async def error(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = 
None) -> None: + """Send an error log message.""" + await self.log("error", message, logger_name=logger_name, extra=extra) diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 17744a6707..0c319d53cc 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context class Message(BaseModel): @@ -135,10 +135,14 @@ def from_function( async def render( self, - arguments: dict[str, Any] | None = None, - context: Context[LifespanContextT, RequestT] | None = None, + arguments: dict[str, Any] | None, + context: Context[LifespanContextT, RequestT], ) -> list[Message]: - """Render the prompt with arguments.""" + """Render the prompt with arguments. + + Raises: + ValueError: If required arguments are missing, or if rendering fails. 
+ """ # Validate required arguments if self.arguments: required = {arg.name for arg in self.arguments if arg.required} diff --git a/src/mcp/server/mcpserver/prompts/manager.py b/src/mcp/server/mcpserver/prompts/manager.py index 21b9741318..28a7a6e98c 100644 --- a/src/mcp/server/mcpserver/prompts/manager.py +++ b/src/mcp/server/mcpserver/prompts/manager.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context logger = get_logger(__name__) @@ -48,12 +48,12 @@ def add_prompt( async def render_prompt( self, name: str, - arguments: dict[str, Any] | None = None, - context: Context[LifespanContextT, RequestT] | None = None, + arguments: dict[str, Any] | None, + context: Context[LifespanContextT, RequestT], ) -> list[Message]: """Render a prompt by name with arguments.""" prompt = self.get_prompt(name) if not prompt: raise ValueError(f"Unknown prompt: {name}") - return await prompt.render(arguments, context=context) + return await prompt.render(arguments, context) diff --git a/src/mcp/server/mcpserver/resources/resource_manager.py b/src/mcp/server/mcpserver/resources/resource_manager.py index ed5b741239..6bf17376d1 100644 --- a/src/mcp/server/mcpserver/resources/resource_manager.py +++ b/src/mcp/server/mcpserver/resources/resource_manager.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context logger = get_logger(__name__) @@ -80,9 +80,7 @@ def add_template( self._templates[template.uri_template] = template return template - async def get_resource( - self, uri: AnyUrl | str, context: Context[LifespanContextT, RequestT] | None = None - ) -> Resource: + async def get_resource(self, uri: AnyUrl | str, context: Context[LifespanContextT, RequestT]) -> Resource: """Get resource by URI, checking concrete 
resources first, then templates.""" uri_str = str(uri) logger.debug("Getting resource", extra={"uri": uri_str}) diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index e796823d9d..2d612657c4 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context class ResourceTemplate(BaseModel): @@ -99,9 +99,13 @@ async def create_resource( self, uri: str, params: dict[str, Any], - context: Context[LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT], ) -> Resource: - """Create a resource from the template with the given parameters.""" + """Create a resource from the template with the given parameters. + + Raises: + ValueError: If creating the resource fails. 
+ """ try: # Add context to params if needed params = inject_context(self.fn, params, context, self.context_kwarg) diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 9c7105a7b4..2a7a58117a 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -12,7 +12,6 @@ import anyio import pydantic_core -from pydantic import BaseModel from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette @@ -27,12 +26,11 @@ from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier from mcp.server.auth.settings import AuthSettings -from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext -from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, UrlElicitationResult, elicit_with_validation -from mcp.server.elicitation import elicit_url as _elicit_url +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import LifespanResultT, Server, request_ctx +from mcp.server.lowlevel.server import LifespanResultT, Server from mcp.server.lowlevel.server import lifespan as default_lifespan +from mcp.server.mcpserver.context import Context from mcp.server.mcpserver.exceptions import ResourceError from mcp.server.mcpserver.prompts import Prompt, PromptManager from mcp.server.mcpserver.resources import FunctionResource, Resource, ResourceManager @@ -300,8 +298,9 @@ async def _handle_list_tools( async def _handle_call_tool( self, ctx: ServerRequestContext[LifespanResultT], params: CallToolRequestParams ) -> CallToolResult: + context = Context(request_context=ctx, mcp_server=self) try: - result = await self.call_tool(params.name, params.arguments or {}) + result = await 
self.call_tool(params.name, params.arguments or {}, context) except MCPError: raise except Exception as e: @@ -332,7 +331,8 @@ async def _handle_list_resources( async def _handle_read_resource( self, ctx: ServerRequestContext[LifespanResultT], params: ReadResourceRequestParams ) -> ReadResourceResult: - results = await self.read_resource(params.uri) + context = Context(request_context=ctx, mcp_server=self) + results = await self.read_resource(params.uri, context) contents: list[TextResourceContents | BlobResourceContents] = [] for item in results: if isinstance(item.content, bytes): @@ -368,7 +368,8 @@ async def _handle_list_prompts( async def _handle_get_prompt( self, ctx: ServerRequestContext[LifespanResultT], params: GetPromptRequestParams ) -> GetPromptResult: - return await self.get_prompt(params.name, params.arguments) + context = Context(request_context=ctx, mcp_server=self) + return await self.get_prompt(params.name, params.arguments, context) async def list_tools(self) -> list[MCPTool]: """List all available tools.""" @@ -387,22 +388,13 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> Context[LifespanResultT, Request]: - """Return a Context object. - - Note that the context will only be valid during a request; outside a - request, most methods will error. 
- """ - try: - request_context = request_ctx.get() - except LookupError: - request_context = None - return Context(request_context=request_context, mcp_server=self) - - async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock] | dict[str, Any]: + async def call_tool( + self, name: str, arguments: dict[str, Any], context: Context[LifespanResultT, Any] | None = None + ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" - context = self.get_context() - return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) + if context is None: + context = Context(mcp_server=self) + return await self._tool_manager.call_tool(name, arguments, context, convert_result=True) async def list_resources(self) -> list[MCPResource]: """List all available resources.""" @@ -438,12 +430,14 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]: for template in templates ] - async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]: + async def read_resource( + self, uri: AnyUrl | str, context: Context[LifespanResultT, Any] | None = None + ) -> Iterable[ReadResourceContents]: """Read a resource by URI.""" - - context = self.get_context() + if context is None: + context = Context(mcp_server=self) try: - resource = await self._resource_manager.get_resource(uri, context=context) + resource = await self._resource_manager.get_resource(uri, context) except ValueError: raise ResourceError(f"Unknown resource: {uri}") @@ -1087,14 +1081,18 @@ async def list_prompts(self) -> list[MCPPrompt]: for prompt in prompts ] - async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult: + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None, context: Context[LifespanResultT, Any] | None = None + ) -> GetPromptResult: """Get a prompt by name with arguments.""" + if context is None: + context = 
Context(mcp_server=self) try: prompt = self._prompt_manager.get_prompt(name) if not prompt: raise ValueError(f"Unknown prompt: {name}") - messages = await prompt.render(arguments, context=self.get_context()) + messages = await prompt.render(arguments, context) return GetPromptResult( description=prompt.description, @@ -1103,263 +1101,3 @@ async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) - except Exception as e: logger.exception(f"Error getting prompt {name}") raise ValueError(str(e)) - - -class Context(BaseModel, Generic[LifespanContextT, RequestT]): - """Context object providing access to MCP capabilities. - - This provides a cleaner interface to MCP's RequestContext functionality. - It gets injected into tool and resource functions that request it via type hints. - - To use context in a tool function, add a parameter with the Context type annotation: - - ```python - @server.tool() - async def my_tool(x: int, ctx: Context) -> str: - # Log messages to the client - await ctx.info(f"Processing {x}") - await ctx.debug("Debug info") - await ctx.warning("Warning message") - await ctx.error("Error message") - - # Report progress - await ctx.report_progress(50, 100) - - # Access resources - data = await ctx.read_resource("resource://data") - - # Get request info - request_id = ctx.request_id - client_id = ctx.client_id - - return str(x) - ``` - - The context parameter name can be anything as long as it's annotated with Context. - The context is optional - tools that don't need it can omit the parameter. - """ - - _request_context: ServerRequestContext[LifespanContextT, RequestT] | None - _mcp_server: MCPServer | None - - def __init__( - self, - *, - request_context: ServerRequestContext[LifespanContextT, RequestT] | None = None, - mcp_server: MCPServer | None = None, - # TODO(Marcelo): We should drop this kwargs parameter. 
- **kwargs: Any, - ): - super().__init__(**kwargs) - self._request_context = request_context - self._mcp_server = mcp_server - - @property - def mcp_server(self) -> MCPServer: - """Access to the MCPServer instance.""" - if self._mcp_server is None: # pragma: no cover - raise ValueError("Context is not available outside of a request") - return self._mcp_server # pragma: no cover - - @property - def request_context(self) -> ServerRequestContext[LifespanContextT, RequestT]: - """Access to the underlying request context.""" - if self._request_context is None: # pragma: no cover - raise ValueError("Context is not available outside of a request") - return self._request_context - - async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: - """Report progress for the current operation. - - Args: - progress: Current progress value (e.g., 24) - total: Optional total value (e.g., 100) - message: Optional message (e.g., "Starting render...") - """ - progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None - - if progress_token is None: # pragma: no cover - return - - await self.request_context.session.send_progress_notification( - progress_token=progress_token, - progress=progress, - total=total, - message=message, - related_request_id=self.request_id, - ) - - async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: - """Read a resource by URI. - - Args: - uri: Resource URI to read - - Returns: - The resource content as either text or bytes - """ - assert self._mcp_server is not None, "Context is not available outside of a request" - return await self._mcp_server.read_resource(uri) - - async def elicit( - self, - message: str, - schema: type[ElicitSchemaModelT], - ) -> ElicitationResult[ElicitSchemaModelT]: - """Elicit information from the client/user. 
- - This method can be used to interactively ask for additional information from the - client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. If the client - is an agent, it might decide how to handle the elicitation -- either by asking - the user or automatically generating a response. - - Args: - message: Message to present to the user - schema: A Pydantic model class defining the expected response structure. - According to the specification, only primitive types are allowed. - - Returns: - An ElicitationResult containing the action taken and the data if accepted - - Note: - Check the result.action to determine if the user accepted, declined, or cancelled. - The result.data will only be populated if action is "accept" and validation succeeded. - """ - - return await elicit_with_validation( - session=self.request_context.session, - message=message, - schema=schema, - related_request_id=self.request_id, - ) - - async def elicit_url( - self, - message: str, - url: str, - elicitation_id: str, - ) -> UrlElicitationResult: - """Request URL mode elicitation from the client. - - This directs the user to an external URL for out-of-band interactions - that must not pass through the MCP client. Use this for: - - Collecting sensitive credentials (API keys, passwords) - - OAuth authorization flows with third-party services - - Payment and subscription flows - - Any interaction where data should not pass through the LLM context - - The response indicates whether the user consented to navigate to the URL. - The actual interaction happens out-of-band. When the elicitation completes, - call `ctx.session.send_elicit_complete(elicitation_id)` to notify the client. 
- - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - - Returns: - UrlElicitationResult indicating accept, decline, or cancel - """ - return await _elicit_url( - session=self.request_context.session, - message=message, - url=url, - elicitation_id=elicitation_id, - related_request_id=self.request_id, - ) - - async def log( - self, - level: Literal["debug", "info", "warning", "error"], - message: str, - *, - logger_name: str | None = None, - extra: dict[str, Any] | None = None, - ) -> None: - """Send a log message to the client. - - Args: - level: Log level (debug, info, warning, error) - message: Log message - logger_name: Optional logger name - extra: Optional dictionary with additional structured data to include - """ - - if extra: - log_data = {"message": message, **extra} - else: - log_data = message - - await self.request_context.session.send_log_message( - level=level, - data=log_data, - logger=logger_name, - related_request_id=self.request_id, - ) - - @property - def client_id(self) -> str | None: - """Get the client ID if available.""" - return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover - - @property - def request_id(self) -> str: - """Get the unique ID for this request.""" - return str(self.request_context.request_id) - - @property - def session(self): - """Access to the underlying session for advanced usage.""" - return self.request_context.session - - async def close_sse_stream(self) -> None: - """Close the SSE stream to trigger client reconnection. - - This method closes the HTTP connection for the current request, triggering - client reconnection. Events continue to be stored in the event store and will - be replayed when the client reconnects with Last-Event-ID. 
- - Use this to implement polling behavior during long-running operations - - the client will reconnect after the retry interval specified in the priming event. - - Note: - This is a no-op if not using StreamableHTTP transport with event_store. - The callback is only available when event_store is configured. - """ - if self._request_context and self._request_context.close_sse_stream: # pragma: no cover - await self._request_context.close_sse_stream() - - async def close_standalone_sse_stream(self) -> None: - """Close the standalone GET SSE stream to trigger client reconnection. - - This method closes the HTTP connection for the standalone GET stream used - for unsolicited server-to-client notifications. The client SHOULD reconnect - with Last-Event-ID to resume receiving notifications. - - Note: - This is a no-op if not using StreamableHTTP transport with event_store. - Currently, client reconnection for standalone GET streams is NOT - implemented - this is a known gap. - """ - if self._request_context and self._request_context.close_standalone_sse_stream: # pragma: no cover - await self._request_context.close_standalone_sse_stream() - - # Convenience methods for common log levels - async def debug(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: - """Send a debug log message.""" - await self.log("debug", message, logger_name=logger_name, extra=extra) - - async def info(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: - """Send an info log message.""" - await self.log("info", message, logger_name=logger_name, extra=extra) - - async def warning( - self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None - ) -> None: - """Send a warning log message.""" - await self.log("warning", message, logger_name=logger_name, extra=extra) - - async def error(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = 
None) -> None: - """Send an error log message.""" - await self.log("error", message, logger_name=logger_name, extra=extra) diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index f6bfadbc4d..dc65be9885 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context class Tool(BaseModel): @@ -92,10 +92,14 @@ def from_function( async def run( self, arguments: dict[str, Any], - context: Context[LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT], convert_result: bool = False, ) -> Any: - """Run the tool with arguments.""" + """Run the tool with arguments. + + Raises: + ToolError: If the tool function raises during execution. + """ try: result = await self.fn_metadata.call_fn_with_arg_validation( self.fn, diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index c6f8384bdb..32ed547973 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context logger = get_logger(__name__) @@ -81,7 +81,7 @@ async def call_tool( self, name: str, arguments: dict[str, Any], - context: Context[LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT], convert_result: bool = False, ) -> Any: """Call a tool by name with arguments.""" @@ -89,4 +89,4 @@ async def call_tool( if not tool: raise ToolError(f"Unknown tool: {name}") - return await tool.run(arguments, context=context, convert_result=convert_result) + return await tool.run(arguments, context, 
convert_result=convert_result) diff --git a/src/mcp/server/mcpserver/utilities/context_injection.py b/src/mcp/server/mcpserver/utilities/context_injection.py index 9cba83e860..ac7ab82d05 100644 --- a/src/mcp/server/mcpserver/utilities/context_injection.py +++ b/src/mcp/server/mcpserver/utilities/context_injection.py @@ -7,6 +7,8 @@ from collections.abc import Callable from typing import Any +from mcp.server.mcpserver.context import Context + def find_context_parameter(fn: Callable[..., Any]) -> str | None: """Find the parameter that should receive the Context object. @@ -20,8 +22,6 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None: Returns: The name of the context parameter, or None if not found """ - from mcp.server.mcpserver.server import Context - # Get type hints to properly resolve string annotations try: hints = typing.get_type_hints(fn) diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 6a2f49f390..1ab90be772 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -3,8 +3,7 @@ from mcp import Client from mcp.client.session import ClientSession -from mcp.server.mcpserver import MCPServer -from mcp.server.mcpserver.server import Context +from mcp.server.mcpserver import Context, MCPServer from mcp.shared._context import RequestContext from mcp.types import ListRootsResult, Root, TextContent diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 31cdeece73..1598fd55f6 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -3,7 +3,7 @@ import pytest from mcp import Client, types -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import Context, MCPServer from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, @@ -33,14 +33,10 @@ async def test_tool() -> bool: # Create a function that can send a 
log notification @server.tool("test_tool_with_log") async def test_tool_with_log( - message: str, level: Literal["debug", "info", "warning", "error"], logger: str + message: str, level: Literal["debug", "info", "warning", "error"], logger: str, ctx: Context ) -> bool: """Send a log notification to the client.""" - await server.get_context().log( - level=level, - message=message, - logger_name=logger, - ) + await ctx.log(level=level, message=message, logger_name=logger) return True @server.tool("test_tool_with_log_extra") @@ -50,9 +46,10 @@ async def test_tool_with_log_extra( logger: str, extra_string: str, extra_dict: dict[str, Any], + ctx: Context, ) -> bool: """Send a log notification to the client with extra fields.""" - await server.get_context().log( + await ctx.log( level=level, message=message, logger_name=logger, diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 3357bc921d..6efcac0a52 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -2,7 +2,7 @@ from mcp import Client from mcp.client.session import ClientSession -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import Context, MCPServer from mcp.shared._context import RequestContext from mcp.types import ( CreateMessageRequestParams, @@ -32,8 +32,8 @@ async def sampling_callback( return callback_return @server.tool("test_sampling") - async def test_sampling_tool(message: str): - value = await server.get_context().session.create_message( + async def test_sampling_tool(message: str, ctx: Context) -> bool: + value = await ctx.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) @@ -77,9 +77,9 @@ async def sampling_callback( return callback_return @server.tool("test_backwards_compat") - async def test_tool(message: str): + async def test_tool(message: str, ctx: Context) -> bool: # Call create_message WITHOUT tools - 
result = await server.get_context().session.create_message( + result = await ctx.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) diff --git a/tests/server/mcpserver/prompts/test_base.py b/tests/server/mcpserver/prompts/test_base.py index 553e47363f..fe18e91bd7 100644 --- a/tests/server/mcpserver/prompts/test_base.py +++ b/tests/server/mcpserver/prompts/test_base.py @@ -2,6 +2,7 @@ import pytest +from mcp.server.mcpserver import Context from mcp.server.mcpserver.prompts.base import AssistantMessage, Message, Prompt, UserMessage from mcp.types import EmbeddedResource, TextContent, TextResourceContents @@ -13,7 +14,9 @@ def fn() -> str: return "Hello, world!" prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_async_fn(self): @@ -21,7 +24,9 @@ async def fn() -> str: return "Hello, world!" prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_fn_with_args(self): @@ -29,7 +34,7 @@ async def fn(name: str, age: int = 30) -> str: return f"Hello, {name}! You're {age} years old." prompt = Prompt.from_function(fn) - assert await prompt.render(arguments={"name": "World"}) == [ + assert await prompt.render({"name": "World"}, Context()) == [ UserMessage(content=TextContent(type="text", text="Hello, World! 
You're 30 years old.")) ] @@ -40,7 +45,7 @@ async def fn(name: str, age: int = 30) -> str: # pragma: no cover prompt = Prompt.from_function(fn) with pytest.raises(ValueError): - await prompt.render(arguments={"age": 40}) + await prompt.render({"age": 40}, Context()) @pytest.mark.anyio async def test_fn_returns_message(self): @@ -48,7 +53,9 @@ async def fn() -> UserMessage: return UserMessage(content="Hello, world!") prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_fn_returns_assistant_message(self): @@ -56,7 +63,9 @@ async def fn() -> AssistantMessage: return AssistantMessage(content=TextContent(type="text", text="Hello, world!")) prompt = Prompt.from_function(fn) - assert await prompt.render() == [AssistantMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + AssistantMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_fn_returns_multiple_messages(self): @@ -70,7 +79,7 @@ async def fn() -> list[Message]: return expected prompt = Prompt.from_function(fn) - assert await prompt.render() == expected + assert await prompt.render(None, Context()) == expected @pytest.mark.anyio async def test_fn_returns_list_of_strings(self): @@ -83,7 +92,7 @@ async def fn() -> list[str]: return expected prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(t) for t in expected] + assert await prompt.render(None, Context()) == [UserMessage(t) for t in expected] @pytest.mark.anyio async def test_fn_returns_resource_content(self): @@ -102,7 +111,7 @@ async def fn() -> UserMessage: ) prompt = Prompt.from_function(fn) - assert await prompt.render() == [ + assert await prompt.render(None, Context()) == [ 
UserMessage( content=EmbeddedResource( type="resource", @@ -136,7 +145,7 @@ async def fn() -> list[Message]: ] prompt = Prompt.from_function(fn) - assert await prompt.render() == [ + assert await prompt.render(None, Context()) == [ UserMessage(content=TextContent(type="text", text="Please analyze this file:")), UserMessage( content=EmbeddedResource( @@ -169,7 +178,7 @@ async def fn() -> dict[str, Any]: } prompt = Prompt.from_function(fn) - assert await prompt.render() == [ + assert await prompt.render(None, Context()) == [ UserMessage( content=EmbeddedResource( type="resource", diff --git a/tests/server/mcpserver/prompts/test_manager.py b/tests/server/mcpserver/prompts/test_manager.py index 02f91c6802..99a03db565 100644 --- a/tests/server/mcpserver/prompts/test_manager.py +++ b/tests/server/mcpserver/prompts/test_manager.py @@ -1,5 +1,6 @@ import pytest +from mcp.server.mcpserver import Context from mcp.server.mcpserver.prompts.base import Prompt, UserMessage from mcp.server.mcpserver.prompts.manager import PromptManager from mcp.types import TextContent @@ -72,7 +73,7 @@ def fn() -> str: manager = PromptManager() prompt = Prompt.from_function(fn) manager.add_prompt(prompt) - messages = await manager.render_prompt("fn") + messages = await manager.render_prompt("fn", None, Context()) assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio @@ -85,7 +86,7 @@ def fn(name: str) -> str: manager = PromptManager() prompt = Prompt.from_function(fn) manager.add_prompt(prompt) - messages = await manager.render_prompt("fn", arguments={"name": "World"}) + messages = await manager.render_prompt("fn", {"name": "World"}, Context()) assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))] @pytest.mark.anyio @@ -93,7 +94,7 @@ async def test_render_unknown_prompt(self): """Test rendering a non-existent prompt.""" manager = PromptManager() with pytest.raises(ValueError, match="Unknown prompt: 
unknown"): - await manager.render_prompt("unknown") + await manager.render_prompt("unknown", None, Context()) @pytest.mark.anyio async def test_render_prompt_with_missing_args(self): @@ -106,4 +107,4 @@ def fn(name: str) -> str: # pragma: no cover prompt = Prompt.from_function(fn) manager.add_prompt(prompt) with pytest.raises(ValueError, match="Missing required arguments"): - await manager.render_prompt("fn") + await manager.render_prompt("fn", None, Context()) diff --git a/tests/server/mcpserver/resources/test_resource_manager.py b/tests/server/mcpserver/resources/test_resource_manager.py index eb9b355aaf..724b579974 100644 --- a/tests/server/mcpserver/resources/test_resource_manager.py +++ b/tests/server/mcpserver/resources/test_resource_manager.py @@ -4,6 +4,7 @@ import pytest from pydantic import AnyUrl +from mcp.server.mcpserver import Context from mcp.server.mcpserver.resources import FileResource, FunctionResource, ResourceManager, ResourceTemplate @@ -86,7 +87,7 @@ async def test_get_resource(self, temp_file: Path): path=temp_file, ) manager.add_resource(resource) - retrieved = await manager.get_resource(resource.uri) + retrieved = await manager.get_resource(resource.uri, Context()) assert retrieved == resource @pytest.mark.anyio @@ -104,7 +105,7 @@ def greet(name: str) -> str: ) manager._templates[template.uri_template] = template - resource = await manager.get_resource(AnyUrl("greet://world")) + resource = await manager.get_resource(AnyUrl("greet://world"), Context()) assert isinstance(resource, FunctionResource) content = await resource.read() assert content == "Hello, world!" 
@@ -114,7 +115,7 @@ async def test_get_unknown_resource(self): """Test getting a non-existent resource.""" manager = ResourceManager() with pytest.raises(ValueError, match="Unknown resource"): - await manager.get_resource(AnyUrl("unknown://test")) + await manager.get_resource(AnyUrl("unknown://test"), Context()) def test_list_resources(self, temp_file: Path): """Test listing all resources.""" diff --git a/tests/server/mcpserver/resources/test_resource_template.py b/tests/server/mcpserver/resources/test_resource_template.py index 08f2033bff..640cfe8031 100644 --- a/tests/server/mcpserver/resources/test_resource_template.py +++ b/tests/server/mcpserver/resources/test_resource_template.py @@ -4,7 +4,7 @@ import pytest from pydantic import BaseModel -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.resources import FunctionResource, ResourceTemplate from mcp.types import Annotations @@ -64,6 +64,7 @@ def my_func(key: str, value: int) -> dict[str, Any]: resource = await template.create_resource( "test://foo/123", {"key": "foo", "value": 123}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -86,7 +87,7 @@ def failing_func(x: str) -> str: ) with pytest.raises(ValueError, match="Error creating resource from template"): - await template.create_resource("fail://test", {"x": "test"}) + await template.create_resource("fail://test", {"x": "test"}, Context()) @pytest.mark.anyio async def test_async_text_resource(self): @@ -104,6 +105,7 @@ async def greet(name: str) -> str: resource = await template.create_resource( "greet://world", {"name": "world"}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -126,6 +128,7 @@ async def get_bytes(value: str) -> bytes: resource = await template.create_resource( "bytes://test", {"value": "test"}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -152,6 +155,7 @@ def get_data(key: str, value: int) -> MyModel: resource = await 
template.create_resource( "test://foo/123", {"key": "foo", "value": 123}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -183,6 +187,7 @@ def get_data(value: str) -> CustomData: resource = await template.create_resource( "test://hello", {"value": "hello"}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -249,7 +254,7 @@ def get_item(item_id: str) -> str: ) # Create a resource from the template - resource = await template.create_resource("resource://items/123", {"item_id": "123"}) + resource = await template.create_resource("resource://items/123", {"item_id": "123"}, Context()) # The resource should inherit the template's annotations assert resource.annotations is not None @@ -298,7 +303,7 @@ def get_item(item_id: str) -> str: ) # Create a resource from the template - resource = await template.create_resource("resource://items/123", {"item_id": "123"}) + resource = await template.create_resource("resource://items/123", {"item_id": "123"}, Context()) # The resource should inherit the template's metadata assert resource.meta is not None diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index cfbe6587bb..3d130bfc33 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -891,7 +891,7 @@ def get_data(name: str) -> str: assert len(await mcp.list_resources()) == 0 # When accessed, should create a concrete resource - resource = await mcp._resource_manager.get_resource("resource://test/data") + resource = await mcp._resource_manager.get_resource("resource://test/data", Context()) assert isinstance(resource, FunctionResource) result = await resource.read() assert result == "Data for test" @@ -1231,6 +1231,19 @@ def prompt_no_context(text: str) -> str: class TestServerPrompts: """Test prompt functionality in MCPServer server.""" + async def test_get_prompt_direct_call_without_context(self): + """Test calling mcp.get_prompt() directly without passing context.""" 
+ mcp = MCPServer() + + @mcp.prompt() + def fn() -> str: + return "Hello, world!" + + result = await mcp.get_prompt("fn") + content = result.messages[0].content + assert isinstance(content, TextContent) + assert content.text == "Hello, world!" + async def test_prompt_decorator(self): """Test that the prompt decorator registers prompts correctly.""" mcp = MCPServer() @@ -1243,7 +1256,7 @@ def fn() -> str: assert len(prompts) == 1 assert prompts[0].name == "fn" # Don't compare functions directly since validate_call wraps them - content = await prompts[0].render() + content = await prompts[0].render(None, Context()) assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" @@ -1258,7 +1271,7 @@ def fn() -> str: prompts = mcp._prompt_manager.list_prompts() assert len(prompts) == 1 assert prompts[0].name == "custom_name" - content = await prompts[0].render() + content = await prompts[0].render(None, Context()) assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" @@ -1273,7 +1286,7 @@ def fn() -> str: prompts = mcp._prompt_manager.list_prompts() assert len(prompts) == 1 assert prompts[0].description == "A custom description" - content = await prompts[0].render() + content = await prompts[0].render(None, Context()) assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" 
diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index 550bba50a3..f990ec47b7 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -188,7 +188,7 @@ def sum(a: int, b: int) -> int: manager = ToolManager() manager.add_tool(sum) - result = await manager.call_tool("sum", {"a": 1, "b": 2}) + result = await manager.call_tool("sum", {"a": 1, "b": 2}, Context()) assert result == 3 @pytest.mark.anyio @@ -199,7 +199,7 @@ async def double(n: int) -> int: manager = ToolManager() manager.add_tool(double) - result = await manager.call_tool("double", {"n": 5}) + result = await manager.call_tool("double", {"n": 5}, Context()) assert result == 10 @pytest.mark.anyio @@ -213,7 +213,7 @@ def __call__(self, x: int) -> int: manager = ToolManager() tool = manager.add_tool(MyTool()) - result = await tool.run({"x": 5}) + result = await tool.run({"x": 5}, Context()) assert result == 10 @pytest.mark.anyio @@ -227,7 +227,7 @@ async def __call__(self, x: int) -> int: manager = ToolManager() tool = manager.add_tool(MyAsyncTool()) - result = await tool.run({"x": 5}) + result = await tool.run({"x": 5}, Context()) assert result == 10 @pytest.mark.anyio @@ -238,7 +238,7 @@ def sum(a: int, b: int = 1) -> int: manager = ToolManager() manager.add_tool(sum) - result = await manager.call_tool("sum", {"a": 1}) + result = await manager.call_tool("sum", {"a": 1}, Context()) assert result == 2 @pytest.mark.anyio @@ -250,13 +250,13 @@ def sum(a: int, b: int) -> int: # pragma: no cover manager = ToolManager() manager.add_tool(sum) with pytest.raises(ToolError): - await manager.call_tool("sum", {"a": 1}) + await manager.call_tool("sum", {"a": 1}, Context()) @pytest.mark.anyio async def test_call_unknown_tool(self): manager = ToolManager() with pytest.raises(ToolError): - await manager.call_tool("unknown", {"a": 1}) + await manager.call_tool("unknown", {"a": 1}, Context()) @pytest.mark.anyio async def 
test_call_tool_with_list_int_input(self): @@ -266,9 +266,9 @@ def sum_vals(vals: list[int]) -> int: manager = ToolManager() manager.add_tool(sum_vals) # Try both with plain list and with JSON list - result = await manager.call_tool("sum_vals", {"vals": "[1, 2, 3]"}) + result = await manager.call_tool("sum_vals", {"vals": "[1, 2, 3]"}, Context()) assert result == 6 - result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]}) + result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]}, Context()) assert result == 6 @pytest.mark.anyio @@ -279,13 +279,13 @@ def concat_strs(vals: list[str] | str) -> str: manager = ToolManager() manager.add_tool(concat_strs) # Try both with plain python object and with JSON list - result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}) + result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}, Context()) assert result == "abc" - result = await manager.call_tool("concat_strs", {"vals": '["a", "b", "c"]'}) + result = await manager.call_tool("concat_strs", {"vals": '["a", "b", "c"]'}, Context()) assert result == "abc" - result = await manager.call_tool("concat_strs", {"vals": "a"}) + result = await manager.call_tool("concat_strs", {"vals": "a"}, Context()) assert result == "a" - result = await manager.call_tool("concat_strs", {"vals": '"a"'}) + result = await manager.call_tool("concat_strs", {"vals": '"a"'}, Context()) assert result == '"a"' @pytest.mark.anyio @@ -297,7 +297,7 @@ class Shrimp(BaseModel): shrimp: list[Shrimp] x: None - def name_shrimp(tank: MyShrimpTank, ctx: Context[ServerSessionT, None]) -> list[str]: + def name_shrimp(tank: MyShrimpTank) -> list[str]: return [x.name for x in tank.shrimp] manager = ToolManager() @@ -305,11 +305,13 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context[ServerSessionT, None]) -> list[ result = await manager.call_tool( "name_shrimp", {"tank": {"x": None, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}}, + Context(), ) assert result == ["rex", 
"gertrude"] result = await manager.call_tool( "name_shrimp", {"tank": '{"x": null, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}'}, + Context(), ) assert result == ["rex", "gertrude"] @@ -364,9 +366,7 @@ def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: manager = ToolManager() manager.add_tool(tool_with_context) - mcp = MCPServer() - ctx = mcp.get_context() - result = await manager.call_tool("tool_with_context", {"x": 42}, context=ctx) + result = await manager.call_tool("tool_with_context", {"x": 42}, context=Context()) assert result == "42" @pytest.mark.anyio @@ -380,22 +380,7 @@ async def async_tool(x: int, ctx: Context[ServerSessionT, None]) -> str: manager = ToolManager() manager.add_tool(async_tool) - mcp = MCPServer() - ctx = mcp.get_context() - result = await manager.call_tool("async_tool", {"x": 42}, context=ctx) - assert result == "42" - - @pytest.mark.anyio - async def test_context_optional(self): - """Test that context is optional when calling tools.""" - - def tool_with_context(x: int, ctx: Context[ServerSessionT, None] | None = None) -> str: - return str(x) - - manager = ToolManager() - manager.add_tool(tool_with_context) - # Should not raise an error when context is not provided - result = await manager.call_tool("tool_with_context", {"x": 42}) + result = await manager.call_tool("async_tool", {"x": 42}, context=Context()) assert result == "42" @pytest.mark.anyio @@ -408,10 +393,8 @@ def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: manager = ToolManager() manager.add_tool(tool_with_context) - mcp = MCPServer() - ctx = mcp.get_context() with pytest.raises(ToolError, match="Error executing tool tool_with_context"): - await manager.call_tool("tool_with_context", {"x": 42}, context=ctx) + await manager.call_tool("tool_with_context", {"x": 42}, context=Context()) class TestToolAnnotations: @@ -471,7 +454,7 @@ def get_user(user_id: int) -> UserOutput: manager = ToolManager() manager.add_tool(get_user) - 
result = await manager.call_tool("get_user", {"user_id": 1}, convert_result=True) + result = await manager.call_tool("get_user", {"user_id": 1}, Context(), convert_result=True) # don't test unstructured output here, just the structured conversion assert len(result) == 2 and result[1] == {"name": "John", "age": 30} @@ -485,9 +468,9 @@ def double_number(n: int) -> int: manager = ToolManager() manager.add_tool(double_number) - result = await manager.call_tool("double_number", {"n": 5}) + result = await manager.call_tool("double_number", {"n": 5}, Context()) assert result == 10 - result = await manager.call_tool("double_number", {"n": 5}, convert_result=True) + result = await manager.call_tool("double_number", {"n": 5}, Context(), convert_result=True) assert isinstance(result[0][0], TextContent) and result[1] == {"result": 10} @pytest.mark.anyio @@ -506,7 +489,7 @@ def get_user_dict(user_id: int) -> UserDict: manager = ToolManager() manager.add_tool(get_user_dict) - result = await manager.call_tool("get_user_dict", {"user_id": 1}) + result = await manager.call_tool("get_user_dict", {"user_id": 1}, Context()) assert result == expected_output @pytest.mark.anyio @@ -526,7 +509,7 @@ def get_person() -> Person: manager = ToolManager() manager.add_tool(get_person) - result = await manager.call_tool("get_person", {}, convert_result=True) + result = await manager.call_tool("get_person", {}, Context(), convert_result=True) # don't test unstructured output here, just the structured conversion assert len(result) == 2 and result[1] == expected_output @@ -543,9 +526,9 @@ def get_numbers() -> list[int]: manager = ToolManager() manager.add_tool(get_numbers) - result = await manager.call_tool("get_numbers", {}) + result = await manager.call_tool("get_numbers", {}, Context()) assert result == expected_list - result = await manager.call_tool("get_numbers", {}, convert_result=True) + result = await manager.call_tool("get_numbers", {}, Context(), convert_result=True) assert 
isinstance(result[0][0], TextContent) and result[1] == expected_output @pytest.mark.anyio @@ -558,7 +541,7 @@ def get_dict() -> dict[str, Any]: manager = ToolManager() manager.add_tool(get_dict, structured_output=False) - result = await manager.call_tool("get_dict", {}) + result = await manager.call_tool("get_dict", {}, Context()) assert isinstance(result, dict) assert result == {"key": "value"} @@ -601,12 +584,12 @@ def get_config() -> dict[str, Any]: assert "properties" not in tool.output_schema # dict[str, Any] has no constraints # Test raw result - result = await manager.call_tool("get_config", {}) + result = await manager.call_tool("get_config", {}, Context()) expected = {"debug": True, "port": 8080, "features": ["auth", "logging"]} assert result == expected # Test converted result - result = await manager.call_tool("get_config", {}) + result = await manager.call_tool("get_config", {}, Context()) assert result == expected @pytest.mark.anyio @@ -626,12 +609,12 @@ def get_scores() -> dict[str, int]: assert tool.output_schema["additionalProperties"]["type"] == "integer" # Test raw result - result = await manager.call_tool("get_scores", {}) + result = await manager.call_tool("get_scores", {}, Context()) expected = {"alice": 100, "bob": 85, "charlie": 92} assert result == expected # Test converted result - result = await manager.call_tool("get_scores", {}) + result = await manager.call_tool("get_scores", {}, Context()) assert result == expected @@ -885,7 +868,7 @@ def greet(name: str) -> str: manager.add_tool(greet) # Verify tool works before removal - result = await manager.call_tool("greet", {"name": "World"}) + result = await manager.call_tool("greet", {"name": "World"}, Context()) assert result == "Hello, World!" 
# Remove the tool @@ -893,7 +876,7 @@ def greet(name: str) -> str: # Verify calling removed tool raises error with pytest.raises(ToolError, match="Unknown tool: greet"): - await manager.call_tool("greet", {"name": "World"}) + await manager.call_tool("greet", {"name": "World"}, Context()) def test_remove_tool_case_sensitive(self): """Test that tool removal is case-sensitive.""" diff --git a/tests/server/mcpserver/tools/__init__.py b/tests/server/mcpserver/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/server/mcpserver/tools/test_base.py b/tests/server/mcpserver/tools/test_base.py new file mode 100644 index 0000000000..22d5f973e9 --- /dev/null +++ b/tests/server/mcpserver/tools/test_base.py @@ -0,0 +1,10 @@ +from mcp.server.mcpserver import Context +from mcp.server.mcpserver.tools.base import Tool + + +def test_context_detected_in_union_annotation(): + def my_tool(x: int, ctx: Context | None) -> str: + raise NotImplementedError + + tool = Tool.from_function(my_tool) + assert tool.context_kwarg == "ctx" From b3149d2f33dd929143123195755bf4b9b57efddb Mon Sep 17 00:00:00 2001 From: Varun6578 <34965159+Varun6578@users.noreply.github.com> Date: Wed, 4 Mar 2026 20:15:11 +0530 Subject: [PATCH 22/84] fix: clean up SSE session on client disconnect (#2200) Co-authored-by: Varun Sharma Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Max Isbey <224885523+maxisbey@users.noreply.github.com> --- src/mcp/server/sse.py | 1 + tests/shared/test_sse.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9007230cea..9dcee67f78 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -186,6 +186,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): ) await read_stream_writer.aclose() await write_stream_reader.aclose() + self._read_stream_writers.pop(session_id, None) logging.debug(f"Client session 
disconnected {session_id}") logger.debug("Starting SSE response task") diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 7b2bc0a139..bfbecc0c8a 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -611,3 +611,30 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: assert not isinstance(msg, Exception) assert isinstance(msg.message, types.JSONRPCResponse) assert msg.message.id == 1 + + +@pytest.mark.anyio +async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None: + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227 + + When a client disconnects, the server should remove the session from + _read_stream_writers. Without this cleanup, stale sessions accumulate and + POST requests to disconnected sessions return 202 Accepted followed by a + ClosedResourceError when the server tries to write to the dead stream. + """ + captured: list[str] = [] + + # Connect a client session, then disconnect + async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + + # After disconnect, POST to the stale session should return 404 + # (not 202 as it did before the fix) + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/messages/?session_id={captured[0]}", + json={"jsonrpc": "2.0", "method": "ping", "id": 99}, + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 404 From 528abfab86f1f3c003bd7d54a1f0bbd65d81c59c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 4 Mar 2026 16:11:34 +0000 Subject: [PATCH 23/84] tests: remove lax-no-cover pragmas by moving assertions before cancellation (#2206) --- tests/shared/test_sse.py | 19 ++++----- tests/shared/test_streamable_http.py | 63 +++++++++++++--------------- 2 files changed, 36 
insertions(+), 46 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index bfbecc0c8a..890e997332 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -203,19 +203,15 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.mark.anyio async def test_sse_client_on_session_created(server: None, server_url: str) -> None: - captured_session_id: str | None = None - - def on_session_created(session_id: str) -> None: - nonlocal captured_session_id - captured_session_id = session_id + captured: list[str] = [] - async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: + async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - assert captured_session_id is not None # pragma: lax no cover - assert len(captured_session_id) > 0 # pragma: lax no cover + # Callback fires when the endpoint event arrives, before sse_client yields. + assert len(captured) == 1 + assert len(captured[0]) > 0 @pytest.mark.parametrize( @@ -248,8 +244,9 @@ def mock_extract(url: str) -> None: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - callback_mock.assert_not_called() # pragma: lax no cover + # Callback would have fired by now (endpoint event arrives before + # sse_client yields); if it hasn't, it won't. 
+ callback_mock.assert_not_called() @pytest.fixture diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 42b1a3698a..61ba4a2e54 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1132,22 +1132,19 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream) as session: + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} # Make a request to confirm session is working tools = await session.list_tools() assert len(tools.tools) == 10 - headers: dict[str, str] = {} # pragma: lax no cover - if captured_session_id: # pragma: lax no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -1196,22 +1193,19 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream) as session: + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} # Make a request to confirm session is working tools = await session.list_tools() assert len(tools.tools) == 10 - headers: dict[str, str] = {} # pragma: lax no cover - if 
captured_session_id: # pragma: lax no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -1231,7 +1225,6 @@ async def test_streamable_http_client_resumption(event_server: tuple[SimpleEvent # Variables to track the state captured_resumption_token: str | None = None captured_notifications: list[types.ServerNotification] = [] - captured_protocol_version: str | int | None = None first_notification_received = False async def message_handler( # pragma: no branch @@ -1258,15 +1251,20 @@ async def on_resumption_token_update(token: str) -> None: read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession( # pragma: no branch + read_stream, write_stream, message_handler=message_handler + ) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None - # Capture the negotiated protocol version - captured_protocol_version = result.protocol_version + # Build phase-2 headers now while both values are in scope + headers: dict[str, Any] = { + MCP_SESSION_ID_HEADER: captured_session_id, + MCP_PROTOCOL_VERSION_HEADER: result.protocol_version, + } # Start the tool that will wait on lock in a task async with anyio.create_task_group() as tg: # pragma: no branch @@ -1291,25 +1289,19 @@ async def run_tool(): while not first_notification_received or not captured_resumption_token: await anyio.sleep(0.1) + # The while loop only exits after first_notification_received=True, + # which is set by message_handler immediately after appending to + # captured_notifications. 
The server tool is blocked on its lock, + # so nothing else can arrive before we cancel. + assert len(captured_notifications) == 1 + assert isinstance(captured_notifications[0], types.LoggingMessageNotification) + assert captured_notifications[0].params.data == "First notification before lock" + # Reset for phase 2 before cancelling + captured_notifications.clear() + # Kill the client session while tool is waiting on lock tg.cancel_scope.cancel() - # Verify we received exactly one notification (inside ClientSession - # so coverage tracks these on Python 3.11, see PR #1897 for details) - assert len(captured_notifications) == 1 # pragma: lax no cover - assert isinstance(captured_notifications[0], types.LoggingMessageNotification) # pragma: lax no cover - assert captured_notifications[0].params.data == "First notification before lock" # pragma: lax no cover - - # Clear notifications and set up headers for phase 2 (between connections, - # not tracked by coverage on Python 3.11 due to cancel scope + sys.settrace bug) - captured_notifications = [] # pragma: lax no cover - assert captured_session_id is not None # pragma: lax no cover - assert captured_protocol_version is not None # pragma: lax no cover - headers: dict[str, Any] = { # pragma: lax no cover - MCP_SESSION_ID_HEADER: captured_session_id, - MCP_PROTOCOL_VERSION_HEADER: captured_protocol_version, - } - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -2092,11 +2084,12 @@ async def on_resumption_token(token: str) -> None: assert isinstance(result.content[0], TextContent) assert "Completed 3 checkpoints" in result.content[0].text - # 4 priming + 3 notifications + 1 response = 8 tokens - assert len(resumption_tokens) == 8, ( # pragma: lax no cover - f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " - f"got {len(resumption_tokens)}: {resumption_tokens}" - ) + # 4 priming + 3 
notifications + 1 response = 8 tokens. All tokens are + # captured before send_request returns, so this is safe to check here. + assert len(resumption_tokens) == 8, ( + f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " + f"got {len(resumption_tokens)}: {resumption_tokens}" + ) @pytest.mark.anyio From 7c0224828bc4e62c53208a051aaa209779868f53 Mon Sep 17 00:00:00 2001 From: Giulio Leone Date: Thu, 5 Mar 2026 15:57:33 +0100 Subject: [PATCH 24/84] fix(oauth): include client_id in token request body for client_secret_post (#2185) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/mcp/client/auth/oauth2.py | 5 +- .../extensions/test_client_credentials.py | 66 +++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 7f5af51867..25075dec3b 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -205,8 +205,9 @@ def prepare_token_auth( headers["Authorization"] = f"Basic {encoded_credentials}" # Don't include client_secret in body for basic auth data = {k: v for k, v in data.items() if k != "client_secret"} - elif auth_method == "client_secret_post" and self.client_info.client_secret: - # Include client_secret in request body + elif auth_method == "client_secret_post" and self.client_info.client_id and self.client_info.client_secret: + # Include client_id and client_secret in request body (RFC 6749 §2.3.1) + data["client_id"] = self.client_info.client_id data["client_secret"] = self.client_info.client_secret # For auth_method == "none", don't add any client_secret diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index 0003b16797..09760f4530 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -252,6 +252,72 @@ async def 
test_exchange_token_client_credentials(self, mock_storage: MockTokenSt assert "scope=read write" in content assert "resource=https://api.example.com/v1/mcp" in content + @pytest.mark.anyio + async def test_exchange_token_client_secret_post_includes_client_id(self, mock_storage: MockTokenStorage): + """Test that client_secret_post includes both client_id and client_secret in body (RFC 6749 §2.3.1).""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + token_endpoint_auth_method="client_secret_post", + scopes="read write", + ) + await provider._initialize() + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com"), + authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + ) + provider.context.protocol_version = "2025-06-18" + + request = await provider._perform_authorization() + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "client_id=test-client-id" in content + assert "client_secret=test-client-secret" in content + # Should NOT have Basic auth header + assert "Authorization" not in request.headers + + @pytest.mark.anyio + async def test_exchange_token_client_secret_post_without_client_id(self, mock_storage: MockTokenStorage): + """Test client_secret_post skips body credentials when client_id is None.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="placeholder", + client_secret="test-client-secret", + token_endpoint_auth_method="client_secret_post", + scopes="read write", + ) + await provider._initialize() + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com"), + 
authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + ) + provider.context.protocol_version = "2025-06-18" + # Override client_info to have client_id=None (edge case) + provider.context.client_info = OAuthClientInformationFull( + redirect_uris=None, + client_id=None, + client_secret="test-client-secret", + grant_types=["client_credentials"], + token_endpoint_auth_method="client_secret_post", + scope="read write", + ) + + request = await provider._perform_authorization() + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + # Neither client_id nor client_secret should be in body since client_id is None + # (RFC 6749 §2.3.1 requires both for client_secret_post) + assert "client_id=" not in content + assert "client_secret=" not in content + assert "Authorization" not in request.headers + @pytest.mark.anyio async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): """Test token exchange without scopes.""" From b33c81167572096baeb7f7cff35987fc1168b28d Mon Sep 17 00:00:00 2001 From: Giulio Leone Date: Thu, 5 Mar 2026 16:44:33 +0100 Subject: [PATCH 25/84] perf: use deque for InMemoryTaskMessageQueue FIFO operations (#2165) --- src/mcp/shared/experimental/tasks/message_queue.py | 9 +++++---- tests/experimental/tasks/test_message_queue.py | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/mcp/shared/experimental/tasks/message_queue.py b/src/mcp/shared/experimental/tasks/message_queue.py index 018c2b7b26..e17c4a8650 100644 --- a/src/mcp/shared/experimental/tasks/message_queue.py +++ b/src/mcp/shared/experimental/tasks/message_queue.py @@ -12,6 +12,7 @@ """ from abc import ABC, abstractmethod +from collections import deque from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any, Literal @@ -151,13 +152,13 @@ class 
InMemoryTaskMessageQueue(TaskMessageQueue): """ def __init__(self) -> None: - self._queues: dict[str, list[QueuedMessage]] = {} + self._queues: dict[str, deque[QueuedMessage]] = {} self._events: dict[str, anyio.Event] = {} - def _get_queue(self, task_id: str) -> list[QueuedMessage]: + def _get_queue(self, task_id: str) -> deque[QueuedMessage]: """Get or create the queue for a task.""" if task_id not in self._queues: - self._queues[task_id] = [] + self._queues[task_id] = deque() return self._queues[task_id] async def enqueue(self, task_id: str, message: QueuedMessage) -> None: @@ -172,7 +173,7 @@ async def dequeue(self, task_id: str) -> QueuedMessage | None: queue = self._get_queue(task_id) if not queue: return None - return queue.pop(0) + return queue.popleft() async def peek(self, task_id: str) -> QueuedMessage | None: """Return the next message without removing it.""" diff --git a/tests/experimental/tasks/test_message_queue.py b/tests/experimental/tasks/test_message_queue.py index a8517e535c..eca113d5b4 100644 --- a/tests/experimental/tasks/test_message_queue.py +++ b/tests/experimental/tasks/test_message_queue.py @@ -1,5 +1,6 @@ """Tests for TaskMessageQueue and InMemoryTaskMessageQueue.""" +from collections import deque from datetime import datetime, timezone import anyio @@ -270,7 +271,7 @@ async def is_empty_with_injection(tid: str) -> bool: if call_count == 2 and tid == task_id: # Before second check, inject a message - this simulates a message # arriving between event creation and the double-check - queue._queues[task_id] = [QueuedMessage(type="request", message=make_request())] + queue._queues[task_id] = deque([QueuedMessage(type="request", message=make_request())]) return await original_is_empty(tid) queue.is_empty = is_empty_with_injection # type: ignore[method-assign] From 92f1b1500d808a32d261d9e101d6f3bae3ad9e25 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 6 Mar 2026 14:50:58 +0000 Subject: [PATCH 
26/84] fix: remove MIME type validation from MCPServer Resource (#2235) --- src/mcp/server/mcpserver/resources/base.py | 6 +----- tests/server/mcpserver/resources/test_resources.py | 8 ++++++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/mcpserver/resources/base.py b/src/mcp/server/mcpserver/resources/base.py index d3ccc425ef..d48e0695cb 100644 --- a/src/mcp/server/mcpserver/resources/base.py +++ b/src/mcp/server/mcpserver/resources/base.py @@ -23,11 +23,7 @@ class Resource(BaseModel, abc.ABC): name: str | None = Field(description="Name of the resource", default=None) title: str | None = Field(description="Human-readable title of the resource", default=None) description: str | None = Field(description="Description of the resource", default=None) - mime_type: str = Field( - default="text/plain", - description="MIME type of the resource content", - pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+(;\s*[a-zA-Z0-9\-_.]+=[a-zA-Z0-9\-_.]+)*$", - ) + mime_type: str = Field(default="text/plain", description="MIME type of the resource content") icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this resource") annotations: Annotations | None = Field(default=None, description="Optional annotations for the resource") meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this resource") diff --git a/tests/server/mcpserver/resources/test_resources.py b/tests/server/mcpserver/resources/test_resources.py index 93dc438d5d..5d36beda85 100644 --- a/tests/server/mcpserver/resources/test_resources.py +++ b/tests/server/mcpserver/resources/test_resources.py @@ -91,6 +91,14 @@ def dummy_func() -> str: # pragma: no cover ) assert resource.mime_type == "application/json" + # RFC 2045 quoted parameter value (gh-1756) + resource = FunctionResource( + uri="resource://test", + fn=dummy_func, + mime_type='text/plain; charset="utf-8"', + ) + assert resource.mime_type == 'text/plain; charset="utf-8"' + 
@pytest.mark.anyio async def test_resource_read_abstract(self): """Test that Resource.read() is abstract.""" From eaf971cf252692c94d3f76d3b9063eaddc7f5eb4 Mon Sep 17 00:00:00 2001 From: Ramesh Reddy Adutla <134313151+rameshreddy-adutla@users.noreply.github.com> Date: Fri, 6 Mar 2026 16:55:37 +0000 Subject: [PATCH 27/84] Add warning log when rejecting request with unknown/expired session ID (#2212) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Max Isbey <224885523+maxisbey@users.noreply.github.com> --- src/mcp/server/streamable_http_manager.py | 1 + tests/server/test_streamable_http_manager.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 50bcd5e791..c25314eab6 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -272,6 +272,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE # Unknown or expired session ID - return 404 per MCP spec # TODO: Align error code once spec clarifies # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 + logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}") error_response = JSONRPCError( jsonrpc="2.0", id=None, diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 54a898cc5c..47cfbf14a4 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -1,6 +1,7 @@ """Tests for StreamableHTTPSessionManager.""" import json +import logging from typing import Any from unittest.mock import AsyncMock, patch @@ -269,7 +270,7 @@ async def mock_receive(): @pytest.mark.anyio -async def test_unknown_session_id_returns_404(): +async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture): """Test that requests with unknown session 
IDs return HTTP 404 per MCP spec.""" app = Server("test-unknown-session") manager = StreamableHTTPSessionManager(app=app) @@ -299,7 +300,8 @@ async def mock_send(message: Message): async def mock_receive(): return {"type": "http.request", "body": b"{}", "more_body": False} # pragma: no cover - await manager.handle_request(scope, mock_receive, mock_send) + with caplog.at_level(logging.INFO): + await manager.handle_request(scope, mock_receive, mock_send) # Find the response start message response_start = next( @@ -315,6 +317,7 @@ async def mock_receive(): assert error_data["id"] is None assert error_data["error"]["code"] == INVALID_REQUEST assert error_data["error"]["message"] == "Session not found" + assert "Rejected request with unknown or expired session ID: non-existent-session-id" in caplog.text @pytest.mark.anyio From 7ba41dcfae2044987fbcca0e10744e992489091f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 6 Mar 2026 17:24:18 +0000 Subject: [PATCH 28/84] fix: make local coverage runs reliable (#2236) --- .github/workflows/shared.yml | 1 + CLAUDE.md | 16 +++++++++++++--- scripts/test | 1 + tests/test_examples.py | 27 ++++++++++++--------------- 4 files changed, 27 insertions(+), 18 deletions(-) diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index 72e328b541..efb45c8898 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -70,6 +70,7 @@ jobs: - name: Run pytest with coverage shell: bash run: | + uv run --frozen --no-sync coverage erase uv run --frozen --no-sync coverage run -m pytest -n auto uv run --frozen --no-sync coverage combine uv run --frozen --no-sync coverage report diff --git a/CLAUDE.md b/CLAUDE.md index e48ce6e70c..98bd451152 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,9 +28,19 @@ This document contains critical information about working with this codebase. 
Fo - Bug fixes require regression tests - IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns. - IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible. - - IMPORTANT: Before pushing, verify 100% branch coverage on changed files by running - `uv run --frozen pytest -x` (coverage is configured in `pyproject.toml` with `fail_under = 100` - and `branch = true`). If any branch is uncovered, add a test for it before pushing. + - Coverage: CI requires 100% (`fail_under = 100`, `branch = true`). + - Full check: `./scripts/test` (~20s, matches CI exactly) + - Targeted check while iterating: + + ```bash + uv run --frozen coverage erase + uv run --frozen coverage run -m pytest tests/path/test_foo.py + uv run --frozen coverage combine + uv run --frozen coverage report --include='src/mcp/path/foo.py' --fail-under=0 + ``` + + Partial runs can't hit 100% (coverage tracks `tests/` too), so `--fail-under=0` + and `--include` scope the report to what you actually changed. - Avoid `anyio.sleep()` with a fixed duration to wait for async operations. 
Instead: - Use `anyio.Event` — set it in the callback/handler, `await event.wait()` in the test - For stream messages, use `await stream.receive()` instead of `sleep()` + `receive_nowait()` diff --git a/scripts/test b/scripts/test index 0d08e47b1b..ee1259b597 100755 --- a/scripts/test +++ b/scripts/test @@ -2,6 +2,7 @@ set -ex +uv run --frozen coverage erase uv run --frozen coverage run -m pytest -n auto $@ uv run --frozen coverage combine uv run --frozen coverage report diff --git a/tests/test_examples.py b/tests/test_examples.py index aa9de09579..3af82f04c5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -5,7 +5,6 @@ # pyright: reportUnknownArgumentType=false # pyright: reportUnknownMemberType=false -import sys from pathlib import Path import pytest @@ -65,12 +64,17 @@ async def test_direct_call_tool_result_return(): @pytest.mark.anyio -async def test_desktop(monkeypatch: pytest.MonkeyPatch): +async def test_desktop(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """Test the desktop server""" - # Mock desktop directory listing - mock_files = [Path("/fake/path/file1.txt"), Path("/fake/path/file2.txt")] - monkeypatch.setattr(Path, "iterdir", lambda self: mock_files) # type: ignore[reportUnknownArgumentType] - monkeypatch.setattr(Path, "home", lambda: Path("/fake/home")) + # Build a real Desktop directory under tmp_path rather than patching + # Path.iterdir — a class-level patch breaks jsonschema_specifications' + # import-time schema discovery when this test happens to be the first + # tool call in an xdist worker. 
+ desktop = tmp_path / "Desktop" + desktop.mkdir() + (desktop / "file1.txt").touch() + (desktop / "file2.txt").touch() + monkeypatch.setattr(Path, "home", lambda: tmp_path) from examples.mcpserver.desktop import mcp @@ -85,15 +89,8 @@ async def test_desktop(monkeypatch: pytest.MonkeyPatch): content = result.contents[0] assert isinstance(content, TextResourceContents) assert isinstance(content.text, str) - if sys.platform == "win32": # pragma: no cover - file_1 = "/fake/path/file1.txt".replace("/", "\\\\") # might be a bug - file_2 = "/fake/path/file2.txt".replace("/", "\\\\") # might be a bug - assert file_1 in content.text - assert file_2 in content.text - # might be a bug, but the test is passing - else: # pragma: lax no cover - assert "/fake/path/file1.txt" in content.text - assert "/fake/path/file2.txt" in content.text + assert "file1.txt" in content.text + assert "file2.txt" in content.text # TODO(v2): Change back to README.md when v2 is released From 51c53f2c189ddc8f5ed0925e582565a0f91b1d9b Mon Sep 17 00:00:00 2001 From: Shivam Aggarwal Date: Mon, 9 Mar 2026 22:00:02 +0530 Subject: [PATCH 29/84] fix: accept wildcard media types in Accept header per RFC 7231 (#2152) Co-authored-by: Shivam --- src/mcp/server/streamable_http.py | 15 +++-- tests/shared/test_streamable_http.py | 86 ++++++++++++++++++++++++++-- 2 files changed, 92 insertions(+), 9 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 04aed345e0..aa99e7c887 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -391,12 +391,19 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: - """Check if the request accepts the required media types.""" + """Check if the request accepts the required media types. 
+ + Supports wildcard media types per RFC 7231, section 5.3.2: + - */* matches any media type + - application/* matches any application/ subtype + - text/* matches any text/ subtype + """ accept_header = request.headers.get("accept", "") - accept_types = [media_type.strip() for media_type in accept_header.split(",")] + accept_types = [media_type.strip().split(";")[0].strip().lower() for media_type in accept_header.split(",")] - has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types) - has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types) + has_wildcard = "*/*" in accept_types + has_json = has_wildcard or any(t in (CONTENT_TYPE_JSON, "application/*") for t in accept_types) + has_sse = has_wildcard or any(t in (CONTENT_TYPE_SSE, "text/*") for t in accept_types) return has_json, has_sse diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 61ba4a2e54..f8ca30441b 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -572,8 +572,10 @@ def json_server_url(json_server_port: int) -> str: # Basic request validation tests def test_accept_header_validation(basic_server: None, basic_server_url: str): """Test that Accept header is properly validated.""" - # Test without Accept header - response = requests.post( + # Test without Accept header (suppress requests library default Accept: */*) + session = requests.Session() + session.headers.pop("Accept") + response = session.post( f"{basic_server_url}/mcp", headers={"Content-Type": "application/json"}, json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, @@ -582,6 +584,52 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): assert "Not Acceptable" in response.text +@pytest.mark.parametrize( + "accept_header", + [ + "*/*", + "application/*, text/*", + "text/*, application/json", + "application/json, text/*", + "*/*;q=0.8", + "application/*;q=0.9, 
text/*;q=0.8", + ], +) +def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): + """Test that wildcard Accept headers are accepted per RFC 7231.""" + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + +@pytest.mark.parametrize( + "accept_header", + [ + "text/html", + "application/*", + "text/*", + ], +) +def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): + """Test that incompatible Accept headers are rejected for SSE mode.""" + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + def test_content_type_validation(basic_server: None, basic_server_url: str): """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type @@ -826,7 +874,10 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_ def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): """Test that json_response servers reject requests without Accept header.""" mcp_url = f"{json_server_url}/mcp" - response = requests.post( + # Suppress requests library default Accept: */* header + session = requests.Session() + session.headers.pop("Accept") + response = session.post( mcp_url, headers={ "Content-Type": "application/json", @@ -853,6 +904,29 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ assert "Not Acceptable" in response.text +@pytest.mark.parametrize( + "accept_header", + [ + "*/*", + "application/*", + "application/*;q=0.9", + ], +) +def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: 
str): + """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" + mcp_url = f"{json_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" + + def test_get_sse_stream(basic_server: None, basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session @@ -941,8 +1015,10 @@ def test_get_validation(basic_server: None, basic_server_url: str): assert init_data is not None negotiated_version = init_data["result"]["protocolVersion"] - # Test without Accept header - response = requests.get( + # Test without Accept header (suppress requests library default Accept: */*) + session = requests.Session() + session.headers.pop("Accept") + response = session.get( mcp_url, headers={ MCP_SESSION_ID_HEADER: session_id, From 31a38b50786f8d22c99d781beaa89b71cad26d23 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 9 Mar 2026 16:52:56 +0000 Subject: [PATCH 30/84] fix: correct Context type parameters across examples and tests (#2256) --- README.v2.md | 19 ++++++--------- .../mcp_everything_server/server.py | 15 ++++++------ examples/snippets/servers/elicitation.py | 7 +++--- examples/snippets/servers/notifications.py | 3 +-- examples/snippets/servers/sampling.py | 3 +-- examples/snippets/servers/tool_progress.py | 3 +-- tests/client/test_list_roots_callback.py | 2 +- tests/server/mcpserver/test_elicitation.py | 21 ++++++++--------- tests/server/mcpserver/test_server.py | 17 +++++++------- tests/server/mcpserver/test_tool_manager.py | 11 ++++----- .../server/mcpserver/test_url_elicitation.py | 23 +++++++++---------- .../test_url_elicitation_error_throw.py | 7 +++--- tests/server/test_lifespan.py | 3 +-- 13 files changed, 59 insertions(+), 75 
deletions(-) diff --git a/README.v2.md b/README.v2.md index bd6927bf92..55d867586d 100644 --- a/README.v2.md +++ b/README.v2.md @@ -346,13 +346,12 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -694,13 +693,12 @@ The Context object provides the following capabilities: ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -826,7 +824,6 @@ import uuid from pydantic import BaseModel, Field from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams @@ -844,7 +841,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context) -> str: """Book a table with date availability check. This demonstrates form mode elicitation for collecting non-sensitive user input. 
@@ -868,7 +865,7 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerS @mcp.tool() -async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> str: +async def secure_payment(amount: float, ctx: Context) -> str: """Process a secure payment requiring URL confirmation. This demonstrates URL mode elicitation using ctx.elicit_url() for @@ -892,7 +889,7 @@ async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> st @mcp.tool() -async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: +async def connect_service(service_name: str, ctx: Context) -> str: """Connect to a third-party service requiring OAuth authorization. This demonstrates the "throw error" pattern using UrlElicitationRequiredError. @@ -933,14 +930,13 @@ Tools can interact with LLMs through sampling (generating text): ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.types import SamplingMessage, TextContent mcp = MCPServer(name="Sampling Example") @mcp.tool() -async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str: +async def generate_poem(topic: str, ctx: Context) -> str: """Generate a poem using LLM sampling.""" prompt = f"Write a short poem about {topic}" @@ -970,13 +966,12 @@ Tools can send logs and notifications through the context: ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index 2101cff28f..a0620b9c1d 100644 --- 
a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -13,7 +13,6 @@ from mcp.server import ServerRequestContext from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.prompts.base import UserMessage -from mcp.server.session import ServerSession from mcp.server.streamable_http import EventCallback, EventMessage, EventStore from mcp.types import ( AudioContent, @@ -142,7 +141,7 @@ def test_multiple_content_types() -> list[TextContent | ImageContent | EmbeddedR @mcp.tool() -async def test_tool_with_logging(ctx: Context[ServerSession, None]) -> str: +async def test_tool_with_logging(ctx: Context) -> str: """Tests tool that emits log messages during execution""" await ctx.info("Tool execution started") await asyncio.sleep(0.05) @@ -155,7 +154,7 @@ async def test_tool_with_logging(ctx: Context[ServerSession, None]) -> str: @mcp.tool() -async def test_tool_with_progress(ctx: Context[ServerSession, None]) -> str: +async def test_tool_with_progress(ctx: Context) -> str: """Tests tool that reports progress notifications""" await ctx.report_progress(progress=0, total=100, message="Completed step 0 of 100") await asyncio.sleep(0.05) @@ -173,7 +172,7 @@ async def test_tool_with_progress(ctx: Context[ServerSession, None]) -> str: @mcp.tool() -async def test_sampling(prompt: str, ctx: Context[ServerSession, None]) -> str: +async def test_sampling(prompt: str, ctx: Context) -> str: """Tests server-initiated sampling (LLM completion request)""" try: # Request sampling from client @@ -198,7 +197,7 @@ class UserResponse(BaseModel): @mcp.tool() -async def test_elicitation(message: str, ctx: Context[ServerSession, None]) -> str: +async def test_elicitation(message: str, ctx: Context) -> str: """Tests server-initiated elicitation (user input request)""" try: # Request user input from client @@ -230,7 +229,7 @@ class SEP1034DefaultsSchema(BaseModel): @mcp.tool() -async def 
test_elicitation_sep1034_defaults(ctx: Context[ServerSession, None]) -> str: +async def test_elicitation_sep1034_defaults(ctx: Context) -> str: """Tests elicitation with default values for all primitive types (SEP-1034)""" try: # Request user input with defaults for all primitive types @@ -289,7 +288,7 @@ class EnumSchemasTestSchema(BaseModel): @mcp.tool() -async def test_elicitation_sep1330_enums(ctx: Context[ServerSession, None]) -> str: +async def test_elicitation_sep1330_enums(ctx: Context) -> str: """Tests elicitation with enum schema variations per SEP-1330""" try: result = await ctx.elicit( @@ -313,7 +312,7 @@ def test_error_handling() -> str: @mcp.tool() -async def test_reconnection(ctx: Context[ServerSession, None]) -> str: +async def test_reconnection(ctx: Context) -> str: """Tests SSE polling by closing stream mid-call (SEP-1699)""" await ctx.info("Before disconnect") diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 70e515c75f..79453f543e 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -10,7 +10,6 @@ from pydantic import BaseModel, Field from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams @@ -28,7 +27,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context) -> str: """Book a table with date availability check. This demonstrates form mode elicitation for collecting non-sensitive user input. 
@@ -52,7 +51,7 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerS @mcp.tool() -async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> str: +async def secure_payment(amount: float, ctx: Context) -> str: """Process a secure payment requiring URL confirmation. This demonstrates URL mode elicitation using ctx.elicit_url() for @@ -76,7 +75,7 @@ async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> st @mcp.tool() -async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: +async def connect_service(service_name: str, ctx: Context) -> str: """Connect to a third-party service requiring OAuth authorization. This demonstrates the "throw error" pattern using UrlElicitationRequiredError. diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 5579af510f..d6d903cc7f 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,11 +1,10 @@ from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") diff --git a/examples/snippets/servers/sampling.py b/examples/snippets/servers/sampling.py index 4ffeeda726..43259589a4 100644 --- a/examples/snippets/servers/sampling.py +++ b/examples/snippets/servers/sampling.py @@ -1,12 +1,11 @@ from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.types import SamplingMessage, TextContent mcp = MCPServer(name="Sampling Example") @mcp.tool() -async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str: +async def generate_poem(topic: str, ctx: 
Context) -> str: """Generate a poem using LLM sampling.""" prompt = f"Write a short poem about {topic}" diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index 0b283cb1f5..376dbc5db8 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,11 +1,10 @@ from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 1ab90be772..be4b9a97b9 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -25,7 +25,7 @@ async def list_roots_callback( return callback_return @server.tool("test_list_roots") - async def test_list_roots(context: Context[None], message: str): + async def test_list_roots(context: Context, message: str): roots = await context.session.list_roots() assert roots == callback_return return True diff --git a/tests/server/mcpserver/test_elicitation.py b/tests/server/mcpserver/test_elicitation.py index 6cf49fbd7a..679fb848f5 100644 --- a/tests/server/mcpserver/test_elicitation.py +++ b/tests/server/mcpserver/test_elicitation.py @@ -8,7 +8,6 @@ from mcp import Client, types from mcp.client.session import ClientSession, ElicitationFnT from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -22,7 +21,7 @@ def create_ask_user_tool(mcp: MCPServer): """Create a standard ask_user tool that handles all 
elicitation responses.""" @mcp.tool(description="A tool that uses elicitation") - async def ask_user(prompt: str, ctx: Context[ServerSession, None]) -> str: + async def ask_user(prompt: str, ctx: Context) -> str: result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) if result.action == "accept" and result.data: @@ -97,7 +96,7 @@ async def test_elicitation_schema_validation(): def create_validation_tool(name: str, schema_class: type[BaseModel]): @mcp.tool(name=name, description=f"Tool testing {name}") - async def tool(ctx: Context[ServerSession, None]) -> str: + async def tool(ctx: Context) -> str: try: await ctx.elicit(message="This should fail validation", schema=schema_class) return "Should not reach here" # pragma: no cover @@ -147,7 +146,7 @@ class OptionalSchema(BaseModel): subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?") @mcp.tool(description="Tool with optional fields") - async def optional_tool(ctx: Context[ServerSession, None]) -> str: + async def optional_tool(ctx: Context) -> str: result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema) if result.action == "accept" and result.data: @@ -188,7 +187,7 @@ class InvalidOptionalSchema(BaseModel): optional_list: list[int] | None = Field(default=None, description="Invalid optional list") @mcp.tool(description="Tool with invalid optional field") - async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: + async def invalid_optional_tool(ctx: Context) -> str: try: await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) return "Should not reach here" # pragma: no cover @@ -214,7 +213,7 @@ class ValidMultiSelectSchema(BaseModel): tags: list[str] = Field(description="Tags") @mcp.tool(description="Tool with valid list[str] field") - async def valid_multiselect_tool(ctx: Context[ServerSession, None]) -> str: + async def valid_multiselect_tool(ctx: Context) -> str: result = await 
ctx.elicit(message="Please provide tags", schema=ValidMultiSelectSchema) if result.action == "accept" and result.data: return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}" @@ -233,7 +232,7 @@ class OptionalMultiSelectSchema(BaseModel): tags: list[str] | None = Field(default=None, description="Optional tags") @mcp.tool(description="Tool with optional list[str] field") - async def optional_multiselect_tool(ctx: Context[ServerSession, None]) -> str: + async def optional_multiselect_tool(ctx: Context) -> str: result = await ctx.elicit(message="Please provide optional tags", schema=OptionalMultiSelectSchema) if result.action == "accept" and result.data: tags_str = ", ".join(result.data.tags) if result.data.tags else "none" @@ -262,7 +261,7 @@ class DefaultsSchema(BaseModel): email: str = Field(description="Email address (required)") @mcp.tool(description="Tool with default values") - async def defaults_tool(ctx: Context[ServerSession, None]) -> str: + async def defaults_tool(ctx: Context) -> str: result = await ctx.elicit(message="Please provide your information", schema=DefaultsSchema) if result.action == "accept" and result.data: @@ -327,7 +326,7 @@ class FavoriteColorSchema(BaseModel): ) @mcp.tool(description="Single color selection") - async def select_favorite_color(ctx: Context[ServerSession, None]) -> str: + async def select_favorite_color(ctx: Context) -> str: result = await ctx.elicit(message="Select your favorite color", schema=FavoriteColorSchema) if result.action == "accept" and result.data: return f"User: {result.data.user_name}, Favorite: {result.data.favorite_color}" @@ -351,7 +350,7 @@ class FavoriteColorsSchema(BaseModel): ) @mcp.tool(description="Multiple color selection") - async def select_favorite_colors(ctx: Context[ServerSession, None]) -> str: + async def select_favorite_colors(ctx: Context) -> str: result = await ctx.elicit(message="Select your favorite colors", schema=FavoriteColorsSchema) if result.action == "accept" and 
result.data: return f"User: {result.data.user_name}, Colors: {', '.join(result.data.favorite_colors)}" @@ -366,7 +365,7 @@ class LegacyColorSchema(BaseModel): ) @mcp.tool(description="Legacy enum format") - async def select_color_legacy(ctx: Context[ServerSession, None]) -> str: + async def select_color_legacy(ctx: Context) -> str: result = await ctx.elicit(message="Select a color (legacy format)", schema=LegacyColorSchema) if result.action == "accept" and result.data: return f"User: {result.data.user_name}, Color: {result.data.color}" diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3d130bfc33..3ef06d0381 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -17,7 +17,6 @@ from mcp.server.mcpserver.prompts.base import Message, UserMessage from mcp.server.mcpserver.resources import FileResource, FunctionResource from mcp.server.mcpserver.utilities.types import Audio, Image -from mcp.server.session import ServerSession from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError from mcp.types import ( @@ -1003,7 +1002,7 @@ async def test_context_detection(self): """Test that context parameters are properly detected.""" mcp = MCPServer() - def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: # pragma: no cover + def tool_with_context(x: int, ctx: Context) -> str: # pragma: no cover return f"Request {ctx.request_id}: {x}" tool = mcp._tool_manager.add_tool(tool_with_context) @@ -1013,7 +1012,7 @@ async def test_context_injection(self): """Test that context is properly injected into tool calls.""" mcp = MCPServer() - def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: + def tool_with_context(x: int, ctx: Context) -> str: assert ctx.request_id is not None return f"Request {ctx.request_id}: {x}" @@ -1030,7 +1029,7 @@ async def test_async_context(self): """Test that context works in async 
functions.""" mcp = MCPServer() - async def async_tool(x: int, ctx: Context[ServerSession, None]) -> str: + async def async_tool(x: int, ctx: Context) -> str: assert ctx.request_id is not None return f"Async request {ctx.request_id}: {x}" @@ -1047,7 +1046,7 @@ async def test_context_logging(self): """Test that context logging methods work.""" mcp = MCPServer() - async def logging_tool(msg: str, ctx: Context[ServerSession, None]) -> str: + async def logging_tool(msg: str, ctx: Context) -> str: await ctx.debug("Debug message") await ctx.info("Info message") await ctx.warning("Warning message") @@ -1094,7 +1093,7 @@ def test_resource() -> str: return "resource data" @mcp.tool() - async def tool_with_resource(ctx: Context[ServerSession, None]) -> str: + async def tool_with_resource(ctx: Context) -> str: r_iter = await ctx.read_resource("test://data") r_list = list(r_iter) assert len(r_list) == 1 @@ -1113,7 +1112,7 @@ async def test_resource_with_context(self): mcp = MCPServer() @mcp.resource("resource://context/{name}") - def resource_with_context(name: str, ctx: Context[ServerSession, None]) -> str: + def resource_with_context(name: str, ctx: Context) -> str: """Resource that receives context.""" assert ctx is not None return f"Resource {name} - context injected" @@ -1166,7 +1165,7 @@ async def test_resource_context_custom_name(self): mcp = MCPServer() @mcp.resource("resource://custom/{id}") - def resource_custom_ctx(id: str, my_ctx: Context[ServerSession, None]) -> str: + def resource_custom_ctx(id: str, my_ctx: Context) -> str: """Resource with custom context parameter name.""" assert my_ctx is not None return f"Resource {id} with context" @@ -1194,7 +1193,7 @@ async def test_prompt_with_context(self): mcp = MCPServer() @mcp.prompt("prompt_with_ctx") - def prompt_with_context(text: str, ctx: Context[ServerSession, None]) -> str: + def prompt_with_context(text: str, ctx: Context) -> str: """Prompt that expects context.""" assert ctx is not None return f"Prompt 
'{text}' - context injected" diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index f990ec47b7..e4dfd4ff9b 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -11,7 +11,6 @@ from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.tools import Tool, ToolManager from mcp.server.mcpserver.utilities.func_metadata import ArgModelBase, FuncMetadata -from mcp.server.session import ServerSessionT from mcp.types import TextContent, ToolAnnotations @@ -319,7 +318,7 @@ def name_shrimp(tank: MyShrimpTank) -> list[str]: class TestToolSchema: @pytest.mark.anyio async def test_context_arg_excluded_from_schema(self): - def something(a: int, ctx: Context[ServerSessionT, None]) -> int: # pragma: no cover + def something(a: int, ctx: Context) -> int: # pragma: no cover return a manager = ToolManager() @@ -336,7 +335,7 @@ def test_context_parameter_detection(self): """Test that context parameters are properly detected in Tool.from_function().""" - def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: # pragma: no cover + def tool_with_context(x: int, ctx: Context) -> str: # pragma: no cover return str(x) manager = ToolManager() @@ -359,7 +358,7 @@ def tool_with_parametrized_context(x: int, ctx: Context[LifespanContextT, Reques async def test_context_injection(self): """Test that context is properly injected during tool execution.""" - def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: + def tool_with_context(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) return str(x) @@ -373,7 +372,7 @@ def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: async def test_context_injection_async(self): """Test that context is properly injected in async tools.""" - async def async_tool(x: int, ctx: Context[ServerSessionT, None]) -> str: + async def async_tool(x: int, ctx: Context) -> str: 
assert isinstance(ctx, Context) return str(x) @@ -387,7 +386,7 @@ async def async_tool(x: int, ctx: Context[ServerSessionT, None]) -> str: async def test_context_error_handling(self): """Test error handling when context injection fails.""" - def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: + def tool_with_context(x: int, ctx: Context) -> str: raise ValueError("Test error") manager = ToolManager() diff --git a/tests/server/mcpserver/test_url_elicitation.py b/tests/server/mcpserver/test_url_elicitation.py index 1311bd6728..af90dc208b 100644 --- a/tests/server/mcpserver/test_url_elicitation.py +++ b/tests/server/mcpserver/test_url_elicitation.py @@ -8,7 +8,6 @@ from mcp.client.session import ClientSession from mcp.server.elicitation import CancelledElicitation, DeclinedElicitation, elicit_url from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -19,7 +18,7 @@ async def test_url_elicitation_accept(): mcp = MCPServer(name="URLElicitationServer") @mcp.tool(description="A tool that uses URL elicitation") - async def request_api_key(ctx: Context[ServerSession, None]) -> str: + async def request_api_key(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Please provide your API key to continue.", url="https://example.com/api_key_setup", @@ -49,7 +48,7 @@ async def test_url_elicitation_decline(): mcp = MCPServer(name="URLElicitationDeclineServer") @mcp.tool(description="A tool that uses URL elicitation") - async def oauth_flow(ctx: Context[ServerSession, None]) -> str: + async def oauth_flow(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Authorize access to your files.", url="https://example.com/oauth/authorize", @@ -75,7 +74,7 @@ async def test_url_elicitation_cancel(): mcp = MCPServer(name="URLElicitationCancelServer") @mcp.tool(description="A 
tool that uses URL elicitation") - async def payment_flow(ctx: Context[ServerSession, None]) -> str: + async def payment_flow(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Complete payment to proceed.", url="https://example.com/payment", @@ -101,7 +100,7 @@ async def test_url_elicitation_helper_function(): mcp = MCPServer(name="URLElicitationHelperServer") @mcp.tool(description="Tool using elicit_url helper") - async def setup_credentials(ctx: Context[ServerSession, None]) -> str: + async def setup_credentials(ctx: Context) -> str: result = await elicit_url( session=ctx.session, message="Set up your credentials", @@ -127,7 +126,7 @@ async def test_url_no_content_in_response(): mcp = MCPServer(name="URLContentCheckServer") @mcp.tool(description="Check URL response format") - async def check_url_response(ctx: Context[ServerSession, None]) -> str: + async def check_url_response(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Test message", url="https://example.com/test", @@ -164,7 +163,7 @@ class NameSchema(BaseModel): name: str = Field(description="Your name") @mcp.tool(description="Test form mode") - async def ask_name(ctx: Context[ServerSession, None]) -> str: + async def ask_name(ctx: Context) -> str: result = await ctx.elicit(message="What is your name?", schema=NameSchema) # Test only checks accept path with data assert result.action == "accept" @@ -195,7 +194,7 @@ async def test_elicit_complete_notification(): notification_sent = False @mcp.tool(description="Tool that sends completion notification") - async def trigger_elicitation(ctx: Context[ServerSession, None]) -> str: + async def trigger_elicitation(ctx: Context) -> str: nonlocal notification_sent # Simulate an async operation (e.g., user completing auth in browser) @@ -238,7 +237,7 @@ async def test_elicit_url_typed_results(): mcp = MCPServer(name="TypedResultsServer") @mcp.tool(description="Test declined result") - async def test_decline(ctx: 
Context[ServerSession, None]) -> str: + async def test_decline(ctx: Context) -> str: result = await elicit_url( session=ctx.session, message="Test decline", @@ -251,7 +250,7 @@ async def test_decline(ctx: Context[ServerSession, None]) -> str: return "Not declined" # pragma: no cover @mcp.tool(description="Test cancelled result") - async def test_cancel(ctx: Context[ServerSession, None]) -> str: + async def test_cancel(ctx: Context) -> str: result = await elicit_url( session=ctx.session, message="Test cancel", @@ -293,7 +292,7 @@ class EmailSchema(BaseModel): email: str = Field(description="Email address") @mcp.tool(description="Test deprecated elicit method") - async def use_deprecated_elicit(ctx: Context[ServerSession, None]) -> str: + async def use_deprecated_elicit(ctx: Context) -> str: # Use the deprecated elicit() method which should call elicit_form() result = await ctx.session.elicit( message="Enter your email", @@ -323,7 +322,7 @@ async def test_ctx_elicit_url_convenience_method(): mcp = MCPServer(name="CtxElicitUrlServer") @mcp.tool(description="A tool that uses ctx.elicit_url() directly") - async def direct_elicit_url(ctx: Context[ServerSession, None]) -> str: + async def direct_elicit_url(ctx: Context) -> str: # Use ctx.elicit_url() directly instead of ctx.session.elicit_url() result = await ctx.elicit_url( message="Test the convenience method", diff --git a/tests/server/mcpserver/test_url_elicitation_error_throw.py b/tests/server/mcpserver/test_url_elicitation_error_throw.py index 2d29937995..1f45fd60f0 100644 --- a/tests/server/mcpserver/test_url_elicitation_error_throw.py +++ b/tests/server/mcpserver/test_url_elicitation_error_throw.py @@ -5,7 +5,6 @@ from mcp import Client, ErrorData, types from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError @@ -15,7 +14,7 @@ async def test_url_elicitation_error_thrown_from_tool(): mcp = 
MCPServer(name="UrlElicitationErrorServer") @mcp.tool(description="A tool that raises UrlElicitationRequiredError") - async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: + async def connect_service(service_name: str, ctx: Context) -> str: # This tool cannot proceed without authorization raise UrlElicitationRequiredError( [ @@ -56,7 +55,7 @@ async def test_url_elicitation_error_from_error(): mcp = MCPServer(name="UrlElicitationErrorServer") @mcp.tool(description="A tool that raises UrlElicitationRequiredError with multiple elicitations") - async def multi_auth(ctx: Context[ServerSession, None]) -> str: + async def multi_auth(ctx: Context) -> str: raise UrlElicitationRequiredError( [ types.ElicitRequestURLParams( @@ -97,7 +96,7 @@ async def test_normal_exceptions_still_return_error_result(): mcp = MCPServer(name="NormalErrorServer") @mcp.tool(description="A tool that raises a normal exception") - async def failing_tool(ctx: Context[ServerSession, None]) -> str: + async def failing_tool(ctx: Context) -> str: raise ValueError("Something went wrong") async with Client(mcp) as client: diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 0f8840d291..0d87905042 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -11,7 +11,6 @@ from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.mcpserver import Context, MCPServer from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.types import ( CallToolRequestParams, @@ -143,7 +142,7 @@ async def test_lifespan(server: MCPServer) -> AsyncIterator[dict[str, bool]]: # Add a tool that checks lifespan context @server.tool() - def check_lifespan(ctx: Context[ServerSession, None]) -> bool: + def check_lifespan(ctx: Context) -> bool: """Tool that checks lifespan context.""" assert 
isinstance(ctx.request_context.lifespan_context, dict) assert ctx.request_context.lifespan_context["started"] From 62eb08e5b23944510b8ec500a51c8f895fb58553 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 9 Mar 2026 17:47:27 +0000 Subject: [PATCH 31/84] fix: don't send log notification on transport error (#2257) --- src/mcp/server/lowlevel/server.py | 5 -- src/mcp/shared/session.py | 2 +- .../test_lowlevel_exception_handling.py | 82 +++++++++++++------ 3 files changed, 58 insertions(+), 31 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 1c84c86107..167f34b8bc 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -414,11 +414,6 @@ async def _handle_message( ) case Exception(): logger.error(f"Received exception from stream: {message}") - await session.send_log_message( - level="error", - data="Internal Server Error", - logger="mcp.server.exception_handler", - ) if raise_exceptions: raise message case _: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b617d702fe..9364abb73b 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -334,7 +334,7 @@ async def _receive_loop(self) -> None: async with self._read_stream, self._write_stream: try: async for message in self._read_stream: - if isinstance(message, Exception): # pragma: no cover + if isinstance(message, Exception): await self._handle_incoming(message) elif isinstance(message.message, JSONRPCRequest): try: diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 848b35b299..46925916d9 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -1,55 +1,42 @@ from unittest.mock import AsyncMock, Mock +import anyio import pytest from mcp import types from mcp.server.lowlevel.server import Server from mcp.server.session 
import ServerSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder @pytest.mark.anyio async def test_exception_handling_with_raise_exceptions_true(): - """Test that exceptions are re-raised when raise_exceptions=True""" + """Transport exceptions are re-raised when raise_exceptions=True.""" server = Server("test-server") session = Mock(spec=ServerSession) - session.send_log_message = AsyncMock() test_exception = RuntimeError("Test error") with pytest.raises(RuntimeError, match="Test error"): await server._handle_message(test_exception, session, {}, raise_exceptions=True) - session.send_log_message.assert_called_once() - @pytest.mark.anyio -@pytest.mark.parametrize( - "exception_class,message", - [ - (ValueError, "Test validation error"), - (RuntimeError, "Test runtime error"), - (KeyError, "Test key error"), - (Exception, "Basic error"), - ], -) -async def test_exception_handling_with_raise_exceptions_false(exception_class: type[Exception], message: str): - """Test that exceptions are logged when raise_exceptions=False""" +async def test_exception_handling_with_raise_exceptions_false(): + """Transport exceptions are logged locally but not sent to the client. + + The transport that reported the error is likely broken; writing back + through it races with stream closure (#1967, #2064). The TypeScript, + Go, and C# SDKs all log locally only. 
+ """ server = Server("test-server") session = Mock(spec=ServerSession) session.send_log_message = AsyncMock() - test_exception = exception_class(message) - - await server._handle_message(test_exception, session, {}, raise_exceptions=False) - - # Should send log message - session.send_log_message.assert_called_once() - call_args = session.send_log_message.call_args + await server._handle_message(RuntimeError("Test error"), session, {}, raise_exceptions=False) - assert call_args.kwargs["level"] == "error" - assert call_args.kwargs["data"] == "Internal Server Error" - assert call_args.kwargs["logger"] == "mcp.server.exception_handler" + session.send_log_message.assert_not_called() @pytest.mark.anyio @@ -72,3 +59,48 @@ async def test_normal_message_handling_not_affected(): # Verify _handle_request was called server._handle_request.assert_called_once() + + +@pytest.mark.anyio +async def test_server_run_exits_cleanly_when_transport_yields_exception_then_closes(): + """Regression test for #1967 / #2064. + + Exercises the real Server.run() path with real memory streams, reproducing + what happens in stateless streamable HTTP when a POST handler throws: + + 1. Transport yields an Exception into the read stream + (streamable_http.py does this in its broad POST-handler except). + 2. Transport closes the read stream (terminate() in stateless mode). + 3. _receive_loop exits its `async with read_stream, write_stream:` block, + closing the write stream. + 4. Meanwhile _handle_message(exc) was spawned via tg.start_soon and runs + after the write stream is closed. + + Before the fix, _handle_message tried to send_log_message through the + closed write stream, raising ClosedResourceError inside the TaskGroup + and crashing server.run(). After the fix, it only logs locally. + """ + server = Server("test-server") + + read_send, read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + # Zero-buffer on the write stream forces send() to block until received. 
+ # With no receiver, a send() sits blocked until _receive_loop exits its + # `async with self._read_stream, self._write_stream:` block and closes the + # stream, at which point the blocked send raises ClosedResourceError. + # This deterministically reproduces the race without sleeps. + write_send, write_recv = anyio.create_memory_object_stream[SessionMessage](0) + + # What the streamable HTTP transport does: push the exception, then close. + read_send.send_nowait(RuntimeError("simulated transport error")) + read_send.close() + + with anyio.fail_after(5): + # stateless=True so server.run doesn't wait for initialize handshake. + # Before this fix, this raised ExceptionGroup(ClosedResourceError). + await server.run(read_recv, write_send, server.create_initialization_options(), stateless=True) + + # write_send was closed inside _receive_loop's `async with`; receive_nowait + # raises EndOfStream iff the buffer is empty (i.e., server wrote nothing). + with pytest.raises(anyio.EndOfStream): + write_recv.receive_nowait() + write_recv.close() From dd52713517d88541f9233cb2af62012ad65bb993 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 12 Mar 2026 12:52:32 +0000 Subject: [PATCH 32/84] Rewrite TestChildProcessCleanup with socket-based deterministic liveness probe (#2265) --- src/mcp/os/win32/utilities.py | 5 + tests/client/test_stdio.py | 454 +++++++++++++++------------------- 2 files changed, 201 insertions(+), 258 deletions(-) diff --git a/src/mcp/os/win32/utilities.py b/src/mcp/os/win32/utilities.py index 0e188691f1..6f68405f78 100644 --- a/src/mcp/os/win32/utilities.py +++ b/src/mcp/os/win32/utilities.py @@ -123,6 +123,11 @@ def pid(self) -> int: """Return the process ID.""" return self.popen.pid + @property + def returncode(self) -> int | None: + """Return the exit code, or ``None`` if the process has not yet terminated.""" + return self.popen.returncode + # ------------------------ # Updated function diff --git 
a/tests/client/test_stdio.py b/tests/client/test_stdio.py index f70c24eee7..06e2cba4b1 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,12 +1,12 @@ import errno -import os import shutil import sys -import tempfile import textwrap import time +from contextlib import AsyncExitStack, suppress import anyio +import anyio.abc import pytest from mcp.client.session import ClientSession @@ -16,12 +16,11 @@ _terminate_process_tree, stdio_client, ) +from mcp.os.win32.utilities import FallbackProcess from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse -from ..shared.test_win32_utils import escape_path_for_python - # Timeout for cleanup of processes that ignore SIGTERM # This timeout ensures the test fails quickly if the cleanup logic doesn't have # proper fallback mechanisms (SIGINT/SIGKILL) for processes that ignore SIGTERM @@ -221,291 +220,230 @@ def sigint_handler(signum, frame): raise -class TestChildProcessCleanup: - """Tests for child process cleanup functionality using _terminate_process_tree. - - These tests verify that child processes are properly terminated when the parent - is killed, addressing the issue where processes like npx spawn child processes - that need to be cleaned up. The tests cover various process tree scenarios: - - - Basic parent-child relationship (single child process) - - Multi-level process trees (parent → child → grandchild) - - Race conditions where parent exits during cleanup - - Note on Windows ResourceWarning: - On Windows, we may see ResourceWarning about subprocess still running. 
This is - expected behavior due to how Windows process termination works: - - anyio's process.terminate() calls Windows TerminateProcess() API - - TerminateProcess() immediately kills the process without allowing cleanup - - subprocess.Popen objects in the killed process can't run their cleanup code - - Python detects this during garbage collection and issues a ResourceWarning - - This warning does NOT indicate a process leak - the processes are properly - terminated. It only means the Popen objects couldn't clean up gracefully. - This is a fundamental difference between Windows and Unix process termination. - """ +# --------------------------------------------------------------------------- +# TestChildProcessCleanup — socket-based deterministic child liveness probe +# --------------------------------------------------------------------------- +# +# These tests verify that `_terminate_process_tree()` kills the *entire* process +# tree (not just the immediate child), which is critical for cleaning up tools +# like `npx` that spawn their own subprocesses. +# +# Mechanism: each subprocess in the tree connects a TCP socket back to a +# listener owned by the test. We then use two kernel-guaranteed blocking-I/O +# signals — neither requires any `sleep()` or polling loop: +# +# 1. `await listener.accept()` blocks until the subprocess connects, +# proving it is running. +# 2. After `_terminate_process_tree()`, `await stream.receive(1)` raises +# `EndOfStream` (clean close / FIN) or `BrokenResourceError` (abrupt +# close / RST — typical on Windows after TerminateJobObject) because the +# kernel closes all file descriptors when a process terminates. Either +# is the direct, OS-level proof that the child is dead. +# +# This replaces an older file-growth-watching approach whose fixed `sleep()` +# durations raced against slow Python interpreter startup on loaded CI runners. 
+ + +def _connect_back_script(port: int) -> str: + """Return a ``python -c`` script body that connects to the given port, + sends ``b'alive'``, then blocks forever. Used by TestChildProcessCleanup + subprocesses as a liveness probe.""" + return ( + f"import socket, time\n" + f"s = socket.create_connection(('127.0.0.1', {port}))\n" + f"s.sendall(b'alive')\n" + f"time.sleep(3600)\n" + ) - @pytest.mark.anyio - @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") - async def test_basic_child_process_cleanup(self): - """Test basic parent-child process cleanup. - Parent spawns a single child process that writes continuously to a file. - """ - # Create a marker file for the child process to write to - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - marker_file = f.name - # Also create a file to verify parent started - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - parent_marker = f.name +def _spawn_then_block(child_script: str) -> str: + """Return a ``python -c`` script body that spawns ``child_script`` as a + subprocess, then blocks forever. 
The ``!r`` injection avoids nested-quote + escaping for arbitrary child script content.""" + return ( + f"import subprocess, sys, time\nsubprocess.Popen([sys.executable, '-c', {child_script!r}])\ntime.sleep(3600)\n" + ) - try: - # Parent script that spawns a child process - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import os - - # Mark that parent started - with open({escape_path_for_python(parent_marker)}, 'w') as f: - f.write('parent started\\n') - - # Child script that writes continuously - child_script = f''' - import time - with open({escape_path_for_python(marker_file)}, 'a') as f: - while True: - f.write(f"{time.time()}") - f.flush() - time.sleep(0.1) - ''' - - # Start the child process - child = subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent just sleeps - while True: - time.sleep(0.1) - """ - ) - print("\nStarting child process termination test...") +async def _open_liveness_listener() -> tuple[anyio.abc.SocketListener, int]: + """Open a TCP listener on localhost and return it along with its port.""" + multi = await anyio.create_tcp_listener(local_host="127.0.0.1") + sock = multi.listeners[0] + assert isinstance(sock, anyio.abc.SocketListener) + addr = sock.extra(anyio.abc.SocketAttribute.local_address) + # IPv4 local_address is (host: str, port: int) + assert isinstance(addr, tuple) and len(addr) >= 2 and isinstance(addr[1], int) + return sock, addr[1] + + +async def _accept_alive(sock: anyio.abc.SocketListener) -> anyio.abc.SocketStream: + """Accept one connection and assert the peer sent ``b'alive'``. + + Blocks deterministically until a subprocess connects (no polling). The + outer test bounds this with ``anyio.fail_after`` to catch the case where + the subprocess chain failed to start. 
+ """ + stream = await sock.accept() + msg = await stream.receive(5) + assert msg == b"alive", f"expected b'alive', got {msg!r}" + return stream - # Start the parent process - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) - # Wait for processes to start - await anyio.sleep(0.5) +async def _assert_stream_closed(stream: anyio.abc.SocketStream) -> None: + """Assert the peer holding the other end of ``stream`` has terminated. - # Verify parent started - assert os.path.exists(parent_marker), "Parent process didn't start" + When a process dies, the kernel closes its file descriptors including + sockets. The next ``receive()`` on the peer socket unblocks with one of: - # Verify child is writing - if os.path.exists(marker_file): # pragma: no branch - initial_size = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size_after_wait = os.path.getsize(marker_file) - assert size_after_wait > initial_size, "Child process should be writing" - print(f"Child is writing (file grew from {initial_size} to {size_after_wait} bytes)") + - ``anyio.EndOfStream`` — clean close (FIN), typical after graceful exit + or POSIX ``SIGTERM``. + - ``anyio.BrokenResourceError`` — abrupt close (RST), typical after + Windows ``TerminateJobObject`` or POSIX ``SIGKILL``. - # Terminate using our function - print("Terminating process and children...") + Either is a deterministic, kernel-level signal that the process is dead — + no sleeps or polling required. + """ + with anyio.fail_after(5.0), pytest.raises((anyio.EndOfStream, anyio.BrokenResourceError)): + await stream.receive(1) + + +async def _terminate_and_reap(proc: anyio.abc.Process | FallbackProcess) -> None: + """Terminate the process tree, reap, and tear down pipe transports. + + ``_terminate_process_tree`` kills the OS process group / Job Object but does + not call ``process.wait()`` or clean up the asyncio pipe transports. 
On + Windows those transports leak and emit ``ResourceWarning`` when GC'd in a + later test, causing ``PytestUnraisableExceptionWarning`` knock-on failures. + + Production ``stdio.py`` avoids this via its ``stdout_reader`` task which + reads stdout to EOF (triggering ``_ProactorReadPipeTransport._eof_received`` + → ``close()``) plus ``async with process:`` which waits and closes stdin. + These tests call ``_terminate_process_tree`` directly, so they replicate + both parts here: ``wait()`` + close stdin + drain stdout to EOF. + + The stdout drain is the non-obvious part: anyio's ``StreamReaderWrapper.aclose()`` + only marks the Python-level reader closed — it never touches the underlying + ``_ProactorReadPipeTransport``. That transport starts paused and only detects + pipe EOF when someone reads, so without a drain it lives until ``__del__``. + + Idempotent: the ``returncode`` guard skips termination if already reaped + (avoids spurious WARNING/ERROR logs from ``terminate_posix_process_tree``'s + fallback path, visible because ``log_cli = true``); ``wait()`` and stream + ``aclose()`` no-op on subsequent calls; the drain raises ``ClosedResourceError`` + on the second call, caught by the suppress. The tests call this explicitly + as the action under test and ``AsyncExitStack`` calls it again on exit as a + safety net. Bounded by ``move_on_after`` to prevent hangs. 
+ """ + with anyio.move_on_after(5.0): + if proc.returncode is None: await _terminate_process_tree(proc) + await proc.wait() + assert proc.stdin is not None + assert proc.stdout is not None + await proc.stdin.aclose() + with suppress(anyio.EndOfStream, anyio.BrokenResourceError, anyio.ClosedResourceError): + await proc.stdout.receive(65536) + await proc.stdout.aclose() - # Verify processes stopped - await anyio.sleep(0.5) - if os.path.exists(marker_file): # pragma: no branch - size_after_cleanup = os.path.getsize(marker_file) - await anyio.sleep(0.5) - final_size = os.path.getsize(marker_file) - print(f"After cleanup: file size {size_after_cleanup} -> {final_size}") - assert final_size == size_after_cleanup, ( - f"Child process still running! File grew by {final_size - size_after_cleanup} bytes" - ) +class TestChildProcessCleanup: + """Integration tests for ``_terminate_process_tree`` covering basic, + nested, and early-parent-exit process tree scenarios. See module-level + comment above for the socket-based liveness probe mechanism. + """ - print("SUCCESS: Child process was properly terminated") + @pytest.mark.anyio + async def test_basic_child_process_cleanup(self): + """Parent spawns one child; terminating the tree kills both.""" + async with AsyncExitStack() as stack: + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # Parent spawns a child; the child connects back to us. + parent_script = _spawn_then_block(_connect_back_script(port)) + proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + stack.push_async_callback(_terminate_and_reap, proc) + + # Deterministic: accept() blocks until the child connects. No sleep. 
+ with anyio.fail_after(10.0): + stream = await _accept_alive(sock) + stack.push_async_callback(stream.aclose) - finally: - # Clean up files - for f in [marker_file, parent_marker]: - try: - os.unlink(f) - except OSError: # pragma: no cover - pass + # Terminate, reap and close transports (wraps _terminate_process_tree, + # the behavior under test). + await _terminate_and_reap(proc) + + # Deterministic: kernel closed child's socket when it died. + await _assert_stream_closed(stream) @pytest.mark.anyio - @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") async def test_nested_process_tree(self): - """Test nested process tree cleanup (parent → child → grandchild). - Each level writes to a different file to verify all processes are terminated. - """ - # Create temporary files for each process level - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1: - parent_file = f1.name - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f2: - child_file = f2.name - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f3: - grandchild_file = f3.name - - try: - # Simple nested process tree test - # We create parent -> child -> grandchild, each writing to a file - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import os - - # Child will spawn grandchild and write to child file - child_script = f'''import subprocess - import sys - import time - - # Grandchild just writes to file - grandchild_script = \"\"\"import time - with open({escape_path_for_python(grandchild_file)}, 'a') as f: - while True: - f.write(f"gc {{time.time()}}") - f.flush() - time.sleep(0.1)\"\"\" - - # Spawn grandchild - subprocess.Popen([sys.executable, '-c', grandchild_script]) - - # Child writes to its file - with open({escape_path_for_python(child_file)}, 'a') as f: - while True: - f.write(f"c {time.time()}") - f.flush() - time.sleep(0.1)''' - - # Spawn child process - 
subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent writes to its file - with open({escape_path_for_python(parent_file)}, 'a') as f: - while True: - f.write(f"p {time.time()}") - f.flush() - time.sleep(0.1) - """ + """Parent → child → grandchild; terminating the tree kills all three.""" + async with AsyncExitStack() as stack: + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # Build a three-level chain: parent spawns child, child spawns + # grandchild. Every level connects back to our socket. + grandchild = _connect_back_script(port) + child = ( + f"import subprocess, sys\n" + f"subprocess.Popen([sys.executable, '-c', {grandchild!r}])\n" + _connect_back_script(port) + ) + parent_script = ( + f"import subprocess, sys\n" + f"subprocess.Popen([sys.executable, '-c', {child!r}])\n" + _connect_back_script(port) ) - - # Start the parent process proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + stack.push_async_callback(_terminate_and_reap, proc) - # Let all processes start - await anyio.sleep(1.0) - - # Verify all are writing - for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): # pragma: no branch - initial_size = os.path.getsize(file_path) - await anyio.sleep(0.3) - new_size = os.path.getsize(file_path) - assert new_size > initial_size, f"{name} process should be writing" + # Deterministic: three blocking accepts, one per tree level. + streams: list[anyio.abc.SocketStream] = [] + with anyio.fail_after(10.0): + for _ in range(3): + stream = await _accept_alive(sock) + stack.push_async_callback(stream.aclose) + streams.append(stream) - # Terminate the whole tree - await _terminate_process_tree(proc) + # Terminate the entire tree (wraps _terminate_process_tree). 
+ await _terminate_and_reap(proc) - # Verify all stopped - await anyio.sleep(0.5) - for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): # pragma: no branch - size1 = os.path.getsize(file_path) - await anyio.sleep(0.3) - size2 = os.path.getsize(file_path) - assert size1 == size2, f"{name} still writing after cleanup!" - - print("SUCCESS: All processes in tree terminated") - - finally: - # Clean up all marker files - for f in [parent_file, child_file, grandchild_file]: - try: - os.unlink(f) - except OSError: # pragma: no cover - pass + # Every level of the tree must be dead: three kernel-level EOFs. + for stream in streams: + await _assert_stream_closed(stream) @pytest.mark.anyio - @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") async def test_early_parent_exit(self): - """Test cleanup when parent exits during termination sequence. - Tests the race condition where parent might die during our termination - sequence but we can still clean up the children via the process group. + """Parent exits immediately on SIGTERM; process-group termination still + catches the child (exercises the race where the parent dies mid-cleanup). 
""" - # Create a temporary file for the child - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - marker_file = f.name - - try: - # Parent that spawns child and waits briefly - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import signal - - # Child that continues running - child_script = f'''import time - with open({escape_path_for_python(marker_file)}, 'a') as f: - while True: - f.write(f"child {time.time()}") - f.flush() - time.sleep(0.1)''' - - # Start child in same process group - subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent waits a bit then exits on SIGTERM - def handle_term(sig, frame): - sys.exit(0) - - signal.signal(signal.SIGTERM, handle_term) - - # Wait - while True: - time.sleep(0.1) - """ + async with AsyncExitStack() as stack: + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # Parent installs a SIGTERM handler that exits immediately, spawns a + # child that connects back to us, then blocks. + child = _connect_back_script(port) + parent_script = ( + f"import signal, subprocess, sys, time\n" + f"signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))\n" + f"subprocess.Popen([sys.executable, '-c', {child!r}])\n" + f"time.sleep(3600)\n" ) - - # Start the parent process proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + stack.push_async_callback(_terminate_and_reap, proc) - # Let child start writing - await anyio.sleep(0.5) - - # Verify child is writing - if os.path.exists(marker_file): # pragma: no branch - size1 = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size2 = os.path.getsize(marker_file) - assert size2 > size1, "Child should be writing" + # Deterministic: child connected means both parent and child are up. 
+ with anyio.fail_after(10.0): + stream = await _accept_alive(sock) + stack.push_async_callback(stream.aclose) - # Terminate - this will kill the process group even if parent exits first - await _terminate_process_tree(proc) + # Parent will sys.exit(0) on SIGTERM, but the process-group kill + # (POSIX killpg / Windows Job Object) must still terminate the child. + await _terminate_and_reap(proc) - # Verify child stopped - await anyio.sleep(0.5) - if os.path.exists(marker_file): # pragma: no branch - size3 = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size4 = os.path.getsize(marker_file) - assert size3 == size4, "Child should be terminated" - - print("SUCCESS: Child terminated even with parent exit during cleanup") - - finally: - # Clean up marker file - try: - os.unlink(marker_file) - except OSError: # pragma: no cover - pass + # Child must be dead despite parent's early exit. + await _assert_stream_closed(stream) @pytest.mark.anyio From 2c73a2a8811de5d4fa34bf23a1bf390d9328bedb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 10:39:56 +0000 Subject: [PATCH 33/84] chore(deps): bump black from 25.1.0 to 26.3.1 in the uv group across 1 directory (#2290) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- uv.lock | 95 +++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 72 insertions(+), 23 deletions(-) diff --git a/uv.lock b/uv.lock index d01d510f17..c01e96a4a6 100644 --- a/uv.lock +++ b/uv.lock @@ -96,7 +96,7 @@ wheels = [ [[package]] name = "black" -version = "25.1.0" +version = "26.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -104,28 +104,38 @@ dependencies = [ { name = "packaging" }, { name = "pathspec" }, { name = "platformdirs" }, + { name = "pytokens" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = 
"typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449, upload-time = "2025-01-29T04:15:40.373Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/3b/4ba3f93ac8d90410423fdd31d7541ada9bcee1df32fb90d26de41ed40e1d/black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32", size = 1629419, upload-time = "2025-01-29T05:37:06.642Z" }, - { url = "https://files.pythonhosted.org/packages/b4/02/0bde0485146a8a5e694daed47561785e8b77a0466ccc1f3e485d5ef2925e/black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da", size = 1461080, upload-time = "2025-01-29T05:37:09.321Z" }, - { url = "https://files.pythonhosted.org/packages/52/0e/abdf75183c830eaca7589144ff96d49bce73d7ec6ad12ef62185cc0f79a2/black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7", size = 1766886, upload-time = "2025-01-29T04:18:24.432Z" }, - { url = "https://files.pythonhosted.org/packages/dc/a6/97d8bb65b1d8a41f8a6736222ba0a334db7b7b77b8023ab4568288f23973/black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9", size = 1419404, upload-time = "2025-01-29T04:19:04.296Z" }, - { url = "https://files.pythonhosted.org/packages/7e/4f/87f596aca05c3ce5b94b8663dbfe242a12843caaa82dd3f85f1ffdc3f177/black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0", size = 1614372, upload-time = "2025-01-29T05:37:11.71Z" }, - { url = 
"https://files.pythonhosted.org/packages/e7/d0/2c34c36190b741c59c901e56ab7f6e54dad8df05a6272a9747ecef7c6036/black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299", size = 1442865, upload-time = "2025-01-29T05:37:14.309Z" }, - { url = "https://files.pythonhosted.org/packages/21/d4/7518c72262468430ead45cf22bd86c883a6448b9eb43672765d69a8f1248/black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096", size = 1749699, upload-time = "2025-01-29T04:18:17.688Z" }, - { url = "https://files.pythonhosted.org/packages/58/db/4f5beb989b547f79096e035c4981ceb36ac2b552d0ac5f2620e941501c99/black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2", size = 1428028, upload-time = "2025-01-29T04:18:51.711Z" }, - { url = "https://files.pythonhosted.org/packages/83/71/3fe4741df7adf015ad8dfa082dd36c94ca86bb21f25608eb247b4afb15b2/black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b", size = 1650988, upload-time = "2025-01-29T05:37:16.707Z" }, - { url = "https://files.pythonhosted.org/packages/13/f3/89aac8a83d73937ccd39bbe8fc6ac8860c11cfa0af5b1c96d081facac844/black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc", size = 1453985, upload-time = "2025-01-29T05:37:18.273Z" }, - { url = "https://files.pythonhosted.org/packages/6f/22/b99efca33f1f3a1d2552c714b1e1b5ae92efac6c43e790ad539a163d1754/black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f", size = 1783816, upload-time = "2025-01-29T04:18:33.823Z" }, - { url = 
"https://files.pythonhosted.org/packages/18/7e/a27c3ad3822b6f2e0e00d63d58ff6299a99a5b3aee69fa77cd4b0076b261/black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba", size = 1440860, upload-time = "2025-01-29T04:19:12.944Z" }, - { url = "https://files.pythonhosted.org/packages/98/87/0edf98916640efa5d0696e1abb0a8357b52e69e82322628f25bf14d263d1/black-25.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f0b18a02996a836cc9c9c78e5babec10930862827b1b724ddfe98ccf2f2fe4f", size = 1650673, upload-time = "2025-01-29T05:37:20.574Z" }, - { url = "https://files.pythonhosted.org/packages/52/e5/f7bf17207cf87fa6e9b676576749c6b6ed0d70f179a3d812c997870291c3/black-25.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afebb7098bfbc70037a053b91ae8437c3857482d3a690fefc03e9ff7aa9a5fd3", size = 1453190, upload-time = "2025-01-29T05:37:22.106Z" }, - { url = "https://files.pythonhosted.org/packages/e3/ee/adda3d46d4a9120772fae6de454c8495603c37c4c3b9c60f25b1ab6401fe/black-25.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:030b9759066a4ee5e5aca28c3c77f9c64789cdd4de8ac1df642c40b708be6171", size = 1782926, upload-time = "2025-01-29T04:18:58.564Z" }, - { url = "https://files.pythonhosted.org/packages/cc/64/94eb5f45dcb997d2082f097a3944cfc7fe87e071907f677e80788a2d7b7a/black-25.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a22f402b410566e2d1c950708c77ebf5ebd5d0d88a6a2e87c86d9fb48afa0d18", size = 1442613, upload-time = "2025-01-29T04:19:27.63Z" }, - { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, +sdist = { url = 
"https://files.pythonhosted.org/packages/e1/c5/61175d618685d42b005847464b8fb4743a67b1b8fdb75e50e5a96c31a27a/black-26.3.1.tar.gz", hash = "sha256:2c50f5063a9641c7eed7795014ba37b0f5fa227f3d408b968936e24bc0566b07", size = 666155, upload-time = "2026-03-12T03:36:03.593Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/a8/11170031095655d36ebc6664fe0897866f6023892396900eec0e8fdc4299/black-26.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:86a8b5035fce64f5dcd1b794cf8ec4d31fe458cf6ce3986a30deb434df82a1d2", size = 1866562, upload-time = "2026-03-12T03:39:58.639Z" }, + { url = "https://files.pythonhosted.org/packages/69/ce/9e7548d719c3248c6c2abfd555d11169457cbd584d98d179111338423790/black-26.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5602bdb96d52d2d0672f24f6ffe5218795736dd34807fd0fd55ccd6bf206168b", size = 1703623, upload-time = "2026-03-12T03:40:00.347Z" }, + { url = "https://files.pythonhosted.org/packages/7f/0a/8d17d1a9c06f88d3d030d0b1d4373c1551146e252afe4547ed601c0e697f/black-26.3.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6c54a4a82e291a1fee5137371ab488866b7c86a3305af4026bdd4dc78642e1ac", size = 1768388, upload-time = "2026-03-12T03:40:01.765Z" }, + { url = "https://files.pythonhosted.org/packages/52/79/c1ee726e221c863cde5164f925bacf183dfdf0397d4e3f94889439b947b4/black-26.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:6e131579c243c98f35bce64a7e08e87fb2d610544754675d4a0e73a070a5aa3a", size = 1412969, upload-time = "2026-03-12T03:40:03.252Z" }, + { url = "https://files.pythonhosted.org/packages/73/a5/15c01d613f5756f68ed8f6d4ec0a1e24b82b18889fa71affd3d1f7fad058/black-26.3.1-cp310-cp310-win_arm64.whl", hash = "sha256:5ed0ca58586c8d9a487352a96b15272b7fa55d139fc8496b519e78023a8dab0a", size = 1220345, upload-time = "2026-03-12T03:40:04.892Z" }, + { url = 
"https://files.pythonhosted.org/packages/17/57/5f11c92861f9c92eb9dddf515530bc2d06db843e44bdcf1c83c1427824bc/black-26.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:28ef38aee69e4b12fda8dba75e21f9b4f979b490c8ac0baa7cb505369ac9e1ff", size = 1851987, upload-time = "2026-03-12T03:40:06.248Z" }, + { url = "https://files.pythonhosted.org/packages/54/aa/340a1463660bf6831f9e39646bf774086dbd8ca7fc3cded9d59bbdf4ad0a/black-26.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf9bf162ed91a26f1adba8efda0b573bc6924ec1408a52cc6f82cb73ec2b142c", size = 1689499, upload-time = "2026-03-12T03:40:07.642Z" }, + { url = "https://files.pythonhosted.org/packages/f3/01/b726c93d717d72733da031d2de10b92c9fa4c8d0c67e8a8a372076579279/black-26.3.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:474c27574d6d7037c1bc875a81d9be0a9a4f9ee95e62800dab3cfaadbf75acd5", size = 1754369, upload-time = "2026-03-12T03:40:09.279Z" }, + { url = "https://files.pythonhosted.org/packages/e3/09/61e91881ca291f150cfc9eb7ba19473c2e59df28859a11a88248b5cbbc4d/black-26.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5e9d0d86df21f2e1677cc4bd090cd0e446278bcbbe49bf3659c308c3e402843e", size = 1413613, upload-time = "2026-03-12T03:40:10.943Z" }, + { url = "https://files.pythonhosted.org/packages/16/73/544f23891b22e7efe4d8f812371ab85b57f6a01b2fc45e3ba2e52ba985b8/black-26.3.1-cp311-cp311-win_arm64.whl", hash = "sha256:9a5e9f45e5d5e1c5b5c29b3bd4265dcc90e8b92cf4534520896ed77f791f4da5", size = 1219719, upload-time = "2026-03-12T03:40:12.597Z" }, + { url = "https://files.pythonhosted.org/packages/dc/f8/da5eae4fc75e78e6dceb60624e1b9662ab00d6b452996046dfa9b8a6025b/black-26.3.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b5e6f89631eb88a7302d416594a32faeee9fb8fb848290da9d0a5f2903519fc1", size = 1895920, upload-time = "2026-03-12T03:40:13.921Z" }, + { url = 
"https://files.pythonhosted.org/packages/2c/9f/04e6f26534da2e1629b2b48255c264cabf5eedc5141d04516d9d68a24111/black-26.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41cd2012d35b47d589cb8a16faf8a32ef7a336f56356babd9fcf70939ad1897f", size = 1718499, upload-time = "2026-03-12T03:40:15.239Z" }, + { url = "https://files.pythonhosted.org/packages/04/91/a5935b2a63e31b331060c4a9fdb5a6c725840858c599032a6f3aac94055f/black-26.3.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f76ff19ec5297dd8e66eb64deda23631e642c9393ab592826fd4bdc97a4bce7", size = 1794994, upload-time = "2026-03-12T03:40:17.124Z" }, + { url = "https://files.pythonhosted.org/packages/e7/0a/86e462cdd311a3c2a8ece708d22aba17d0b2a0d5348ca34b40cdcbea512e/black-26.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:ddb113db38838eb9f043623ba274cfaf7d51d5b0c22ecb30afe58b1bb8322983", size = 1420867, upload-time = "2026-03-12T03:40:18.83Z" }, + { url = "https://files.pythonhosted.org/packages/5b/e5/22515a19cb7eaee3440325a6b0d95d2c0e88dd180cb011b12ae488e031d1/black-26.3.1-cp312-cp312-win_arm64.whl", hash = "sha256:dfdd51fc3e64ea4f35873d1b3fb25326773d55d2329ff8449139ebaad7357efb", size = 1230124, upload-time = "2026-03-12T03:40:20.425Z" }, + { url = "https://files.pythonhosted.org/packages/f5/77/5728052a3c0450c53d9bb3945c4c46b91baa62b2cafab6801411b6271e45/black-26.3.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:855822d90f884905362f602880ed8b5df1b7e3ee7d0db2502d4388a954cc8c54", size = 1895034, upload-time = "2026-03-12T03:40:21.813Z" }, + { url = "https://files.pythonhosted.org/packages/52/73/7cae55fdfdfbe9d19e9a8d25d145018965fe2079fa908101c3733b0c55a0/black-26.3.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8a33d657f3276328ce00e4d37fe70361e1ec7614da5d7b6e78de5426cb56332f", size = 1718503, upload-time = "2026-03-12T03:40:23.666Z" }, + { url = 
"https://files.pythonhosted.org/packages/e1/87/af89ad449e8254fdbc74654e6467e3c9381b61472cc532ee350d28cfdafb/black-26.3.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f1cd08e99d2f9317292a311dfe578fd2a24b15dbce97792f9c4d752275c1fa56", size = 1793557, upload-time = "2026-03-12T03:40:25.497Z" }, + { url = "https://files.pythonhosted.org/packages/43/10/d6c06a791d8124b843bf325ab4ac7d2f5b98731dff84d6064eafd687ded1/black-26.3.1-cp313-cp313-win_amd64.whl", hash = "sha256:c7e72339f841b5a237ff14f7d3880ddd0fc7f98a1199e8c4327f9a4f478c1839", size = 1422766, upload-time = "2026-03-12T03:40:27.14Z" }, + { url = "https://files.pythonhosted.org/packages/59/4f/40a582c015f2d841ac24fed6390bd68f0fc896069ff3a886317959c9daf8/black-26.3.1-cp313-cp313-win_arm64.whl", hash = "sha256:afc622538b430aa4c8c853f7f63bc582b3b8030fd8c80b70fb5fa5b834e575c2", size = 1232140, upload-time = "2026-03-12T03:40:28.882Z" }, + { url = "https://files.pythonhosted.org/packages/d5/da/e36e27c9cebc1311b7579210df6f1c86e50f2d7143ae4fcf8a5017dc8809/black-26.3.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:2d6bfaf7fd0993b420bed691f20f9492d53ce9a2bcccea4b797d34e947318a78", size = 1889234, upload-time = "2026-03-12T03:40:30.964Z" }, + { url = "https://files.pythonhosted.org/packages/0e/7b/9871acf393f64a5fa33668c19350ca87177b181f44bb3d0c33b2d534f22c/black-26.3.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:f89f2ab047c76a9c03f78d0d66ca519e389519902fa27e7a91117ef7611c0568", size = 1720522, upload-time = "2026-03-12T03:40:32.346Z" }, + { url = "https://files.pythonhosted.org/packages/03/87/e766c7f2e90c07fb7586cc787c9ae6462b1eedab390191f2b7fc7f6170a9/black-26.3.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b07fc0dab849d24a80a29cfab8d8a19187d1c4685d8a5e6385a5ce323c1f015f", size = 1787824, upload-time = "2026-03-12T03:40:33.636Z" }, + { url = 
"https://files.pythonhosted.org/packages/ac/94/2424338fb2d1875e9e83eed4c8e9c67f6905ec25afd826a911aea2b02535/black-26.3.1-cp314-cp314-win_amd64.whl", hash = "sha256:0126ae5b7c09957da2bdbd91a9ba1207453feada9e9fe51992848658c6c8e01c", size = 1445855, upload-time = "2026-03-12T03:40:35.442Z" }, + { url = "https://files.pythonhosted.org/packages/86/43/0c3338bd928afb8ee7471f1a4eec3bdbe2245ccb4a646092a222e8669840/black-26.3.1-cp314-cp314-win_arm64.whl", hash = "sha256:92c0ec1f2cc149551a2b7b47efc32c866406b6891b0ee4625e95967c8f4acfb1", size = 1258109, upload-time = "2026-03-12T03:40:36.832Z" }, + { url = "https://files.pythonhosted.org/packages/8e/0d/52d98722666d6fc6c3dd4c76df339501d6efd40e0ff95e6186a7b7f0befd/black-26.3.1-py3-none-any.whl", hash = "sha256:2bd5aa94fc267d38bb21a70d7410a89f1a1d318841855f698746f8e7f51acd1b", size = 207542, upload-time = "2026-03-12T03:36:01.668Z" }, ] [[package]] @@ -1636,11 +1646,11 @@ wheels = [ [[package]] name = "pathspec" -version = "0.12.1" +version = "1.0.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/36/e27608899f9b8d4dff0617b2d9ab17ca5608956ca44461ac14ac48b44015/pathspec-1.0.4.tar.gz", hash = "sha256:0210e2ae8a21a9137c0d470578cb0e595af87edaa6ebf12ff176f14a02e0e645", size = 131200, upload-time = "2026-01-27T03:59:46.938Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, + { url = 
"https://files.pythonhosted.org/packages/ef/3c/2c197d226f9ea224a9ab8d197933f9da0ae0aac5b6e0f884e2b8d9c8e9f7/pathspec-1.0.4-py3-none-any.whl", hash = "sha256:fb6ae2fd4e7c921a165808a552060e722767cfa526f99ca5156ed2ce45a5c723", size = 55206, upload-time = "2026-01-27T03:59:45.137Z" }, ] [[package]] @@ -2064,6 +2074,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1b/d0/397f9626e711ff749a95d96b7af99b9c566a9bb5129b8e4c10fc4d100304/python_multipart-0.0.22-py3-none-any.whl", hash = "sha256:2b2cd894c83d21bf49d702499531c7bafd057d730c201782048f7945d82de155", size = 24579, upload-time = "2026-01-25T10:15:54.811Z" }, ] +[[package]] +name = "pytokens" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b6/34/b4e015b99031667a7b960f888889c5bd34ef585c85e1cb56a594b92836ac/pytokens-0.4.1.tar.gz", hash = "sha256:292052fe80923aae2260c073f822ceba21f3872ced9a68bb7953b348e561179a", size = 23015, upload-time = "2026-01-30T01:03:45.924Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/24/f206113e05cb8ef51b3850e7ef88f20da6f4bf932190ceb48bd3da103e10/pytokens-0.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a44ed93ea23415c54f3face3b65ef2b844d96aeb3455b8a69b3df6beab6acc5", size = 161522, upload-time = "2026-01-30T01:02:50.393Z" }, + { url = "https://files.pythonhosted.org/packages/d4/e9/06a6bf1b90c2ed81a9c7d2544232fe5d2891d1cd480e8a1809ca354a8eb2/pytokens-0.4.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:add8bf86b71a5d9fb5b89f023a80b791e04fba57960aa790cc6125f7f1d39dfe", size = 246945, upload-time = "2026-01-30T01:02:52.399Z" }, + { url = "https://files.pythonhosted.org/packages/69/66/f6fb1007a4c3d8b682d5d65b7c1fb33257587a5f782647091e3408abe0b8/pytokens-0.4.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:670d286910b531c7b7e3c0b453fd8156f250adb140146d234a82219459b9640c", size = 259525, upload-time = "2026-01-30T01:02:53.737Z" }, + { url = "https://files.pythonhosted.org/packages/04/92/086f89b4d622a18418bac74ab5db7f68cf0c21cf7cc92de6c7b919d76c88/pytokens-0.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4e691d7f5186bd2842c14813f79f8884bb03f5995f0575272009982c5ac6c0f7", size = 262693, upload-time = "2026-01-30T01:02:54.871Z" }, + { url = "https://files.pythonhosted.org/packages/b4/7b/8b31c347cf94a3f900bdde750b2e9131575a61fdb620d3d3c75832262137/pytokens-0.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:27b83ad28825978742beef057bfe406ad6ed524b2d28c252c5de7b4a6dd48fa2", size = 103567, upload-time = "2026-01-30T01:02:56.414Z" }, + { url = "https://files.pythonhosted.org/packages/3d/92/790ebe03f07b57e53b10884c329b9a1a308648fc083a6d4a39a10a28c8fc/pytokens-0.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d70e77c55ae8380c91c0c18dea05951482e263982911fc7410b1ffd1dadd3440", size = 160864, upload-time = "2026-01-30T01:02:57.882Z" }, + { url = "https://files.pythonhosted.org/packages/13/25/a4f555281d975bfdd1eba731450e2fe3a95870274da73fb12c40aeae7625/pytokens-0.4.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4a58d057208cb9075c144950d789511220b07636dd2e4708d5645d24de666bdc", size = 248565, upload-time = "2026-01-30T01:02:59.912Z" }, + { url = "https://files.pythonhosted.org/packages/17/50/bc0394b4ad5b1601be22fa43652173d47e4c9efbf0044c62e9a59b747c56/pytokens-0.4.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b49750419d300e2b5a3813cf229d4e5a4c728dae470bcc89867a9ad6f25a722d", size = 260824, upload-time = "2026-01-30T01:03:01.471Z" }, + { url = "https://files.pythonhosted.org/packages/4e/54/3e04f9d92a4be4fc6c80016bc396b923d2a6933ae94b5f557c939c460ee0/pytokens-0.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:d9907d61f15bf7261d7e775bd5d7ee4d2930e04424bab1972591918497623a16", size = 264075, upload-time = "2026-01-30T01:03:04.143Z" }, + { url = "https://files.pythonhosted.org/packages/d1/1b/44b0326cb5470a4375f37988aea5d61b5cc52407143303015ebee94abfd6/pytokens-0.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:ee44d0f85b803321710f9239f335aafe16553b39106384cef8e6de40cb4ef2f6", size = 103323, upload-time = "2026-01-30T01:03:05.412Z" }, + { url = "https://files.pythonhosted.org/packages/41/5d/e44573011401fb82e9d51e97f1290ceb377800fb4eed650b96f4753b499c/pytokens-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:140709331e846b728475786df8aeb27d24f48cbcf7bcd449f8de75cae7a45083", size = 160663, upload-time = "2026-01-30T01:03:06.473Z" }, + { url = "https://files.pythonhosted.org/packages/f0/e6/5bbc3019f8e6f21d09c41f8b8654536117e5e211a85d89212d59cbdab381/pytokens-0.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6d6c4268598f762bc8e91f5dbf2ab2f61f7b95bdc07953b602db879b3c8c18e1", size = 255626, upload-time = "2026-01-30T01:03:08.177Z" }, + { url = "https://files.pythonhosted.org/packages/bf/3c/2d5297d82286f6f3d92770289fd439956b201c0a4fc7e72efb9b2293758e/pytokens-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:24afde1f53d95348b5a0eb19488661147285ca4dd7ed752bbc3e1c6242a304d1", size = 269779, upload-time = "2026-01-30T01:03:09.756Z" }, + { url = "https://files.pythonhosted.org/packages/20/01/7436e9ad693cebda0551203e0bf28f7669976c60ad07d6402098208476de/pytokens-0.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5ad948d085ed6c16413eb5fec6b3e02fa00dc29a2534f088d3302c47eb59adf9", size = 268076, upload-time = "2026-01-30T01:03:10.957Z" }, + { url = "https://files.pythonhosted.org/packages/2e/df/533c82a3c752ba13ae7ef238b7f8cdd272cf1475f03c63ac6cf3fcfb00b6/pytokens-0.4.1-cp312-cp312-win_amd64.whl", hash = 
"sha256:3f901fe783e06e48e8cbdc82d631fca8f118333798193e026a50ce1b3757ea68", size = 103552, upload-time = "2026-01-30T01:03:12.066Z" }, + { url = "https://files.pythonhosted.org/packages/cb/dc/08b1a080372afda3cceb4f3c0a7ba2bde9d6a5241f1edb02a22a019ee147/pytokens-0.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8bdb9d0ce90cbf99c525e75a2fa415144fd570a1ba987380190e8b786bc6ef9b", size = 160720, upload-time = "2026-01-30T01:03:13.843Z" }, + { url = "https://files.pythonhosted.org/packages/64/0c/41ea22205da480837a700e395507e6a24425151dfb7ead73343d6e2d7ffe/pytokens-0.4.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5502408cab1cb18e128570f8d598981c68a50d0cbd7c61312a90507cd3a1276f", size = 254204, upload-time = "2026-01-30T01:03:14.886Z" }, + { url = "https://files.pythonhosted.org/packages/e0/d2/afe5c7f8607018beb99971489dbb846508f1b8f351fcefc225fcf4b2adc0/pytokens-0.4.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:29d1d8fb1030af4d231789959f21821ab6325e463f0503a61d204343c9b355d1", size = 268423, upload-time = "2026-01-30T01:03:15.936Z" }, + { url = "https://files.pythonhosted.org/packages/68/d4/00ffdbd370410c04e9591da9220a68dc1693ef7499173eb3e30d06e05ed1/pytokens-0.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:970b08dd6b86058b6dc07efe9e98414f5102974716232d10f32ff39701e841c4", size = 266859, upload-time = "2026-01-30T01:03:17.458Z" }, + { url = "https://files.pythonhosted.org/packages/a7/c9/c3161313b4ca0c601eeefabd3d3b576edaa9afdefd32da97210700e47652/pytokens-0.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:9bd7d7f544d362576be74f9d5901a22f317efc20046efe2034dced238cbbfe78", size = 103520, upload-time = "2026-01-30T01:03:18.652Z" }, + { url = "https://files.pythonhosted.org/packages/8f/a7/b470f672e6fc5fee0a01d9e75005a0e617e162381974213a945fcd274843/pytokens-0.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = 
"sha256:4a14d5f5fc78ce85e426aa159489e2d5961acf0e47575e08f35584009178e321", size = 160821, upload-time = "2026-01-30T01:03:19.684Z" }, + { url = "https://files.pythonhosted.org/packages/80/98/e83a36fe8d170c911f864bfded690d2542bfcfacb9c649d11a9e6eb9dc41/pytokens-0.4.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97f50fd18543be72da51dd505e2ed20d2228c74e0464e4262e4899797803d7fa", size = 254263, upload-time = "2026-01-30T01:03:20.834Z" }, + { url = "https://files.pythonhosted.org/packages/0f/95/70d7041273890f9f97a24234c00b746e8da86df462620194cef1d411ddeb/pytokens-0.4.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dc74c035f9bfca0255c1af77ddd2d6ae8419012805453e4b0e7513e17904545d", size = 268071, upload-time = "2026-01-30T01:03:21.888Z" }, + { url = "https://files.pythonhosted.org/packages/da/79/76e6d09ae19c99404656d7db9c35dfd20f2086f3eb6ecb496b5b31163bad/pytokens-0.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:f66a6bbe741bd431f6d741e617e0f39ec7257ca1f89089593479347cc4d13324", size = 271716, upload-time = "2026-01-30T01:03:23.633Z" }, + { url = "https://files.pythonhosted.org/packages/79/37/482e55fa1602e0a7ff012661d8c946bafdc05e480ea5a32f4f7e336d4aa9/pytokens-0.4.1-cp314-cp314-win_amd64.whl", hash = "sha256:b35d7e5ad269804f6697727702da3c517bb8a5228afa450ab0fa787732055fc9", size = 104539, upload-time = "2026-01-30T01:03:24.788Z" }, + { url = "https://files.pythonhosted.org/packages/30/e8/20e7db907c23f3d63b0be3b8a4fd1927f6da2395f5bcc7f72242bb963dfe/pytokens-0.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:8fcb9ba3709ff77e77f1c7022ff11d13553f3c30299a9fe246a166903e9091eb", size = 168474, upload-time = "2026-01-30T01:03:26.428Z" }, + { url = "https://files.pythonhosted.org/packages/d6/81/88a95ee9fafdd8f5f3452107748fd04c24930d500b9aba9738f3ade642cc/pytokens-0.4.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", 
hash = "sha256:79fc6b8699564e1f9b521582c35435f1bd32dd06822322ec44afdeba666d8cb3", size = 290473, upload-time = "2026-01-30T01:03:27.415Z" }, + { url = "https://files.pythonhosted.org/packages/cf/35/3aa899645e29b6375b4aed9f8d21df219e7c958c4c186b465e42ee0a06bf/pytokens-0.4.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d31b97b3de0f61571a124a00ffe9a81fb9939146c122c11060725bd5aea79975", size = 303485, upload-time = "2026-01-30T01:03:28.558Z" }, + { url = "https://files.pythonhosted.org/packages/52/a0/07907b6ff512674d9b201859f7d212298c44933633c946703a20c25e9d81/pytokens-0.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:967cf6e3fd4adf7de8fc73cd3043754ae79c36475c1c11d514fc72cf5490094a", size = 306698, upload-time = "2026-01-30T01:03:29.653Z" }, + { url = "https://files.pythonhosted.org/packages/39/2a/cbbf9250020a4a8dd53ba83a46c097b69e5eb49dd14e708f496f548c6612/pytokens-0.4.1-cp314-cp314t-win_amd64.whl", hash = "sha256:584c80c24b078eec1e227079d56dc22ff755e0ba8654d8383b2c549107528918", size = 116287, upload-time = "2026-01-30T01:03:30.912Z" }, + { url = "https://files.pythonhosted.org/packages/c6/78/397db326746f0a342855b81216ae1f0a32965deccfd7c830a2dbc66d2483/pytokens-0.4.1-py3-none-any.whl", hash = "sha256:26cef14744a8385f35d0e095dc8b3a7583f6c953c2e3d269c7f82484bf5ad2de", size = 13729, upload-time = "2026-01-30T01:03:45.029Z" }, +] + [[package]] name = "pywin32" version = "311" From e1fd62e0f324bf64479d663cf02ba492326f64df Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 13 Mar 2026 14:43:54 +0000 Subject: [PATCH 34/84] fix: close all memory stream ends in client transport cleanup (#2266) --- pyproject.toml | 2 +- src/mcp/client/sse.py | 2 + src/mcp/client/streamable_http.py | 2 + src/mcp/client/websocket.py | 13 ++- tests/client/test_transport_stream_cleanup.py | 102 ++++++++++++++++++ tests/shared/test_sse.py | 6 -- uv.lock | 2 +- 7 files changed, 117 
insertions(+), 12 deletions(-) create mode 100644 tests/client/test_transport_stream_cleanup.py diff --git a/pyproject.toml b/pyproject.toml index 737839a23c..f275b90cfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] dependencies = [ - "anyio>=4.5", + "anyio>=4.9", "httpx>=0.27.1", "httpx-sse>=0.4", "pydantic>=2.12.0", diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 61026aa0c9..972efce588 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -160,3 +160,5 @@ async def post_writer(endpoint_url: str): finally: await read_stream_writer.aclose() await write_stream.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9f3dd5e0ba..3416bbc816 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -577,3 +577,5 @@ def start_get_stream() -> None: finally: await read_stream_writer.aclose() await write_stream.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 79e75fad18..de473f36d3 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -38,11 +38,10 @@ async def websocket_client( write_stream: MemoryObjectSendStream[SessionMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - # Connect using websockets, requesting the "mcp" subprotocol async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws: + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) async def ws_reader(): """Reads text messages from the WebSocket, parses them as JSON-RPC 
messages, @@ -68,7 +67,13 @@ async def ws_writer(): msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True) await ws.send(json.dumps(msg_dict)) - async with anyio.create_task_group() as tg: + async with ( + read_stream_writer, + read_stream, + write_stream, + write_stream_reader, + anyio.create_task_group() as tg, + ): # Start reader and writer tasks tg.start_soon(ws_reader) tg.start_soon(ws_writer) diff --git a/tests/client/test_transport_stream_cleanup.py b/tests/client/test_transport_stream_cleanup.py new file mode 100644 index 0000000000..631b0fff22 --- /dev/null +++ b/tests/client/test_transport_stream_cleanup.py @@ -0,0 +1,102 @@ +"""Regression tests for memory stream leaks in client transports. + +When a connection error occurs (404, 403, ConnectError), transport context +managers must close ALL 4 memory stream ends they created. anyio memory streams +are paired but independent — closing the writer does NOT close the reader. +Unclosed stream ends emit ResourceWarning on GC, which pytest promotes to a +test failure in whatever test happens to be running when GC triggers. + +These tests force GC after the transport context exits, so any leaked stream +triggers a ResourceWarning immediately and deterministically here, rather than +nondeterministically in an unrelated later test. +""" + +import gc +import sys +from collections.abc import Iterator +from contextlib import contextmanager + +import httpx +import pytest + +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamable_http_client +from mcp.client.websocket import websocket_client + + +@contextmanager +def _assert_no_memory_stream_leak() -> Iterator[None]: + """Fail if any anyio MemoryObject stream emits ResourceWarning during the block. + + Uses a custom sys.unraisablehook to capture ONLY MemoryObject stream leaks, + ignoring unrelated resources (e.g. PipeHandle from flaky stdio tests on the + same xdist worker). 
gc.collect() is forced after the block to make leaks + deterministic. + """ + leaked: list[str] = [] + old_hook = sys.unraisablehook + + def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover + # Only executes if a leak occurs (i.e. the bug is present). + # args.object is the __del__ function (not the stream instance) when + # unraisablehook fires from a finalizer, so check exc_value — the + # actual ResourceWarning("Unclosed "). + # Non-MemoryObject unraisables (e.g. PipeHandle leaked by an earlier + # flaky test on the same xdist worker) are deliberately ignored — + # this test should not fail for another test's resource leak. + if "MemoryObject" in str(args.exc_value): + leaked.append(str(args.exc_value)) + + sys.unraisablehook = hook + try: + yield + gc.collect() + assert not leaked, f"Memory streams leaked: {leaked}" + finally: + sys.unraisablehook = old_hook + + +@pytest.mark.anyio +async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None: + """sse_client must close all 4 stream ends when the connection fails. + + Before the fix, only read_stream_writer and write_stream were closed in + the finally block. read_stream and write_stream_reader were leaked. + """ + with _assert_no_memory_stream_leak(): + # sse_client enters a task group BEFORE connecting, so anyio wraps the + # ConnectError from aconnect_sse in an ExceptionGroup. + with pytest.raises(Exception) as exc_info: # noqa: B017 + async with sse_client(f"http://127.0.0.1:{free_tcp_port}/sse"): + pytest.fail("should not reach here") # pragma: no cover + + assert exc_info.group_contains(httpx.ConnectError) + # exc_info holds the traceback → holds frame locals → keeps leaked + # streams alive. Must drop it before gc.collect() can detect a leak. + del exc_info + + +@pytest.mark.anyio +async def test_streamable_http_client_closes_all_streams_on_exit() -> None: + """streamable_http_client must close all 4 stream ends on exit. 
+ + Before the fix, read_stream was never closed — not even on the happy path. + This test enters and exits the context without sending any messages, so no + network connection is ever attempted (streamable_http connects lazily). + """ + with _assert_no_memory_stream_leak(): + async with streamable_http_client("http://127.0.0.1:1/mcp"): + pass + + +@pytest.mark.anyio +async def test_websocket_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None: + """websocket_client must close all 4 stream ends when ws_connect fails. + + Before the fix, there was no try/finally at all — if ws_connect raised, + all 4 streams were leaked. + """ + with _assert_no_memory_stream_leak(): + with pytest.raises(OSError): + async with websocket_client(f"ws://127.0.0.1:{free_tcp_port}/ws"): + pytest.fail("should not reach here") # pragma: no cover diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 890e997332..5629a5707b 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -544,12 +544,6 @@ def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result assert sse._endpoint.startswith("/") -# ResourceWarning filter: When mocking aconnect_sse, the sse_client's internal task -# group doesn't receive proper cancellation signals, so the sse_reader task's finally -# block (which closes read_stream_writer) doesn't execute. This is a test artifact - -# the actual code path (`if not sse.data: continue`) IS exercised and works correctly. -# Production code with real SSE connections cleans up properly. -@pytest.mark.filterwarnings("ignore::ResourceWarning") @pytest.mark.anyio async def test_sse_client_handles_empty_keepalive_pings() -> None: """Test that SSE client properly handles empty data lines (keep-alive pings). 
diff --git a/uv.lock b/uv.lock index c01e96a4a6..c25047e48d 100644 --- a/uv.lock +++ b/uv.lock @@ -847,7 +847,7 @@ docs = [ [package.metadata] requires-dist = [ - { name = "anyio", specifier = ">=4.5" }, + { name = "anyio", specifier = ">=4.9" }, { name = "httpx", specifier = ">=0.27.1" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" }, From abfb482246a65f829d092b594d109cf23ac35a7e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 16 Mar 2026 11:37:01 +0000 Subject: [PATCH 35/84] refactor(examples): migrate all HTTP examples to streamable_http_app() (#2291) --- docs/experimental/tasks-server.md | 22 +---------- examples/servers/simple-pagination/README.md | 6 +-- .../mcp_simple_pagination/server.py | 29 ++------------ examples/servers/simple-prompt/README.md | 6 +-- .../simple-prompt/mcp_simple_prompt/server.py | 29 ++------------ examples/servers/simple-resource/README.md | 6 +-- .../mcp_simple_resource/server.py | 29 ++------------ .../simple-streamablehttp-stateless/README.md | 1 - .../server.py | 34 ++-------------- .../servers/simple-streamablehttp/README.md | 2 - .../mcp_simple_streamablehttp/server.py | 39 ++----------------- .../mcp_simple_task_interactive/server.py | 20 +--------- .../simple-task/mcp_simple_task/server.py | 18 +-------- examples/servers/simple-tool/README.md | 6 +-- .../simple-tool/mcp_simple_tool/server.py | 29 ++------------ .../mcp_sse_polling_demo/server.py | 38 +++--------------- 16 files changed, 43 insertions(+), 271 deletions(-) diff --git a/docs/experimental/tasks-server.md b/docs/experimental/tasks-server.md index 761dc5de5c..b350ee3bb6 100644 --- a/docs/experimental/tasks-server.md +++ b/docs/experimental/tasks-server.md @@ -408,16 +408,10 @@ For custom error messages, call `task.fail()` before raising. 
For web applications, use the Streamable HTTP transport: ```python -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager - import uvicorn -from starlette.applications import Starlette -from starlette.routing import Mount from mcp.server import Server from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import ( CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, TASK_REQUIRED, ) @@ -462,22 +456,8 @@ async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTask return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) -def create_app(): - session_manager = StreamableHTTPSessionManager(app=server) - - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncIterator[None]: - async with session_manager.run(): - yield - - return Starlette( - routes=[Mount("/mcp", app=session_manager.handle_request)], - lifespan=lifespan, - ) - - if __name__ == "__main__": - uvicorn.run(create_app(), host="127.0.0.1", port=8000) + uvicorn.run(server.streamable_http_app(), host="127.0.0.1", port=8000) ``` ## Testing Task Servers diff --git a/examples/servers/simple-pagination/README.md b/examples/servers/simple-pagination/README.md index e732b8efbe..4cab40fd34 100644 --- a/examples/servers/simple-pagination/README.md +++ b/examples/servers/simple-pagination/README.md @@ -4,14 +4,14 @@ A simple MCP server demonstrating pagination for tools, resources, and prompts u ## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-pagination -# Using SSE transport on custom port -uv run mcp-simple-pagination --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-pagination --transport 
streamable-http --port 8000 ``` The server exposes: diff --git a/examples/servers/simple-pagination/mcp_simple_pagination/server.py b/examples/servers/simple-pagination/mcp_simple_pagination/server.py index bac27a0f1f..c94f2ac3d1 100644 --- a/examples/servers/simple-pagination/mcp_simple_pagination/server.py +++ b/examples/servers/simple-pagination/mcp_simple_pagination/server.py @@ -10,7 +10,6 @@ import click from mcp import types from mcp.server import Server, ServerRequestContext -from starlette.requests import Request T = TypeVar("T") @@ -143,10 +142,10 @@ async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRe @click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option("--port", default=8000, help="Port to listen on for HTTP") @click.option( "--transport", - type=click.Choice(["stdio", "sse"]), + type=click.Choice(["stdio", "streamable-http"]), default="stdio", help="Transport type", ) @@ -161,30 +160,10 @@ def main(port: int, transport: str) -> int: on_get_prompt=handle_get_prompt, ) - if transport == "sse": - from mcp.server.sse import SseServerTransport - from starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route - - sse = SseServerTransport("/messages/") - - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], app.create_initialization_options()) - return Response() - - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse, methods=["GET"]), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff 
--git a/examples/servers/simple-prompt/README.md b/examples/servers/simple-prompt/README.md index 48e796e198..c837da876e 100644 --- a/examples/servers/simple-prompt/README.md +++ b/examples/servers/simple-prompt/README.md @@ -4,14 +4,14 @@ A simple MCP server that exposes a customizable prompt template with optional co ## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-prompt -# Using SSE transport on custom port -uv run mcp-simple-prompt --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-prompt --transport streamable-http --port 8000 ``` The server exposes a prompt named "simple" that accepts two optional arguments: diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index 6cf99d4b69..74b71b3f38 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -2,7 +2,6 @@ import click from mcp import types from mcp.server import Server, ServerRequestContext -from starlette.requests import Request def create_messages(context: str | None = None, topic: str | None = None) -> list[types.PromptMessage]: @@ -69,10 +68,10 @@ async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRe @click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option("--port", default=8000, help="Port to listen on for HTTP") @click.option( "--transport", - type=click.Choice(["stdio", "sse"]), + type=click.Choice(["stdio", "streamable-http"]), default="stdio", help="Transport type", ) @@ -83,30 +82,10 @@ def main(port: int, transport: str) -> int: on_get_prompt=handle_get_prompt, ) - if transport == "sse": - from mcp.server.sse import SseServerTransport - from 
starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route - - sse = SseServerTransport("/messages/") - - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], app.create_initialization_options()) - return Response() - - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff --git a/examples/servers/simple-resource/README.md b/examples/servers/simple-resource/README.md index df674e91e4..7fb2ab7cdc 100644 --- a/examples/servers/simple-resource/README.md +++ b/examples/servers/simple-resource/README.md @@ -4,14 +4,14 @@ A simple MCP server that exposes sample text files as resources. ## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-resource -# Using SSE transport on custom port -uv run mcp-simple-resource --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-resource --transport streamable-http --port 8000 ``` The server exposes some basic text file resources that can be read by clients. 
diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index b9b6a1d960..8d11054145 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -4,7 +4,6 @@ import click from mcp import types from mcp.server import Server, ServerRequestContext -from starlette.requests import Request SAMPLE_RESOURCES = { "greeting": { @@ -62,10 +61,10 @@ async def handle_read_resource( @click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option("--port", default=8000, help="Port to listen on for HTTP") @click.option( "--transport", - type=click.Choice(["stdio", "sse"]), + type=click.Choice(["stdio", "streamable-http"]), default="stdio", help="Transport type", ) @@ -76,30 +75,10 @@ def main(port: int, transport: str) -> int: on_read_resource=handle_read_resource, ) - if transport == "sse": - from mcp.server.sse import SseServerTransport - from starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route - - sse = SseServerTransport("/messages/") - - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], app.create_initialization_options()) - return Response() - - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse, methods=["GET"]), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff --git a/examples/servers/simple-streamablehttp-stateless/README.md 
b/examples/servers/simple-streamablehttp-stateless/README.md index b87250b353..a254f88d14 100644 --- a/examples/servers/simple-streamablehttp-stateless/README.md +++ b/examples/servers/simple-streamablehttp-stateless/README.md @@ -7,7 +7,6 @@ A stateless MCP server example demonstrating the StreamableHttp transport withou - Uses the StreamableHTTP transport in stateless mode (mcp_session_id=None) - Each request creates a new ephemeral connection - No session state maintained between requests -- Task lifecycle scoped to individual requests - Suitable for deployment in multi-node environments ## Usage diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index cb4a6503ce..e2b8d2ef2f 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -1,17 +1,11 @@ -import contextlib import logging -from collections.abc import AsyncIterator import anyio import click import uvicorn from mcp import types from mcp.server import Server, ServerRequestContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware -from starlette.routing import Mount -from starlette.types import Receive, Scope, Send logger = logging.getLogger(__name__) @@ -104,39 +98,17 @@ def main( on_call_tool=handle_call_tool, ) - # Create the session manager with true stateless mode - session_manager = StreamableHTTPSessionManager( - app=app, - event_store=None, + starlette_app = app.streamable_http_app( + stateless_http=True, json_response=json_response, - stateless=True, - ) - - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await 
session_manager.handle_request(scope, receive, send) - - @contextlib.asynccontextmanager - async def lifespan(app: Starlette) -> AsyncIterator[None]: - """Context manager for session manager.""" - async with session_manager.run(): - logger.info("Application started with StreamableHTTP session manager!") - try: - yield - finally: - logger.info("Application shutting down...") - - # Create an ASGI application using the transport - starlette_app = Starlette( debug=True, - routes=[Mount("/mcp", app=handle_streamable_http)], - lifespan=lifespan, ) # Wrap ASGI application with CORS middleware to expose Mcp-Session-Id header # for browser-based clients (ensures 500 errors get proper CORS headers) starlette_app = CORSMiddleware( starlette_app, - allow_origins=["*"], # Allow all origins - adjust as needed for production + allow_origins=["*"], # Note: streamable_http_app() enforces localhost-only Origin by default allow_methods=["GET", "POST", "DELETE"], # MCP streamable HTTP methods expose_headers=["Mcp-Session-Id"], ) diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md index 9836367170..3eed3320e7 100644 --- a/examples/servers/simple-streamablehttp/README.md +++ b/examples/servers/simple-streamablehttp/README.md @@ -6,9 +6,7 @@ A simple MCP server example demonstrating the StreamableHttp transport, which en - Uses the StreamableHTTP transport for server-client communication - Supports REST API operations (POST, GET, DELETE) for `/mcp` endpoint -- Task management with anyio task groups - Ability to send multiple notifications over time to the client -- Proper resource cleanup and lifespan management - Resumability support via InMemoryEventStore ## Usage diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 2f2a53b1b1..ec9761d1b0 100644 --- 
a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -1,16 +1,11 @@ -import contextlib import logging -from collections.abc import AsyncIterator import anyio import click +import uvicorn from mcp import types from mcp.server import Server, ServerRequestContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware -from starlette.routing import Mount -from starlette.types import Receive, Scope, Send from .event_store import InMemoryEventStore @@ -127,47 +122,21 @@ def main( # For production, use a persistent storage solution. event_store = InMemoryEventStore() - # Create the session manager with our app and event store - session_manager = StreamableHTTPSessionManager( - app=app, - event_store=event_store, # Enable resumability + starlette_app = app.streamable_http_app( + event_store=event_store, json_response=json_response, - ) - - # ASGI handler for streamable HTTP connections - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, send) - - @contextlib.asynccontextmanager - async def lifespan(app: Starlette) -> AsyncIterator[None]: - """Context manager for managing session manager lifecycle.""" - async with session_manager.run(): - logger.info("Application started with StreamableHTTP session manager!") - try: - yield - finally: - logger.info("Application shutting down...") - - # Create an ASGI application using the transport - starlette_app = Starlette( debug=True, - routes=[ - Mount("/mcp", app=handle_streamable_http), - ], - lifespan=lifespan, ) # Wrap ASGI application with CORS middleware to expose Mcp-Session-Id header # for browser-based clients (ensures 500 errors get proper CORS headers) starlette_app = CORSMiddleware( starlette_app, - 
allow_origins=["*"], # Allow all origins - adjust as needed for production + allow_origins=["*"], # Note: streamable_http_app() enforces localhost-only Origin by default allow_methods=["GET", "POST", "DELETE"], # MCP streamable HTTP methods expose_headers=["Mcp-Session-Id"], ) - import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) return 0 diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py index 6938b6552a..bc06e12088 100644 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py +++ b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py @@ -6,8 +6,6 @@ - ServerTaskContext.elicit() and ServerTaskContext.create_message() queue requests properly """ -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager from typing import Any import click @@ -15,9 +13,6 @@ from mcp import types from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette -from starlette.routing import Mount async def handle_list_tools( @@ -134,23 +129,10 @@ async def handle_call_tool( server.experimental.enable_tasks() -def create_app(session_manager: StreamableHTTPSessionManager) -> Starlette: - @asynccontextmanager - async def app_lifespan(app: Starlette) -> AsyncIterator[None]: - async with session_manager.run(): - yield - - return Starlette( - routes=[Mount("/mcp", app=session_manager.handle_request)], - lifespan=app_lifespan, - ) - - @click.command() @click.option("--port", default=8000, help="Port to listen on") def main(port: int) -> int: - session_manager = StreamableHTTPSessionManager(app=server) - starlette_app = create_app(session_manager) + starlette_app = server.streamable_http_app() 
print(f"Starting server on http://localhost:{port}/mcp") uvicorn.run(starlette_app, host="127.0.0.1", port=port) return 0 diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py index 50ae3ca9af..7583cd8f0e 100644 --- a/examples/servers/simple-task/mcp_simple_task/server.py +++ b/examples/servers/simple-task/mcp_simple_task/server.py @@ -1,17 +1,11 @@ """Simple task server demonstrating MCP tasks over streamable HTTP.""" -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager - import anyio import click import uvicorn from mcp import types from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette -from starlette.routing import Mount async def handle_list_tools( @@ -69,17 +63,7 @@ async def work(task: ServerTaskContext) -> types.CallToolResult: @click.command() @click.option("--port", default=8000, help="Port to listen on") def main(port: int) -> int: - session_manager = StreamableHTTPSessionManager(app=server) - - @asynccontextmanager - async def app_lifespan(app: Starlette) -> AsyncIterator[None]: - async with session_manager.run(): - yield - - starlette_app = Starlette( - routes=[Mount("/mcp", app=session_manager.handle_request)], - lifespan=app_lifespan, - ) + starlette_app = server.streamable_http_app() print(f"Starting server on http://localhost:{port}/mcp") uvicorn.run(starlette_app, host="127.0.0.1", port=port) diff --git a/examples/servers/simple-tool/README.md b/examples/servers/simple-tool/README.md index 06020b4b0e..7d3759f9de 100644 --- a/examples/servers/simple-tool/README.md +++ b/examples/servers/simple-tool/README.md @@ -3,14 +3,14 @@ A simple MCP server that exposes a website fetching tool. 
## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-tool -# Using SSE transport on custom port -uv run mcp-simple-tool --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-tool --transport streamable-http --port 8000 ``` The server exposes a tool named "fetch" that accepts one required argument: diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 9fe71e5b7a..226058b955 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -3,7 +3,6 @@ from mcp import types from mcp.server import Server, ServerRequestContext from mcp.shared._httpx_utils import create_mcp_http_client -from starlette.requests import Request async def fetch_website( @@ -51,10 +50,10 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ @click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option("--port", default=8000, help="Port to listen on for HTTP") @click.option( "--transport", - type=click.Choice(["stdio", "sse"]), + type=click.Choice(["stdio", "streamable-http"]), default="stdio", help="Transport type", ) @@ -65,30 +64,10 @@ def main(port: int, transport: str) -> int: on_call_tool=handle_call_tool, ) - if transport == "sse": - from mcp.server.sse import SseServerTransport - from starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route - - sse = SseServerTransport("/messages/") - - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], 
app.create_initialization_options()) - return Response() - - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse, methods=["GET"]), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py index c8178c35a4..14bc174c47 100644 --- a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py @@ -12,18 +12,13 @@ uv run mcp-sse-polling-demo --port 3000 """ -import contextlib import logging -from collections.abc import AsyncIterator import anyio import click +import uvicorn from mcp import types from mcp.server import Server, ServerRequestContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette -from starlette.routing import Mount -from starlette.types import Receive, Scope, Send from .event_store import InMemoryEventStore @@ -149,37 +144,14 @@ def main(port: int, log_level: str, retry_interval: int) -> int: on_call_tool=handle_call_tool, ) - # Create event store for resumability - event_store = InMemoryEventStore() - - # Create session manager with event store and retry interval - session_manager = StreamableHTTPSessionManager( - app=app, - event_store=event_store, + starlette_app = app.streamable_http_app( + event_store=InMemoryEventStore(), retry_interval=retry_interval, - ) - - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, send) - - @contextlib.asynccontextmanager - async def lifespan(starlette_app: Starlette) -> 
AsyncIterator[None]: - async with session_manager.run(): - logger.info(f"SSE Polling Demo server started on port {port}") - logger.info("Try: POST /mcp with tools/call for 'process_batch'") - yield - logger.info("Server shutting down...") - - starlette_app = Starlette( debug=True, - routes=[ - Mount("/mcp", app=handle_streamable_http), - ], - lifespan=lifespan, ) - import uvicorn - + logger.info(f"SSE Polling Demo server starting on port {port}") + logger.info("Try: POST /mcp with tools/call for 'process_batch'") uvicorn.run(starlette_app, host="127.0.0.1", port=port) return 0 From 75a80b6f07f0dfbe0ceb8df50c920a4bda3266d0 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 16 Mar 2026 23:30:20 +0000 Subject: [PATCH 36/84] refactor: connect-first stream lifecycle for sse and streamable_http (#2292) Co-authored-by: Marcelo Trylesinski --- src/mcp/client/sse.py | 201 +++++++++--------- src/mcp/client/streamable_http.py | 76 ++++--- tests/client/test_transport_stream_cleanup.py | 37 +++- 3 files changed, 161 insertions(+), 153 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 972efce588..7b66b5c1b6 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -57,108 +57,101 @@ async def sse_client( write_stream: MemoryObjectSendStream[SessionMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx_client_factory( - headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) - ) as client: - async with aconnect_sse( - client, - "GET", - url, - ) as event_source: - event_source.response.raise_for_status() - logger.debug("SSE connection established") - - async def 
sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): - try: - async for sse in event_source.aiter_sse(): # pragma: no branch - logger.debug(f"Received SSE event: {sse.event}") - match sse.event: - case "endpoint": - endpoint_url = urljoin(url, sse.data) - logger.debug(f"Received endpoint URL: {endpoint_url}") - - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - if ( # pragma: no cover - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme != endpoint_parsed.scheme - ): - error_msg = ( # pragma: no cover - f"Endpoint origin does not match connection origin: {endpoint_url}" - ) - logger.error(error_msg) # pragma: no cover - raise ValueError(error_msg) # pragma: no cover - - if on_session_created: - session_id = _extract_session_id_from_endpoint(endpoint_url) - if session_id: - on_session_created(session_id) - - task_status.started(endpoint_url) - - case "message": - # Skip empty data (keep-alive pings) - if not sse.data: - continue - try: - message = types.jsonrpc_message_adapter.validate_json( - sse.data, by_name=False - ) - logger.debug(f"Received server message: {message}") - except Exception as exc: # pragma: no cover - logger.exception("Error parsing server message") # pragma: no cover - await read_stream_writer.send(exc) # pragma: no cover - continue # pragma: no cover - - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - case _: # pragma: no cover - logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover - except SSEError as sse_exc: # pragma: lax no cover - logger.exception("Encountered SSE exception") - raise sse_exc - except Exception as exc: # pragma: lax no cover - logger.exception("Error in sse_reader") - await read_stream_writer.send(exc) - finally: - await read_stream_writer.aclose() - - async def post_writer(endpoint_url: str): - try: - async with write_stream_reader: - async for session_message in write_stream_reader: - logger.debug(f"Sending 
client message: {session_message}") - response = await client.post( - endpoint_url, - json=session_message.message.model_dump( - by_alias=True, - mode="json", - exclude_unset=True, - ), + logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") + async with httpx_client_factory( + headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) + ) as client: + async with aconnect_sse(client, "GET", url) as event_source: + event_source.response.raise_for_status() + logger.debug("SSE connection established") + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): + try: + async for sse in event_source.aiter_sse(): # pragma: no branch + logger.debug(f"Received SSE event: {sse.event}") + match sse.event: + case "endpoint": + endpoint_url = urljoin(url, sse.data) + logger.debug(f"Received endpoint URL: {endpoint_url}") + + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) + if ( # pragma: no cover + url_parsed.netloc != endpoint_parsed.netloc + or url_parsed.scheme != endpoint_parsed.scheme + ): + error_msg = ( # pragma: no cover + f"Endpoint origin does not match connection origin: {endpoint_url}" ) - response.raise_for_status() - logger.debug(f"Client message sent successfully: {response.status_code}") - except Exception: # pragma: lax no cover - logger.exception("Error in post_writer") - finally: - await write_stream.aclose() - - endpoint_url = await tg.start(sse_reader) - logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") - tg.start_soon(post_writer, endpoint_url) - - try: - yield read_stream, write_stream - finally: - tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() - await read_stream.aclose() - await write_stream_reader.aclose() + logger.error(error_msg) # 
pragma: no cover + raise ValueError(error_msg) # pragma: no cover + + if on_session_created: + session_id = _extract_session_id_from_endpoint(endpoint_url) + if session_id: + on_session_created(session_id) + + task_status.started(endpoint_url) + + case "message": + # Skip empty data (keep-alive pings) + if not sse.data: + continue + try: + message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False) + logger.debug(f"Received server message: {message}") + except Exception as exc: # pragma: no cover + logger.exception("Error parsing server message") # pragma: no cover + await read_stream_writer.send(exc) # pragma: no cover + continue # pragma: no cover + + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + case _: # pragma: no cover + logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover + except SSEError as sse_exc: # pragma: lax no cover + logger.exception("Encountered SSE exception") + raise sse_exc + except Exception as exc: # pragma: lax no cover + logger.exception("Error in sse_reader") + await read_stream_writer.send(exc) + finally: + await read_stream_writer.aclose() + + async def post_writer(endpoint_url: str): + try: + async with write_stream_reader, write_stream: + async for session_message in write_stream_reader: + logger.debug(f"Sending client message: {session_message}") + response = await client.post( + endpoint_url, + json=session_message.message.model_dump( + by_alias=True, + mode="json", + exclude_unset=True, + ), + ) + response.raise_for_status() + logger.debug(f"Client message sent successfully: {response.status_code}") + except Exception: # pragma: lax no cover + logger.exception("Error in post_writer") + + # On Python 3.14, coverage.py reports a phantom branch arc on this + # line (->yield) when nested two async-with levels deep. The branch + # is the unreachable "did __aexit__ suppress?" arm for memory streams. 
+ async with ( # pragma: no branch + read_stream_writer, + read_stream, + write_stream, + write_stream_reader, + anyio.create_task_group() as tg, + ): + endpoint_url = await tg.start(sse_reader) + logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") + tg.start_soon(post_writer, endpoint_url) + + yield read_stream, write_stream + tg.cancel_scope.cancel() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3416bbc816..3afb94b034 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -440,7 +440,7 @@ async def post_writer( ) -> None: """Handle writing requests to the server.""" try: - async with write_stream_reader: + async with write_stream_reader, read_stream_writer, write_stream: async for session_message in write_stream_reader: message = session_message.message metadata = ( @@ -480,9 +480,6 @@ async def handle_request_async(): except Exception: # pragma: lax no cover logger.exception("Error in post_writer") - finally: - await read_stream_writer.aclose() - await write_stream.aclose() async def terminate_session(self, client: httpx.AsyncClient) -> None: """Terminate the session by sending a DELETE request.""" @@ -533,9 +530,6 @@ async def streamable_http_client( Example: See examples/snippets/clients/ for usage patterns. 
""" - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - # Determine if we need to create and manage the client client_provided = http_client is not None client = http_client @@ -546,36 +540,40 @@ async def streamable_http_client( transport = StreamableHTTPTransport(url) - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") - - async with contextlib.AsyncExitStack() as stack: - # Only manage client lifecycle if we created it - if not client_provided: - await stack.enter_async_context(client) - - def start_get_stream() -> None: - tg.start_soon(transport.handle_get_stream, client, read_stream_writer) - - tg.start_soon( - transport.post_writer, - client, - write_stream_reader, - read_stream_writer, - write_stream, - start_get_stream, - tg, - ) + logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") + + async with contextlib.AsyncExitStack() as stack: + # Only manage client lifecycle if we created it + if not client_provided: + await stack.enter_async_context(client) + + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + + async with ( + read_stream_writer, + read_stream, + write_stream, + write_stream_reader, + anyio.create_task_group() as tg, + ): + + def start_get_stream() -> None: + tg.start_soon(transport.handle_get_stream, client, read_stream_writer) + + tg.start_soon( + transport.post_writer, + client, + write_stream_reader, + read_stream_writer, + write_stream, + start_get_stream, + tg, + ) - try: - yield read_stream, write_stream - finally: - if transport.session_id and terminate_on_close: - await transport.terminate_session(client) - tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await 
write_stream.aclose() - await read_stream.aclose() - await write_stream_reader.aclose() + try: + yield read_stream, write_stream + finally: + if transport.session_id and terminate_on_close: + await transport.terminate_session(client) + tg.cancel_scope.cancel() diff --git a/tests/client/test_transport_stream_cleanup.py b/tests/client/test_transport_stream_cleanup.py index 631b0fff22..1e6be3c725 100644 --- a/tests/client/test_transport_stream_cleanup.py +++ b/tests/client/test_transport_stream_cleanup.py @@ -58,22 +58,39 @@ def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover @pytest.mark.anyio async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None: - """sse_client must close all 4 stream ends when the connection fails. + """sse_client creates streams only after the SSE connection succeeds, so a + ConnectError propagates directly with nothing to leak. - Before the fix, only read_stream_writer and write_stream were closed in - the finally block. read_stream and write_stream_reader were leaked. + Before the fix, streams were created before connecting and only 2 of 4 were + closed in the finally block. """ with _assert_no_memory_stream_leak(): - # sse_client enters a task group BEFORE connecting, so anyio wraps the - # ConnectError from aconnect_sse in an ExceptionGroup. - with pytest.raises(Exception) as exc_info: # noqa: B017 + with pytest.raises(httpx.ConnectError): async with sse_client(f"http://127.0.0.1:{free_tcp_port}/sse"): pytest.fail("should not reach here") # pragma: no cover - assert exc_info.group_contains(httpx.ConnectError) - # exc_info holds the traceback → holds frame locals → keeps leaked - # streams alive. Must drop it before gc.collect() can detect a leak. 
- del exc_info + +@pytest.mark.anyio +async def test_sse_client_closes_all_streams_on_http_error() -> None: + """sse_client creates streams only after raise_for_status() passes, so an + HTTPStatusError from a 4xx/5xx response propagates bare (not wrapped in an + ExceptionGroup) with nothing to leak — the task group is never entered. + """ + + def return_403(request: httpx.Request) -> httpx.Response: + return httpx.Response(403) + + def mock_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient(transport=httpx.MockTransport(return_403)) + + with _assert_no_memory_stream_leak(): + with pytest.raises(httpx.HTTPStatusError): + async with sse_client("http://test/sse", httpx_client_factory=mock_factory): + pytest.fail("should not reach here") # pragma: no cover @pytest.mark.anyio From 1a2244f402bdb08f1611bfd4ad34ef36a91a23a7 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 17 Mar 2026 18:40:39 +0000 Subject: [PATCH 37/84] fix: handle non-UTF-8 bytes in stdio server stdin (#2302) --- src/mcp/server/stdio.py | 4 ++-- tests/server/test_stdio.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index e526bab569..5ea6c4e778 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -39,7 +39,7 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. # python is platform-dependent (Windows is particularly problematic), so we # re-wrap the underlying binary stream to ensure UTF-8. 
if not stdin: - stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8")) + stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace")) if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) @@ -58,7 +58,7 @@ async def stdin_reader(): async for line in stdin: try: message = types.jsonrpc_message_adapter.validate_json(line, by_name=False) - except Exception as exc: # pragma: no cover + except Exception as exc: await read_stream_writer.send(exc) continue diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 9a7ddaab40..677a993567 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,4 +1,6 @@ import io +import sys +from io import TextIOWrapper import anyio import pytest @@ -59,3 +61,34 @@ async def test_stdio_server(): assert len(received_responses) == 2 assert received_responses[0] == JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") assert received_responses[1] == JSONRPCResponse(jsonrpc="2.0", id=4, result={}) + + +@pytest.mark.anyio +async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): + """Non-UTF-8 bytes on stdin must not crash the server. + + Invalid bytes are replaced with U+FFFD, which then fails JSON parsing and + is delivered as an in-stream exception. Subsequent valid messages must + still be processed. + """ + # \xff\xfe are invalid UTF-8 start bytes. + valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + raw_stdin = io.BytesIO(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + + # Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that + # stdio_server()'s default path wraps it with errors='replace'. 
+ monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8")) + + with anyio.fail_after(5): + async with stdio_server() as (read_stream, write_stream): + await write_stream.aclose() + async with read_stream: # pragma: no branch + # First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream + first = await read_stream.receive() + assert isinstance(first, Exception) + + # Second line: valid message still comes through + second = await read_stream.receive() + assert isinstance(second, SessionMessage) + assert second.message == valid From ff50351f9e08b0a7dbbcade3813c48986b737b05 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 17 Mar 2026 19:53:39 +0000 Subject: [PATCH 38/84] ci: run strict-no-cover in scripts/test to catch stale pragmas locally (#2305) --- CLAUDE.md | 19 ++++++++++++++++--- scripts/test | 3 +++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 98bd451152..2eee085e13 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -29,18 +29,31 @@ This document contains critical information about working with this codebase. Fo - IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns. - IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible. - Coverage: CI requires 100% (`fail_under = 100`, `branch = true`). - - Full check: `./scripts/test` (~20s, matches CI exactly) - - Targeted check while iterating: + - Full check: `./scripts/test` (~23s). Runs coverage + `strict-no-cover` on the + default Python. Not identical to CI: CI also runs 3.10–3.14 × {ubuntu, windows}, + and some branch-coverage quirks only surface on specific matrix entries. 
+ - Targeted check while iterating (~4s, deterministic): ```bash uv run --frozen coverage erase uv run --frozen coverage run -m pytest tests/path/test_foo.py uv run --frozen coverage combine uv run --frozen coverage report --include='src/mcp/path/foo.py' --fail-under=0 + UV_FROZEN=1 uv run --frozen strict-no-cover ``` Partial runs can't hit 100% (coverage tracks `tests/` too), so `--fail-under=0` - and `--include` scope the report to what you actually changed. + and `--include` scope the report. `strict-no-cover` has no false positives on + partial runs — if your new test executes a line marked `# pragma: no cover`, + even a single-file run catches it. + - Coverage pragmas: + - `# pragma: no cover` — line is never executed. CI's `strict-no-cover` fails if + it IS executed. When your test starts covering such a line, remove the pragma. + - `# pragma: lax no cover` — excluded from coverage but not checked by + `strict-no-cover`. Use for lines covered on some platforms/versions but not + others. + - `# pragma: no branch` — excludes branch arcs only. coverage.py misreports the + `->exit` arc for nested `async with` on Python 3.11+ (worse on 3.14/Windows). - Avoid `anyio.sleep()` with a fixed duration to wait for async operations. Instead: - Use `anyio.Event` — set it in the callback/handler, `await event.wait()` in the test - For stream messages, use `await stream.receive()` instead of `sleep()` + `receive_nowait()` diff --git a/scripts/test b/scripts/test index ee1259b597..dc43f351dd 100755 --- a/scripts/test +++ b/scripts/test @@ -6,3 +6,6 @@ uv run --frozen coverage erase uv run --frozen coverage run -m pytest -n auto $@ uv run --frozen coverage combine uv run --frozen coverage report +# strict-no-cover spawns `uv run coverage json` internally without --frozen; +# UV_FROZEN=1 propagates to that subprocess so it doesn't touch uv.lock. 
+UV_FROZEN=1 uv run --frozen strict-no-cover From 7826ade12b50ce377d0856a3e3448ea36269a4c7 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:25:11 +0000 Subject: [PATCH 39/84] test: convert test_integration.py to in-memory transport (fix flaky) (#2277) --- src/mcp/server/session.py | 2 +- tests/server/mcpserver/test_integration.py | 744 +++++++-------------- 2 files changed, 231 insertions(+), 515 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 759d2131a1..ce467e6c93 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -474,7 +474,7 @@ async def send_progress_notification( related_request_id, ) - async def send_resource_list_changed(self) -> None: # pragma: no cover + async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" await self.send_notification(types.ResourceListChangedNotification()) diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index c4ea2dad65..90e333b775 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -1,7 +1,7 @@ """Integration tests for MCPServer server functionality. These tests validate the proper functioning of MCPServer features using focused, -single-feature servers across different transports (SSE and StreamableHTTP). +single-feature example servers over an in-memory transport. """ # TODO(Marcelo): The `examples` package is not being imported as package. We need to solve this. 
# pyright: reportUnknownMemberType=false @@ -10,12 +10,8 @@ # pyright: reportUnknownArgumentType=false import json -import multiprocessing -import socket -from collections.abc import Generator import pytest -import uvicorn from inline_snapshot import snapshot from examples.snippets.servers import ( @@ -30,9 +26,8 @@ structured_output, tool_progress, ) +from mcp.client import Client from mcp.client.session import ClientSession -from mcp.client.sse import sse_client -from mcp.client.streamable_http import streamable_http_client from mcp.shared._context import RequestContext from mcp.shared.session import RequestResponder from mcp.types import ( @@ -42,7 +37,6 @@ ElicitRequestParams, ElicitResult, GetPromptResult, - InitializeResult, LoggingMessageNotification, LoggingMessageNotificationParams, NotificationParams, @@ -58,7 +52,8 @@ TextResourceContents, ToolListChangedNotification, ) -from tests.test_helpers import wait_for_server + +pytestmark = pytest.mark.anyio class NotificationCollector: @@ -85,105 +80,6 @@ async def handle_generic_notification( self.tool_notifications.append(message.params) -# Common fixtures -@pytest.fixture -def server_port() -> int: - """Get a free port for testing.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - """Get the server URL for testing.""" - return f"http://127.0.0.1:{server_port}" - - -def run_server_with_transport(module_name: str, port: int, transport: str) -> None: # pragma: no cover - """Run server with specified transport.""" - # Get the MCP instance based on module name - if module_name == "basic_tool": - mcp = basic_tool.mcp - elif module_name == "basic_resource": - mcp = basic_resource.mcp - elif module_name == "basic_prompt": - mcp = basic_prompt.mcp - elif module_name == "tool_progress": - mcp = tool_progress.mcp - elif module_name == "sampling": - mcp = sampling.mcp - elif module_name == "elicitation": - mcp = 
elicitation.mcp - elif module_name == "completion": - mcp = completion.mcp - elif module_name == "notifications": - mcp = notifications.mcp - elif module_name == "mcpserver_quickstart": - mcp = mcpserver_quickstart.mcp - elif module_name == "structured_output": - mcp = structured_output.mcp - else: - raise ImportError(f"Unknown module: {module_name}") - - # Create app based on transport type - if transport == "sse": - app = mcp.sse_app() - elif transport == "streamable-http": - app = mcp.streamable_http_app() - else: - raise ValueError(f"Invalid transport for test server: {transport}") - - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error")) - print(f"Starting {transport} server on port {port}") - server.run() - - -@pytest.fixture -def server_transport(request: pytest.FixtureRequest, server_port: int) -> Generator[str, None, None]: - """Start server in a separate process with specified MCP instance and transport. - - Args: - request: pytest request with param tuple of (module_name, transport) - server_port: Port to run the server on - - Yields: - str: The transport type ('sse' or 'streamable_http') - """ - module_name, transport = request.param - - proc = multiprocessing.Process( - target=run_server_with_transport, - args=(module_name, server_port, transport), - daemon=True, - ) - proc.start() - - # Wait for server to be ready - wait_for_server(server_port) - - yield transport - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Server process failed to terminate") - - -# Helper function to create client based on transport -def create_client_for_transport(transport: str, server_url: str): - """Create the appropriate client context manager based on transport type.""" - if transport == "sse": - endpoint = f"{server_url}/sse" - return sse_client(endpoint) - elif transport == "streamable-http": - endpoint = f"{server_url}/mcp" - return streamable_http_client(endpoint) - else: # 
pragma: no cover - raise ValueError(f"Invalid transport: {transport}") - - -# Callback functions for testing async def sampling_callback( context: RequestContext[ClientSession], params: CreateMessageRequestParams ) -> CreateMessageResult: @@ -210,147 +106,85 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E return ElicitResult(action="decline") -# Test basic tools -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_tool", "sse"), - ("basic_tool", "streamable-http"), - ], - indirect=True, -) -async def test_basic_tools(server_transport: str, server_url: str) -> None: +async def test_basic_tools() -> None: """Test basic tool functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Tool Example" - assert result.capabilities.tools is not None - - # Test sum tool - tool_result = await session.call_tool("sum", {"a": 5, "b": 3}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "8" - - # Test weather tool - weather_result = await session.call_tool("get_weather", {"city": "London"}) - assert len(weather_result.content) == 1 - assert isinstance(weather_result.content[0], TextContent) - assert "Weather in London: 22degreesC" in weather_result.content[0].text - - -# Test resources -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_resource", "sse"), - ("basic_resource", "streamable-http"), - ], - indirect=True, -) -async def test_basic_resources(server_transport: str, server_url: str) -> None: + async with Client(basic_tool.mcp) as client: + assert client.server_capabilities is 
not None + assert client.server_capabilities.tools is not None + + # Test sum tool + tool_result = await client.call_tool("sum", {"a": 5, "b": 3}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "8" + + # Test weather tool + weather_result = await client.call_tool("get_weather", {"city": "London"}) + assert len(weather_result.content) == 1 + assert isinstance(weather_result.content[0], TextContent) + assert "Weather in London: 22degreesC" in weather_result.content[0].text + + +async def test_basic_resources() -> None: """Test basic resource functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Resource Example" - assert result.capabilities.resources is not None - - # Test document resource - doc_content = await session.read_resource("file://documents/readme") - assert isinstance(doc_content, ReadResourceResult) - assert len(doc_content.contents) == 1 - assert isinstance(doc_content.contents[0], TextResourceContents) - assert "Content of readme" in doc_content.contents[0].text - - # Test settings resource - settings_content = await session.read_resource("config://settings") - assert isinstance(settings_content, ReadResourceResult) - assert len(settings_content.contents) == 1 - assert isinstance(settings_content.contents[0], TextResourceContents) - settings_json = json.loads(settings_content.contents[0].text) - assert settings_json["theme"] == "dark" - assert settings_json["language"] == "en" - - -# Test prompts -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_prompt", "sse"), - ("basic_prompt", "streamable-http"), - ], - 
indirect=True, -) -async def test_basic_prompts(server_transport: str, server_url: str) -> None: + async with Client(basic_resource.mcp) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.resources is not None + + # Test document resource + doc_content = await client.read_resource("file://documents/readme") + assert isinstance(doc_content, ReadResourceResult) + assert len(doc_content.contents) == 1 + assert isinstance(doc_content.contents[0], TextResourceContents) + assert "Content of readme" in doc_content.contents[0].text + + # Test settings resource + settings_content = await client.read_resource("config://settings") + assert isinstance(settings_content, ReadResourceResult) + assert len(settings_content.contents) == 1 + assert isinstance(settings_content.contents[0], TextResourceContents) + settings_json = json.loads(settings_content.contents[0].text) + assert settings_json["theme"] == "dark" + assert settings_json["language"] == "en" + + +async def test_basic_prompts() -> None: """Test basic prompt functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Prompt Example" - assert result.capabilities.prompts is not None - - # Test review_code prompt - prompts = await session.list_prompts() - review_prompt = next((p for p in prompts.prompts if p.name == "review_code"), None) - assert review_prompt is not None - - prompt_result = await session.get_prompt("review_code", {"code": "def hello():\n print('Hello')"}) - assert isinstance(prompt_result, GetPromptResult) - assert len(prompt_result.messages) == 1 - assert isinstance(prompt_result.messages[0].content, TextContent) - assert "Please review this 
code:" in prompt_result.messages[0].content.text - assert "def hello():" in prompt_result.messages[0].content.text - - # Test debug_error prompt - debug_result = await session.get_prompt( - "debug_error", {"error": "TypeError: 'NoneType' object is not subscriptable"} - ) - assert isinstance(debug_result, GetPromptResult) - assert len(debug_result.messages) == 3 - assert debug_result.messages[0].role == "user" - assert isinstance(debug_result.messages[0].content, TextContent) - assert "I'm seeing this error:" in debug_result.messages[0].content.text - assert debug_result.messages[1].role == "user" - assert isinstance(debug_result.messages[1].content, TextContent) - assert "TypeError" in debug_result.messages[1].content.text - assert debug_result.messages[2].role == "assistant" - assert isinstance(debug_result.messages[2].content, TextContent) - assert "I'll help debug that" in debug_result.messages[2].content.text - - -# Test progress reporting -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("tool_progress", "sse"), - ("tool_progress", "streamable-http"), - ], - indirect=True, -) -async def test_tool_progress(server_transport: str, server_url: str) -> None: + async with Client(basic_prompt.mcp) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.prompts is not None + + # Test review_code prompt + prompts = await client.list_prompts() + review_prompt = next((p for p in prompts.prompts if p.name == "review_code"), None) + assert review_prompt is not None + + prompt_result = await client.get_prompt("review_code", {"code": "def hello():\n print('Hello')"}) + assert isinstance(prompt_result, GetPromptResult) + assert len(prompt_result.messages) == 1 + assert isinstance(prompt_result.messages[0].content, TextContent) + assert "Please review this code:" in prompt_result.messages[0].content.text + assert "def hello():" in prompt_result.messages[0].content.text + + # Test debug_error prompt + debug_result 
= await client.get_prompt( + "debug_error", {"error": "TypeError: 'NoneType' object is not subscriptable"} + ) + assert isinstance(debug_result, GetPromptResult) + assert len(debug_result.messages) == 3 + assert debug_result.messages[0].role == "user" + assert isinstance(debug_result.messages[0].content, TextContent) + assert "I'm seeing this error:" in debug_result.messages[0].content.text + assert debug_result.messages[1].role == "user" + assert isinstance(debug_result.messages[1].content, TextContent) + assert "TypeError" in debug_result.messages[1].content.text + assert debug_result.messages[2].role == "assistant" + assert isinstance(debug_result.messages[2].content, TextContent) + assert "I'll help debug that" in debug_result.messages[2].content.text + + +async def test_tool_progress() -> None: """Test tool progress reporting.""" - transport = server_transport collector = NotificationCollector() async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): @@ -358,134 +192,79 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] if isinstance(message, Exception): # pragma: no cover raise message - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Progress Example" - - # Test progress callback - progress_updates = [] - - async def progress_callback(progress: float, total: float | None, message: str | None) -> None: - progress_updates.append((progress, total, message)) - - # Call tool with progress - steps = 3 - tool_result = await session.call_tool( - "long_running_task", - {"task_name": "Test Task", "steps": steps}, - progress_callback=progress_callback, - ) - assert 
tool_result.content == snapshot([TextContent(text="Task 'Test Task' completed")]) - - # Verify progress updates - assert len(progress_updates) == steps - for i, (progress, total, message) in enumerate(progress_updates): - expected_progress = (i + 1) / steps - assert abs(progress - expected_progress) < 0.01 - assert total == 1.0 - assert f"Step {i + 1}/{steps}" in message - - # Verify log messages - assert len(collector.log_messages) > 0 - - -# Test sampling -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("sampling", "sse"), - ("sampling", "streamable-http"), - ], - indirect=True, -) -async def test_sampling(server_transport: str, server_url: str) -> None: + async with Client(tool_progress.mcp, message_handler=message_handler) as client: + # Test progress callback + progress_updates = [] + + async def progress_callback(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append((progress, total, message)) + + # Call tool with progress + steps = 3 + tool_result = await client.call_tool( + "long_running_task", + {"task_name": "Test Task", "steps": steps}, + progress_callback=progress_callback, + ) + assert tool_result.content == snapshot([TextContent(text="Task 'Test Task' completed")]) + + # Verify progress updates + assert len(progress_updates) == steps + for i, (progress, total, message) in enumerate(progress_updates): + expected_progress = (i + 1) / steps + assert abs(progress - expected_progress) < 0.01 + assert total == 1.0 + assert f"Step {i + 1}/{steps}" in message + + # Verify log messages + assert len(collector.log_messages) > 0 + + +async def test_sampling() -> None: """Test sampling (LLM interaction) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session: - # Test initialization - 
result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Sampling Example" - assert result.capabilities.tools is not None - - # Test sampling tool - sampling_result = await session.call_tool("generate_poem", {"topic": "nature"}) - assert len(sampling_result.content) == 1 - assert isinstance(sampling_result.content[0], TextContent) - assert "This is a simulated LLM response" in sampling_result.content[0].text - - -# Test elicitation -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("elicitation", "sse"), - ("elicitation", "streamable-http"), - ], - indirect=True, -) -async def test_elicitation(server_transport: str, server_url: str) -> None: + async with Client(sampling.mcp, sampling_callback=sampling_callback) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.tools is not None + + # Test sampling tool + sampling_result = await client.call_tool("generate_poem", {"topic": "nature"}) + assert len(sampling_result.content) == 1 + assert isinstance(sampling_result.content[0], TextContent) + assert "This is a simulated LLM response" in sampling_result.content[0].text + + +async def test_elicitation() -> None: """Test elicitation (user interaction) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, elicitation_callback=elicitation_callback) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Elicitation Example" - - # Test booking with unavailable date (triggers elicitation) - booking_result = await session.call_tool( - "book_table", - { - "date": "2024-12-25", # Unavailable date - "time": "19:00", - "party_size": 4, - }, - ) - assert len(booking_result.content) == 1 - 
assert isinstance(booking_result.content[0], TextContent) - assert "[SUCCESS] Booked for 2024-12-26" in booking_result.content[0].text - - # Test booking with available date (no elicitation) - booking_result = await session.call_tool( - "book_table", - { - "date": "2024-12-20", # Available date - "time": "20:00", - "party_size": 2, - }, - ) - assert len(booking_result.content) == 1 - assert isinstance(booking_result.content[0], TextContent) - assert "[SUCCESS] Booked for 2024-12-20 at 20:00" in booking_result.content[0].text - - -# Test notifications -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("notifications", "sse"), - ("notifications", "streamable-http"), - ], - indirect=True, -) -async def test_notifications(server_transport: str, server_url: str) -> None: + async with Client(elicitation.mcp, elicitation_callback=elicitation_callback) as client: + # Test booking with unavailable date (triggers elicitation) + booking_result = await client.call_tool( + "book_table", + { + "date": "2024-12-25", # Unavailable date + "time": "19:00", + "party_size": 4, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + assert "[SUCCESS] Booked for 2024-12-26" in booking_result.content[0].text + + # Test booking with available date (no elicitation) + booking_result = await client.call_tool( + "book_table", + { + "date": "2024-12-20", # Available date + "time": "20:00", + "party_size": 2, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + assert "[SUCCESS] Booked for 2024-12-20 at 20:00" in booking_result.content[0].text + + +async def test_notifications() -> None: """Test notifications and logging functionality.""" - transport = server_transport collector = NotificationCollector() async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): @@ -493,150 +272,87 @@ async def 
message_handler(message: RequestResponder[ServerRequest, ClientResult] if isinstance(message, Exception): # pragma: no cover raise message - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Notifications Example" - - # Call tool that generates notifications - tool_result = await session.call_tool("process_data", {"data": "test_data"}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert "Processed: test_data" in tool_result.content[0].text - - # Verify log messages at different levels - assert len(collector.log_messages) >= 4 - log_levels = {msg.level for msg in collector.log_messages} - assert "debug" in log_levels - assert "info" in log_levels - assert "warning" in log_levels - assert "error" in log_levels - - # Verify resource list changed notification - assert len(collector.resource_notifications) > 0 - - -# Test completion -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("completion", "sse"), - ("completion", "streamable-http"), - ], - indirect=True, -) -async def test_completion(server_transport: str, server_url: str) -> None: + async with Client(notifications.mcp, message_handler=message_handler) as client: + # Call tool that generates notifications + tool_result = await client.call_tool("process_data", {"data": "test_data"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert "Processed: test_data" in tool_result.content[0].text + + # Verify log messages at different levels + assert len(collector.log_messages) >= 4 + log_levels = {msg.level for msg in collector.log_messages} + assert "debug" in log_levels + assert 
"info" in log_levels + assert "warning" in log_levels + assert "error" in log_levels + + # Verify resource list changed notification + assert len(collector.resource_notifications) > 0 + + +async def test_completion() -> None: """Test completion (autocomplete) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Example" - assert result.capabilities.resources is not None - assert result.capabilities.prompts is not None - - # Test resource completion - completion_result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), - argument={"name": "repo", "value": ""}, - context_arguments={"owner": "modelcontextprotocol"}, - ) - - assert completion_result is not None - assert hasattr(completion_result, "completion") - assert completion_result.completion is not None - assert len(completion_result.completion.values) == 3 - assert "python-sdk" in completion_result.completion.values - assert "typescript-sdk" in completion_result.completion.values - assert "specification" in completion_result.completion.values - - # Test prompt completion - completion_result = await session.complete( - ref=PromptReference(type="ref/prompt", name="review_code"), - argument={"name": "language", "value": "py"}, - ) - - assert completion_result is not None - assert hasattr(completion_result, "completion") - assert completion_result.completion is not None - assert "python" in completion_result.completion.values - assert all(lang.startswith("py") for lang in completion_result.completion.values) - - -# Test MCPServer quickstart example -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - 
("mcpserver_quickstart", "sse"), - ("mcpserver_quickstart", "streamable-http"), - ], - indirect=True, -) -async def test_mcpserver_quickstart(server_transport: str, server_url: str) -> None: + async with Client(completion.mcp) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.resources is not None + assert client.server_capabilities.prompts is not None + + # Test resource completion + completion_result = await client.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "modelcontextprotocol"}, + ) + + assert completion_result is not None + assert hasattr(completion_result, "completion") + assert completion_result.completion is not None + assert len(completion_result.completion.values) == 3 + assert "python-sdk" in completion_result.completion.values + assert "typescript-sdk" in completion_result.completion.values + assert "specification" in completion_result.completion.values + + # Test prompt completion + completion_result = await client.complete( + ref=PromptReference(type="ref/prompt", name="review_code"), + argument={"name": "language", "value": "py"}, + ) + + assert completion_result is not None + assert hasattr(completion_result, "completion") + assert completion_result.completion is not None + assert "python" in completion_result.completion.values + assert all(lang.startswith("py") for lang in completion_result.completion.values) + + +async def test_mcpserver_quickstart() -> None: """Test MCPServer quickstart example.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Demo" - - # Test 
add tool - tool_result = await session.call_tool("add", {"a": 10, "b": 20}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "30" - - # Test greeting resource directly - resource_result = await session.read_resource("greeting://Alice") - assert len(resource_result.contents) == 1 - assert isinstance(resource_result.contents[0], TextResourceContents) - assert resource_result.contents[0].text == "Hello, Alice!" - - -# Test structured output example -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("structured_output", "sse"), - ("structured_output", "streamable-http"), - ], - indirect=True, -) -async def test_structured_output(server_transport: str, server_url: str) -> None: + async with Client(mcpserver_quickstart.mcp) as client: + # Test add tool + tool_result = await client.call_tool("add", {"a": 10, "b": 20}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "30" + + # Test greeting resource directly + resource_result = await client.read_resource("greeting://Alice") + assert len(resource_result.contents) == 1 + assert isinstance(resource_result.contents[0], TextResourceContents) + assert resource_result.contents[0].text == "Hello, Alice!" 
+ + +async def test_structured_output() -> None: """Test structured output functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Structured Output Example" - - # Test get_weather tool - weather_result = await session.call_tool("get_weather", {"city": "New York"}) - assert len(weather_result.content) == 1 - assert isinstance(weather_result.content[0], TextContent) - - # Check that the result contains expected weather data - result_text = weather_result.content[0].text - assert "22.5" in result_text # temperature - assert "sunny" in result_text # condition - assert "45" in result_text # humidity - assert "5.2" in result_text # wind_speed + async with Client(structured_output.mcp) as client: + # Test get_weather tool + weather_result = await client.call_tool("get_weather", {"city": "New York"}) + assert len(weather_result.content) == 1 + assert isinstance(weather_result.content[0], TextContent) + + # Check that the result contains expected weather data + result_text = weather_result.content[0].text + assert "22.5" in result_text # temperature + assert "sunny" in result_text # condition + assert "45" in result_text # humidity + assert "5.2" in result_text # wind_speed From 67201a9bbd9c31419716886d9c1036e6370f83dc Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:48:30 +0000 Subject: [PATCH 40/84] test: fix WS test port race; narrow to single smoke test covering both transport ends (#2267) --- src/mcp/server/websocket.py | 8 +- tests/client/test_client.py | 17 ++- tests/shared/test_ws.py | 205 ++++-------------------------------- tests/test_helpers.py | 54 ++++++++++ 4 files changed, 97 
insertions(+), 187 deletions(-) diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 3e675da5fd..32b50560cc 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -10,7 +10,7 @@ from mcp.shared.message import SessionMessage -@asynccontextmanager # pragma: no cover +@asynccontextmanager async def websocket_server(scope: Scope, receive: Receive, send: Send): """WebSocket server transport for MCP. This is an ASGI application, suitable for use with a framework like Starlette and a server like Hypercorn. @@ -34,13 +34,13 @@ async def ws_reader(): async for msg in websocket.iter_text(): try: client_message = types.jsonrpc_message_adapter.validate_json(msg, by_name=False) - except ValidationError as exc: + except ValidationError as exc: # pragma: no cover await read_stream_writer.send(exc) continue session_message = SessionMessage(client_message) await read_stream_writer.send(session_message) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await websocket.close() async def ws_writer(): @@ -49,7 +49,7 @@ async def ws_writer(): async for session_message in write_stream_reader: obj = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await websocket.send_text(obj) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await websocket.close() async with anyio.create_task_group() as tg: diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 45300063a2..3bdd305702 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -8,7 +8,7 @@ import pytest from inline_snapshot import snapshot -from mcp import types +from mcp import MCPError, types from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.server import Server, ServerRequestContext @@ -175,6 +175,21 @@ async def test_read_resource(app: MCPServer): ) +async def 
test_read_resource_error_propagates(): + """MCPError raised by a server handler propagates to the client with its code intact.""" + + async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> ReadResourceResult: + raise MCPError(code=404, message="no resource with that URI was found") + + server = Server("test", on_read_resource=handle_read_resource) + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("unknown://example") + assert exc_info.value.error.code == 404 + + async def test_get_prompt(app: MCPServer): """Test getting a prompt.""" async with Client(app) as client: diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 9addb661de..482dcdcf32 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -1,210 +1,51 @@ -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator -from urllib.parse import urlparse +"""Smoke test for the WebSocket transport. + +Runs the full WS stack end-to-end over a real TCP connection, covering both +``src/mcp/client/websocket.py`` and ``src/mcp/server/websocket.py``. MCP +semantics (error propagation, timeouts, etc.) are transport-agnostic and are +covered in ``tests/client/test_client.py`` and ``tests/issues/test_88_random_error.py``. 
+""" + +from collections.abc import Generator -import anyio import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp import MCPError from mcp.client.session import ClientSession from mcp.client.websocket import websocket_client -from mcp.server import Server, ServerRequestContext +from mcp.server import Server from mcp.server.websocket import websocket_server -from mcp.types import ( - CallToolRequestParams, - CallToolResult, - EmptyResult, - InitializeResult, - ListToolsResult, - PaginatedRequestParams, - ReadResourceRequestParams, - ReadResourceResult, - TextContent, - TextResourceContents, - Tool, -) -from tests.test_helpers import wait_for_server +from mcp.types import EmptyResult, InitializeResult +from tests.test_helpers import run_uvicorn_in_thread SERVER_NAME = "test_server_for_WS" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"ws://127.0.0.1:{server_port}" - - -async def handle_read_resource( # pragma: no cover - ctx: ServerRequestContext, params: ReadResourceRequestParams -) -> ReadResourceResult: - parsed = urlparse(str(params.uri)) - if parsed.scheme == "foobar": - return ReadResourceResult( - contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")] - ) - elif parsed.scheme == "slow": - await anyio.sleep(2.0) - return ReadResourceResult( - contents=[ - TextResourceContents( - uri=str(params.uri), text=f"Slow response from {parsed.netloc}", mime_type="text/plain" - ) - ] - ) - raise MCPError(code=404, message="OOPS! 
no resource with that URI was found") - - -async def handle_list_tools( # pragma: no cover - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) - +def make_server_app() -> Starlette: + srv = Server(SERVER_NAME) -async def handle_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) - - -def _create_server() -> Server: # pragma: no cover - return Server( - SERVER_NAME, - on_read_resource=handle_read_resource, - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - - -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover - """Create test Starlette app with WebSocket transport""" - server = _create_server() - - async def handle_ws(websocket: WebSocket): + async def handle_ws(websocket: WebSocket) -> None: async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: - await server.run(streams[0], streams[1], server.create_initialization_options()) - - app = Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) - return app - - -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() + await srv.run(streams[0], streams[1], srv.create_initialization_options()) - # Wait for server to be running - print("waiting for server to start") - 
wait_for_server(server_port) + return Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) - yield - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.fixture() -async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - """Create and initialize a WebSocket client session""" - async with websocket_client(server_url + "/ws") as streams: - async with ClientSession(*streams) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == SERVER_NAME - - # Test ping - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) - - yield session +@pytest.fixture +def ws_server_url() -> Generator[str, None, None]: + with run_uvicorn_in_thread(make_server_app()) as base_url: + yield base_url.replace("http://", "ws://") + "/ws" -# Tests @pytest.mark.anyio -async def test_ws_client_basic_connection(server: None, server_url: str) -> None: - """Test the WebSocket connection establishment""" - async with websocket_client(server_url + "/ws") as streams: +async def test_ws_client_basic_connection(ws_server_url: str) -> None: + async with websocket_client(ws_server_url) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.server_info.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) - - -@pytest.mark.anyio -async def test_ws_client_happy_request_and_response( - initialized_ws_client_session: ClientSession, -) -> None: - """Test a successful request and response via WebSocket""" - result = await initialized_ws_client_session.read_resource("foobar://example") - assert 
isinstance(result, ReadResourceResult) - assert isinstance(result.contents, list) - assert len(result.contents) > 0 - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Read example" - - -@pytest.mark.anyio -async def test_ws_client_exception_handling( - initialized_ws_client_session: ClientSession, -) -> None: - """Test exception handling in WebSocket communication""" - with pytest.raises(MCPError) as exc_info: - await initialized_ws_client_session.read_resource("unknown://example") - assert exc_info.value.error.code == 404 - - -@pytest.mark.anyio -async def test_ws_client_timeout( - initialized_ws_client_session: ClientSession, -) -> None: - """Test timeout handling in WebSocket communication""" - # Set a very short timeout to trigger a timeout exception - with pytest.raises(TimeoutError): - with anyio.fail_after(0.1): # 100ms timeout - await initialized_ws_client_session.read_resource("slow://example") - - # Now test that we can still use the session after a timeout - with anyio.fail_after(5): # Longer timeout to allow completion - result = await initialized_ws_client_session.read_resource("foobar://example") - assert isinstance(result, ReadResourceResult) - assert isinstance(result.contents, list) - assert len(result.contents) > 0 - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Read example" diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5c04c269ff..810c72820b 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,7 +1,61 @@ """Common test utilities for MCP server tests.""" import socket +import threading import time +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +import uvicorn + +_SERVER_SHUTDOWN_TIMEOUT_S = 5.0 + + +@contextmanager +def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None, None]: + """Run a uvicorn server in a background thread on an 
ephemeral port. + + The socket is bound and put into listening state *before* the thread + starts, so the port is known immediately with no wait. The kernel's + listen queue buffers any connections that arrive before uvicorn's event + loop reaches ``accept()``, so callers can connect as soon as this + function yields — no polling, no sleeps, no startup race. + + This also avoids the TOCTOU race of the old pick-a-port-then-rebind + pattern: the socket passed here is the one uvicorn serves on, with no + gap where another pytest-xdist worker could claim it. + + Args: + app: ASGI application to serve. + **config_kwargs: Additional keyword arguments for :class:`uvicorn.Config` + (e.g. ``log_level``). ``host``/``port`` are ignored since the + socket is pre-bound. + + Yields: + The base URL of the running server, e.g. ``http://127.0.0.1:54321``. + """ + host = "127.0.0.1" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((host, 0)) + sock.listen() + port = sock.getsockname()[1] + + config_kwargs.setdefault("log_level", "error") + # Uvicorn's interface autodetection calls asyncio.iscoroutinefunction, + # which Python 3.14 deprecates. Under filterwarnings=error this crashes + # the server thread silently. Starlette is asgi3; skip the autodetect. 
+ config_kwargs.setdefault("interface", "asgi3") + server = uvicorn.Server(config=uvicorn.Config(app=app, **config_kwargs)) + + thread = threading.Thread(target=server.run, kwargs={"sockets": [sock]}, daemon=True) + thread.start() + try: + yield f"http://{host}:{port}" + finally: + server.should_exit = True + thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S) def wait_for_server(port: int, timeout: float = 20.0) -> None: From 20dd94632ef1b8a8e2db5ca9c0c38dab845fa5bb Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 17:31:26 +0000 Subject: [PATCH 41/84] feat(client): store InitializeResult as initialize_result (#2300) --- docs/migration.md | 24 +++++++++++++++++++ src/mcp/client/client.py | 15 ++++++++---- src/mcp/client/session.py | 13 ++++++----- tests/client/test_client.py | 3 ++- tests/client/test_session.py | 27 +++++++++++----------- tests/client/transports/test_memory.py | 2 +- tests/server/mcpserver/test_integration.py | 17 +++++--------- 7 files changed, 64 insertions(+), 37 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 7cf0325533..dd6a7a18f4 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -169,6 +169,30 @@ result = await session.list_resources(params=PaginatedRequestParams(cursor="next result = await session.list_tools(params=PaginatedRequestParams(cursor="next_page_token")) ``` +### `ClientSession.get_server_capabilities()` replaced by `initialize_result` property + +`ClientSession` now stores the full `InitializeResult` via an `initialize_result` property. This provides access to `server_info`, `capabilities`, `instructions`, and the negotiated `protocol_version` through a single property. The `get_server_capabilities()` method has been removed. 
+ +**Before (v1):** + +```python +capabilities = session.get_server_capabilities() +# server_info, instructions, protocol_version were not stored — had to capture initialize() return value +``` + +**After (v2):** + +```python +result = session.initialize_result +if result is not None: + capabilities = result.capabilities + server_info = result.server_info + instructions = result.instructions + version = result.protocol_version +``` + +The high-level `Client.initialize_result` returns the same `InitializeResult` but is non-nullable — initialization is guaranteed inside the context manager, so no `None` check is needed. This replaces v1's `Client.server_capabilities`; use `client.initialize_result.capabilities` instead. + ### `McpError` renamed to `MCPError` The `McpError` exception class has been renamed to `MCPError` for consistent naming with the MCP acronym style used throughout the SDK. diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 7dc67c5844..34d6a360fa 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -19,6 +19,7 @@ EmptyResult, GetPromptResult, Implementation, + InitializeResult, ListPromptsResult, ListResourcesResult, ListResourceTemplatesResult, @@ -29,7 +30,6 @@ ReadResourceResult, RequestParamsMeta, ResourceTemplateReference, - ServerCapabilities, ) @@ -155,9 +155,16 @@ def session(self) -> ClientSession: return self._session @property - def server_capabilities(self) -> ServerCapabilities | None: - """The server capabilities received during initialization, or None if not yet initialized.""" - return self.session.get_server_capabilities() + def initialize_result(self) -> InitializeResult: + """The server's InitializeResult. + + Contains server_info, capabilities, instructions, and the negotiated protocol_version. + Raises RuntimeError if accessed outside the context manager. 
+ """ + result = self.session.initialize_result + if result is None: # pragma: no cover + raise RuntimeError("Client must be used within an async context manager") + return result async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> EmptyResult: """Send a ping request to the server.""" diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a0ca751bd7..7c964a334c 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -131,7 +131,7 @@ def __init__( self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} - self._server_capabilities: types.ServerCapabilities | None = None + self._initialize_result: types.InitializeResult | None = None self._experimental_features: ExperimentalClientFeatures | None = None # Experimental: Task handlers (use defaults if not provided) @@ -185,18 +185,19 @@ async def initialize(self) -> types.InitializeResult: if result.protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocol_version}") - self._server_capabilities = result.capabilities + self._initialize_result = result await self.send_notification(types.InitializedNotification()) return result - def get_server_capabilities(self) -> types.ServerCapabilities | None: - """Return the server capabilities received during initialization. + @property + def initialize_result(self) -> types.InitializeResult | None: + """The server's InitializeResult. None until initialize() has been called. - Returns None if the session has not been initialized yet. + Contains server_info, capabilities, instructions, and the negotiated protocol_version. 
""" - return self._server_capabilities + return self._initialize_result @property def experimental(self) -> ExperimentalClientFeatures: diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 3bdd305702..18368e6bb3 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -99,7 +99,7 @@ def greeting_prompt(name: str) -> str: async def test_client_is_initialized(app: MCPServer): """Test that the client is initialized after entering context.""" async with Client(app) as client: - assert client.server_capabilities == snapshot( + assert client.initialize_result.capabilities == snapshot( ServerCapabilities( experimental={}, prompts=PromptsCapability(list_changed=False), @@ -107,6 +107,7 @@ async def test_client_is_initialized(app: MCPServer): tools=ToolsCapability(list_changed=False), ) ) + assert client.initialize_result.server_info.name == "test" async def test_client_with_simple_server(simple_server: Server): diff --git a/tests/client/test_session.py b/tests/client/test_session.py index d6d13e273c..f25c964f03 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -540,8 +540,8 @@ async def mock_server(): @pytest.mark.anyio -async def test_get_server_capabilities(): - """Test that get_server_capabilities returns None before init and capabilities after""" +async def test_initialize_result(): + """Test that initialize_result is None before init and contains the full result after.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -551,6 +551,8 @@ async def test_get_server_capabilities(): resources=types.ResourcesCapability(subscribe=True, list_changed=True), tools=types.ToolsCapability(list_changed=False), ) + expected_server_info = Implementation(name="mock-server", version="0.1.0") + expected_instructions = "Use the tools wisely." 
async def mock_server(): session_message = await client_to_server_receive.receive() @@ -564,7 +566,8 @@ async def mock_server(): result = InitializeResult( protocol_version=LATEST_PROTOCOL_VERSION, capabilities=expected_capabilities, - server_info=Implementation(name="mock-server", version="0.1.0"), + server_info=expected_server_info, + instructions=expected_instructions, ) async with server_to_client_send: @@ -590,21 +593,17 @@ async def mock_server(): server_to_client_send, server_to_client_receive, ): - assert session.get_server_capabilities() is None + assert session.initialize_result is None tg.start_soon(mock_server) await session.initialize() - capabilities = session.get_server_capabilities() - assert capabilities is not None - assert capabilities == expected_capabilities - assert capabilities.logging is not None - assert capabilities.prompts is not None - assert capabilities.prompts.list_changed is True - assert capabilities.resources is not None - assert capabilities.resources.subscribe is True - assert capabilities.tools is not None - assert capabilities.tools.list_changed is False + result = session.initialize_result + assert result is not None + assert result.server_info == expected_server_info + assert result.capabilities == expected_capabilities + assert result.instructions == expected_instructions + assert result.protocol_version == LATEST_PROTOCOL_VERSION @pytest.mark.anyio diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 47be3e2089..c8fc41fd5d 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -69,7 +69,7 @@ async def test_with_mcpserver(mcpserver_server: MCPServer): async def test_server_is_running(mcpserver_server: MCPServer): """Test that the server is running and responding to requests.""" async with Client(mcpserver_server) as client: - assert client.server_capabilities is not None + assert client.initialize_result.capabilities.tools is not None 
async def test_list_tools(mcpserver_server: MCPServer): diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index 90e333b775..f71c0574cd 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -109,8 +109,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E async def test_basic_tools() -> None: """Test basic tool functionality.""" async with Client(basic_tool.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.tools is not None + assert client.initialize_result.capabilities.tools is not None # Test sum tool tool_result = await client.call_tool("sum", {"a": 5, "b": 3}) @@ -128,8 +127,7 @@ async def test_basic_tools() -> None: async def test_basic_resources() -> None: """Test basic resource functionality.""" async with Client(basic_resource.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.resources is not None + assert client.initialize_result.capabilities.resources is not None # Test document resource doc_content = await client.read_resource("file://documents/readme") @@ -151,8 +149,7 @@ async def test_basic_resources() -> None: async def test_basic_prompts() -> None: """Test basic prompt functionality.""" async with Client(basic_prompt.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.prompts is not None + assert client.initialize_result.capabilities.prompts is not None # Test review_code prompt prompts = await client.list_prompts() @@ -223,8 +220,7 @@ async def progress_callback(progress: float, total: float | None, message: str | async def test_sampling() -> None: """Test sampling (LLM interaction) functionality.""" async with Client(sampling.mcp, sampling_callback=sampling_callback) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.tools is 
not None + assert client.initialize_result.capabilities.tools is not None # Test sampling tool sampling_result = await client.call_tool("generate_poem", {"topic": "nature"}) @@ -294,9 +290,8 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] async def test_completion() -> None: """Test completion (autocomplete) functionality.""" async with Client(completion.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.resources is not None - assert client.server_capabilities.prompts is not None + assert client.initialize_result.capabilities.resources is not None + assert client.initialize_result.capabilities.prompts is not None # Test resource completion completion_result = await client.complete( From 5388bea53ad7b13db7d031cdfd27392474c89007 Mon Sep 17 00:00:00 2001 From: Jonathan Hefner Date: Wed, 18 Mar 2026 13:15:17 -0500 Subject: [PATCH 42/84] docs: generate hierarchical per-module API reference pages (#2103) --- docs/api.md | 1 - docs/hooks/gen_ref_pages.py | 35 +++++++++++++++++++++++++++++++++++ docs/index.md | 2 +- docs/migration.md | 2 +- mkdocs.yml | 11 +++++++---- pyproject.toml | 2 ++ uv.lock | 28 ++++++++++++++++++++++++++++ 7 files changed, 74 insertions(+), 7 deletions(-) delete mode 100644 docs/api.md create mode 100644 docs/hooks/gen_ref_pages.py diff --git a/docs/api.md b/docs/api.md deleted file mode 100644 index 3f696af543..0000000000 --- a/docs/api.md +++ /dev/null @@ -1 +0,0 @@ -::: mcp diff --git a/docs/hooks/gen_ref_pages.py b/docs/hooks/gen_ref_pages.py new file mode 100644 index 0000000000..ad8c19b45f --- /dev/null +++ b/docs/hooks/gen_ref_pages.py @@ -0,0 +1,35 @@ +"""Generate the code reference pages and navigation.""" + +from pathlib import Path + +import mkdocs_gen_files + +nav = mkdocs_gen_files.Nav() + +root = Path(__file__).parent.parent.parent +src = root / "src" + +for path in sorted(src.rglob("*.py")): + module_path = path.relative_to(src).with_suffix("") + 
doc_path = path.relative_to(src).with_suffix(".md") + full_doc_path = Path("api", doc_path) + + parts = tuple(module_path.parts) + + if parts[-1] == "__init__": + parts = parts[:-1] + doc_path = doc_path.with_name("index.md") + full_doc_path = full_doc_path.with_name("index.md") + elif parts[-1].startswith("_"): + continue + + nav[parts] = doc_path.as_posix() + + with mkdocs_gen_files.open(full_doc_path, "w") as fd: + ident = ".".join(parts) + fd.write(f"::: {ident}") + + mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root)) + +with mkdocs_gen_files.open("api/SUMMARY.md", "w") as nav_file: + nav_file.writelines(nav.build_literate_nav()) diff --git a/docs/index.md b/docs/index.md index e096d910b4..436d1c8fcd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -64,4 +64,4 @@ npx -y @modelcontextprotocol/inspector ## API Reference -Full API documentation is available in the [API Reference](api.md). +Full API documentation is available in the [API Reference](api/mcp/index.md). diff --git a/docs/migration.md b/docs/migration.md index dd6a7a18f4..3b47f9aade 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -883,6 +883,6 @@ The lowlevel `Server` also now exposes a `session_manager` property to access th If you encounter issues during migration: -1. Check the [API Reference](api.md) for updated method signatures +1. Check the [API Reference](api/mcp/index.md) for updated method signatures 2. Review the [examples](https://github.com/modelcontextprotocol/python-sdk/tree/main/examples) for updated usage patterns 3. 
Open an issue on [GitHub](https://github.com/modelcontextprotocol/python-sdk/issues) if you find a bug or need further assistance diff --git a/mkdocs.yml b/mkdocs.yml index 070c533e31..3a555785a7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -25,7 +25,7 @@ nav: - Introduction: experimental/tasks.md - Server Implementation: experimental/tasks-server.md - Client Usage: experimental/tasks-client.md - - API Reference: api.md + - API Reference: api/ theme: name: "material" @@ -115,10 +115,15 @@ plugins: - social: enabled: !ENV [ENABLE_SOCIAL_CARDS, false] - glightbox + - gen-files: + scripts: + - docs/hooks/gen_ref_pages.py + - literate-nav: + nav_file: SUMMARY.md - mkdocstrings: handlers: python: - paths: [src/mcp] + paths: [src] options: relative_crossrefs: true members_order: source @@ -126,8 +131,6 @@ plugins: show_signature_annotations: true signature_crossrefs: true group_by_category: false - # 3 because docs are in pages with an H2 just above them - heading_level: 3 inventories: - url: https://docs.python.org/3/objects.inv - url: https://docs.pydantic.dev/latest/objects.inv diff --git a/pyproject.toml b/pyproject.toml index f275b90cfd..624ade1709 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,9 @@ dev = [ ] docs = [ "mkdocs>=1.6.1", + "mkdocs-gen-files>=0.5.0", "mkdocs-glightbox>=0.4.0", + "mkdocs-literate-nav>=0.6.1", "mkdocs-material[imaging]>=9.5.45", "mkdocstrings-python>=2.0.1", ] diff --git a/uv.lock b/uv.lock index c25047e48d..4af3532ea2 100644 --- a/uv.lock +++ b/uv.lock @@ -840,7 +840,9 @@ dev = [ ] docs = [ { name = "mkdocs" }, + { name = "mkdocs-gen-files" }, { name = "mkdocs-glightbox" }, + { name = "mkdocs-literate-nav" }, { name = "mkdocs-material", extra = ["imaging"] }, { name = "mkdocstrings-python" }, ] @@ -888,7 +890,9 @@ dev = [ ] docs = [ { name = "mkdocs", specifier = ">=1.6.1" }, + { name = "mkdocs-gen-files", specifier = ">=0.5.0" }, { name = "mkdocs-glightbox", specifier = ">=0.4.0" }, + { name = "mkdocs-literate-nav", 
specifier = ">=0.6.1" }, { name = "mkdocs-material", extras = ["imaging"], specifier = ">=9.5.45" }, { name = "mkdocstrings-python", specifier = ">=2.0.1" }, ] @@ -1501,6 +1505,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/4d/7123b6fa2278000688ebd338e2a06d16870aaf9eceae6ba047ea05f92df1/mkdocs_autorefs-1.4.3-py3-none-any.whl", hash = "sha256:469d85eb3114801d08e9cc55d102b3ba65917a869b893403b8987b601cf55dc9", size = 25034, upload-time = "2025-08-26T14:23:15.906Z" }, ] +[[package]] +name = "mkdocs-gen-files" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mkdocs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/61/35/f26349f7fa18414eb2e25d75a6fa9c7e3186c36e1d227c0b2d785a7bd5c4/mkdocs_gen_files-0.6.0.tar.gz", hash = "sha256:52022dc14dcc0451e05e54a8f5d5e7760351b6701eff816d1e9739577ec5635e", size = 8642, upload-time = "2025-11-23T12:13:22.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/ec/72417415563c60ae01b36f0d497f1f4c803972f447ef4fb7f7746d6e07db/mkdocs_gen_files-0.6.0-py3-none-any.whl", hash = "sha256:815af15f3e2dbfda379629c1b95c02c8e6f232edf2a901186ea3b204ab1135b2", size = 8182, upload-time = "2025-11-23T12:13:20.756Z" }, +] + [[package]] name = "mkdocs-get-deps" version = "0.2.0" @@ -1527,6 +1543,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/cf/e9a0ce9da269746906fdc595c030f6df66793dad1487abd1699af2ba44f1/mkdocs_glightbox-0.5.1-py3-none-any.whl", hash = "sha256:f47af0daff164edf8d36e553338425be3aab6e34b987d9cbbc2ae7819a98cb01", size = 26431, upload-time = "2025-09-04T13:10:27.933Z" }, ] +[[package]] +name = "mkdocs-literate-nav" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mkdocs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/5f/99aa379b305cd1c2084d42db3d26f6de0ea9bf2cc1d10ed17f61aff35b9a/mkdocs_literate_nav-0.6.2.tar.gz", hash = 
"sha256:760e1708aa4be86af81a2b56e82c739d5a8388a0eab1517ecfd8e5aa40810a75", size = 17419, upload-time = "2025-03-18T21:53:09.711Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/84/b5b14d2745e4dd1a90115186284e9ee1b4d0863104011ab46abb7355a1c3/mkdocs_literate_nav-0.6.2-py3-none-any.whl", hash = "sha256:0a6489a26ec7598477b56fa112056a5e3a6c15729f0214bea8a4dbc55bd5f630", size = 13261, upload-time = "2025-03-18T21:53:08.1Z" }, +] + [[package]] name = "mkdocs-material" version = "9.7.2" From 883d89309755def3048d2c8b6ad85a2b9861130b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:16:34 +0000 Subject: [PATCH 43/84] test: rewrite cli.claude config tests to assert JSON output directly (#2311) Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> Co-authored-by: Felix Weinberger --- src/mcp/cli/claude.py | 20 +++-- tests/cli/test_claude.py | 146 ++++++++++++++++++++++++++++++++++++ tests/client/test_config.py | 75 ------------------ 3 files changed, 158 insertions(+), 83 deletions(-) create mode 100644 tests/cli/test_claude.py delete mode 100644 tests/client/test_config.py diff --git a/src/mcp/cli/claude.py b/src/mcp/cli/claude.py index 071b4b6fbb..93bf218fbb 100644 --- a/src/mcp/cli/claude.py +++ b/src/mcp/cli/claude.py @@ -33,7 +33,7 @@ def get_claude_config_path() -> Path | None: # pragma: no cover def get_uv_path() -> str: """Get the full path to the uv executable.""" uv_path = shutil.which("uv") - if not uv_path: # pragma: no cover + if not uv_path: logger.error( "uv executable not found in PATH, falling back to 'uv'. Please ensure uv is installed and in your PATH" ) @@ -65,7 +65,7 @@ def update_claude_config( """ config_dir = get_claude_config_path() uv_path = get_uv_path() - if not config_dir: # pragma: no cover + if not config_dir: raise RuntimeError( "Claude Desktop config directory not found. 
Please ensure Claude Desktop" " is installed and has been run at least once to initialize its config." @@ -90,7 +90,7 @@ def update_claude_config( config["mcpServers"] = {} # Always preserve existing env vars and merge with new ones - if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: # pragma: no cover + if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: existing_env = config["mcpServers"][server_name]["env"] if env_vars: # New vars take precedence over existing ones @@ -103,22 +103,26 @@ def update_claude_config( # Collect all packages in a set to deduplicate packages = {MCP_PACKAGE} - if with_packages: # pragma: no cover + if with_packages: packages.update(pkg for pkg in with_packages if pkg) # Add all packages with --with for pkg in sorted(packages): args.extend(["--with", pkg]) - if with_editable: # pragma: no cover + if with_editable: args.extend(["--with-editable", str(with_editable)]) # Convert file path to absolute before adding to command # Split off any :object suffix first - if ":" in file_spec: + # First check if we have a Windows path (e.g., C:\...) 
+ has_windows_drive = len(file_spec) > 1 and file_spec[1] == ":" + + # Split on the last colon, but only if it's not part of the Windows drive letter + if ":" in (file_spec[2:] if has_windows_drive else file_spec): file_path, server_object = file_spec.rsplit(":", 1) file_spec = f"{Path(file_path).resolve()}:{server_object}" - else: # pragma: no cover + else: file_spec = str(Path(file_spec).resolve()) # Add mcp run command @@ -127,7 +131,7 @@ def update_claude_config( server_config: dict[str, Any] = {"command": uv_path, "args": args} # Add environment variables if specified - if env_vars: # pragma: no cover + if env_vars: server_config["env"] = env_vars config["mcpServers"][server_name] = server_config diff --git a/tests/cli/test_claude.py b/tests/cli/test_claude.py new file mode 100644 index 0000000000..73d4f0eb52 --- /dev/null +++ b/tests/cli/test_claude.py @@ -0,0 +1,146 @@ +"""Tests for mcp.cli.claude — Claude Desktop config file generation.""" + +import json +from pathlib import Path +from typing import Any + +import pytest + +from mcp.cli.claude import get_uv_path, update_claude_config + + +@pytest.fixture +def config_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + """Temp Claude config dir with get_claude_config_path and get_uv_path mocked.""" + claude_dir = tmp_path / "Claude" + claude_dir.mkdir() + monkeypatch.setattr("mcp.cli.claude.get_claude_config_path", lambda: claude_dir) + monkeypatch.setattr("mcp.cli.claude.get_uv_path", lambda: "/fake/bin/uv") + return claude_dir + + +def _read_server(config_dir: Path, name: str) -> dict[str, Any]: + config = json.loads((config_dir / "claude_desktop_config.json").read_text()) + return config["mcpServers"][name] + + +def test_generates_uv_run_command(config_dir: Path): + """Should write a uv run command that invokes mcp run on the resolved file spec.""" + assert update_claude_config(file_spec="server.py:app", server_name="my_server") + + resolved = Path("server.py").resolve() + assert 
_read_server(config_dir, "my_server") == { + "command": "/fake/bin/uv", + "args": ["run", "--frozen", "--with", "mcp[cli]", "mcp", "run", f"{resolved}:app"], + } + + +def test_file_spec_without_object_suffix(config_dir: Path): + """File specs without :object should still resolve to an absolute path.""" + assert update_claude_config(file_spec="server.py", server_name="s") + + assert _read_server(config_dir, "s")["args"][-1] == str(Path("server.py").resolve()) + + +def test_with_packages_sorted_and_deduplicated(config_dir: Path): + """Extra packages should appear as --with flags, sorted and deduplicated with mcp[cli].""" + assert update_claude_config(file_spec="s.py:app", server_name="s", with_packages=["zebra", "aardvark", "zebra"]) + + args = _read_server(config_dir, "s")["args"] + assert args[:8] == ["run", "--frozen", "--with", "aardvark", "--with", "mcp[cli]", "--with", "zebra"] + + +def test_with_editable_adds_flag(config_dir: Path, tmp_path: Path): + """with_editable should add --with-editable after the --with flags.""" + editable = tmp_path / "project" + assert update_claude_config(file_spec="s.py:app", server_name="s", with_editable=editable) + + args = _read_server(config_dir, "s")["args"] + assert args[4:6] == ["--with-editable", str(editable)] + + +def test_env_vars_written(config_dir: Path): + """env_vars should be written under the server's env key.""" + assert update_claude_config(file_spec="s.py:app", server_name="s", env_vars={"KEY": "val"}) + + assert _read_server(config_dir, "s")["env"] == {"KEY": "val"} + + +def test_existing_env_vars_merged_new_wins(config_dir: Path): + """Re-installing should merge env vars, with new values overriding existing ones.""" + (config_dir / "claude_desktop_config.json").write_text( + json.dumps({"mcpServers": {"s": {"env": {"OLD": "keep", "KEY": "old"}}}}) + ) + + assert update_claude_config(file_spec="s.py:app", server_name="s", env_vars={"KEY": "new"}) + + assert _read_server(config_dir, "s")["env"] == {"OLD": 
"keep", "KEY": "new"} + + +def test_existing_env_vars_preserved_without_new(config_dir: Path): + """Re-installing without env_vars should keep the existing env block intact.""" + (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"s": {"env": {"KEEP": "me"}}}})) + + assert update_claude_config(file_spec="s.py:app", server_name="s") + + assert _read_server(config_dir, "s")["env"] == {"KEEP": "me"} + + +def test_other_servers_preserved(config_dir: Path): + """Installing a new server should not clobber existing mcpServers entries.""" + (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"other": {"command": "x"}}})) + + assert update_claude_config(file_spec="s.py:app", server_name="s") + + config = json.loads((config_dir / "claude_desktop_config.json").read_text()) + assert set(config["mcpServers"]) == {"other", "s"} + assert config["mcpServers"]["other"] == {"command": "x"} + + +def test_raises_when_config_dir_missing(monkeypatch: pytest.MonkeyPatch): + """Should raise RuntimeError when Claude Desktop config dir can't be found.""" + monkeypatch.setattr("mcp.cli.claude.get_claude_config_path", lambda: None) + monkeypatch.setattr("mcp.cli.claude.get_uv_path", lambda: "/fake/bin/uv") + + with pytest.raises(RuntimeError, match="Claude Desktop config directory not found"): + update_claude_config(file_spec="s.py:app", server_name="s") + + +@pytest.mark.parametrize("which_result, expected", [("/usr/local/bin/uv", "/usr/local/bin/uv"), (None, "uv")]) +def test_get_uv_path(monkeypatch: pytest.MonkeyPatch, which_result: str | None, expected: str): + """Should return shutil.which's result, or fall back to bare 'uv' when not on PATH.""" + + def fake_which(cmd: str) -> str | None: + return which_result + + monkeypatch.setattr("shutil.which", fake_which) + assert get_uv_path() == expected + + +@pytest.mark.parametrize( + "file_spec, expected_last_arg", + [ + ("C:\\Users\\server.py", "C:\\Users\\server.py"), + 
("C:\\Users\\server.py:app", "C:\\Users\\server.py:app"), + ], +) +def test_windows_drive_letter_not_split( + config_dir: Path, monkeypatch: pytest.MonkeyPatch, file_spec: str, expected_last_arg: str +): + """Drive-letter paths like 'C:\\server.py' must not be split on the drive colon. + + Before the fix, a bare 'C:\\path\\server.py' would hit rsplit(":", 1) and yield + ("C", "\\path\\server.py"), calling resolve() on Path("C") instead of the full path. + """ + seen: list[str] = [] + + def fake_resolve(self: Path) -> Path: + seen.append(str(self)) + return self + + monkeypatch.setattr(Path, "resolve", fake_resolve) + + assert update_claude_config(file_spec=file_spec, server_name="s") + + assert seen == ["C:\\Users\\server.py"] + assert _read_server(config_dir, "s")["args"][-1] == expected_last_arg diff --git a/tests/client/test_config.py b/tests/client/test_config.py deleted file mode 100644 index d1a0576ff3..0000000000 --- a/tests/client/test_config.py +++ /dev/null @@ -1,75 +0,0 @@ -import json -import subprocess -from pathlib import Path -from unittest.mock import patch - -import pytest - -from mcp.cli.claude import update_claude_config - - -@pytest.fixture -def temp_config_dir(tmp_path: Path): - """Create a temporary Claude config directory.""" - config_dir = tmp_path / "Claude" - config_dir.mkdir() - return config_dir - - -@pytest.fixture -def mock_config_path(temp_config_dir: Path): - """Mock get_claude_config_path to return our temporary directory.""" - with patch("mcp.cli.claude.get_claude_config_path", return_value=temp_config_dir): - yield temp_config_dir - - -def test_command_execution(mock_config_path: Path): - """Test that the generated command can actually be executed.""" - # Setup - server_name = "test_server" - file_spec = "test_server.py:app" - - # Update config - success = update_claude_config(file_spec=file_spec, server_name=server_name) - assert success - - # Read the generated config - config_file = mock_config_path / 
"claude_desktop_config.json" - config = json.loads(config_file.read_text()) - - # Get the command and args - server_config = config["mcpServers"][server_name] - command = server_config["command"] - args = server_config["args"] - - test_args = [command] + args + ["--help"] - - result = subprocess.run(test_args, capture_output=True, text=True, timeout=20, check=False) - - assert result.returncode == 0 - assert "usage" in result.stdout.lower() - - -def test_absolute_uv_path(mock_config_path: Path): - """Test that the absolute path to uv is used when available.""" - # Mock the shutil.which function to return a fake path - mock_uv_path = "/usr/local/bin/uv" - - with patch("mcp.cli.claude.get_uv_path", return_value=mock_uv_path): - # Setup - server_name = "test_server" - file_spec = "test_server.py:app" - - # Update config - success = update_claude_config(file_spec=file_spec, server_name=server_name) - assert success - - # Read the generated config - config_file = mock_config_path / "claude_desktop_config.json" - config = json.loads(config_file.read_text()) - - # Verify the command is the absolute path - server_config = config["mcpServers"][server_name] - command = server_config["command"] - - assert command == mock_uv_path From 92c693bb730d059b2cc3836ccf0cdfe0720e9e18 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 20 Mar 2026 13:37:32 +0000 Subject: [PATCH 44/84] fix: cancel in-flight handlers when transport closes in server.run() (#2306) --- src/mcp/server/lowlevel/server.py | 51 ++++++--- src/mcp/shared/session.py | 6 +- tests/server/test_cancel_handling.py | 158 +++++++++++++++++++++++++++ 3 files changed, 199 insertions(+), 16 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 167f34b8bc..c288422720 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -387,16 +387,23 @@ async def run( await 
stack.enter_async_context(task_support.run()) async with anyio.create_task_group() as tg: - async for message in session.incoming_messages: - logger.debug("Received message: %s", message) - - tg.start_soon( - self._handle_message, - message, - session, - lifespan_context, - raise_exceptions, - ) + try: + async for message in session.incoming_messages: + logger.debug("Received message: %s", message) + + tg.start_soon( + self._handle_message, + message, + session, + lifespan_context, + raise_exceptions, + ) + finally: + # Transport closed: cancel in-flight handlers. Without this the + # TG join waits for them, and when they eventually try to + # respond they hit a closed write stream (the session's + # _receive_loop closed it when the read stream ended). + tg.cancel_scope.cancel() async def _handle_message( self, @@ -470,16 +477,32 @@ async def _handle_request( except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): - logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) - return + if message.cancelled: + # Client sent CancelledNotification; responder.cancel() already + # sent an error response, so skip the duplicate. + logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) + return + # Transport-close cancellation from the TG in run(); re-raise so the + # TG swallows its own cancellation. + raise except Exception as err: if raise_exceptions: # pragma: no cover raise err response = types.ErrorData(code=0, message=str(err)) + else: # pragma: no cover + response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") + try: await message.respond(response) - else: # pragma: no cover - await message.respond(types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found")) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + # Transport closed between handler unblocking and respond. 
Happens + # when _receive_loop's finally wakes a handler blocked on + # send_request: the handler runs to respond() before run()'s TG + # cancel fires, but after the write stream closed. Closed if our + # end closed (_receive_loop's async-with exit); Broken if the peer + # end closed first (streamable_http terminate()). + logger.debug("Response for %s dropped - transport closed", message.request_id) + return logger.debug("Response sent") diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 9364abb73b..6fc59923f7 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -105,7 +105,7 @@ def __exit__( ) -> None: """Exit the context manager, performing cleanup and notifying completion.""" try: - if self._completed: # pragma: no branch + if self._completed: self._on_complete(self) finally: self._entered = False @@ -418,7 +418,9 @@ async def _receive_loop(self) -> None: finally: # after the read stream is closed, we need to send errors # to any pending requests - for id, stream in self._response_streams.items(): + # Snapshot: stream.send() wakes the waiter, whose finally pops + # from _response_streams before the next __next__() call. 
+ for id, stream in list(self._response_streams.items()): error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") try: await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 297f3d6a5c..cff5a37c15 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -6,12 +6,19 @@ from mcp import Client from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError +from mcp.shared.message import SessionMessage from mcp.types import ( + LATEST_PROTOCOL_VERSION, CallToolRequest, CallToolRequestParams, CallToolResult, CancelledNotification, CancelledNotificationParams, + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCNotification, + JSONRPCRequest, ListToolsResult, PaginatedRequestParams, TextContent, @@ -90,3 +97,154 @@ async def first_request(): assert isinstance(content, TextContent) assert content.text == "Call number: 2" assert call_count == 2 + + +@pytest.mark.anyio +async def test_server_cancels_in_flight_handlers_on_transport_close(): + """When the transport closes mid-request, server.run() must cancel in-flight + handlers rather than join on them. + + Without the cancel, the task group waits for the handler, which then tries + to respond through a write stream that _receive_loop already closed, + raising ClosedResourceError and crashing server.run() with exit code 1. + + This drives server.run() with raw memory streams because InMemoryTransport + wraps it in its own finally-cancel (_memory.py) which masks the bug. 
+ """ + handler_started = anyio.Event() + handler_cancelled = anyio.Event() + server_run_returned = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_cancelled.set() + # unreachable: sleep_forever only exits via cancellation + raise AssertionError # pragma: no cover + + server = Server("test", on_call_tool=handle_call_tool) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + await server.run(server_read, server_write, server.create_initialization_options()) + server_run_returned.set() + + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test", version="1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + call_req = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"), + ) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + + await to_server.send(SessionMessage(init_req)) + await from_server.receive() # init response + await to_server.send(SessionMessage(initialized)) + await to_server.send(SessionMessage(call_req)) + + await handler_started.wait() + + # Close the server's input stream — this is what stdin EOF does. + # server.run()'s incoming_messages loop ends, finally-cancel fires, + # handler gets CancelledError, server.run() returns. 
+ await to_server.aclose() + + await server_run_returned.wait() + + assert handler_cancelled.is_set() + + +@pytest.mark.anyio +async def test_server_handles_transport_close_with_pending_server_to_client_requests(): + """When the transport closes while handlers are blocked on server→client + requests (sampling, roots, elicitation), server.run() must still exit cleanly. + + Two bugs covered: + 1. _receive_loop's finally iterates _response_streams with await checkpoints + inside; the woken handler's send_request finally pops from that dict + before the next __next__() — RuntimeError: dictionary changed size. + 2. The woken handler's MCPError is caught in _handle_request, which falls + through to respond() against a write stream _receive_loop already closed. + """ + handlers_started = 0 + both_started = anyio.Event() + server_run_returned = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + nonlocal handlers_started + handlers_started += 1 + if handlers_started == 2: + both_started.set() + # Blocks on send_request waiting for a client response that never comes. + # _receive_loop's finally will wake this with CONNECTION_CLOSED. 
+ await ctx.session.list_roots() + raise AssertionError # pragma: no cover + + server = Server("test", on_call_tool=handle_call_tool) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + await server.run(server_read, server_write, server.create_initialization_options()) + server_run_returned.set() + + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test", version="1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + + await to_server.send(SessionMessage(init_req)) + await from_server.receive() # init response + await to_server.send(SessionMessage(initialized)) + + # Two tool calls → two handlers → two _response_streams entries. + for rid in (2, 3): + call_req = JSONRPCRequest( + jsonrpc="2.0", + id=rid, + method="tools/call", + params=CallToolRequestParams(name="t", arguments={}).model_dump(by_alias=True, mode="json"), + ) + await to_server.send(SessionMessage(call_req)) + + await both_started.wait() + # Drain the two roots/list requests so send_request's _write_stream.send() + # completes and both handlers are parked at response_stream_reader.receive(). + await from_server.receive() + await from_server.receive() + + await to_server.aclose() + + # Without the fixes: RuntimeError (dict mutation) or ClosedResourceError + # (respond after write-stream close) escapes run_server and this hangs. 
+ await server_run_returned.wait() From 7ba4fb881d85406f44a5af8169fb7200fa7c8e49 Mon Sep 17 00:00:00 2001 From: Felix Weinberger <3823880+felixweinberger@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:36:36 +0000 Subject: [PATCH 45/84] ci: skip claude.yml when comment is '@claude review' (#2337) --- .github/workflows/claude.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 8421cf954c..59dac99dcb 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -14,7 +14,7 @@ on: jobs: claude: if: | - (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude') && !startsWith(github.event.comment.body, '@claude review')) || (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) From 98f8ef295a6a4178ebeb75fd2e9f346be5c9eb0e Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 25 Mar 2026 23:29:13 +0100 Subject: [PATCH 46/84] Restrict httpx version to <1.0.0 (#2345) Co-authored-by: Max Isbey <224885523+maxisbey@users.noreply.github.com> --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 624ade1709..e1b19e3c9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ ] dependencies = [ "anyio>=4.9", - "httpx>=0.27.1", + "httpx>=0.27.1,<1.0.0", "httpx-sse>=0.4", "pydantic>=2.12.0", "starlette>=0.48.0; python_version >= '3.14'", diff --git a/uv.lock b/uv.lock index 4af3532ea2..8f9a5396a2 100644 --- a/uv.lock +++ b/uv.lock @@ -850,7 +850,7 @@ docs = [ 
[package.metadata] requires-dist = [ { name = "anyio", specifier = ">=4.9" }, - { name = "httpx", specifier = ">=0.27.1" }, + { name = "httpx", specifier = ">=0.27.1,<1.0.0" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" }, { name = "pydantic", specifier = ">=2.12.0" }, From 3517a29c828d596f0e3fb5f82fcfc86fd7a14dd0 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 27 Mar 2026 13:42:15 +0000 Subject: [PATCH 47/84] feat(server): restore `dependencies` parameter on MCPServer (#2358) --- src/mcp/server/mcpserver/server.py | 6 ++++++ tests/server/mcpserver/test_server.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 2a7a58117a..6f9bb0e287 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -105,6 +105,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # prompt settings warn_on_duplicate_prompts: bool + dependencies: list[str] + """List of dependencies to install in the server environment. 
Used by the `mcp install` and `mcp dev` CLI.""" + lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None """An async context manager that will be called when the server is started.""" @@ -142,6 +145,7 @@ def __init__( warn_on_duplicate_resources: bool = True, warn_on_duplicate_tools: bool = True, warn_on_duplicate_prompts: bool = True, + dependencies: list[str] | None = None, lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, auth: AuthSettings | None = None, ): @@ -151,9 +155,11 @@ def __init__( warn_on_duplicate_resources=warn_on_duplicate_resources, warn_on_duplicate_tools=warn_on_duplicate_tools, warn_on_duplicate_prompts=warn_on_duplicate_prompts, + dependencies=dependencies or [], lifespan=lifespan, auth=auth, ) + self.dependencies = self.settings.dependencies self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3ef06d0381..49b6deb4bb 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -65,6 +65,15 @@ async def test_create_server(self): assert len(mcp.icons) == 1 assert mcp.icons[0].src == "https://example.com/icon.png" + def test_dependencies(self): + """Dependencies list is read by `mcp install` / `mcp dev` CLI commands.""" + mcp = MCPServer("test", dependencies=["pandas", "numpy"]) + assert mcp.dependencies == ["pandas", "numpy"] + assert mcp.settings.dependencies == ["pandas", "numpy"] + + mcp_no_deps = MCPServer("test") + assert mcp_no_deps.dependencies == [] + async def test_sse_app_returns_starlette_app(self): """Test that sse_app returns a Starlette application with correct routes.""" mcp = MCPServer("test") From 
fb2276b95fb5ac6631f9c160c9b1bd7a6d7312a9 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 27 Mar 2026 14:02:41 +0000 Subject: [PATCH 48/84] ci: remove claude-code-review workflow (#2359) --- .github/workflows/claude-code-review.yml | 33 ------------------------ 1 file changed, 33 deletions(-) delete mode 100644 .github/workflows/claude-code-review.yml diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml deleted file mode 100644 index 514f979d7c..0000000000 --- a/.github/workflows/claude-code-review.yml +++ /dev/null @@ -1,33 +0,0 @@ -# Source: https://github.com/anthropics/claude-code-action/blob/main/docs/code-review.md -name: Claude Code Review - -on: - pull_request: - types: [opened, synchronize, ready_for_review, reopened] - -jobs: - claude-review: - # Fork PRs don't have access to secrets or OIDC tokens, so the action - # cannot authenticate. See https://github.com/anthropics/claude-code-action/issues/339 - if: github.event.pull_request.head.repo.fork == false && github.actor != 'dependabot[bot]' - runs-on: ubuntu-latest - permissions: - contents: read - pull-requests: read - issues: read - id-token: write - - steps: - - name: Checkout repository - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - fetch-depth: 1 - - - name: Run Claude Code Review - id: claude-review - uses: anthropics/claude-code-action@2f8ba26a219c06cfb0f468eef8d97055fa814f97 # v1.0.53 - with: - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} - plugin_marketplaces: "https://github.com/anthropics/claude-code.git" - plugins: "code-review@claude-code-plugins" - prompt: "/code-review:code-review ${{ github.repository }}/pull/${{ github.event.pull_request.number }}" From e6235d1667ee59c63cfa365ce0136beae05f067a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 31 Mar 2026 12:49:38 -0400 Subject: [PATCH 49/84] Propagate contextvars.Context through anyio 
streams without modifying SessionMessage (#2298) --- .../mcp_simple_auth_client/main.py | 6 +- pyproject.toml | 5 +- src/mcp/client/__main__.py | 6 +- src/mcp/client/_transport.py | 7 +- src/mcp/client/session.py | 6 +- src/mcp/client/sse.py | 23 ++-- src/mcp/client/streamable_http.py | 23 ++-- src/mcp/server/lowlevel/server.py | 15 ++- src/mcp/server/session.py | 7 +- src/mcp/server/sse.py | 13 +- src/mcp/server/stdio.py | 12 +- src/mcp/server/streamable_http.py | 18 +-- src/mcp/server/websocket.py | 12 +- src/mcp/shared/_context_streams.py | 119 ++++++++++++++++++ src/mcp/shared/_stream_protocols.py | 49 ++++++++ src/mcp/shared/memory.py | 10 +- src/mcp/shared/session.py | 26 ++-- tests/client/conftest.py | 4 +- tests/client/test_client.py | 33 +++++ tests/shared/test_streamable_http.py | 10 +- 20 files changed, 310 insertions(+), 94 deletions(-) create mode 100644 src/mcp/shared/_context_streams.py create mode 100644 src/mcp/shared/_stream_protocols.py diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 5fac56be5d..6ef2f0b113 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -18,7 +18,7 @@ from urllib.parse import parse_qs, urlparse import httpx -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp.client._transport import ReadStream, WriteStream from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.session import ClientSession from mcp.client.sse import sse_client @@ -241,8 +241,8 @@ async def _default_redirect_handler(authorization_url: str) -> None: async def _run_session( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: 
WriteStream[SessionMessage], ): """Run the MCP session with the given streams.""" print("🤝 Initializing MCP session...") diff --git a/pyproject.toml b/pyproject.toml index e1b19e3c9f..7d8b4a8743 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -219,13 +219,10 @@ skip_covered = true show_missing = true ignore_errors = true precision = 2 -exclude_lines = [ - "pragma: no cover", +exclude_also = [ "pragma: lax no cover", - "if TYPE_CHECKING:", "@overload", "raise NotImplementedError", - "^\\s*\\.\\.\\.\\s*$", ] # https://coverage.readthedocs.io/en/latest/config.html#paths diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index f3db17906d..b9ec344226 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -6,9 +6,9 @@ from urllib.parse import urlparse import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import types +from mcp.client._transport import ReadStream, WriteStream from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client @@ -33,8 +33,8 @@ async def message_handler( async def run_session( - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], client_info: types.Implementation | None = None, ): async with ClientSession( diff --git a/src/mcp/client/_transport.py b/src/mcp/client/_transport.py index a863629005..0163fef950 100644 --- a/src/mcp/client/_transport.py +++ b/src/mcp/client/_transport.py @@ -5,11 +5,12 @@ from contextlib import AbstractAsyncContextManager from typing import Protocol -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.message import SessionMessage -TransportStreams = 
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] +__all__ = ["ReadStream", "WriteStream", "Transport", "TransportStreams"] + +TransportStreams = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]] class Transport(AbstractAsyncContextManager[TransportStreams], Protocol): diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7c964a334c..0cea454a77 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -4,10 +4,10 @@ from typing import Any, Protocol import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import TypeAdapter from mcp import types +from mcp.client._transport import ReadStream, WriteStream from mcp.client.experimental import ExperimentalClientFeatures from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared._context import RequestContext @@ -109,8 +109,8 @@ class ClientSession( ): def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 7b66b5c1b6..193204a153 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -7,11 +7,11 @@ import anyio import httpx from anyio.abc import TaskStatus -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse from httpx_sse._exceptions import SSEError from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import 
SessionMessage @@ -51,12 +51,6 @@ async def sse_client( auth: Optional HTTPX authentication handler. on_session_created: Optional callback invoked with the session ID when received. """ - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) @@ -65,8 +59,8 @@ async def sse_client( event_source.response.raise_for_status() logger.debug("SSE connection established") - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): try: @@ -124,7 +118,8 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): async def post_writer(endpoint_url: str): try: async with write_stream_reader, write_stream: - async for session_message in write_stream_reader: + + async def _send_message(session_message: SessionMessage) -> None: logger.debug(f"Sending client message: {session_message}") response = await client.post( endpoint_url, @@ -136,6 +131,14 @@ async def post_writer(endpoint_url: str): ) response.raise_for_status() logger.debug(f"Client message sent successfully: {response.status_code}") + + async for session_message in write_stream_reader: + sender_ctx = write_stream_reader.last_context + if sender_ctx is not None: + async with anyio.create_task_group() as tg: + sender_ctx.run(tg.start_soon, _send_message, session_message) + 
else: + await _send_message(session_message) # pragma: no cover except Exception: # pragma: lax no cover logger.exception("Error in post_writer") diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3afb94b034..9a119c6338 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -11,11 +11,11 @@ import anyio import httpx from anyio.abc import TaskGroup -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse from pydantic import ValidationError from mcp.client._transport import TransportStreams +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( @@ -38,8 +38,8 @@ # TODO(Marcelo): Put the TransportStreams in a module under shared, so we can import here. 
SessionMessageOrError = SessionMessage | Exception -StreamWriter = MemoryObjectSendStream[SessionMessageOrError] -StreamReader = MemoryObjectReceiveStream[SessionMessage] +StreamWriter = ContextSendStream[SessionMessageOrError] +StreamReader = ContextReceiveStream[SessionMessage] MCP_SESSION_ID = "mcp-session-id" MCP_PROTOCOL_VERSION = "mcp-protocol-version" @@ -434,14 +434,15 @@ async def post_writer( client: httpx.AsyncClient, write_stream_reader: StreamReader, read_stream_writer: StreamWriter, - write_stream: MemoryObjectSendStream[SessionMessage], + write_stream: ContextSendStream[SessionMessage], start_get_stream: Callable[[], None], tg: TaskGroup, ) -> None: """Handle writing requests to the server.""" try: async with write_stream_reader, read_stream_writer, write_stream: - async for session_message in write_stream_reader: + + async def _handle_message(session_message: SessionMessage) -> None: message = session_message.message metadata = ( session_message.metadata @@ -478,6 +479,14 @@ async def handle_request_async(): else: await handle_request_async() + async for session_message in write_stream_reader: + sender_ctx = write_stream_reader.last_context + if sender_ctx is not None: + async with anyio.create_task_group() as tg_local: + sender_ctx.run(tg_local.start_soon, _handle_message, session_message) + else: + await _handle_message(session_message) # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Error in post_writer") @@ -547,8 +556,8 @@ async def streamable_http_client( if not client_provided: await stack.enter_async_context(client) - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async with ( read_stream_writer, diff 
--git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index c288422720..0fdbbff866 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -36,6 +36,7 @@ async def main(): from __future__ import annotations +import contextvars import logging import warnings from collections.abc import AsyncIterator, Awaitable, Callable @@ -44,7 +45,6 @@ async def main(): from typing import Any, Generic import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -65,6 +65,7 @@ async def main(): from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder @@ -355,8 +356,8 @@ def session_manager(self) -> StreamableHTTPSessionManager: async def run( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], initialization_options: InitializationOptions, # When False, exceptions are returned as messages to the client. 
# When True, exceptions are raised, which will cause the server to shut down @@ -391,7 +392,13 @@ async def run( async for message in session.incoming_messages: logger.debug("Received message: %s", message) - tg.start_soon( + if isinstance(message, RequestResponder) and message.context is not None: + context = message.context + else: + context = contextvars.copy_context() + + context.run( + tg.start_soon, self._handle_message, message, session, diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ce467e6c93..20b640527a 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -33,13 +33,14 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.streams.memory import MemoryObjectReceiveStream from pydantic import AnyUrl, TypeAdapter from mcp import types from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY @@ -79,8 +80,8 @@ class ServerSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, ) -> None: diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9dcee67f78..48192ff612 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -43,7 +43,6 @@ 
async def handle_sse(request): from uuid import UUID, uuid4 import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from sse_starlette import EventSourceResponse from starlette.requests import Request @@ -55,6 +54,7 @@ async def handle_sse(request): TransportSecurityMiddleware, TransportSecuritySettings, ) +from mcp.shared._context_streams import ContextSendStream, create_context_streams from mcp.shared.message import ServerMessageMetadata, SessionMessage logger = logging.getLogger(__name__) @@ -72,7 +72,7 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] + _read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]] _security: TransportSecurityMiddleware def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: @@ -129,14 +129,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag raise ValueError("Request validation failed") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) session_id = uuid4() self._read_stream_writers[session_id] = read_stream_writer diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 5ea6c4e778..5c1459dff6 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -23,9 +23,9 @@ async def 
run_server(): import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared.message import SessionMessage @@ -43,14 +43,8 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async def stdin_reader(): try: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index aa99e7c887..f14201857c 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -25,6 +25,8 @@ from starlette.types import Receive, Scope, Send from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( @@ -119,10 +121,10 @@ class StreamableHTTPServerTransport: """ # Server notification streams for POST requests as well as standalone SSE stream - _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None - 
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None - _write_stream: MemoryObjectSendStream[SessionMessage] | None = None - _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None + _read_stream_writer: ContextSendStream[SessionMessage | Exception] | None = None + _read_stream: ContextReceiveStream[SessionMessage | Exception] | None = None + _write_stream: ContextSendStream[SessionMessage] | None = None + _write_stream_reader: ContextReceiveStream[SessionMessage] | None = None _security: TransportSecurityMiddleware def __init__( @@ -954,8 +956,8 @@ async def connect( self, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], - MemoryObjectSendStream[SessionMessage], + ReadStream[SessionMessage | Exception], + WriteStream[SessionMessage], ], None, ]: @@ -967,8 +969,8 @@ async def connect( # Create the memory streams for this connection - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) # Store the streams self._read_stream_writer = read_stream_writer diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 32b50560cc..277f9b5af2 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -1,12 +1,12 @@ from contextlib import asynccontextmanager import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic_core import ValidationError from starlette.types import Receive, Scope, Send from starlette.websockets import WebSocket from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared.message import SessionMessage @@ -19,14 +19,8 @@ async def 
websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async def ws_reader(): try: diff --git a/src/mcp/shared/_context_streams.py b/src/mcp/shared/_context_streams.py new file mode 100644 index 0000000000..04c33306d9 --- /dev/null +++ b/src/mcp/shared/_context_streams.py @@ -0,0 +1,119 @@ +"""Context-aware memory stream wrappers. + +anyio memory streams do not propagate ``contextvars.Context`` across task +boundaries. These thin wrappers capture the sender's context at ``send()`` +time and expose it on the receive side via ``last_context``, so consumers +can restore it with ``ctx.run(handler, item)``. + +The iteration interface is unchanged (yields ``T``, not tuples), keeping +these wrappers duck-type compatible with plain ``MemoryObjectSendStream`` +and ``MemoryObjectReceiveStream``. +""" + +from __future__ import annotations + +import contextvars +from types import TracebackType +from typing import Any, Generic, TypeVar + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +T = TypeVar("T") + +# Internal payload carried through the underlying raw stream. 
+_Envelope = tuple[contextvars.Context, T] + + +class ContextSendStream(Generic[T]): + """Send-side wrapper that snapshots ``contextvars.copy_context()`` on every ``send()``.""" + + __slots__ = ("_inner",) + + def __init__(self, inner: MemoryObjectSendStream[_Envelope[T]]) -> None: + self._inner = inner + + async def send(self, item: T) -> None: + await self._inner.send((contextvars.copy_context(), item)) + + def close(self) -> None: + self._inner.close() + + async def aclose(self) -> None: + await self._inner.aclose() + + def clone(self) -> ContextSendStream[T]: # pragma: no cover + return ContextSendStream(self._inner.clone()) + + async def __aenter__(self) -> ContextSendStream[T]: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.aclose() + return None + + +class ContextReceiveStream(Generic[T]): + """Receive-side wrapper that yields ``T`` and stores the sender's context in ``last_context``.""" + + __slots__ = ("_inner", "last_context") + + def __init__(self, inner: MemoryObjectReceiveStream[_Envelope[T]]) -> None: + self._inner = inner + self.last_context: contextvars.Context | None = None + + async def receive(self) -> T: + ctx, item = await self._inner.receive() + self.last_context = ctx + return item + + def close(self) -> None: + self._inner.close() + + async def aclose(self) -> None: + await self._inner.aclose() + + def clone(self) -> ContextReceiveStream[T]: # pragma: no cover + return ContextReceiveStream(self._inner.clone()) + + def __aiter__(self) -> ContextReceiveStream[T]: + return self + + async def __anext__(self) -> T: + try: + return await self.receive() + except anyio.EndOfStream: + raise StopAsyncIteration + + async def __aenter__(self) -> ContextReceiveStream[T]: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | 
None, + ) -> bool | None: + await self.aclose() + return None + + +class create_context_streams( + tuple[ContextSendStream[T], ContextReceiveStream[T]], +): + """Create context-aware memory object streams. + + Supports ``create_context_streams[T](n)`` bracket syntax, + matching anyio's ``create_memory_object_stream`` API style. + """ + + def __new__(cls, max_buffer_size: float = 0) -> tuple[ContextSendStream[T], ContextReceiveStream[T]]: # type: ignore[type-var] + raw_send: MemoryObjectSendStream[Any] + raw_receive: MemoryObjectReceiveStream[Any] + raw_send, raw_receive = anyio.create_memory_object_stream(max_buffer_size) + return (ContextSendStream(raw_send), ContextReceiveStream(raw_receive)) diff --git a/src/mcp/shared/_stream_protocols.py b/src/mcp/shared/_stream_protocols.py new file mode 100644 index 0000000000..b799751329 --- /dev/null +++ b/src/mcp/shared/_stream_protocols.py @@ -0,0 +1,49 @@ +"""Stream protocols for MCP transports. + +These are general-purpose protocols satisfied by both ``MemoryObjectSendStream``/ +``MemoryObjectReceiveStream`` and the context-aware wrappers in ``_context_streams``. +""" + +from __future__ import annotations + +from types import TracebackType +from typing import Protocol, TypeVar + +from typing_extensions import Self + +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +class ReadStream(Protocol[T_co]): + """Protocol for reading items from a stream. + + Consumers that need the sender's context should use + ``getattr(stream, 'last_context', None)``. + """ + + async def receive(self) -> T_co: ... + async def aclose(self) -> None: ... + def __aiter__(self) -> ReadStream[T_co]: ... + async def __anext__(self) -> T_co: ... + async def __aenter__(self) -> Self: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... 
+ + +class WriteStream(Protocol[T_contra]): + """Protocol for writing items to a stream.""" + + async def send(self, item: T_contra, /) -> None: ... + async def aclose(self) -> None: ... + async def __aenter__(self) -> Self: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index f2d5e2b9ad..468590d095 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -5,12 +5,10 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared.message import SessionMessage -MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] +MessageStream = tuple[ContextReceiveStream[SessionMessage | Exception], ContextSendStream[SessionMessage | Exception]] @asynccontextmanager @@ -22,8 +20,8 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS (read_stream, write_stream) """ # Create streams for both directions - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_to_client_send, server_to_client_receive = create_context_streams[SessionMessage | Exception](1) + client_to_server_send, client_to_server_receive = create_context_streams[SessionMessage | Exception](1) client_streams = (server_to_client_receive, client_to_server_send) server_streams = (client_to_server_receive, server_to_client_send) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 
6fc59923f7..3534fbb768 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextvars import logging from collections.abc import Callable from contextlib import AsyncExitStack @@ -7,10 +8,11 @@ from typing import Any, Generic, Protocol, TypeVar import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.streams.memory import MemoryObjectSendStream from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.response_router import ResponseRouter @@ -79,11 +81,13 @@ def __init__( session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any], message_metadata: MessageMetadata = None, + context: contextvars.Context | None = None, ) -> None: self.request_id = request_id self.request_meta = request_meta self.request = request self.message_metadata = message_metadata + self.context = context self._session = session self._completed = False self._cancel_scope = anyio.CancelScope() @@ -181,8 +185,8 @@ class BaseSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], # If none, reading will never time out read_timeout_seconds: float | None = None, ) -> None: @@ -333,10 +337,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: async def _receive_loop(self) -> None: async with self._read_stream, self._write_stream: try: - async for message in self._read_stream: - if 
isinstance(message, Exception): - await self._handle_incoming(message) - elif isinstance(message.message, JSONRPCRequest): + + async def _handle_session_message(message: SessionMessage) -> None: + sender_context: contextvars.Context | None = getattr(self._read_stream, "last_context", None) + if isinstance(message.message, JSONRPCRequest): try: validated_request = self._receive_request_adapter.validate_python( message.message.model_dump(by_alias=True, mode="json", exclude_none=True), @@ -349,6 +353,7 @@ async def _receive_loop(self) -> None: session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), message_metadata=message.metadata, + context=sender_context, ) self._in_flight[responder.request_id] = responder await self._received_request(responder) @@ -406,6 +411,13 @@ async def _receive_loop(self) -> None: else: # Response or error await self._handle_response(message) + async for message in self._read_stream: + if isinstance(message, Exception): + await self._handle_incoming(message) + continue + + await _handle_session_message(message) + except anyio.ClosedResourceError: # This is expected when the client disconnects abruptly. 
# Without this handler, the exception would propagate up and diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 2e39f13630..081e1d68ea 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -4,15 +4,15 @@ from unittest.mock import patch import pytest -from anyio.streams.memory import MemoryObjectSendStream import mcp.shared.memory +from mcp.client._transport import WriteStream from mcp.shared.message import SessionMessage from mcp.types import JSONRPCNotification, JSONRPCRequest class SpyMemoryObjectSendStream: - def __init__(self, original_stream: MemoryObjectSendStream[SessionMessage]): + def __init__(self, original_stream: WriteStream[SessionMessage]): self.original_stream = original_stream self.sent_messages: list[SessionMessage] = [] diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 18368e6bb3..ac52a9024a 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -2,6 +2,9 @@ from __future__ import annotations +import contextvars +from collections.abc import Iterator +from contextlib import contextmanager from unittest.mock import patch import anyio @@ -320,3 +323,33 @@ async def test_client_uses_transport_directly(app: MCPServer): structured_content={"result": "Hello, Transport!"}, ) ) + + +_TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial") + + +@contextmanager +def _set_test_contextvar(value: str) -> Iterator[None]: + token = _TEST_CONTEXTVAR.set(value) + try: + yield + finally: + _TEST_CONTEXTVAR.reset(token) + + +async def test_context_propagation(): + """Sender's contextvars.Context is propagated to the server handler.""" + server = MCPServer("test") + + @server.tool() + async def check_context() -> str: + """Return the contextvar value visible to the handler.""" + return _TEST_CONTEXTVAR.get() + + async with Client(server) as client: + with _set_test_contextvar("client_value"): + result = await client.call_tool("check_context", {}) + + assert 
result.content[0].text == "client_value", ( # type: ignore[union-attr] + "Server handler did not see the sender's contextvars.Context" + ) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f8ca30441b..3d5770fb61 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -45,6 +45,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._context import RequestContext +from mcp.shared._context_streams import create_context_streams from mcp.shared._httpx_utils import ( MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT, @@ -1783,8 +1784,8 @@ async def test_handle_sse_event_skips_empty_data(): # Create a mock SSE event with empty data (keep-alive ping) mock_sse = ServerSentEvent(event="message", data="", id=None, retry=None) - # Create a mock stream writer - write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + # Create a context-aware stream writer (matches StreamWriter type alias) + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) try: # Call _handle_sse_event with empty data - should return False and not raise @@ -1794,8 +1795,9 @@ async def test_handle_sse_event_skips_empty_data(): assert result is False # Nothing should have been written to the stream - # Check buffer is empty (statistics().current_buffer_used returns buffer size) - assert write_stream.statistics().current_buffer_used == 0 + with pytest.raises(TimeoutError): + with anyio.fail_after(0): + await read_stream.receive() finally: await write_stream.aclose() await read_stream.aclose() From 3ce0f76e6e3b33f035b6b28421e1b7c6dbe8c77f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 31 Mar 2026 13:43:56 -0400 Subject: [PATCH 50/84] Don't block the event loop on sync resource and prompt functions (#2380) 
--- src/mcp/server/mcpserver/prompts/base.py | 10 ++-- .../server/mcpserver/resources/templates.py | 10 ++-- src/mcp/server/mcpserver/resources/types.py | 9 ++-- tests/server/mcpserver/prompts/test_base.py | 19 +++++++ .../resources/test_function_resources.py | 52 +++++++++++++++++++ .../resources/test_resource_template.py | 20 +++++++ 6 files changed, 107 insertions(+), 13 deletions(-) diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 0c319d53cc..b4810c100e 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -2,10 +2,12 @@ from __future__ import annotations +import functools import inspect from collections.abc import Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any, Literal +import anyio.to_thread import pydantic_core from pydantic import BaseModel, Field, TypeAdapter, validate_call @@ -155,10 +157,10 @@ async def render( # Add context to arguments if needed call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg) - # Call function and check if result is a coroutine - result = self.fn(**call_args) - if inspect.iscoroutine(result): - result = await result + if inspect.iscoroutinefunction(self.fn): + result = await self.fn(**call_args) + else: + result = await anyio.to_thread.run_sync(functools.partial(self.fn, **call_args)) # Validate messages if not isinstance(result, list | tuple): diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index 2d612657c4..542b5e6f81 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -2,12 +2,14 @@ from __future__ import annotations +import functools import inspect import re from collections.abc import Callable from typing import TYPE_CHECKING, Any from urllib.parse import unquote +import anyio.to_thread from pydantic import BaseModel, Field, validate_call from 
mcp.server.mcpserver.resources.types import FunctionResource, Resource @@ -110,10 +112,10 @@ async def create_resource( # Add context to params if needed params = inject_context(self.fn, params, context, self.context_kwarg) - # Call function and check if result is a coroutine - result = self.fn(**params) - if inspect.iscoroutine(result): - result = await result + if inspect.iscoroutinefunction(self.fn): + result = await self.fn(**params) + else: + result = await anyio.to_thread.run_sync(functools.partial(self.fn, **params)) return FunctionResource( uri=uri, # type: ignore diff --git a/src/mcp/server/mcpserver/resources/types.py b/src/mcp/server/mcpserver/resources/types.py index 42aecd6e39..04763be8ba 100644 --- a/src/mcp/server/mcpserver/resources/types.py +++ b/src/mcp/server/mcpserver/resources/types.py @@ -55,11 +55,10 @@ class FunctionResource(Resource): async def read(self) -> str | bytes: """Read the resource by calling the wrapped function.""" try: - # Call the function first to see if it returns a coroutine - result = self.fn() - # If it's a coroutine, await it - if inspect.iscoroutine(result): - result = await result + if inspect.iscoroutinefunction(self.fn): + result = await self.fn() + else: + result = await anyio.to_thread.run_sync(self.fn) if isinstance(result, Resource): # pragma: no cover return await result.read() diff --git a/tests/server/mcpserver/prompts/test_base.py b/tests/server/mcpserver/prompts/test_base.py index fe18e91bd7..d4e4e6b5a6 100644 --- a/tests/server/mcpserver/prompts/test_base.py +++ b/tests/server/mcpserver/prompts/test_base.py @@ -1,3 +1,4 @@ +import threading from typing import Any import pytest @@ -190,3 +191,21 @@ async def fn() -> dict[str, Any]: ) ) ] + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync prompt functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn() -> str: + 
fn_thread.append(threading.get_ident()) + return "hello" + + prompt = Prompt.from_function(blocking_fn) + messages = await prompt.render(None, Context()) + + assert messages == [UserMessage(content=TextContent(type="text", text="hello"))] + assert fn_thread[0] != main_thread diff --git a/tests/server/mcpserver/resources/test_function_resources.py b/tests/server/mcpserver/resources/test_function_resources.py index 5f5c216ed1..c1ff960617 100644 --- a/tests/server/mcpserver/resources/test_function_resources.py +++ b/tests/server/mcpserver/resources/test_function_resources.py @@ -1,3 +1,7 @@ +import threading + +import anyio +import anyio.from_thread import pytest from pydantic import BaseModel @@ -190,3 +194,51 @@ def get_data() -> str: # pragma: no cover ) assert resource.meta is None + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync resource functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn() -> str: + fn_thread.append(threading.get_ident()) + return "data" + + resource = FunctionResource(uri="resource://test", name="test", fn=blocking_fn) + result = await resource.read() + + assert result == "data" + assert fn_thread[0] != main_thread + + +@pytest.mark.anyio +async def test_sync_fn_does_not_block_event_loop(): + """A blocking sync resource function must not stall the event loop. + + On regression (sync runs inline), anyio.from_thread.run_sync raises + RuntimeError because there is no worker-thread context, failing fast. 
+ """ + handler_entered = anyio.Event() + release = threading.Event() + + def blocking_fn() -> str: + anyio.from_thread.run_sync(handler_entered.set) + release.wait() + return "done" + + resource = FunctionResource(uri="resource://test", name="test", fn=blocking_fn) + result: list[str | bytes] = [] + + async def run() -> None: + result.append(await resource.read()) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(run) + await handler_entered.wait() + release.set() + + assert result == ["done"] diff --git a/tests/server/mcpserver/resources/test_resource_template.py b/tests/server/mcpserver/resources/test_resource_template.py index 640cfe8031..2a7ba8d503 100644 --- a/tests/server/mcpserver/resources/test_resource_template.py +++ b/tests/server/mcpserver/resources/test_resource_template.py @@ -1,4 +1,5 @@ import json +import threading from typing import Any import pytest @@ -310,3 +311,22 @@ def get_item(item_id: str) -> str: assert resource.meta == metadata assert resource.meta["category"] == "inventory" assert resource.meta["cacheable"] is True + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync template functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn(name: str) -> str: + fn_thread.append(threading.get_ident()) + return f"hello {name}" + + template = ResourceTemplate.from_function(fn=blocking_fn, uri_template="test://{name}") + resource = await template.create_resource("test://world", {"name": "world"}, Context()) + + assert isinstance(resource, FunctionResource) + assert await resource.read() == "hello world" + assert fn_thread[0] != main_thread From 37891f42a409ac10a46fb1577e7fd3ec3f70d5eb Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 31 Mar 2026 16:33:33 -0400 Subject: [PATCH 51/84] Add basic OpenTelemetry tracing for client and server requests (#2381) --- pyproject.toml | 2 
+ src/mcp/server/lowlevel/server.py | 150 +++++++++++-------- src/mcp/shared/_otel.py | 36 +++++ src/mcp/shared/session.py | 50 ++++--- tests/shared/test_otel.py | 44 ++++++ uv.lock | 237 ++++++++++++++++++++++++++++++ 6 files changed, 436 insertions(+), 83 deletions(-) create mode 100644 src/mcp/shared/_otel.py create mode 100644 tests/shared/test_otel.py diff --git a/pyproject.toml b/pyproject.toml index 7d8b4a8743..be1200cff0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "pyjwt[crypto]>=2.10.1", "typing-extensions>=4.13.0", "typing-inspection>=0.4.1", + "opentelemetry-api>=1.28.0", ] [project.optional-dependencies] @@ -71,6 +72,7 @@ dev = [ "coverage[toml]>=7.10.7,<=7.13", "pillow>=12.0", "strict-no-cover", + "logfire>=3.0.0", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 0fdbbff866..59de0ace45 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -42,9 +42,10 @@ async def main(): from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from importlib.metadata import version as importlib_version -from typing import Any, Generic +from typing import Any, Generic, cast import anyio +from opentelemetry.trace import SpanKind, StatusCode from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -65,6 +66,7 @@ async def main(): from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._otel import extract_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from 
mcp.shared.message import ServerMessageMetadata, SessionMessage @@ -446,72 +448,90 @@ async def _handle_request( ): logger.info("Processing request of type %s", type(req).__name__) - if handler := self._request_handlers.get(req.method): - logger.debug("Dispatching request of type %s", type(req).__name__) + target = getattr(req.params, "name", None) if req.params else None + span_name = f"MCP handle {req.method} {target}" if target else f"MCP handle {req.method}" - try: - # Extract request context and close_sse_stream from message metadata - request_data = None - close_sse_stream_cb = None - close_standalone_sse_stream_cb = None - if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata): - request_data = message.message_metadata.request_context - close_sse_stream_cb = message.message_metadata.close_sse_stream - close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream + # Extract W3C trace context from _meta (SEP-414). 
+ meta = cast(dict[str, Any] | None, getattr(req.params, "meta", None)) if req.params else None + parent_context = extract_trace_context(meta) if meta is not None else None - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - # Get task metadata from request params if present - task_metadata = None - if hasattr(req, "params") and req.params is not None: - task_metadata = getattr(req.params, "task", None) - ctx = ServerRequestContext( - request_id=message.request_id, - meta=message.request_meta, - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=task_metadata, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - request=request_data, - close_sse_stream=close_sse_stream_cb, - close_standalone_sse_stream=close_standalone_sse_stream_cb, - ) - response = await handler(ctx, req.params) - except MCPError as err: - response = err.error - except anyio.get_cancelled_exc_class(): - if message.cancelled: - # Client sent CancelledNotification; responder.cancel() already - # sent an error response, so skip the duplicate. - logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) - return - # Transport-close cancellation from the TG in run(); re-raise so the - # TG swallows its own cancellation. - raise - except Exception as err: - if raise_exceptions: # pragma: no cover - raise err - response = types.ErrorData(code=0, message=str(err)) - else: # pragma: no cover - response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") - - try: - await message.respond(response) - except (anyio.BrokenResourceError, anyio.ClosedResourceError): - # Transport closed between handler unblocking and respond. 
Happens - # when _receive_loop's finally wakes a handler blocked on - # send_request: the handler runs to respond() before run()'s TG - # cancel fires, but after the write stream closed. Closed if our - # end closed (_receive_loop's async-with exit); Broken if the peer - # end closed first (streamable_http terminate()). - logger.debug("Response for %s dropped - transport closed", message.request_id) - return - - logger.debug("Response sent") + with otel_span( + span_name, + kind=SpanKind.SERVER, + attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id}, + context=parent_context, + ) as span: + if handler := self._request_handlers.get(req.method): + logger.debug("Dispatching request of type %s", type(req).__name__) + + try: + # Extract request context and close_sse_stream from message metadata + request_data = None + close_sse_stream_cb = None + close_standalone_sse_stream_cb = None + if message.message_metadata is not None and isinstance( + message.message_metadata, ServerMessageMetadata + ): + request_data = message.message_metadata.request_context + close_sse_stream_cb = message.message_metadata.close_sse_stream + close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream + + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + # Get task metadata from request params if present + task_metadata = None + if hasattr(req, "params") and req.params is not None: # pragma: no branch + task_metadata = getattr(req.params, "task", None) + ctx = ServerRequestContext( + request_id=message.request_id, + meta=message.request_meta, + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=task_metadata, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + request=request_data, + 
close_sse_stream=close_sse_stream_cb, + close_standalone_sse_stream=close_standalone_sse_stream_cb, + ) + response = await handler(ctx, req.params) + except MCPError as err: + response = err.error + except anyio.get_cancelled_exc_class(): + if message.cancelled: + # Client sent CancelledNotification; responder.cancel() already + # sent an error response, so skip the duplicate. + logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) + return + # Transport-close cancellation from the TG in run(); re-raise so the + # TG swallows its own cancellation. + raise + except Exception as err: + if raise_exceptions: # pragma: no cover + raise err + response = types.ErrorData(code=0, message=str(err)) + else: # pragma: no cover + response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") + + if isinstance(response, types.ErrorData) and span is not None: + span.set_status(StatusCode.ERROR, response.message) + + try: + await message.respond(response) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + # Transport closed between handler unblocking and respond. Happens + # when _receive_loop's finally wakes a handler blocked on + # send_request: the handler runs to respond() before run()'s TG + # cancel fires, but after the write stream closed. Closed if our + # end closed (_receive_loop's async-with exit); Broken if the peer + # end closed first (streamable_http terminate()). 
+ logger.debug("Response for %s dropped - transport closed", message.request_id) + return + + logger.debug("Response sent") async def _handle_notification( self, diff --git a/src/mcp/shared/_otel.py b/src/mcp/shared/_otel.py new file mode 100644 index 0000000000..170e873a0f --- /dev/null +++ b/src/mcp/shared/_otel.py @@ -0,0 +1,36 @@ +"""OpenTelemetry helpers for MCP.""" + +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any + +from opentelemetry.context import Context +from opentelemetry.propagate import extract, inject +from opentelemetry.trace import SpanKind, get_tracer + +_tracer = get_tracer("mcp-python-sdk") + + +@contextmanager +def otel_span( + name: str, + *, + kind: SpanKind, + attributes: dict[str, Any] | None = None, + context: Context | None = None, +) -> Iterator[Any]: + """Create an OTel span.""" + with _tracer.start_as_current_span(name, kind=kind, attributes=attributes, context=context) as span: + yield span + + +def inject_trace_context(meta: dict[str, Any]) -> None: + """Inject W3C trace context (traceparent/tracestate) into a `_meta` dict.""" + inject(meta) + + +def extract_trace_context(meta: dict[str, Any]) -> Context: + """Extract W3C trace context from a `_meta` dict.""" + return extract(meta) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 3534fbb768..243eef5ae6 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -9,9 +9,11 @@ import anyio from anyio.streams.memory import MemoryObjectSendStream +from opentelemetry.trace import SpanKind from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +from mcp.shared._otel import inject_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage @@ -268,24 +270,36 @@ async def send_request( 
self._progress_callbacks[request_id] = progress_callback try: - jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) - await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) - - # request read timeout takes precedence over session read timeout - timeout = request_read_timeout_seconds or self._session_read_timeout_seconds - - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - class_name = request.__class__.__name__ - message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." - raise MCPError(code=REQUEST_TIMEOUT, message=message) - - if isinstance(response_or_error, JSONRPCError): - raise MCPError.from_jsonrpc_error(response_or_error) - else: - return result_type.model_validate(response_or_error.result, by_name=False) + target = request_data.get("params", {}).get("name") + span_name = f"MCP send {request.method} {target}" if target else f"MCP send {request.method}" + + with otel_span( + span_name, + kind=SpanKind.CLIENT, + attributes={"mcp.method.name": request.method, "jsonrpc.request.id": request_id}, + ): + # Inject W3C trace context into _meta (SEP-414). + meta: dict[str, Any] = request_data.setdefault("params", {}).setdefault("_meta", {}) + inject_trace_context(meta) + + jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) + await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) + + # request read timeout takes precedence over session read timeout + timeout = request_read_timeout_seconds or self._session_read_timeout_seconds + + try: + with anyio.fail_after(timeout): + response_or_error = await response_stream_reader.receive() + except TimeoutError: + class_name = request.__class__.__name__ + message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." 
+ raise MCPError(code=REQUEST_TIMEOUT, message=message) + + if isinstance(response_or_error, JSONRPCError): + raise MCPError.from_jsonrpc_error(response_or_error) + else: + return result_type.model_validate(response_or_error.result, by_name=False) finally: self._response_streams.pop(request_id, None) diff --git a/tests/shared/test_otel.py b/tests/shared/test_otel.py new file mode 100644 index 0000000000..ec7ff78cc1 --- /dev/null +++ b/tests/shared/test_otel.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import pytest +from logfire.testing import CaptureLogfire + +from mcp import types +from mcp.client.client import Client +from mcp.server.mcpserver import MCPServer + +pytestmark = pytest.mark.anyio + + +# Logfire warns about propagated trace context by default (distributed_tracing=None). +# This is expected here since we're testing cross-boundary context propagation. +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +async def test_client_and_server_spans(capfire: CaptureLogfire): + """Verify that calling a tool produces client and server spans with correct attributes.""" + server = MCPServer("test") + + @server.tool() + def greet(name: str) -> str: + """Greet someone.""" + return f"Hello, {name}!" + + async with Client(server) as client: + result = await client.call_tool("greet", {"name": "World"}) + + assert isinstance(result.content[0], types.TextContent) + assert result.content[0].text == "Hello, World!" 
+ + spans = capfire.exporter.exported_spans_as_dict() + span_names = {s["name"] for s in spans} + + assert "MCP send tools/call greet" in span_names + assert "MCP handle tools/call greet" in span_names + + client_span = next(s for s in spans if s["name"] == "MCP send tools/call greet") + server_span = next(s for s in spans if s["name"] == "MCP handle tools/call greet") + + assert client_span["attributes"]["mcp.method.name"] == "tools/call" + assert server_span["attributes"]["mcp.method.name"] == "tools/call" + + # Server span should be in the same trace as the client span (context propagation). + assert server_span["context"]["trace_id"] == client_span["context"]["trace_id"] diff --git a/uv.lock b/uv.lock index 8f9a5396a2..5efbb05dce 100644 --- a/uv.lock +++ b/uv.lock @@ -579,6 +579,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034, upload-time = "2022-05-02T15:47:14.552Z" }, ] +[[package]] +name = "googleapis-common-protos" +version = "1.73.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/c0/4a54c386282c13449eca8bbe2ddb518181dc113e78d240458a68856b4d69/googleapis_common_protos-1.73.1.tar.gz", hash = "sha256:13114f0e9d2391756a0194c3a8131974ed7bffb06086569ba193364af59163b6", size = 147506, upload-time = "2026-03-26T22:17:38.451Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/82/fcb6520612bec0c39b973a6c0954b6a0d948aadfe8f7e9487f60ceb8bfa6/googleapis_common_protos-1.73.1-py3-none-any.whl", hash = "sha256:e51f09eb0a43a8602f5a915870972e6b4a394088415c79d79605a46d8e826ee8", size = 297556, upload-time = "2026-03-26T22:15:58.455Z" }, +] + [[package]] name = "griffe" version = "1.14.0" @@ -646,6 +658,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "importlib-metadata" +version = "8.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/49/3b30cad09e7771a4982d9975a8cbf64f00d4a1ececb53297f1d9a7be1b10/importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb", size = 57107, upload-time = "2025-12-21T10:00:19.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, +] + [[package]] name = "iniconfig" version = "2.1.0" @@ -710,6 +734,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, ] +[[package]] +name = "logfire" +version = "4.31.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "executing" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-sdk" }, + { name = "protobuf" }, + { name = "rich" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/61/fc/21f923243d8c3ca2ebfa97de46970ced734e66ac634c1c35b6abb41300f1/logfire-4.31.0.tar.gz", hash = "sha256:361bfda17c9d70ada5d220211033bae06b871ddac9d5b06978bc0ceca6b8e658", size = 1080609, upload-time = "2026-03-27T19:00:46.339Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/1a/8c860e35bf847ac0d647d94bad89dccbb66cbcafdd61d8334f8cc7cfdd58/logfire-4.31.0-py3-none-any.whl", hash = "sha256:49fad38b5e6f199a98e9c8814e860c8a42595bb81479b52a20413e53ee475b72", size = 308896, upload-time = "2026-03-27T19:00:43.107Z" }, +] + [[package]] name = "markdown" version = "3.9" @@ -797,6 +840,7 @@ dependencies = [ { name = "httpx" }, { name = "httpx-sse" }, { name = "jsonschema" }, + { name = "opentelemetry-api" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyjwt", extra = ["crypto"] }, @@ -826,6 +870,7 @@ dev = [ { name = "coverage", extra = ["toml"] }, { name = "dirty-equals" }, { name = "inline-snapshot" }, + { name = "logfire" }, { name = "mcp", extra = ["cli", "ws"] }, { name = "pillow" }, { name = "pyright" }, @@ -853,6 +898,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.27.1,<1.0.0" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" }, + { name = "opentelemetry-api", specifier = ">=1.28.0" }, { name = "pydantic", specifier = ">=2.12.0" }, { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.1" }, @@ -876,6 +922,7 @@ dev = [ { name = "coverage", extras = ["toml"], specifier = ">=7.10.7,<=7.13" }, { name = "dirty-equals", specifier = ">=0.9.0" }, { name = "inline-snapshot", specifier = ">=0.23.0" }, + { name = "logfire", specifier = ">=3.0.0" }, { name = "mcp", extras = ["cli", "ws"], editable = "." 
}, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, @@ -1642,6 +1689,103 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, ] +[[package]] +name = "opentelemetry-api" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/97/b9/3161be15bb8e3ad01be8be5a968a9237c3027c5be504362ff800fca3e442/opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c", size = 65767, upload-time = "2025-12-11T13:32:39.182Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/df/d3f1ddf4bb4cb50ed9b1139cc7b1c54c34a1e7ce8fd1b9a37c0d1551a6bd/opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950", size = 66356, upload-time = "2025-12-11T13:32:17.304Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-common" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-proto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/9d/22d241b66f7bbde88a3bfa6847a351d2c46b84de23e71222c6aae25c7050/opentelemetry_exporter_otlp_proto_common-1.39.1.tar.gz", hash = "sha256:763370d4737a59741c89a67b50f9e39271639ee4afc999dadfe768541c027464", size = 20409, upload-time = "2025-12-11T13:32:40.885Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/02/ffc3e143d89a27ac21fd557365b98bd0653b98de8a101151d5805b5d4c33/opentelemetry_exporter_otlp_proto_common-1.39.1-py3-none-any.whl", hash = 
"sha256:08f8a5862d64cc3435105686d0216c1365dc5701f86844a8cd56597d0c764fde", size = 18366, upload-time = "2025-12-11T13:32:20.2Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-http" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "googleapis-common-protos" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-common" }, + { name = "opentelemetry-proto" }, + { name = "opentelemetry-sdk" }, + { name = "requests" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/80/04/2a08fa9c0214ae38880df01e8bfae12b067ec0793446578575e5080d6545/opentelemetry_exporter_otlp_proto_http-1.39.1.tar.gz", hash = "sha256:31bdab9745c709ce90a49a0624c2bd445d31a28ba34275951a6a362d16a0b9cb", size = 17288, upload-time = "2025-12-11T13:32:42.029Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/f1/b27d3e2e003cd9a3592c43d099d2ed8d0a947c15281bf8463a256db0b46c/opentelemetry_exporter_otlp_proto_http-1.39.1-py3-none-any.whl", hash = "sha256:d9f5207183dd752a412c4cd564ca8875ececba13be6e9c6c370ffb752fd59985", size = 19641, upload-time = "2025-12-11T13:32:22.248Z" }, +] + +[[package]] +name = "opentelemetry-instrumentation" +version = "0.60b1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "packaging" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/0f/7e6b713ac117c1f5e4e3300748af699b9902a2e5e34c9cf443dde25a01fa/opentelemetry_instrumentation-0.60b1.tar.gz", hash = "sha256:57ddc7974c6eb35865af0426d1a17132b88b2ed8586897fee187fd5b8944bd6a", size = 31706, upload-time = "2025-12-11T13:36:42.515Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/d2/6788e83c5c86a2690101681aeef27eeb2a6bf22df52d3f263a22cee20915/opentelemetry_instrumentation-0.60b1-py3-none-any.whl", hash = 
"sha256:04480db952b48fb1ed0073f822f0ee26012b7be7c3eac1a3793122737c78632d", size = 33096, upload-time = "2025-12-11T13:35:33.067Z" }, +] + +[[package]] +name = "opentelemetry-proto" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/1d/f25d76d8260c156c40c97c9ed4511ec0f9ce353f8108ca6e7561f82a06b2/opentelemetry_proto-1.39.1.tar.gz", hash = "sha256:6c8e05144fc0d3ed4d22c2289c6b126e03bcd0e6a7da0f16cedd2e1c2772e2c8", size = 46152, upload-time = "2025-12-11T13:32:48.681Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/95/b40c96a7b5203005a0b03d8ce8cd212ff23f1793d5ba289c87a097571b18/opentelemetry_proto-1.39.1-py3-none-any.whl", hash = "sha256:22cdc78efd3b3765d09e68bfbd010d4fc254c9818afd0b6b423387d9dee46007", size = 72535, upload-time = "2025-12-11T13:32:33.866Z" }, +] + +[[package]] +name = "opentelemetry-sdk" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/fb/c76080c9ba07e1e8235d24cdcc4d125ef7aa3edf23eb4e497c2e50889adc/opentelemetry_sdk-1.39.1.tar.gz", hash = "sha256:cf4d4563caf7bff906c9f7967e2be22d0d6b349b908be0d90fb21c8e9c995cc6", size = 171460, upload-time = "2025-12-11T13:32:49.369Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/98/e91cf858f203d86f4eccdf763dcf01cf03f1dae80c3750f7e635bfa206b6/opentelemetry_sdk-1.39.1-py3-none-any.whl", hash = "sha256:4d5482c478513ecb0a5d938dcc61394e647066e0cc2676bee9f3af3f3f45f01c", size = 132565, upload-time = "2025-12-11T13:32:35.069Z" }, +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.60b1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = 
"typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/df/553f93ed38bf22f4b999d9be9c185adb558982214f33eae539d3b5cd0858/opentelemetry_semantic_conventions-0.60b1.tar.gz", hash = "sha256:87c228b5a0669b748c76d76df6c364c369c28f1c465e50f661e39737e84bc953", size = 137935, upload-time = "2025-12-11T13:32:50.487Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/5e/5958555e09635d09b75de3c4f8b9cae7335ca545d77392ffe7331534c402/opentelemetry_semantic_conventions-0.60b1-py3-none-any.whl", hash = "sha256:9fa8c8b0c110da289809292b0591220d3a7b53c1526a23021e977d68597893fb", size = 219982, upload-time = "2025-12-11T13:32:36.955Z" }, +] + [[package]] name = "outcome" version = "1.3.0.post0" @@ -1797,6 +1941,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "protobuf" +version = "6.33.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/70/e908e9c5e52ef7c3a6c7902c9dfbb34c7e29c25d2f81ade3856445fd5c94/protobuf-6.33.6.tar.gz", hash = "sha256:a6768d25248312c297558af96a9f9c929e8c4cee0659cb07e780731095f38135", size = 444531, upload-time = "2026-03-18T19:05:00.988Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/9f/2f509339e89cfa6f6a4c4ff50438db9ca488dec341f7e454adad60150b00/protobuf-6.33.6-cp310-abi3-win32.whl", hash = "sha256:7d29d9b65f8afef196f8334e80d6bc1d5d4adedb449971fefd3723824e6e77d3", size = 425739, upload-time = "2026-03-18T19:04:48.373Z" }, + { url = "https://files.pythonhosted.org/packages/76/5d/683efcd4798e0030c1bab27374fd13a89f7c2515fb1f3123efdfaa5eab57/protobuf-6.33.6-cp310-abi3-win_amd64.whl", hash = 
"sha256:0cd27b587afca21b7cfa59a74dcbd48a50f0a6400cfb59391340ad729d91d326", size = 437089, upload-time = "2026-03-18T19:04:50.381Z" }, + { url = "https://files.pythonhosted.org/packages/5c/01/a3c3ed5cd186f39e7880f8303cc51385a198a81469d53d0fdecf1f64d929/protobuf-6.33.6-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:9720e6961b251bde64edfdab7d500725a2af5280f3f4c87e57c0208376aa8c3a", size = 427737, upload-time = "2026-03-18T19:04:51.866Z" }, + { url = "https://files.pythonhosted.org/packages/ee/90/b3c01fdec7d2f627b3a6884243ba328c1217ed2d978def5c12dc50d328a3/protobuf-6.33.6-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e2afbae9b8e1825e3529f88d514754e094278bb95eadc0e199751cdd9a2e82a2", size = 324610, upload-time = "2026-03-18T19:04:53.096Z" }, + { url = "https://files.pythonhosted.org/packages/9b/ca/25afc144934014700c52e05103c2421997482d561f3101ff352e1292fb81/protobuf-6.33.6-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:c96c37eec15086b79762ed265d59ab204dabc53056e3443e702d2681f4b39ce3", size = 339381, upload-time = "2026-03-18T19:04:54.616Z" }, + { url = "https://files.pythonhosted.org/packages/16/92/d1e32e3e0d894fe00b15ce28ad4944ab692713f2e7f0a99787405e43533a/protobuf-6.33.6-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:e9db7e292e0ab79dd108d7f1a94fe31601ce1ee3f7b79e0692043423020b0593", size = 323436, upload-time = "2026-03-18T19:04:55.768Z" }, + { url = "https://files.pythonhosted.org/packages/c4/72/02445137af02769918a93807b2b7890047c32bfb9f90371cbc12688819eb/protobuf-6.33.6-py3-none-any.whl", hash = "sha256:77179e006c476e69bf8e8ce866640091ec42e1beb80b213c3900006ecfba6901", size = 170656, upload-time = "2026-03-18T19:04:59.826Z" }, +] + [[package]] name = "pycparser" version = "2.23" @@ -2766,3 +2925,81 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = 
"sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884, upload-time = "2025-03-05T20:03:27.934Z" }, { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] + +[[package]] +name = "wrapt" +version = "1.17.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/8f/aeb76c5b46e273670962298c23e7ddde79916cb74db802131d49a85e4b7d/wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0", size = 55547, upload-time = "2025-08-12T05:53:21.714Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/23/bb82321b86411eb51e5a5db3fb8f8032fd30bd7c2d74bfe936136b2fa1d6/wrapt-1.17.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88bbae4d40d5a46142e70d58bf664a89b6b4befaea7b2ecc14e03cedb8e06c04", size = 53482, upload-time = "2025-08-12T05:51:44.467Z" }, + { url = "https://files.pythonhosted.org/packages/45/69/f3c47642b79485a30a59c63f6d739ed779fb4cc8323205d047d741d55220/wrapt-1.17.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6b13af258d6a9ad602d57d889f83b9d5543acd471eee12eb51f5b01f8eb1bc2", size = 38676, upload-time = "2025-08-12T05:51:32.636Z" }, + { url = "https://files.pythonhosted.org/packages/d1/71/e7e7f5670c1eafd9e990438e69d8fb46fa91a50785332e06b560c869454f/wrapt-1.17.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd341868a4b6714a5962c1af0bd44f7c404ef78720c7de4892901e540417111c", size = 38957, upload-time = "2025-08-12T05:51:54.655Z" }, + { url = "https://files.pythonhosted.org/packages/de/17/9f8f86755c191d6779d7ddead1a53c7a8aa18bccb7cea8e7e72dfa6a8a09/wrapt-1.17.3-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = 
"sha256:f9b2601381be482f70e5d1051a5965c25fb3625455a2bf520b5a077b22afb775", size = 81975, upload-time = "2025-08-12T05:52:30.109Z" }, + { url = "https://files.pythonhosted.org/packages/f2/15/dd576273491f9f43dd09fce517f6c2ce6eb4fe21681726068db0d0467096/wrapt-1.17.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343e44b2a8e60e06a7e0d29c1671a0d9951f59174f3709962b5143f60a2a98bd", size = 83149, upload-time = "2025-08-12T05:52:09.316Z" }, + { url = "https://files.pythonhosted.org/packages/0c/c4/5eb4ce0d4814521fee7aa806264bf7a114e748ad05110441cd5b8a5c744b/wrapt-1.17.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:33486899acd2d7d3066156b03465b949da3fd41a5da6e394ec49d271baefcf05", size = 82209, upload-time = "2025-08-12T05:52:10.331Z" }, + { url = "https://files.pythonhosted.org/packages/31/4b/819e9e0eb5c8dc86f60dfc42aa4e2c0d6c3db8732bce93cc752e604bb5f5/wrapt-1.17.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e6f40a8aa5a92f150bdb3e1c44b7e98fb7113955b2e5394122fa5532fec4b418", size = 81551, upload-time = "2025-08-12T05:52:31.137Z" }, + { url = "https://files.pythonhosted.org/packages/f8/83/ed6baf89ba3a56694700139698cf703aac9f0f9eb03dab92f57551bd5385/wrapt-1.17.3-cp310-cp310-win32.whl", hash = "sha256:a36692b8491d30a8c75f1dfee65bef119d6f39ea84ee04d9f9311f83c5ad9390", size = 36464, upload-time = "2025-08-12T05:53:01.204Z" }, + { url = "https://files.pythonhosted.org/packages/2f/90/ee61d36862340ad7e9d15a02529df6b948676b9a5829fd5e16640156627d/wrapt-1.17.3-cp310-cp310-win_amd64.whl", hash = "sha256:afd964fd43b10c12213574db492cb8f73b2f0826c8df07a68288f8f19af2ebe6", size = 38748, upload-time = "2025-08-12T05:53:00.209Z" }, + { url = "https://files.pythonhosted.org/packages/bd/c3/cefe0bd330d389c9983ced15d326f45373f4073c9f4a8c2f99b50bfea329/wrapt-1.17.3-cp310-cp310-win_arm64.whl", hash = "sha256:af338aa93554be859173c39c85243970dc6a289fa907402289eeae7543e1ae18", size = 36810, upload-time = 
"2025-08-12T05:52:51.906Z" }, + { url = "https://files.pythonhosted.org/packages/52/db/00e2a219213856074a213503fdac0511203dceefff26e1daa15250cc01a0/wrapt-1.17.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:273a736c4645e63ac582c60a56b0acb529ef07f78e08dc6bfadf6a46b19c0da7", size = 53482, upload-time = "2025-08-12T05:51:45.79Z" }, + { url = "https://files.pythonhosted.org/packages/5e/30/ca3c4a5eba478408572096fe9ce36e6e915994dd26a4e9e98b4f729c06d9/wrapt-1.17.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5531d911795e3f935a9c23eb1c8c03c211661a5060aab167065896bbf62a5f85", size = 38674, upload-time = "2025-08-12T05:51:34.629Z" }, + { url = "https://files.pythonhosted.org/packages/31/25/3e8cc2c46b5329c5957cec959cb76a10718e1a513309c31399a4dad07eb3/wrapt-1.17.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0610b46293c59a3adbae3dee552b648b984176f8562ee0dba099a56cfbe4df1f", size = 38959, upload-time = "2025-08-12T05:51:56.074Z" }, + { url = "https://files.pythonhosted.org/packages/5d/8f/a32a99fc03e4b37e31b57cb9cefc65050ea08147a8ce12f288616b05ef54/wrapt-1.17.3-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b32888aad8b6e68f83a8fdccbf3165f5469702a7544472bdf41f582970ed3311", size = 82376, upload-time = "2025-08-12T05:52:32.134Z" }, + { url = "https://files.pythonhosted.org/packages/31/57/4930cb8d9d70d59c27ee1332a318c20291749b4fba31f113c2f8ac49a72e/wrapt-1.17.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cccf4f81371f257440c88faed6b74f1053eef90807b77e31ca057b2db74edb1", size = 83604, upload-time = "2025-08-12T05:52:11.663Z" }, + { url = "https://files.pythonhosted.org/packages/a8/f3/1afd48de81d63dd66e01b263a6fbb86e1b5053b419b9b33d13e1f6d0f7d0/wrapt-1.17.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8a210b158a34164de8bb68b0e7780041a903d7b00c87e906fb69928bf7890d5", size = 82782, upload-time = "2025-08-12T05:52:12.626Z" }, + { url = 
"https://files.pythonhosted.org/packages/1e/d7/4ad5327612173b144998232f98a85bb24b60c352afb73bc48e3e0d2bdc4e/wrapt-1.17.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:79573c24a46ce11aab457b472efd8d125e5a51da2d1d24387666cd85f54c05b2", size = 82076, upload-time = "2025-08-12T05:52:33.168Z" }, + { url = "https://files.pythonhosted.org/packages/bb/59/e0adfc831674a65694f18ea6dc821f9fcb9ec82c2ce7e3d73a88ba2e8718/wrapt-1.17.3-cp311-cp311-win32.whl", hash = "sha256:c31eebe420a9a5d2887b13000b043ff6ca27c452a9a22fa71f35f118e8d4bf89", size = 36457, upload-time = "2025-08-12T05:53:03.936Z" }, + { url = "https://files.pythonhosted.org/packages/83/88/16b7231ba49861b6f75fc309b11012ede4d6b0a9c90969d9e0db8d991aeb/wrapt-1.17.3-cp311-cp311-win_amd64.whl", hash = "sha256:0b1831115c97f0663cb77aa27d381237e73ad4f721391a9bfb2fe8bc25fa6e77", size = 38745, upload-time = "2025-08-12T05:53:02.885Z" }, + { url = "https://files.pythonhosted.org/packages/9a/1e/c4d4f3398ec073012c51d1c8d87f715f56765444e1a4b11e5180577b7e6e/wrapt-1.17.3-cp311-cp311-win_arm64.whl", hash = "sha256:5a7b3c1ee8265eb4c8f1b7d29943f195c00673f5ab60c192eba2d4a7eae5f46a", size = 36806, upload-time = "2025-08-12T05:52:53.368Z" }, + { url = "https://files.pythonhosted.org/packages/9f/41/cad1aba93e752f1f9268c77270da3c469883d56e2798e7df6240dcb2287b/wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0", size = 53998, upload-time = "2025-08-12T05:51:47.138Z" }, + { url = "https://files.pythonhosted.org/packages/60/f8/096a7cc13097a1869fe44efe68dace40d2a16ecb853141394047f0780b96/wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba", size = 39020, upload-time = "2025-08-12T05:51:35.906Z" }, + { url = "https://files.pythonhosted.org/packages/33/df/bdf864b8997aab4febb96a9ae5c124f700a5abd9b5e13d2a3214ec4be705/wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd", size = 39098, upload-time = "2025-08-12T05:51:57.474Z" }, + { url = "https://files.pythonhosted.org/packages/9f/81/5d931d78d0eb732b95dc3ddaeeb71c8bb572fb01356e9133916cd729ecdd/wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828", size = 88036, upload-time = "2025-08-12T05:52:34.784Z" }, + { url = "https://files.pythonhosted.org/packages/ca/38/2e1785df03b3d72d34fc6252d91d9d12dc27a5c89caef3335a1bbb8908ca/wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9", size = 88156, upload-time = "2025-08-12T05:52:13.599Z" }, + { url = "https://files.pythonhosted.org/packages/b3/8b/48cdb60fe0603e34e05cffda0b2a4adab81fd43718e11111a4b0100fd7c1/wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396", size = 87102, upload-time = "2025-08-12T05:52:14.56Z" }, + { url = "https://files.pythonhosted.org/packages/3c/51/d81abca783b58f40a154f1b2c56db1d2d9e0d04fa2d4224e357529f57a57/wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc", size = 87732, upload-time = "2025-08-12T05:52:36.165Z" }, + { url = "https://files.pythonhosted.org/packages/9e/b1/43b286ca1392a006d5336412d41663eeef1ad57485f3e52c767376ba7e5a/wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe", size = 36705, upload-time = "2025-08-12T05:53:07.123Z" }, + { url = "https://files.pythonhosted.org/packages/28/de/49493f962bd3c586ab4b88066e967aa2e0703d6ef2c43aa28cb83bf7b507/wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c", 
size = 38877, upload-time = "2025-08-12T05:53:05.436Z" }, + { url = "https://files.pythonhosted.org/packages/f1/48/0f7102fe9cb1e8a5a77f80d4f0956d62d97034bbe88d33e94699f99d181d/wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6", size = 36885, upload-time = "2025-08-12T05:52:54.367Z" }, + { url = "https://files.pythonhosted.org/packages/fc/f6/759ece88472157acb55fc195e5b116e06730f1b651b5b314c66291729193/wrapt-1.17.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a47681378a0439215912ef542c45a783484d4dd82bac412b71e59cf9c0e1cea0", size = 54003, upload-time = "2025-08-12T05:51:48.627Z" }, + { url = "https://files.pythonhosted.org/packages/4f/a9/49940b9dc6d47027dc850c116d79b4155f15c08547d04db0f07121499347/wrapt-1.17.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:54a30837587c6ee3cd1a4d1c2ec5d24e77984d44e2f34547e2323ddb4e22eb77", size = 39025, upload-time = "2025-08-12T05:51:37.156Z" }, + { url = "https://files.pythonhosted.org/packages/45/35/6a08de0f2c96dcdd7fe464d7420ddb9a7655a6561150e5fc4da9356aeaab/wrapt-1.17.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ecf15d6af39246fe33e507105d67e4b81d8f8d2c6598ff7e3ca1b8a37213f7", size = 39108, upload-time = "2025-08-12T05:51:58.425Z" }, + { url = "https://files.pythonhosted.org/packages/0c/37/6faf15cfa41bf1f3dba80cd3f5ccc6622dfccb660ab26ed79f0178c7497f/wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6fd1ad24dc235e4ab88cda009e19bf347aabb975e44fd5c2fb22a3f6e4141277", size = 88072, upload-time = "2025-08-12T05:52:37.53Z" }, + { url = "https://files.pythonhosted.org/packages/78/f2/efe19ada4a38e4e15b6dff39c3e3f3f73f5decf901f66e6f72fe79623a06/wrapt-1.17.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ed61b7c2d49cee3c027372df5809a59d60cf1b6c2f81ee980a091f3afed6a2d", size = 88214, upload-time = "2025-08-12T05:52:15.886Z" }, + { url 
= "https://files.pythonhosted.org/packages/40/90/ca86701e9de1622b16e09689fc24b76f69b06bb0150990f6f4e8b0eeb576/wrapt-1.17.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:423ed5420ad5f5529db9ce89eac09c8a2f97da18eb1c870237e84c5a5c2d60aa", size = 87105, upload-time = "2025-08-12T05:52:17.914Z" }, + { url = "https://files.pythonhosted.org/packages/fd/e0/d10bd257c9a3e15cbf5523025252cc14d77468e8ed644aafb2d6f54cb95d/wrapt-1.17.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e01375f275f010fcbf7f643b4279896d04e571889b8a5b3f848423d91bf07050", size = 87766, upload-time = "2025-08-12T05:52:39.243Z" }, + { url = "https://files.pythonhosted.org/packages/e8/cf/7d848740203c7b4b27eb55dbfede11aca974a51c3d894f6cc4b865f42f58/wrapt-1.17.3-cp313-cp313-win32.whl", hash = "sha256:53e5e39ff71b3fc484df8a522c933ea2b7cdd0d5d15ae82e5b23fde87d44cbd8", size = 36711, upload-time = "2025-08-12T05:53:10.074Z" }, + { url = "https://files.pythonhosted.org/packages/57/54/35a84d0a4d23ea675994104e667ceff49227ce473ba6a59ba2c84f250b74/wrapt-1.17.3-cp313-cp313-win_amd64.whl", hash = "sha256:1f0b2f40cf341ee8cc1a97d51ff50dddb9fcc73241b9143ec74b30fc4f44f6cb", size = 38885, upload-time = "2025-08-12T05:53:08.695Z" }, + { url = "https://files.pythonhosted.org/packages/01/77/66e54407c59d7b02a3c4e0af3783168fff8e5d61def52cda8728439d86bc/wrapt-1.17.3-cp313-cp313-win_arm64.whl", hash = "sha256:7425ac3c54430f5fc5e7b6f41d41e704db073309acfc09305816bc6a0b26bb16", size = 36896, upload-time = "2025-08-12T05:52:55.34Z" }, + { url = "https://files.pythonhosted.org/packages/02/a2/cd864b2a14f20d14f4c496fab97802001560f9f41554eef6df201cd7f76c/wrapt-1.17.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:cf30f6e3c077c8e6a9a7809c94551203c8843e74ba0c960f4a98cd80d4665d39", size = 54132, upload-time = "2025-08-12T05:51:49.864Z" }, + { url = "https://files.pythonhosted.org/packages/d5/46/d011725b0c89e853dc44cceb738a307cde5d240d023d6d40a82d1b4e1182/wrapt-1.17.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = 
"sha256:e228514a06843cae89621384cfe3a80418f3c04aadf8a3b14e46a7be704e4235", size = 39091, upload-time = "2025-08-12T05:51:38.935Z" }, + { url = "https://files.pythonhosted.org/packages/2e/9e/3ad852d77c35aae7ddebdbc3b6d35ec8013af7d7dddad0ad911f3d891dae/wrapt-1.17.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5ea5eb3c0c071862997d6f3e02af1d055f381b1d25b286b9d6644b79db77657c", size = 39172, upload-time = "2025-08-12T05:51:59.365Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f7/c983d2762bcce2326c317c26a6a1e7016f7eb039c27cdf5c4e30f4160f31/wrapt-1.17.3-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:281262213373b6d5e4bb4353bc36d1ba4084e6d6b5d242863721ef2bf2c2930b", size = 87163, upload-time = "2025-08-12T05:52:40.965Z" }, + { url = "https://files.pythonhosted.org/packages/e4/0f/f673f75d489c7f22d17fe0193e84b41540d962f75fce579cf6873167c29b/wrapt-1.17.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dc4a8d2b25efb6681ecacad42fca8859f88092d8732b170de6a5dddd80a1c8fa", size = 87963, upload-time = "2025-08-12T05:52:20.326Z" }, + { url = "https://files.pythonhosted.org/packages/df/61/515ad6caca68995da2fac7a6af97faab8f78ebe3bf4f761e1b77efbc47b5/wrapt-1.17.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:373342dd05b1d07d752cecbec0c41817231f29f3a89aa8b8843f7b95992ed0c7", size = 86945, upload-time = "2025-08-12T05:52:21.581Z" }, + { url = "https://files.pythonhosted.org/packages/d3/bd/4e70162ce398462a467bc09e768bee112f1412e563620adc353de9055d33/wrapt-1.17.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d40770d7c0fd5cbed9d84b2c3f2e156431a12c9a37dc6284060fb4bec0b7ffd4", size = 86857, upload-time = "2025-08-12T05:52:43.043Z" }, + { url = "https://files.pythonhosted.org/packages/2b/b8/da8560695e9284810b8d3df8a19396a6e40e7518059584a1a394a2b35e0a/wrapt-1.17.3-cp314-cp314-win32.whl", hash = 
"sha256:fbd3c8319de8e1dc79d346929cd71d523622da527cca14e0c1d257e31c2b8b10", size = 37178, upload-time = "2025-08-12T05:53:12.605Z" }, + { url = "https://files.pythonhosted.org/packages/db/c8/b71eeb192c440d67a5a0449aaee2310a1a1e8eca41676046f99ed2487e9f/wrapt-1.17.3-cp314-cp314-win_amd64.whl", hash = "sha256:e1a4120ae5705f673727d3253de3ed0e016f7cd78dc463db1b31e2463e1f3cf6", size = 39310, upload-time = "2025-08-12T05:53:11.106Z" }, + { url = "https://files.pythonhosted.org/packages/45/20/2cda20fd4865fa40f86f6c46ed37a2a8356a7a2fde0773269311f2af56c7/wrapt-1.17.3-cp314-cp314-win_arm64.whl", hash = "sha256:507553480670cab08a800b9463bdb881b2edeed77dc677b0a5915e6106e91a58", size = 37266, upload-time = "2025-08-12T05:52:56.531Z" }, + { url = "https://files.pythonhosted.org/packages/77/ed/dd5cf21aec36c80443c6f900449260b80e2a65cf963668eaef3b9accce36/wrapt-1.17.3-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ed7c635ae45cfbc1a7371f708727bf74690daedc49b4dba310590ca0bd28aa8a", size = 56544, upload-time = "2025-08-12T05:51:51.109Z" }, + { url = "https://files.pythonhosted.org/packages/8d/96/450c651cc753877ad100c7949ab4d2e2ecc4d97157e00fa8f45df682456a/wrapt-1.17.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:249f88ed15503f6492a71f01442abddd73856a0032ae860de6d75ca62eed8067", size = 40283, upload-time = "2025-08-12T05:51:39.912Z" }, + { url = "https://files.pythonhosted.org/packages/d1/86/2fcad95994d9b572db57632acb6f900695a648c3e063f2cd344b3f5c5a37/wrapt-1.17.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a03a38adec8066d5a37bea22f2ba6bbf39fcdefbe2d91419ab864c3fb515454", size = 40366, upload-time = "2025-08-12T05:52:00.693Z" }, + { url = "https://files.pythonhosted.org/packages/64/0e/f4472f2fdde2d4617975144311f8800ef73677a159be7fe61fa50997d6c0/wrapt-1.17.3-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5d4478d72eb61c36e5b446e375bbc49ed002430d17cdec3cecb36993398e1a9e", size = 108571, upload-time = 
"2025-08-12T05:52:44.521Z" }, + { url = "https://files.pythonhosted.org/packages/cc/01/9b85a99996b0a97c8a17484684f206cbb6ba73c1ce6890ac668bcf3838fb/wrapt-1.17.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223db574bb38637e8230eb14b185565023ab624474df94d2af18f1cdb625216f", size = 113094, upload-time = "2025-08-12T05:52:22.618Z" }, + { url = "https://files.pythonhosted.org/packages/25/02/78926c1efddcc7b3aa0bc3d6b33a822f7d898059f7cd9ace8c8318e559ef/wrapt-1.17.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e405adefb53a435f01efa7ccdec012c016b5a1d3f35459990afc39b6be4d5056", size = 110659, upload-time = "2025-08-12T05:52:24.057Z" }, + { url = "https://files.pythonhosted.org/packages/dc/ee/c414501ad518ac3e6fe184753632fe5e5ecacdcf0effc23f31c1e4f7bfcf/wrapt-1.17.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:88547535b787a6c9ce4086917b6e1d291aa8ed914fdd3a838b3539dc95c12804", size = 106946, upload-time = "2025-08-12T05:52:45.976Z" }, + { url = "https://files.pythonhosted.org/packages/be/44/a1bd64b723d13bb151d6cc91b986146a1952385e0392a78567e12149c7b4/wrapt-1.17.3-cp314-cp314t-win32.whl", hash = "sha256:41b1d2bc74c2cac6f9074df52b2efbef2b30bdfe5f40cb78f8ca22963bc62977", size = 38717, upload-time = "2025-08-12T05:53:15.214Z" }, + { url = "https://files.pythonhosted.org/packages/79/d9/7cfd5a312760ac4dd8bf0184a6ee9e43c33e47f3dadc303032ce012b8fa3/wrapt-1.17.3-cp314-cp314t-win_amd64.whl", hash = "sha256:73d496de46cd2cdbdbcce4ae4bcdb4afb6a11234a1df9c085249d55166b95116", size = 41334, upload-time = "2025-08-12T05:53:14.178Z" }, + { url = "https://files.pythonhosted.org/packages/46/78/10ad9781128ed2f99dbc474f43283b13fea8ba58723e98844367531c18e9/wrapt-1.17.3-cp314-cp314t-win_arm64.whl", hash = "sha256:f38e60678850c42461d4202739f9bf1e3a737c7ad283638251e79cc49effb6b6", size = 38471, upload-time = "2025-08-12T05:52:57.784Z" }, + { url = 
"https://files.pythonhosted.org/packages/1f/f6/a933bd70f98e9cf3e08167fc5cd7aaaca49147e48411c0bd5ae701bb2194/wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22", size = 23591, upload-time = "2025-08-12T05:53:20.674Z" }, +] + +[[package]] +name = "zipp" +version = "3.23.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, +] From d5b9155f14ab3b4b203a47c55b53955b2d0bca09 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 17:11:12 -0400 Subject: [PATCH 52/84] chore(deps): bump requests from 2.32.5 to 2.33.0 in the uv group across 1 directory (#2350) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/uv.lock b/uv.lock index 5efbb05dce..e6fb526167 100644 --- a/uv.lock +++ b/uv.lock @@ -2394,7 +2394,7 @@ wheels = [ [[package]] name = "requests" -version = "2.32.5" +version = "2.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -2402,9 +2402,9 @@ dependencies = [ { name = "idna" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = 
"sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/64/8860370b167a9721e8956ae116825caff829224fbca0ca6e7bf8ddef8430/requests-2.33.0.tar.gz", hash = "sha256:c7ebc5e8b0f21837386ad0e1c8fe8b829fa5f544d8df3b2253bff14ef29d7652", size = 134232, upload-time = "2026-03-25T15:10:41.586Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, + { url = "https://files.pythonhosted.org/packages/56/5d/c814546c2333ceea4ba42262d8c4d55763003e767fa169adc693bd524478/requests-2.33.0-py3-none-any.whl", hash = "sha256:3324635456fa185245e24865e810cecec7b4caf933d7eb133dcde67d48cee69b", size = 65017, upload-time = "2026-03-25T15:10:40.382Z" }, ] [[package]] From cf4e435db015c9fe2d6534ad1acb30b9bb94ca7d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 8 Apr 2026 14:12:20 +0200 Subject: [PATCH 53/84] Use shared `is_async_callable` instead of `inspect.iscoroutinefunction` (#2389) --- src/mcp/server/mcpserver/prompts/base.py | 7 ++-- .../server/mcpserver/resources/templates.py | 7 ++-- src/mcp/server/mcpserver/resources/types.py | 11 ++++--- src/mcp/server/mcpserver/tools/base.py | 14 ++------ src/mcp/shared/_callable_inspection.py | 33 +++++++++++++++++++ 5 files changed, 50 insertions(+), 22 deletions(-) create mode 100644 src/mcp/shared/_callable_inspection.py diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index b4810c100e..e5b2af7d82 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -3,7 +3,6 @@ from __future__ import annotations import functools -import inspect from collections.abc import 
Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any, Literal @@ -13,6 +12,7 @@ from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context from mcp.server.mcpserver.utilities.func_metadata import func_metadata +from mcp.shared._callable_inspection import is_async_callable from mcp.types import ContentBlock, Icon, TextContent if TYPE_CHECKING: @@ -157,8 +157,9 @@ async def render( # Add context to arguments if needed call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg) - if inspect.iscoroutinefunction(self.fn): - result = await self.fn(**call_args) + fn = self.fn + if is_async_callable(fn): + result = await fn(**call_args) else: result = await anyio.to_thread.run_sync(functools.partial(self.fn, **call_args)) diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index 542b5e6f81..f1ee29a37f 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -3,7 +3,6 @@ from __future__ import annotations import functools -import inspect import re from collections.abc import Callable from typing import TYPE_CHECKING, Any @@ -15,6 +14,7 @@ from mcp.server.mcpserver.resources.types import FunctionResource, Resource from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context from mcp.server.mcpserver.utilities.func_metadata import func_metadata +from mcp.shared._callable_inspection import is_async_callable from mcp.types import Annotations, Icon if TYPE_CHECKING: @@ -112,8 +112,9 @@ async def create_resource( # Add context to params if needed params = inject_context(self.fn, params, context, self.context_kwarg) - if inspect.iscoroutinefunction(self.fn): - result = await self.fn(**params) + fn = self.fn + if is_async_callable(fn): + result = await fn(**params) else: result = await anyio.to_thread.run_sync(functools.partial(self.fn, **params)) 
diff --git a/src/mcp/server/mcpserver/resources/types.py b/src/mcp/server/mcpserver/resources/types.py index 04763be8ba..d9e472e362 100644 --- a/src/mcp/server/mcpserver/resources/types.py +++ b/src/mcp/server/mcpserver/resources/types.py @@ -1,6 +1,7 @@ """Concrete resource implementations.""" -import inspect +from __future__ import annotations + import json from collections.abc import Callable from pathlib import Path @@ -14,6 +15,7 @@ from pydantic import Field, ValidationInfo, validate_call from mcp.server.mcpserver.resources.base import Resource +from mcp.shared._callable_inspection import is_async_callable from mcp.types import Annotations, Icon @@ -55,8 +57,9 @@ class FunctionResource(Resource): async def read(self) -> str | bytes: """Read the resource by calling the wrapped function.""" try: - if inspect.iscoroutinefunction(self.fn): - result = await self.fn() + fn = self.fn + if is_async_callable(fn): + result = await fn() else: result = await anyio.to_thread.run_sync(self.fn) @@ -83,7 +86,7 @@ def from_function( icons: list[Icon] | None = None, annotations: Annotations | None = None, meta: dict[str, Any] | None = None, - ) -> "FunctionResource": + ) -> FunctionResource: """Create a FunctionResource from a function.""" func_name = name or fn.__name__ if func_name == "": # pragma: no cover diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index dc65be9885..754313eb8a 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -1,7 +1,5 @@ from __future__ import annotations -import functools -import inspect from collections.abc import Callable from functools import cached_property from typing import TYPE_CHECKING, Any @@ -11,6 +9,7 @@ from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata +from 
mcp.shared._callable_inspection import is_async_callable from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.shared.tool_name_validation import validate_and_warn_tool_name from mcp.types import Icon, ToolAnnotations @@ -63,7 +62,7 @@ def from_function( raise ValueError("You must provide a name for lambda functions") func_doc = description or fn.__doc__ or "" - is_async = _is_async_callable(fn) + is_async = is_async_callable(fn) if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) @@ -118,12 +117,3 @@ async def run( raise except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e - - -def _is_async_callable(obj: Any) -> bool: - while isinstance(obj, functools.partial): # pragma: lax no cover - obj = obj.func - - return inspect.iscoroutinefunction(obj) or ( - callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) - ) diff --git a/src/mcp/shared/_callable_inspection.py b/src/mcp/shared/_callable_inspection.py new file mode 100644 index 0000000000..0e89e446f8 --- /dev/null +++ b/src/mcp/shared/_callable_inspection.py @@ -0,0 +1,33 @@ +"""Callable inspection utilities. + +Adapted from Starlette's `is_async_callable` implementation. +https://github.com/encode/starlette/blob/main/starlette/_utils.py +""" + +from __future__ import annotations + +import functools +import inspect +from collections.abc import Awaitable, Callable +from typing import Any, TypeGuard, TypeVar, overload + +T = TypeVar("T") + +AwaitableCallable = Callable[..., Awaitable[T]] + + +@overload +def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ... + + +@overload +def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ... 
+ + +def is_async_callable(obj: Any) -> Any: + while isinstance(obj, functools.partial): # pragma: lax no cover + obj = obj.func + + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) + ) From f27d2aac055ed031512eba84b31ac849bb81da2f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 9 Apr 2026 13:25:16 +0100 Subject: [PATCH 54/84] docs: fill migration guide gaps surfaced by automated upgrade eval (#2412) --- docs/migration.md | 240 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 235 insertions(+), 5 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 3b47f9aade..2528f046c6 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -38,6 +38,7 @@ http_client = httpx.AsyncClient( headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30, read=300), auth=my_auth, + follow_redirects=True, ) async with http_client: @@ -48,6 +49,8 @@ async with http_client: ... ``` +v1's internal client set `follow_redirects=True`; set it explicitly when supplying your own `httpx.AsyncClient` to preserve that behavior. + ### `get_session_id` callback removed from `streamable_http_client` The `get_session_id` callback (third element of the returned tuple) has been removed from `streamable_http_client`. The function now returns a 2-tuple `(read_stream, write_stream)` instead of a 3-tuple. @@ -100,6 +103,8 @@ async with http_client: The `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters have been removed from `StreamableHTTPTransport`. Configure these on the `httpx.AsyncClient` instead (see example above). +Note: `sse_client` retains its `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters — only the streamable HTTP transport changed. 
+ ### Removed type aliases and classes The following deprecated type aliases and classes have been removed from `mcp.types`: @@ -126,6 +131,52 @@ from mcp.types import ContentBlock, ResourceTemplateReference # Use `str` instead of `Cursor` for pagination cursors ``` +### Field names changed from camelCase to snake_case + +All Pydantic model fields in `mcp.types` now use snake_case names for Python attribute access. The JSON wire format is unchanged — serialization still uses camelCase via Pydantic aliases. + +**Before (v1):** + +```python +result = await session.call_tool("my_tool", {"x": 1}) +if result.isError: + ... + +tools = await session.list_tools() +cursor = tools.nextCursor +schema = tools.tools[0].inputSchema +``` + +**After (v2):** + +```python +result = await session.call_tool("my_tool", {"x": 1}) +if result.is_error: + ... + +tools = await session.list_tools() +cursor = tools.next_cursor +schema = tools.tools[0].input_schema +``` + +Common renames: + +| v1 (camelCase) | v2 (snake_case) | +|----------------|-----------------| +| `inputSchema` | `input_schema` | +| `outputSchema` | `output_schema` | +| `isError` | `is_error` | +| `nextCursor` | `next_cursor` | +| `mimeType` | `mime_type` | +| `structuredContent` | `structured_content` | +| `serverInfo` | `server_info` | +| `protocolVersion` | `protocol_version` | +| `uriTemplate` | `uri_template` | +| `listChanged` | `list_changed` | +| `progressToken` | `progress_token` | + +Because `populate_by_name=True` is set, the old camelCase names still work as constructor kwargs (e.g., `Tool(inputSchema={...})` is accepted), but attribute access must use snake_case (`tool.input_schema`). + ### `args` parameter removed from `ClientSessionGroup.call_tool()` The deprecated `args` parameter has been removed from `ClientSessionGroup.call_tool()`. Use `arguments` instead. 
@@ -225,6 +276,28 @@ except MCPError as e: from mcp import MCPError ``` +The constructor signature also changed — it now takes `code`, `message`, and optional `data` directly instead of wrapping an `ErrorData`: + +**Before (v1):** + +```python +from mcp.shared.exceptions import McpError +from mcp.types import ErrorData, INVALID_REQUEST + +raise McpError(ErrorData(code=INVALID_REQUEST, message="bad input")) +``` + +**After (v2):** + +```python +from mcp.shared.exceptions import MCPError +from mcp.types import INVALID_REQUEST + +raise MCPError(INVALID_REQUEST, "bad input") +# or, if you already have an ErrorData: +raise MCPError.from_error_data(error_data) +``` + ### `FastMCP` renamed to `MCPServer` The `FastMCP` class has been renamed to `MCPServer` to better reflect its role as the main server class in the SDK. This is a simple rename with no functional changes to the class itself. @@ -240,11 +313,19 @@ mcp = FastMCP("Demo") **After (v2):** ```python -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import MCPServer, Context mcp = MCPServer("Demo") ``` +`Context` is the type annotation for the `ctx` parameter injected into tools, resources, and prompts (see [`get_context()` removed](#mcpserverget_context-removed) below). + +All submodules under `mcp.server.fastmcp.*` are now under `mcp.server.mcpserver.*` with the same structure. Common imports: + +- `Image`, `Audio` — from `mcp.server.mcpserver` (or `.utilities.types`) +- `UserMessage`, `AssistantMessage` — from `mcp.server.mcpserver.prompts.base` +- `ToolError`, `ResourceError` — from `mcp.server.mcpserver.exceptions` + ### `mount_path` parameter removed from MCPServer The `mount_path` parameter has been removed from `MCPServer.__init__()`, `MCPServer.run()`, `MCPServer.run_sse_async()`, and `MCPServer.sse_app()`. It was also removed from the `Settings` class. 
@@ -312,6 +393,8 @@ app = Starlette(routes=[Mount("/", app=mcp.streamable_http_app(json_response=Tru **Note:** DNS rebinding protection is automatically enabled when `host` is `127.0.0.1`, `localhost`, or `::1`. This now happens in `sse_app()` and `streamable_http_app()` instead of the constructor. +If you were mutating these via `mcp.settings` after construction (e.g., `mcp.settings.port = 9000`), pass them to `run()` / `sse_app()` / `streamable_http_app()` instead — these fields no longer exist on `Settings`. The `debug` and `log_level` parameters remain on the constructor. + ### `MCPServer.get_context()` removed `MCPServer.get_context()` has been removed. Context is now injected by the framework and passed explicitly — there is no ambient ContextVar to read from. @@ -331,6 +414,8 @@ async def my_tool(x: int) -> str: **After (v2):** ```python +from mcp.server.mcpserver import Context + @mcp.tool() async def my_tool(x: int, ctx: Context) -> str: await ctx.info("Processing...") @@ -343,6 +428,45 @@ async def my_tool(x: int, ctx: Context) -> str: The internal layers (`ToolManager.call_tool`, `Tool.run`, `Prompt.render`, `ResourceTemplate.create_resource`, etc.) now require `context` as a positional argument. +### Registering lowlevel handlers on `MCPServer` (workaround) + +`MCPServer` does not expose public APIs for `subscribe_resource`, `unsubscribe_resource`, or `set_logging_level` handlers. In v1, the workaround was to reach into the private lowlevel server and use its decorator methods: + +**Before (v1):** + +```python +@mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage] +async def handle_set_logging_level(level: str) -> None: + ... 
+ +mcp._mcp_server.subscribe_resource()(handle_subscribe) # pyright: ignore[reportPrivateUsage] +``` + +In v2, the lowlevel `Server` no longer has decorator methods (handlers are constructor-only), so the equivalent workaround is `_add_request_handler`: + +**After (v2):** + +```python +from mcp.server import ServerRequestContext +from mcp.types import EmptyResult, SetLevelRequestParams, SubscribeRequestParams + + +async def handle_set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + ... + return EmptyResult() + + +async def handle_subscribe(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: + ... + return EmptyResult() + + +mcp._lowlevel_server._add_request_handler("logging/setLevel", handle_set_logging_level) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscribe) # pyright: ignore[reportPrivateUsage] +``` + +This is a private API and may change. A public way to register these handlers on `MCPServer` is planned; until then, use this workaround or use the lowlevel `Server` directly. 
+ ### Replace `RootModel` by union types with `TypeAdapter` validation The following union types are no longer `RootModel` subclasses: @@ -383,6 +507,22 @@ notification = server_notification_adapter.validate_python(data) # No .root access needed - notification is the actual type ``` +The same applies when constructing values — the wrapper call is no longer needed: + +**Before (v1):** + +```python +await session.send_notification(ClientNotification(InitializedNotification())) +await session.send_request(ClientRequest(PingRequest()), EmptyResult) +``` + +**After (v2):** + +```python +await session.send_notification(InitializedNotification()) +await session.send_request(PingRequest(), EmptyResult) +``` + **Available adapters:** | Union Type | Adapter | @@ -428,6 +568,8 @@ server = Server("my-server", on_call_tool=handle_call_tool) ### `RequestContext` type parameters simplified +The `mcp.shared.context` module has been removed. `RequestContext` is now split into `ClientRequestContext` (in `mcp.client.context`) and `ServerRequestContext` (in `mcp.server.context`). + The `RequestContext` class has been split to separate shared fields from server-specific fields. The shared `RequestContext` now only takes 1 type parameter (the session type) instead of 3. **`RequestContext` changes:** @@ -458,11 +600,27 @@ ctx: ClientRequestContext server_ctx: ServerRequestContext[LifespanContextT, RequestT] ``` +The high-level `Context` class (injected into `@mcp.tool()` etc.) similarly dropped its `ServerSessionT` parameter: `Context[ServerSessionT, LifespanContextT, RequestT]` → `Context[LifespanContextT, RequestT]`. Both remaining parameters have defaults, so bare `Context` is usually sufficient: + +**Before (v1):** + +```python +async def my_tool(ctx: Context[ServerSession, None]) -> str: ... +``` + +**After (v2):** + +```python +async def my_tool(ctx: Context) -> str: ... +# or, with an explicit lifespan type: +async def my_tool(ctx: Context[MyLifespanState]) -> str: ... 
+``` + ### `ProgressContext` and `progress()` context manager removed The `mcp.shared.progress` module (`ProgressContext`, `Progress`, and the `progress()` context manager) has been removed. This module had no real-world adoption — all users send progress notifications via `Context.report_progress()` or `session.send_progress_notification()` directly. -**Before:** +**Before (v1):** ```python from mcp.shared.progress import progress @@ -490,6 +648,46 @@ await session.send_progress_notification( ) ``` +### `create_connected_server_and_client_session` removed + +The `create_connected_server_and_client_session` helper in `mcp.shared.memory` has been removed. Use `mcp.client.Client` instead — it accepts a `Server` or `MCPServer` instance directly and handles the in-memory transport and session setup for you. + +**Before (v1):** + +```python +from mcp.shared.memory import create_connected_server_and_client_session + +async with create_connected_server_and_client_session(server) as session: + result = await session.call_tool("my_tool", {"x": 1}) +``` + +**After (v2):** + +```python +from mcp.client import Client + +async with Client(server) as client: + result = await client.call_tool("my_tool", {"x": 1}) +``` + +`Client` accepts the same callback parameters the old helper did (`sampling_callback`, `list_roots_callback`, `logging_callback`, `message_handler`, `elicitation_callback`, `client_info`) plus `raise_exceptions` to surface server-side errors. 
+ +If you need direct access to the underlying `ClientSession` and memory streams (e.g., for low-level transport testing), `create_client_server_memory_streams` is still available in `mcp.shared.memory`: + +```python +import anyio +from mcp.client.session import ClientSession +from mcp.shared.memory import create_client_server_memory_streams + +async with create_client_server_memory_streams() as (client_streams, server_streams): + async with anyio.create_task_group() as tg: + tg.start_soon(lambda: server.run(*server_streams, server.create_initialization_options())) + async with ClientSession(*client_streams) as session: + await session.initialize() + ... + tg.cancel_scope.cancel() +``` + ### Resource URI type changed from `AnyUrl` to `str` The `uri` field on resource-related types now uses `str` instead of Pydantic's `AnyUrl`. This aligns with the [MCP specification schema](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/draft/schema.ts) which defines URIs as plain strings (`uri: string`) without strict URL validation. This change allows relative paths like `users/me` that were previously rejected. @@ -593,6 +791,8 @@ if ListToolsRequest in server.request_handlers: server = Server("my-server", on_list_tools=handle_list_tools) ``` +If you need to check whether a handler is registered, track this yourself — there is currently no public introspection API. + ### Lowlevel `Server`: decorator-based handlers replaced with constructor `on_*` params The lowlevel `Server` class no longer uses decorator methods for handler registration. Instead, handlers are passed as `on_*` keyword arguments to the constructor. @@ -645,6 +845,29 @@ server = Server("my-server", on_list_tools=handle_list_tools, on_call_tool=handl - Handlers return the full result type (e.g. `ListToolsResult`) rather than unwrapped values (e.g. `list[Tool]`). - The automatic `jsonschema` input/output validation that the old `call_tool()` decorator performed has been removed. 
There is no built-in replacement — if you relied on schema validation in the lowlevel server, you will need to validate inputs yourself in your handler. +**Complete handler reference:** + +All handlers receive `ctx: ServerRequestContext` as the first argument. The second argument and return type are: + +| v1 decorator | v2 constructor kwarg | `params` type | return type | +|---|---|---|---| +| `@server.list_tools()` | `on_list_tools` | `PaginatedRequestParams \| None` | `ListToolsResult` | +| `@server.call_tool()` | `on_call_tool` | `CallToolRequestParams` | `CallToolResult \| CreateTaskResult` | +| `@server.list_resources()` | `on_list_resources` | `PaginatedRequestParams \| None` | `ListResourcesResult` | +| `@server.list_resource_templates()` | `on_list_resource_templates` | `PaginatedRequestParams \| None` | `ListResourceTemplatesResult` | +| `@server.read_resource()` | `on_read_resource` | `ReadResourceRequestParams` | `ReadResourceResult` | +| `@server.subscribe_resource()` | `on_subscribe_resource` | `SubscribeRequestParams` | `EmptyResult` | +| `@server.unsubscribe_resource()` | `on_unsubscribe_resource` | `UnsubscribeRequestParams` | `EmptyResult` | +| `@server.list_prompts()` | `on_list_prompts` | `PaginatedRequestParams \| None` | `ListPromptsResult` | +| `@server.get_prompt()` | `on_get_prompt` | `GetPromptRequestParams` | `GetPromptResult` | +| `@server.completion()` | `on_completion` | `CompleteRequestParams` | `CompleteResult` | +| `@server.set_logging_level()` | `on_set_logging_level` | `SetLevelRequestParams` | `EmptyResult` | +| — | `on_ping` | `RequestParams \| None` | `EmptyResult` | +| `@server.progress_notification()` | `on_progress` | `ProgressNotificationParams` | `None` | +| — | `on_roots_list_changed` | `NotificationParams \| None` | `None` | + +All `params` and return types are importable from `mcp.types`. 
+ **Notification handlers:** ```python @@ -694,10 +917,17 @@ Note: `params.arguments` can be `None` (the old decorator defaulted it to `{}`). **`read_resource()` — content type wrapping removed:** -The old decorator auto-wrapped `str` into `TextResourceContents` and `bytes` into `BlobResourceContents` (with base64 encoding), and applied a default mime type of `text/plain`: +The old decorator auto-wrapped `Iterable[ReadResourceContents]` (and the deprecated `str`/`bytes` shorthand) into `TextResourceContents`/`BlobResourceContents`, handling base64 encoding and mime-type defaulting: ```python -# Before (v1) — str/bytes auto-wrapped with mime type defaulting +# Before (v1) — Iterable[ReadResourceContents] auto-wrapped +from mcp.server.lowlevel.helper_types import ReadResourceContents + +@server.read_resource() +async def handle(uri: AnyUrl) -> Iterable[ReadResourceContents]: + return [ReadResourceContents(content="file contents", mime_type="text/plain")] + +# Before (v1) — str/bytes shorthand (already deprecated in v1) @server.read_resource() async def handle(uri: str) -> str: return "file contents" @@ -849,7 +1079,7 @@ params = CallToolRequestParams( params = CallToolRequestParams( name="my_tool", arguments={}, - _meta={"progressToken": "tok", "customField": "value"}, # OK + _meta={"my_custom_key": "value", "another": 123}, # OK ) ``` From c5f12ec1e969ef69d4964816b650faf322eda3f3 Mon Sep 17 00:00:00 2001 From: Matt LeMay Date: Sun, 12 Apr 2026 07:42:12 -0500 Subject: [PATCH 55/84] Add `resources` parameter to `MCPServer` (#2414) Co-authored-by: Marcelo Trylesinski --- .../mcpserver/resources/resource_manager.py | 16 +- src/mcp/server/mcpserver/server.py | 5 +- .../server/mcpserver/tools/tool_manager.py | 16 +- .../resources/test_resource_manager.py | 292 ++++++++---------- tests/server/mcpserver/test_server.py | 26 ++ 5 files changed, 170 insertions(+), 185 deletions(-) diff --git a/src/mcp/server/mcpserver/resources/resource_manager.py 
b/src/mcp/server/mcpserver/resources/resource_manager.py index 6bf17376d1..766cf51aea 100644 --- a/src/mcp/server/mcpserver/resources/resource_manager.py +++ b/src/mcp/server/mcpserver/resources/resource_manager.py @@ -22,28 +22,26 @@ class ResourceManager: """Manages MCPServer resources.""" - def __init__(self, warn_on_duplicate_resources: bool = True): + def __init__(self, warn_on_duplicate_resources: bool = True, *, resources: list[Resource] | None = None): self._resources: dict[str, Resource] = {} self._templates: dict[str, ResourceTemplate] = {} self.warn_on_duplicate_resources = warn_on_duplicate_resources + for resource in resources or (): + self.add_resource(resource) + def add_resource(self, resource: Resource) -> Resource: """Add a resource to the manager. Args: - resource: A Resource instance to add + resource: A Resource instance to add. Returns: - The added resource. If a resource with the same URI already exists, - returns the existing resource. + The added resource. If a resource with the same URI already exists, returns the existing resource. 
""" logger.debug( "Adding resource", - extra={ - "uri": resource.uri, - "type": type(resource).__name__, - "resource_name": resource.name, - }, + extra={"uri": resource.uri, "type": type(resource).__name__, "resource_name": resource.name}, ) existing = self._resources.get(str(resource.uri)) if existing: diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 6f9bb0e287..be77705da6 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -140,6 +140,7 @@ def __init__( token_verifier: TokenVerifier | None = None, *, tools: list[Tool] | None = None, + resources: list[Resource] | None = None, debug: bool = False, log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", warn_on_duplicate_resources: bool = True, @@ -162,7 +163,9 @@ def __init__( self.dependencies = self.settings.dependencies self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) + self._resource_manager = ResourceManager( + resources=resources, warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources + ) self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) self._lowlevel_server = Server( name=name or "mcp-server", diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index 32ed547973..eef4911f9e 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -18,18 +18,12 @@ class ToolManager: """Manages MCPServer tools.""" - def __init__( - self, - warn_on_duplicate_tools: bool = True, - *, - tools: list[Tool] | None = None, - ): + def __init__(self, warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | None = None): self._tools: dict[str, Tool] = {} - if tools is not None: 
- for tool in tools: - if warn_on_duplicate_tools and tool.name in self._tools: - logger.warning(f"Tool already exists: {tool.name}") - self._tools[tool.name] = tool + for tool in tools or (): + if warn_on_duplicate_tools and tool.name in self._tools: + logger.warning(f"Tool already exists: {tool.name}") + self._tools[tool.name] = tool self.warn_on_duplicate_tools = warn_on_duplicate_tools diff --git a/tests/server/mcpserver/resources/test_resource_manager.py b/tests/server/mcpserver/resources/test_resource_manager.py index 724b579974..b91c71581c 100644 --- a/tests/server/mcpserver/resources/test_resource_manager.py +++ b/tests/server/mcpserver/resources/test_resource_manager.py @@ -1,5 +1,5 @@ +import logging from pathlib import Path -from tempfile import NamedTemporaryFile import pytest from pydantic import AnyUrl @@ -8,170 +8,134 @@ from mcp.server.mcpserver.resources import FileResource, FunctionResource, ResourceManager, ResourceTemplate -@pytest.fixture -def temp_file(): +@pytest.fixture() +def temp_file(tmp_path: Path): """Create a temporary file for testing. File is automatically cleaned up after the test if it still exists. 
""" - content = "test content" - with NamedTemporaryFile(mode="w", delete=False) as f: - f.write(content) - path = Path(f.name).resolve() - yield path - try: # pragma: lax no cover - path.unlink() - except FileNotFoundError: # pragma: lax no cover - pass # File was already deleted by the test - - -class TestResourceManager: - """Test ResourceManager functionality.""" - - def test_add_resource(self, temp_file: Path): - """Test adding a resource.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - added = manager.add_resource(resource) - assert added == resource - assert manager.list_resources() == [resource] - - def test_add_duplicate_resource(self, temp_file: Path): - """Test adding the same resource twice.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - first = manager.add_resource(resource) - second = manager.add_resource(resource) - assert first == second - assert manager.list_resources() == [resource] - - def test_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): - """Test warning on duplicate resources.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - manager.add_resource(resource) - manager.add_resource(resource) - assert "Resource already exists" in caplog.text - - def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): - """Test disabling warning on duplicate resources.""" - manager = ResourceManager(warn_on_duplicate_resources=False) - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - manager.add_resource(resource) - manager.add_resource(resource) - assert "Resource already exists" not in caplog.text - - @pytest.mark.anyio - async def test_get_resource(self, temp_file: Path): - """Test getting a 
resource by URI.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - manager.add_resource(resource) - retrieved = await manager.get_resource(resource.uri, Context()) - assert retrieved == resource - - @pytest.mark.anyio - async def test_get_resource_from_template(self): - """Test getting a resource through a template.""" - manager = ResourceManager() - - def greet(name: str) -> str: - return f"Hello, {name}!" - - template = ResourceTemplate.from_function( - fn=greet, - uri_template="greet://{name}", - name="greeter", - ) - manager._templates[template.uri_template] = template - - resource = await manager.get_resource(AnyUrl("greet://world"), Context()) - assert isinstance(resource, FunctionResource) - content = await resource.read() - assert content == "Hello, world!" - - @pytest.mark.anyio - async def test_get_unknown_resource(self): - """Test getting a non-existent resource.""" - manager = ResourceManager() - with pytest.raises(ValueError, match="Unknown resource"): - await manager.get_resource(AnyUrl("unknown://test"), Context()) - - def test_list_resources(self, temp_file: Path): - """Test listing all resources.""" - manager = ResourceManager() - resource1 = FileResource( - uri=f"file://{temp_file}", - name="test1", - path=temp_file, - ) - resource2 = FileResource( - uri=f"file://{temp_file}2", - name="test2", - path=temp_file, - ) - manager.add_resource(resource1) - manager.add_resource(resource2) - resources = manager.list_resources() - assert len(resources) == 2 - assert resources == [resource1, resource2] - - -class TestResourceManagerMetadata: - """Test ResourceManager Metadata""" - - def test_add_template_with_metadata(self): - """Test that ResourceManager.add_template() accepts and passes meta parameter.""" - - manager = ResourceManager() - - def get_item(id: str) -> str: # pragma: no cover - return f"Item {id}" - - metadata = {"source": "database", "cached": True} - - template = 
manager.add_template( - fn=get_item, - uri_template="resource://items/{id}", - meta=metadata, - ) - - assert template.meta is not None - assert template.meta == metadata - assert template.meta["source"] == "database" - assert template.meta["cached"] is True - - def test_add_template_without_metadata(self): - """Test that ResourceManager.add_template() works without meta parameter.""" - - manager = ResourceManager() - - def get_item(id: str) -> str: # pragma: no cover - return f"Item {id}" - - template = manager.add_template( - fn=get_item, - uri_template="resource://items/{id}", - ) - - assert template.meta is None + tmp_file = tmp_path / "file" + tmp_file.touch() + yield tmp_file + + +def test_init_with_resources(temp_file: Path, caplog: pytest.LogCaptureFixture): + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager = ResourceManager(resources=[resource]) + assert manager.list_resources() == [resource] + + duplicate_resource = FileResource(uri=f"file://{temp_file}", name="duplicate", path=temp_file) + + with caplog.at_level(logging.WARNING): + manager = ResourceManager(True, resources=[resource, duplicate_resource]) + + assert "Resource already exists" in caplog.text + assert manager.list_resources() == [resource] + + +def test_add_resource(temp_file: Path): + """Test adding a resource.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + added = manager.add_resource(resource) + assert added == resource + assert manager.list_resources() == [resource] + + +def test_add_duplicate_resource(temp_file: Path): + """Test adding the same resource twice.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + first = manager.add_resource(resource) + second = manager.add_resource(resource) + assert first == second + assert manager.list_resources() == [resource] + + +def test_warn_on_duplicate_resources(temp_file: 
Path, caplog: pytest.LogCaptureFixture): + """Test warning on duplicate resources.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager.add_resource(resource) + manager.add_resource(resource) + assert "Resource already exists" in caplog.text + + +def test_disable_warn_on_duplicate_resources(temp_file: Path, caplog: pytest.LogCaptureFixture): + """Test disabling warning on duplicate resources.""" + manager = ResourceManager(warn_on_duplicate_resources=False) + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager.add_resource(resource) + manager.add_resource(resource) + assert "Resource already exists" not in caplog.text + + +@pytest.mark.anyio +async def test_get_resource(temp_file: Path): + """Test getting a resource by URI.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager.add_resource(resource) + retrieved = await manager.get_resource(resource.uri, Context()) + assert retrieved == resource + + +@pytest.mark.anyio +async def test_get_resource_from_template(): + """Test getting a resource through a template.""" + manager = ResourceManager() + + def greet(name: str) -> str: + return f"Hello, {name}!" + + template = ResourceTemplate.from_function(fn=greet, uri_template="greet://{name}", name="greeter") + manager._templates[template.uri_template] = template + + resource = await manager.get_resource(AnyUrl("greet://world"), Context()) + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert content == "Hello, world!" 
+ + +@pytest.mark.anyio +async def test_get_unknown_resource(): + """Test getting a non-existent resource.""" + manager = ResourceManager() + with pytest.raises(ValueError, match="Unknown resource"): + await manager.get_resource(AnyUrl("unknown://test"), Context()) + + +def test_list_resources(temp_file: Path): + """Test listing all resources.""" + manager = ResourceManager() + resource1 = FileResource(uri=f"file://{temp_file}", name="test1", path=temp_file) + resource2 = FileResource(uri=f"file://{temp_file}2", name="test2", path=temp_file) + + manager.add_resource(resource1) + manager.add_resource(resource2) + + resources = manager.list_resources() + assert len(resources) == 2 + assert resources == [resource1, resource2] + + +def get_item(id: str) -> str: ... + + +def test_add_template_with_metadata(): + """Test that ResourceManager.add_template() accepts and passes meta parameter.""" + manager = ResourceManager() + metadata = {"source": "database", "cached": True} + template = manager.add_template(fn=get_item, uri_template="resource://items/{id}", meta=metadata) + + assert template.meta is not None + assert template.meta == metadata + assert template.meta["source"] == "database" + assert template.meta["cached"] is True + + +def test_add_template_without_metadata(): + """Test that ResourceManager.add_template() works without meta parameter.""" + manager = ResourceManager() + template = manager.add_template(fn=get_item, uri_template="resource://items/{id}") + assert template.meta is None diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 49b6deb4bb..3457ec944a 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -685,6 +685,32 @@ async def test_remove_tool_and_call(self): class TestServerResources: + async def test_init_with_resources(self): + def get_text() -> str: + """Seeded resource.""" + return "Hello from init!" 
+ + resource = FunctionResource.from_function(fn=get_text, uri="resource://init", name="init_resource") + + mcp = MCPServer(resources=[resource]) + + async with Client(mcp) as client: + assert client.initialize_result.capabilities.resources is not None + + resources = await client.list_resources() + assert len(resources.resources) == 1 + listed = resources.resources[0] + assert listed.uri == "resource://init" + assert listed.name == "init_resource" + assert listed.description == "Seeded resource." + + result = await client.read_resource("resource://init") + + assert len(result.contents) == 1 + content = result.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Hello from init!" + async def test_text_resource(self): mcp = MCPServer() From 8f806da611131e012937cd6ba189c0bca1bceea2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 12 Apr 2026 12:48:13 +0000 Subject: [PATCH 56/84] chore(deps): bump cryptography from 46.0.5 to 46.0.7 in the uv group across 1 directory (#2406) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- uv.lock | 102 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/uv.lock b/uv.lock index e6fb526167..705d014aa5 100644 --- a/uv.lock +++ b/uv.lock @@ -448,62 +448,62 @@ toml = [ [[package]] name = "cryptography" -version = "46.0.5" +version = "46.0.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/60/04/ee2a9e8542e4fa2773b81771ff8349ff19cdd56b7258a0cc442639052edb/cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d", size = 750064, 
upload-time = "2026-02-10T19:18:38.255Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/81/b0bb27f2ba931a65409c6b8a8b358a7f03c0e46eceacddff55f7c84b1f3b/cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad", size = 7176289, upload-time = "2026-02-10T19:17:08.274Z" }, - { url = "https://files.pythonhosted.org/packages/ff/9e/6b4397a3e3d15123de3b1806ef342522393d50736c13b20ec4c9ea6693a6/cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b", size = 4275637, upload-time = "2026-02-10T19:17:10.53Z" }, - { url = "https://files.pythonhosted.org/packages/63/e7/471ab61099a3920b0c77852ea3f0ea611c9702f651600397ac567848b897/cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b", size = 4424742, upload-time = "2026-02-10T19:17:12.388Z" }, - { url = "https://files.pythonhosted.org/packages/37/53/a18500f270342d66bf7e4d9f091114e31e5ee9e7375a5aba2e85a91e0044/cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263", size = 4277528, upload-time = "2026-02-10T19:17:13.853Z" }, - { url = "https://files.pythonhosted.org/packages/22/29/c2e812ebc38c57b40e7c583895e73c8c5adb4d1e4a0cc4c5a4fdab2b1acc/cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d", size = 4947993, upload-time = "2026-02-10T19:17:15.618Z" }, - { url = "https://files.pythonhosted.org/packages/6b/e7/237155ae19a9023de7e30ec64e5d99a9431a567407ac21170a046d22a5a3/cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed", size = 4456855, upload-time = 
"2026-02-10T19:17:17.221Z" }, - { url = "https://files.pythonhosted.org/packages/2d/87/fc628a7ad85b81206738abbd213b07702bcbdada1dd43f72236ef3cffbb5/cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2", size = 3984635, upload-time = "2026-02-10T19:17:18.792Z" }, - { url = "https://files.pythonhosted.org/packages/84/29/65b55622bde135aedf4565dc509d99b560ee4095e56989e815f8fd2aa910/cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2", size = 4277038, upload-time = "2026-02-10T19:17:20.256Z" }, - { url = "https://files.pythonhosted.org/packages/bc/36/45e76c68d7311432741faf1fbf7fac8a196a0a735ca21f504c75d37e2558/cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0", size = 4912181, upload-time = "2026-02-10T19:17:21.825Z" }, - { url = "https://files.pythonhosted.org/packages/6d/1a/c1ba8fead184d6e3d5afcf03d569acac5ad063f3ac9fb7258af158f7e378/cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731", size = 4456482, upload-time = "2026-02-10T19:17:25.133Z" }, - { url = "https://files.pythonhosted.org/packages/f9/e5/3fb22e37f66827ced3b902cf895e6a6bc1d095b5b26be26bd13c441fdf19/cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82", size = 4405497, upload-time = "2026-02-10T19:17:26.66Z" }, - { url = "https://files.pythonhosted.org/packages/1a/df/9d58bb32b1121a8a2f27383fabae4d63080c7ca60b9b5c88be742be04ee7/cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1", size = 4667819, upload-time = "2026-02-10T19:17:28.569Z" }, - { url = 
"https://files.pythonhosted.org/packages/ea/ed/325d2a490c5e94038cdb0117da9397ece1f11201f425c4e9c57fe5b9f08b/cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48", size = 3028230, upload-time = "2026-02-10T19:17:30.518Z" }, - { url = "https://files.pythonhosted.org/packages/e9/5a/ac0f49e48063ab4255d9e3b79f5def51697fce1a95ea1370f03dc9db76f6/cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4", size = 3480909, upload-time = "2026-02-10T19:17:32.083Z" }, - { url = "https://files.pythonhosted.org/packages/00/13/3d278bfa7a15a96b9dc22db5a12ad1e48a9eb3d40e1827ef66a5df75d0d0/cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2", size = 7119287, upload-time = "2026-02-10T19:17:33.801Z" }, - { url = "https://files.pythonhosted.org/packages/67/c8/581a6702e14f0898a0848105cbefd20c058099e2c2d22ef4e476dfec75d7/cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678", size = 4265728, upload-time = "2026-02-10T19:17:35.569Z" }, - { url = "https://files.pythonhosted.org/packages/dd/4a/ba1a65ce8fc65435e5a849558379896c957870dd64fecea97b1ad5f46a37/cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87", size = 4408287, upload-time = "2026-02-10T19:17:36.938Z" }, - { url = "https://files.pythonhosted.org/packages/f8/67/8ffdbf7b65ed1ac224d1c2df3943553766914a8ca718747ee3871da6107e/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee", size = 4270291, upload-time = "2026-02-10T19:17:38.748Z" }, - { url = 
"https://files.pythonhosted.org/packages/f8/e5/f52377ee93bc2f2bba55a41a886fd208c15276ffbd2569f2ddc89d50e2c5/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981", size = 4927539, upload-time = "2026-02-10T19:17:40.241Z" }, - { url = "https://files.pythonhosted.org/packages/3b/02/cfe39181b02419bbbbcf3abdd16c1c5c8541f03ca8bda240debc467d5a12/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9", size = 4442199, upload-time = "2026-02-10T19:17:41.789Z" }, - { url = "https://files.pythonhosted.org/packages/c0/96/2fcaeb4873e536cf71421a388a6c11b5bc846e986b2b069c79363dc1648e/cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648", size = 3960131, upload-time = "2026-02-10T19:17:43.379Z" }, - { url = "https://files.pythonhosted.org/packages/d8/d2/b27631f401ddd644e94c5cf33c9a4069f72011821cf3dc7309546b0642a0/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4", size = 4270072, upload-time = "2026-02-10T19:17:45.481Z" }, - { url = "https://files.pythonhosted.org/packages/f4/a7/60d32b0370dae0b4ebe55ffa10e8599a2a59935b5ece1b9f06edb73abdeb/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0", size = 4892170, upload-time = "2026-02-10T19:17:46.997Z" }, - { url = "https://files.pythonhosted.org/packages/d2/b9/cf73ddf8ef1164330eb0b199a589103c363afa0cf794218c24d524a58eab/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663", size = 4441741, upload-time = "2026-02-10T19:17:48.661Z" }, - { url = 
"https://files.pythonhosted.org/packages/5f/eb/eee00b28c84c726fe8fa0158c65afe312d9c3b78d9d01daf700f1f6e37ff/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826", size = 4396728, upload-time = "2026-02-10T19:17:50.058Z" }, - { url = "https://files.pythonhosted.org/packages/65/f4/6bc1a9ed5aef7145045114b75b77c2a8261b4d38717bd8dea111a63c3442/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d", size = 4652001, upload-time = "2026-02-10T19:17:51.54Z" }, - { url = "https://files.pythonhosted.org/packages/86/ef/5d00ef966ddd71ac2e6951d278884a84a40ffbd88948ef0e294b214ae9e4/cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a", size = 3003637, upload-time = "2026-02-10T19:17:52.997Z" }, - { url = "https://files.pythonhosted.org/packages/b7/57/f3f4160123da6d098db78350fdfd9705057aad21de7388eacb2401dceab9/cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4", size = 3469487, upload-time = "2026-02-10T19:17:54.549Z" }, - { url = "https://files.pythonhosted.org/packages/e2/fa/a66aa722105ad6a458bebd64086ca2b72cdd361fed31763d20390f6f1389/cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31", size = 7170514, upload-time = "2026-02-10T19:17:56.267Z" }, - { url = "https://files.pythonhosted.org/packages/0f/04/c85bdeab78c8bc77b701bf0d9bdcf514c044e18a46dcff330df5448631b0/cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18", size = 4275349, upload-time = "2026-02-10T19:17:58.419Z" }, - { url = 
"https://files.pythonhosted.org/packages/5c/32/9b87132a2f91ee7f5223b091dc963055503e9b442c98fc0b8a5ca765fab0/cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235", size = 4420667, upload-time = "2026-02-10T19:18:00.619Z" }, - { url = "https://files.pythonhosted.org/packages/a1/a6/a7cb7010bec4b7c5692ca6f024150371b295ee1c108bdc1c400e4c44562b/cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a", size = 4276980, upload-time = "2026-02-10T19:18:02.379Z" }, - { url = "https://files.pythonhosted.org/packages/8e/7c/c4f45e0eeff9b91e3f12dbd0e165fcf2a38847288fcfd889deea99fb7b6d/cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76", size = 4939143, upload-time = "2026-02-10T19:18:03.964Z" }, - { url = "https://files.pythonhosted.org/packages/37/19/e1b8f964a834eddb44fa1b9a9976f4e414cbb7aa62809b6760c8803d22d1/cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614", size = 4453674, upload-time = "2026-02-10T19:18:05.588Z" }, - { url = "https://files.pythonhosted.org/packages/db/ed/db15d3956f65264ca204625597c410d420e26530c4e2943e05a0d2f24d51/cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229", size = 3978801, upload-time = "2026-02-10T19:18:07.167Z" }, - { url = "https://files.pythonhosted.org/packages/41/e2/df40a31d82df0a70a0daf69791f91dbb70e47644c58581d654879b382d11/cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1", size = 4276755, upload-time = "2026-02-10T19:18:09.813Z" }, - { url = 
"https://files.pythonhosted.org/packages/33/45/726809d1176959f4a896b86907b98ff4391a8aa29c0aaaf9450a8a10630e/cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d", size = 4901539, upload-time = "2026-02-10T19:18:11.263Z" }, - { url = "https://files.pythonhosted.org/packages/99/0f/a3076874e9c88ecb2ecc31382f6e7c21b428ede6f55aafa1aa272613e3cd/cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c", size = 4452794, upload-time = "2026-02-10T19:18:12.914Z" }, - { url = "https://files.pythonhosted.org/packages/02/ef/ffeb542d3683d24194a38f66ca17c0a4b8bf10631feef44a7ef64e631b1a/cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4", size = 4404160, upload-time = "2026-02-10T19:18:14.375Z" }, - { url = "https://files.pythonhosted.org/packages/96/93/682d2b43c1d5f1406ed048f377c0fc9fc8f7b0447a478d5c65ab3d3a66eb/cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9", size = 4667123, upload-time = "2026-02-10T19:18:15.886Z" }, - { url = "https://files.pythonhosted.org/packages/45/2d/9c5f2926cb5300a8eefc3f4f0b3f3df39db7f7ce40c8365444c49363cbda/cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72", size = 3010220, upload-time = "2026-02-10T19:18:17.361Z" }, - { url = "https://files.pythonhosted.org/packages/48/ef/0c2f4a8e31018a986949d34a01115dd057bf536905dca38897bacd21fac3/cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595", size = 3467050, upload-time = "2026-02-10T19:18:18.899Z" }, - { url = 
"https://files.pythonhosted.org/packages/eb/dd/2d9fdb07cebdf3d51179730afb7d5e576153c6744c3ff8fded23030c204e/cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c", size = 3476964, upload-time = "2026-02-10T19:18:20.687Z" }, - { url = "https://files.pythonhosted.org/packages/e9/6f/6cc6cc9955caa6eaf83660b0da2b077c7fe8ff9950a3c5e45d605038d439/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a", size = 4218321, upload-time = "2026-02-10T19:18:22.349Z" }, - { url = "https://files.pythonhosted.org/packages/3e/5d/c4da701939eeee699566a6c1367427ab91a8b7088cc2328c09dbee940415/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356", size = 4381786, upload-time = "2026-02-10T19:18:24.529Z" }, - { url = "https://files.pythonhosted.org/packages/ac/97/a538654732974a94ff96c1db621fa464f455c02d4bb7d2652f4edc21d600/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da", size = 4217990, upload-time = "2026-02-10T19:18:25.957Z" }, - { url = "https://files.pythonhosted.org/packages/ae/11/7e500d2dd3ba891197b9efd2da5454b74336d64a7cc419aa7327ab74e5f6/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257", size = 4381252, upload-time = "2026-02-10T19:18:27.496Z" }, - { url = "https://files.pythonhosted.org/packages/bc/58/6b3d24e6b9bc474a2dcdee65dfd1f008867015408a271562e4b690561a4d/cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7", size = 3407605, upload-time = "2026-02-10T19:18:29.233Z" }, +sdist = { url = 
"https://files.pythonhosted.org/packages/47/93/ac8f3d5ff04d54bc814e961a43ae5b0b146154c89c61b47bb07557679b18/cryptography-46.0.7.tar.gz", hash = "sha256:e4cfd68c5f3e0bfdad0d38e023239b96a2fe84146481852dffbcca442c245aa5", size = 750652, upload-time = "2026-04-08T01:57:54.692Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/5d/4a8f770695d73be252331e60e526291e3df0c9b27556a90a6b47bccca4c2/cryptography-46.0.7-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:ea42cbe97209df307fdc3b155f1b6fa2577c0defa8f1f7d3be7d31d189108ad4", size = 7179869, upload-time = "2026-04-08T01:56:17.157Z" }, + { url = "https://files.pythonhosted.org/packages/5f/45/6d80dc379b0bbc1f9d1e429f42e4cb9e1d319c7a8201beffd967c516ea01/cryptography-46.0.7-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b36a4695e29fe69215d75960b22577197aca3f7a25b9cf9d165dcfe9d80bc325", size = 4275492, upload-time = "2026-04-08T01:56:19.36Z" }, + { url = "https://files.pythonhosted.org/packages/4a/9a/1765afe9f572e239c3469f2cb429f3ba7b31878c893b246b4b2994ffe2fe/cryptography-46.0.7-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ad9ef796328c5e3c4ceed237a183f5d41d21150f972455a9d926593a1dcb308", size = 4426670, upload-time = "2026-04-08T01:56:21.415Z" }, + { url = "https://files.pythonhosted.org/packages/8f/3e/af9246aaf23cd4ee060699adab1e47ced3f5f7e7a8ffdd339f817b446462/cryptography-46.0.7-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:73510b83623e080a2c35c62c15298096e2a5dc8d51c3b4e1740211839d0dea77", size = 4280275, upload-time = "2026-04-08T01:56:23.539Z" }, + { url = "https://files.pythonhosted.org/packages/0f/54/6bbbfc5efe86f9d71041827b793c24811a017c6ac0fd12883e4caa86b8ed/cryptography-46.0.7-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cbd5fb06b62bd0721e1170273d3f4d5a277044c47ca27ee257025146c34cbdd1", size = 4928402, upload-time = "2026-04-08T01:56:25.624Z" }, + { url = 
"https://files.pythonhosted.org/packages/2d/cf/054b9d8220f81509939599c8bdbc0c408dbd2bdd41688616a20731371fe0/cryptography-46.0.7-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:420b1e4109cc95f0e5700eed79908cef9268265c773d3a66f7af1eef53d409ef", size = 4459985, upload-time = "2026-04-08T01:56:27.309Z" }, + { url = "https://files.pythonhosted.org/packages/f9/46/4e4e9c6040fb01c7467d47217d2f882daddeb8828f7df800cb806d8a2288/cryptography-46.0.7-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:24402210aa54baae71d99441d15bb5a1919c195398a87b563df84468160a65de", size = 3990652, upload-time = "2026-04-08T01:56:29.095Z" }, + { url = "https://files.pythonhosted.org/packages/36/5f/313586c3be5a2fbe87e4c9a254207b860155a8e1f3cca99f9910008e7d08/cryptography-46.0.7-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8a469028a86f12eb7d2fe97162d0634026d92a21f3ae0ac87ed1c4a447886c83", size = 4279805, upload-time = "2026-04-08T01:56:30.928Z" }, + { url = "https://files.pythonhosted.org/packages/69/33/60dfc4595f334a2082749673386a4d05e4f0cf4df8248e63b2c3437585f2/cryptography-46.0.7-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:9694078c5d44c157ef3162e3bf3946510b857df5a3955458381d1c7cfc143ddb", size = 4892883, upload-time = "2026-04-08T01:56:32.614Z" }, + { url = "https://files.pythonhosted.org/packages/c7/0b/333ddab4270c4f5b972f980adef4faa66951a4aaf646ca067af597f15563/cryptography-46.0.7-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:42a1e5f98abb6391717978baf9f90dc28a743b7d9be7f0751a6f56a75d14065b", size = 4459756, upload-time = "2026-04-08T01:56:34.306Z" }, + { url = "https://files.pythonhosted.org/packages/d2/14/633913398b43b75f1234834170947957c6b623d1701ffc7a9600da907e89/cryptography-46.0.7-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91bbcb08347344f810cbe49065914fe048949648f6bd5c2519f34619142bbe85", size = 4410244, upload-time = "2026-04-08T01:56:35.977Z" }, + { url = 
"https://files.pythonhosted.org/packages/10/f2/19ceb3b3dc14009373432af0c13f46aa08e3ce334ec6eff13492e1812ccd/cryptography-46.0.7-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5d1c02a14ceb9148cc7816249f64f623fbfee39e8c03b3650d842ad3f34d637e", size = 4674868, upload-time = "2026-04-08T01:56:38.034Z" }, + { url = "https://files.pythonhosted.org/packages/1a/bb/a5c213c19ee94b15dfccc48f363738633a493812687f5567addbcbba9f6f/cryptography-46.0.7-cp311-abi3-win32.whl", hash = "sha256:d23c8ca48e44ee015cd0a54aeccdf9f09004eba9fc96f38c911011d9ff1bd457", size = 3026504, upload-time = "2026-04-08T01:56:39.666Z" }, + { url = "https://files.pythonhosted.org/packages/2b/02/7788f9fefa1d060ca68717c3901ae7fffa21ee087a90b7f23c7a603c32ae/cryptography-46.0.7-cp311-abi3-win_amd64.whl", hash = "sha256:397655da831414d165029da9bc483bed2fe0e75dde6a1523ec2fe63f3c46046b", size = 3488363, upload-time = "2026-04-08T01:56:41.893Z" }, + { url = "https://files.pythonhosted.org/packages/7b/56/15619b210e689c5403bb0540e4cb7dbf11a6bf42e483b7644e471a2812b3/cryptography-46.0.7-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:d151173275e1728cf7839aaa80c34fe550c04ddb27b34f48c232193df8db5842", size = 7119671, upload-time = "2026-04-08T01:56:44Z" }, + { url = "https://files.pythonhosted.org/packages/74/66/e3ce040721b0b5599e175ba91ab08884c75928fbeb74597dd10ef13505d2/cryptography-46.0.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:db0f493b9181c7820c8134437eb8b0b4792085d37dbb24da050476ccb664e59c", size = 4268551, upload-time = "2026-04-08T01:56:46.071Z" }, + { url = "https://files.pythonhosted.org/packages/03/11/5e395f961d6868269835dee1bafec6a1ac176505a167f68b7d8818431068/cryptography-46.0.7-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ebd6daf519b9f189f85c479427bbd6e9c9037862cf8fe89ee35503bd209ed902", size = 4408887, upload-time = "2026-04-08T01:56:47.718Z" }, + { url = 
"https://files.pythonhosted.org/packages/40/53/8ed1cf4c3b9c8e611e7122fb56f1c32d09e1fff0f1d77e78d9ff7c82653e/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:b7b412817be92117ec5ed95f880defe9cf18a832e8cafacf0a22337dc1981b4d", size = 4271354, upload-time = "2026-04-08T01:56:49.312Z" }, + { url = "https://files.pythonhosted.org/packages/50/46/cf71e26025c2e767c5609162c866a78e8a2915bbcfa408b7ca495c6140c4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:fbfd0e5f273877695cb93baf14b185f4878128b250cc9f8e617ea0c025dfb022", size = 4905845, upload-time = "2026-04-08T01:56:50.916Z" }, + { url = "https://files.pythonhosted.org/packages/c0/ea/01276740375bac6249d0a971ebdf6b4dc9ead0ee0a34ef3b5a88c1a9b0d4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ffca7aa1d00cf7d6469b988c581598f2259e46215e0140af408966a24cf086ce", size = 4444641, upload-time = "2026-04-08T01:56:52.882Z" }, + { url = "https://files.pythonhosted.org/packages/3d/4c/7d258f169ae71230f25d9f3d06caabcff8c3baf0978e2b7d65e0acac3827/cryptography-46.0.7-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:60627cf07e0d9274338521205899337c5d18249db56865f943cbe753aa96f40f", size = 3967749, upload-time = "2026-04-08T01:56:54.597Z" }, + { url = "https://files.pythonhosted.org/packages/b5/2a/2ea0767cad19e71b3530e4cad9605d0b5e338b6a1e72c37c9c1ceb86c333/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:80406c3065e2c55d7f49a9550fe0c49b3f12e5bfff5dedb727e319e1afb9bf99", size = 4270942, upload-time = "2026-04-08T01:56:56.416Z" }, + { url = "https://files.pythonhosted.org/packages/41/3d/fe14df95a83319af25717677e956567a105bb6ab25641acaa093db79975d/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:c5b1ccd1239f48b7151a65bc6dd54bcfcc15e028c8ac126d3fada09db0e07ef1", size = 4871079, upload-time = "2026-04-08T01:56:58.31Z" }, + { url = 
"https://files.pythonhosted.org/packages/9c/59/4a479e0f36f8f378d397f4eab4c850b4ffb79a2f0d58704b8fa0703ddc11/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:d5f7520159cd9c2154eb61eb67548ca05c5774d39e9c2c4339fd793fe7d097b2", size = 4443999, upload-time = "2026-04-08T01:57:00.508Z" }, + { url = "https://files.pythonhosted.org/packages/28/17/b59a741645822ec6d04732b43c5d35e4ef58be7bfa84a81e5ae6f05a1d33/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:fcd8eac50d9138c1d7fc53a653ba60a2bee81a505f9f8850b6b2888555a45d0e", size = 4399191, upload-time = "2026-04-08T01:57:02.654Z" }, + { url = "https://files.pythonhosted.org/packages/59/6a/bb2e166d6d0e0955f1e9ff70f10ec4b2824c9cfcdb4da772c7dd69cc7d80/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:65814c60f8cc400c63131584e3e1fad01235edba2614b61fbfbfa954082db0ee", size = 4655782, upload-time = "2026-04-08T01:57:04.592Z" }, + { url = "https://files.pythonhosted.org/packages/95/b6/3da51d48415bcb63b00dc17c2eff3a651b7c4fed484308d0f19b30e8cb2c/cryptography-46.0.7-cp314-cp314t-win32.whl", hash = "sha256:fdd1736fed309b4300346f88f74cd120c27c56852c3838cab416e7a166f67298", size = 3002227, upload-time = "2026-04-08T01:57:06.91Z" }, + { url = "https://files.pythonhosted.org/packages/32/a8/9f0e4ed57ec9cebe506e58db11ae472972ecb0c659e4d52bbaee80ca340a/cryptography-46.0.7-cp314-cp314t-win_amd64.whl", hash = "sha256:e06acf3c99be55aa3b516397fe42f5855597f430add9c17fa46bf2e0fb34c9bb", size = 3475332, upload-time = "2026-04-08T01:57:08.807Z" }, + { url = "https://files.pythonhosted.org/packages/a7/7f/cd42fc3614386bc0c12f0cb3c4ae1fc2bbca5c9662dfed031514911d513d/cryptography-46.0.7-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:462ad5cb1c148a22b2e3bcc5ad52504dff325d17daf5df8d88c17dda1f75f2a4", size = 7165618, upload-time = "2026-04-08T01:57:10.645Z" }, + { url = 
"https://files.pythonhosted.org/packages/a5/d0/36a49f0262d2319139d2829f773f1b97ef8aef7f97e6e5bd21455e5a8fb5/cryptography-46.0.7-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:84d4cced91f0f159a7ddacad249cc077e63195c36aac40b4150e7a57e84fffe7", size = 4270628, upload-time = "2026-04-08T01:57:12.885Z" }, + { url = "https://files.pythonhosted.org/packages/8a/6c/1a42450f464dda6ffbe578a911f773e54dd48c10f9895a23a7e88b3e7db5/cryptography-46.0.7-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:128c5edfe5e5938b86b03941e94fac9ee793a94452ad1365c9fc3f4f62216832", size = 4415405, upload-time = "2026-04-08T01:57:14.923Z" }, + { url = "https://files.pythonhosted.org/packages/9a/92/4ed714dbe93a066dc1f4b4581a464d2d7dbec9046f7c8b7016f5286329e2/cryptography-46.0.7-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5e51be372b26ef4ba3de3c167cd3d1022934bc838ae9eaad7e644986d2a3d163", size = 4272715, upload-time = "2026-04-08T01:57:16.638Z" }, + { url = "https://files.pythonhosted.org/packages/b7/e6/a26b84096eddd51494bba19111f8fffe976f6a09f132706f8f1bf03f51f7/cryptography-46.0.7-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cdf1a610ef82abb396451862739e3fc93b071c844399e15b90726ef7470eeaf2", size = 4918400, upload-time = "2026-04-08T01:57:19.021Z" }, + { url = "https://files.pythonhosted.org/packages/c7/08/ffd537b605568a148543ac3c2b239708ae0bd635064bab41359252ef88ed/cryptography-46.0.7-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1d25aee46d0c6f1a501adcddb2d2fee4b979381346a78558ed13e50aa8a59067", size = 4450634, upload-time = "2026-04-08T01:57:21.185Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/0cd51dd86ab5b9befe0d031e276510491976c3a80e9f6e31810cce46c4ad/cryptography-46.0.7-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:cdfbe22376065ffcf8be74dc9a909f032df19bc58a699456a21712d6e5eabfd0", size = 3985233, upload-time = "2026-04-08T01:57:22.862Z" }, + { url = 
"https://files.pythonhosted.org/packages/92/49/819d6ed3a7d9349c2939f81b500a738cb733ab62fbecdbc1e38e83d45e12/cryptography-46.0.7-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:abad9dac36cbf55de6eb49badd4016806b3165d396f64925bf2999bcb67837ba", size = 4271955, upload-time = "2026-04-08T01:57:24.814Z" }, + { url = "https://files.pythonhosted.org/packages/80/07/ad9b3c56ebb95ed2473d46df0847357e01583f4c52a85754d1a55e29e4d0/cryptography-46.0.7-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:935ce7e3cfdb53e3536119a542b839bb94ec1ad081013e9ab9b7cfd478b05006", size = 4879888, upload-time = "2026-04-08T01:57:26.88Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c7/201d3d58f30c4c2bdbe9b03844c291feb77c20511cc3586daf7edc12a47b/cryptography-46.0.7-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:35719dc79d4730d30f1c2b6474bd6acda36ae2dfae1e3c16f2051f215df33ce0", size = 4449961, upload-time = "2026-04-08T01:57:29.068Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ef/649750cbf96f3033c3c976e112265c33906f8e462291a33d77f90356548c/cryptography-46.0.7-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:7bbc6ccf49d05ac8f7d7b5e2e2c33830d4fe2061def88210a126d130d7f71a85", size = 4401696, upload-time = "2026-04-08T01:57:31.029Z" }, + { url = "https://files.pythonhosted.org/packages/41/52/a8908dcb1a389a459a29008c29966c1d552588d4ae6d43f3a1a4512e0ebe/cryptography-46.0.7-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a1529d614f44b863a7b480c6d000fe93b59acee9c82ffa027cfadc77521a9f5e", size = 4664256, upload-time = "2026-04-08T01:57:33.144Z" }, + { url = "https://files.pythonhosted.org/packages/4b/fa/f0ab06238e899cc3fb332623f337a7364f36f4bb3f2534c2bb95a35b132c/cryptography-46.0.7-cp38-abi3-win32.whl", hash = "sha256:f247c8c1a1fb45e12586afbb436ef21ff1e80670b2861a90353d9b025583d246", size = 3013001, upload-time = "2026-04-08T01:57:34.933Z" }, + { url = 
"https://files.pythonhosted.org/packages/d2/f1/00ce3bde3ca542d1acd8f8cfa38e446840945aa6363f9b74746394b14127/cryptography-46.0.7-cp38-abi3-win_amd64.whl", hash = "sha256:506c4ff91eff4f82bdac7633318a526b1d1309fc07ca76a3ad182cb5b686d6d3", size = 3472985, upload-time = "2026-04-08T01:57:36.714Z" }, + { url = "https://files.pythonhosted.org/packages/63/0c/dca8abb64e7ca4f6b2978769f6fea5ad06686a190cec381f0a796fdcaaba/cryptography-46.0.7-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:fc9ab8856ae6cf7c9358430e49b368f3108f050031442eaeb6b9d87e4dcf4e4f", size = 3476879, upload-time = "2026-04-08T01:57:38.664Z" }, + { url = "https://files.pythonhosted.org/packages/3a/ea/075aac6a84b7c271578d81a2f9968acb6e273002408729f2ddff517fed4a/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d3b99c535a9de0adced13d159c5a9cf65c325601aa30f4be08afd680643e9c15", size = 4219700, upload-time = "2026-04-08T01:57:40.625Z" }, + { url = "https://files.pythonhosted.org/packages/6c/7b/1c55db7242b5e5612b29fc7a630e91ee7a6e3c8e7bf5406d22e206875fbd/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d02c738dacda7dc2a74d1b2b3177042009d5cab7c7079db74afc19e56ca1b455", size = 4385982, upload-time = "2026-04-08T01:57:42.725Z" }, + { url = "https://files.pythonhosted.org/packages/cb/da/9870eec4b69c63ef5925bf7d8342b7e13bc2ee3d47791461c4e49ca212f4/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:04959522f938493042d595a736e7dbdff6eb6cc2339c11465b3ff89343b65f65", size = 4219115, upload-time = "2026-04-08T01:57:44.939Z" }, + { url = "https://files.pythonhosted.org/packages/f4/72/05aa5832b82dd341969e9a734d1812a6aadb088d9eb6f0430fc337cc5a8f/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3986ac1dee6def53797289999eabe84798ad7817f3e97779b5061a95b0ee4968", size = 4385479, upload-time = "2026-04-08T01:57:46.86Z" }, + { url = 
"https://files.pythonhosted.org/packages/20/2a/1b016902351a523aa2bd446b50a5bc1175d7a7d1cf90fe2ef904f9b84ebc/cryptography-46.0.7-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:258514877e15963bd43b558917bc9f54cf7cf866c38aa576ebf47a77ddbc43a4", size = 3412829, upload-time = "2026-04-08T01:57:48.874Z" }, ] [[package]] From 941089e06aa4f7ea09c58f8cae42e28059bb9c9b Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Mon, 13 Apr 2026 12:50:04 +0100 Subject: [PATCH 57/84] docs: modernize development guidelines and rename to AGENTS.md (#2413) --- AGENTS.md | 138 ++++++++++++++++++++++++++++++++++++++ CLAUDE.md | 175 +------------------------------------------------ pyproject.toml | 1 - 3 files changed, 139 insertions(+), 175 deletions(-) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000..9692271044 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,138 @@ +# Development Guidelines + +## Branching Model + + + +- `main` is currently the V2 rework. Breaking changes are expected here — when removing or + replacing an API, delete it outright and document the change in + `docs/migration.md`. Do not add `@deprecated` shims or backward-compat layers + on `main`. +- `v1.x` is the release branch for the current stable line. Backport PRs target + this branch and use a `[v1.x]` title prefix. +- `README.md` is frozen at v1 (a pre-commit hook rejects edits). Edit + `README.v2.md` instead. + +## Package Management + +- ONLY use uv, NEVER pip +- Installation: `uv add ` +- Running tools: `uv run --frozen `. Always pass `--frozen` so uv doesn't + rewrite `uv.lock` as a side effect. +- Cross-version testing: `uv run --frozen --python 3.10 pytest ...` to run + against a specific interpreter (CI covers 3.10–3.14). +- Upgrading: `uv lock --upgrade-package ` +- FORBIDDEN: `uv pip install`, `@latest` syntax +- Don't raise dependency floors for CVEs alone. The `>=` constraint already + lets users upgrade. 
Only raise a floor when the SDK needs functionality from + the newer version, and don't add SDK code to work around a dependency's + vulnerability. See Kludex/uvicorn#2643 and python-sdk #1552 for reasoning. + +## Code Quality + +- Type hints required for all code +- Public APIs must have docstrings. When a public API raises exceptions a + caller would reasonably catch, document them in a `Raises:` section. Don't + list exceptions from argument validation or programmer error. +- `src/mcp/__init__.py` defines the public API surface via `__all__`. Adding a + symbol there is a deliberate API decision, not a convenience re-export. +- IMPORTANT: All imports go at the top of the file — inline imports hide + dependencies and obscure circular-import bugs. Only exception: when a + top-level import genuinely can't work (lazy-loading optional deps, or + tests that re-import a module). + +## Testing + +- Framework: `uv run --frozen pytest` +- Async testing: use anyio, not asyncio +- Do not use `Test` prefixed classes, use functions +- IMPORTANT: Tests should be fast and deterministic. Prefer in-memory async execution; + reach for threads only when necessary, and subprocesses only as a last resort. +- For end-to-end behavior, an in-memory `Client(server)` is usually the + cleanest approach (see `tests/client/test_client.py` for the canonical + pattern). For narrower changes, testing the function directly is fine. Use + judgment. +- Test files mirror the source tree: `src/mcp/client/stdio.py` → + `tests/client/test_stdio.py`. Add tests to the existing file for that module. +- Avoid `anyio.sleep()` with a fixed duration to wait for async operations. 
Instead: + - Use `anyio.Event` — set it in the callback/handler, `await event.wait()` in the test + - For stream messages, use `await stream.receive()` instead of `sleep()` + `receive_nowait()` + - Exception: `sleep()` is appropriate when testing time-based features (e.g., timeouts) +- Wrap indefinite waits (`event.wait()`, `stream.receive()`) in `anyio.fail_after(5)` to prevent hangs +- Pytest is configured with `filterwarnings = ["error"]`, so warnings fail + tests. Don't silence warnings from your own code; fix the underlying cause. + Scoped `ignore::` entries for upstream libraries are acceptable in + `pyproject.toml` with a comment explaining why. + +### Coverage + +CI requires 100% (`fail_under = 100`, `branch = true`). + +- Full check: `./scripts/test` (~23s). Runs coverage + `strict-no-cover` on the + default Python. Not identical to CI: CI runs 3.10–3.14 × {ubuntu, windows} + × {locked, lowest-direct}, and some branch-coverage quirks only surface on + specific matrix entries. +- Targeted check while iterating (~4s, deterministic): + + ```bash + uv run --frozen coverage erase + uv run --frozen coverage run -m pytest tests/path/test_foo.py + uv run --frozen coverage combine + uv run --frozen coverage report --include='src/mcp/path/foo.py' --fail-under=0 + # UV_FROZEN=1 propagates --frozen to the uv subprocess strict-no-cover spawns + UV_FROZEN=1 uv run --frozen strict-no-cover + ``` + + Partial runs can't hit 100% (coverage tracks `tests/` too), so `--fail-under=0` + and `--include` scope the report. `strict-no-cover` has no false positives on + partial runs — if your new test executes a line marked `# pragma: no cover`, + even a single-file run catches it. + +Avoid adding new `# pragma: no cover`, `# type: ignore`, or `# noqa` comments. +In tests, use `assert isinstance(x, T)` to narrow types instead of +`# type: ignore`. In library code (`src/`), a `# pragma: no cover` needs very +good reasoning — it usually means a test is missing. 
Audit before pushing: + +```bash +git diff origin/main... | grep -E '^\+.*(pragma|type: ignore|noqa)' +``` + +What the existing pragmas mean: + +- `# pragma: no cover` — line is never executed. CI's `strict-no-cover` (skipped + on Windows runners) fails if it IS executed. When your test starts covering + such a line, remove the pragma. +- `# pragma: lax no cover` — excluded from coverage but not checked by + `strict-no-cover`. Use for lines covered on some platforms/versions but not + others. +- `# pragma: no branch` — excludes branch arcs only. coverage.py misreports the + `->exit` arc for nested `async with` on Python 3.11+ (worse on 3.14/Windows). + +## Breaking Changes + +When making breaking changes, document them in `docs/migration.md`. Include: + +- What changed +- Why it changed +- How to migrate existing code + +Search for related sections in the migration guide and group related changes together +rather than adding new standalone sections. + +## Formatting & Type Checking + +- Format: `uv run --frozen ruff format .` +- Lint: `uv run --frozen ruff check . --fix` +- Type check: `uv run --frozen pyright` +- Pre-commit runs all of the above plus markdownlint, a `uv.lock` consistency + check, and README checks — see `.pre-commit-config.yaml` + +## Exception Handling + +- **Always use `logger.exception()` instead of `logger.error()` when catching exceptions** + - Don't include the exception in the message: `logger.exception("Failed")` not `logger.exception(f"Failed: {e}")` +- **Catch specific exceptions** where possible: + - File ops: `except (OSError, PermissionError):` + - JSON: `except json.JSONDecodeError:` + - Network: `except (ConnectionError, TimeoutError):` +- **FORBIDDEN** `except Exception:` - unless in top-level handlers diff --git a/CLAUDE.md b/CLAUDE.md index 2eee085e13..43c994c2d3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,174 +1 @@ -# Development Guidelines - -This document contains critical information about working with this codebase. 
Follow these guidelines precisely. - - ## Core Development Rules - -1. Package Management - - ONLY use uv, NEVER pip - - Installation: `uv add <package>` - - Running tools: `uv run <tool>` - - Upgrading: `uv lock --upgrade-package <package>` - - FORBIDDEN: `uv pip install`, `@latest` syntax - -2. Code Quality - - Type hints required for all code - - Public APIs must have docstrings - - Functions must be focused and small - - Follow existing patterns exactly - - Line length: 120 chars maximum - - FORBIDDEN: imports inside functions. THEY SHOULD BE AT THE TOP OF THE FILE. - -3. Testing Requirements - - Framework: `uv run --frozen pytest` - - Async testing: use anyio, not asyncio - - Do not use `Test` prefixed classes, use functions - - Coverage: test edge cases and errors - - New features require tests - - Bug fixes require regression tests - - IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns. - - IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible. - - Coverage: CI requires 100% (`fail_under = 100`, `branch = true`). - - Full check: `./scripts/test` (~23s). Runs coverage + `strict-no-cover` on the - default Python. Not identical to CI: CI also runs 3.10–3.14 × {ubuntu, windows}, - and some branch-coverage quirks only surface on specific matrix entries. - - Targeted check while iterating (~4s, deterministic): - - ```bash - uv run --frozen coverage erase - uv run --frozen coverage run -m pytest tests/path/test_foo.py - uv run --frozen coverage combine - uv run --frozen coverage report --include='src/mcp/path/foo.py' --fail-under=0 - UV_FROZEN=1 uv run --frozen strict-no-cover - ``` - - Partial runs can't hit 100% (coverage tracks `tests/` too), so `--fail-under=0` - and `--include` scope the report. `strict-no-cover` has no false positives on - partial runs — if your new test executes a line marked `# pragma: no cover`, - even a single-file run catches it.
- - Coverage pragmas: - - `# pragma: no cover` — line is never executed. CI's `strict-no-cover` fails if - it IS executed. When your test starts covering such a line, remove the pragma. - - `# pragma: lax no cover` — excluded from coverage but not checked by - `strict-no-cover`. Use for lines covered on some platforms/versions but not - others. - - `# pragma: no branch` — excludes branch arcs only. coverage.py misreports the - `->exit` arc for nested `async with` on Python 3.11+ (worse on 3.14/Windows). - - Avoid `anyio.sleep()` with a fixed duration to wait for async operations. Instead: - - Use `anyio.Event` — set it in the callback/handler, `await event.wait()` in the test - - For stream messages, use `await stream.receive()` instead of `sleep()` + `receive_nowait()` - - Exception: `sleep()` is appropriate when testing time-based features (e.g., timeouts) - - Wrap indefinite waits (`event.wait()`, `stream.receive()`) in `anyio.fail_after(5)` to prevent hangs - -Test files mirror the source tree: `src/mcp/client/streamable_http.py` → `tests/client/test_streamable_http.py` -Add tests to the existing file for that module. - -- For commits fixing bugs or adding features based on user reports add: - - ```bash - git commit --trailer "Reported-by:<name>" - ``` - - Where `<name>` is the name of the user. - -- For commits related to a Github issue, add - - ```bash - git commit --trailer "Github-Issue:#<number>" - ``` - -- NEVER ever mention a `co-authored-by` or similar aspects. In particular, never - mention the tool used to create the commit message or PR. - -## Pull Requests - -- Create a detailed message of what changed. Focus on the high level description of - the problem it tries to solve, and how it is solved. Don't go into the specifics of the - code unless it adds clarity. - -- NEVER ever mention a `co-authored-by` or similar aspects. In particular, never - mention the tool used to create the commit message or PR.
- -## Breaking Changes - -When making breaking changes, document them in `docs/migration.md`. Include: - -- What changed -- Why it changed -- How to migrate existing code - -Search for related sections in the migration guide and group related changes together -rather than adding new standalone sections. - -## Python Tools - -## Code Formatting - -1. Ruff - - Format: `uv run --frozen ruff format .` - - Check: `uv run --frozen ruff check .` - - Fix: `uv run --frozen ruff check . --fix` - - Critical issues: - - Line length (88 chars) - - Import sorting (I001) - - Unused imports - - Line wrapping: - - Strings: use parentheses - - Function calls: multi-line with proper indent - - Imports: try to use a single line - -2. Type Checking - - Tool: `uv run --frozen pyright` - - Requirements: - - Type narrowing for strings - - Version warnings can be ignored if checks pass - -3. Pre-commit - - Config: `.pre-commit-config.yaml` - - Runs: on git commit - - Tools: Prettier (YAML/JSON), Ruff (Python) - - Ruff updates: - - Check PyPI versions - - Update config rev - - Commit config first - -## Error Resolution - -1. CI Failures - - Fix order: - 1. Formatting - 2. Type errors - 3. Linting - - Type errors: - - Get full line context - - Check Optional types - - Add type narrowing - - Verify function signatures - -2. Common Issues - - Line length: - - Break strings with parentheses - - Multi-line function calls - - Split imports - - Types: - - Add None checks - - Narrow string types - - Match existing patterns - -3. 
Best Practices - - Check git status before commits - - Run formatters before type checks - - Keep changes minimal - - Follow existing patterns - - Document public APIs - - Test thoroughly - -## Exception Handling - -- **Always use `logger.exception()` instead of `logger.error()` when catching exceptions** - - Don't include the exception in the message: `logger.exception("Failed")` not `logger.exception(f"Failed: {e}")` -- **Catch specific exceptions** where possible: - - File ops: `except (OSError, PermissionError):` - - JSON: `except json.JSONDecodeError:` - - Network: `except (ConnectionError, TimeoutError):` -- **FORBIDDEN** `except Exception:` - unless in top-level handlers +@AGENTS.md diff --git a/pyproject.toml b/pyproject.toml index be1200cff0..a5d2c3d80a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,7 +185,6 @@ filterwarnings = [ # This should be fixed on Uvicorn's side. "ignore::DeprecationWarning:websockets", "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", - "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", # pywin32 internal deprecation warning "ignore:getargs.*The 'u' format is deprecated:DeprecationWarning", ] From 2dfb51a4d1af35ed0462c49ea77cca0a3e68766d Mon Sep 17 00:00:00 2001 From: Felix Weinberger <3823880+felixweinberger@users.noreply.github.com> Date: Mon, 13 Apr 2026 15:42:35 +0100 Subject: [PATCH 58/84] fix(auth): coerce empty-string optional URL fields to None in OAuthClientMetadata (#2404) --- AGENTS.md | 4 +- src/mcp/shared/auth.py | 18 +++++++++ tests/shared/test_auth.py | 82 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 102 insertions(+), 2 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 9692271044..307bd81b3e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -45,7 +45,9 @@ - Framework: `uv run --frozen pytest` - Async testing: use anyio, not asyncio -- Do not use `Test` prefixed classes, use functions +- Do not use `Test` prefixed classes — write plain 
top-level `test_*` functions. + Legacy files still contain `Test*` classes; do NOT follow that pattern for new + tests even when adding to such a file. - IMPORTANT: Tests should be fast and deterministic. Prefer in-memory async execution; reach for threads only when necessary, and subprocesses only as a last resort. - For end-to-end behavior, an in-memory `Client(server)` is usually the diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index ca5b7b45ab..ebf534d792 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -67,6 +67,24 @@ class OAuthClientMetadata(BaseModel): software_id: str | None = None software_version: str | None = None + @field_validator( + "client_uri", + "logo_uri", + "tos_uri", + "policy_uri", + "jwks_uri", + mode="before", + ) + @classmethod + def _empty_string_optional_url_to_none(cls, v: object) -> object: + # RFC 7591 §2 marks these URL fields OPTIONAL. Some authorization servers + # echo omitted metadata back as "" instead of dropping the keys, which + # AnyHttpUrl would otherwise reject — throwing away an otherwise valid + # registration response. Treat "" as absent. + if v == "": + return None + return v + def validate_scope(self, requested_scope: str | None) -> list[str] | None: if requested_scope is None: return None diff --git a/tests/shared/test_auth.py b/tests/shared/test_auth.py index cd3c35332f..7463bc5a8a 100644 --- a/tests/shared/test_auth.py +++ b/tests/shared/test_auth.py @@ -1,6 +1,9 @@ """Tests for OAuth 2.0 shared code.""" -from mcp.shared.auth import OAuthMetadata +import pytest +from pydantic import ValidationError + +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata def test_oauth(): @@ -58,3 +61,80 @@ def test_oauth_with_jarm(): "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], } ) + + +# RFC 7591 §2 marks client_uri/logo_uri/tos_uri/policy_uri/jwks_uri as OPTIONAL. 
+# Some authorization servers echo the client's omitted metadata back as "" +# instead of dropping the keys; without coercion, AnyHttpUrl rejects "" and +# the whole registration response is thrown away even though the server +# returned a valid client_id. + + +@pytest.mark.parametrize( + "empty_field", + ["client_uri", "logo_uri", "tos_uri", "policy_uri", "jwks_uri"], +) +def test_optional_url_empty_string_coerced_to_none(empty_field: str): + data = { + "redirect_uris": ["https://example.com/callback"], + empty_field: "", + } + metadata = OAuthClientMetadata.model_validate(data) + assert getattr(metadata, empty_field) is None + + +def test_all_optional_urls_empty_together(): + data = { + "redirect_uris": ["https://example.com/callback"], + "client_uri": "", + "logo_uri": "", + "tos_uri": "", + "policy_uri": "", + "jwks_uri": "", + } + metadata = OAuthClientMetadata.model_validate(data) + assert metadata.client_uri is None + assert metadata.logo_uri is None + assert metadata.tos_uri is None + assert metadata.policy_uri is None + assert metadata.jwks_uri is None + + +def test_valid_url_passes_through_unchanged(): + data = { + "redirect_uris": ["https://example.com/callback"], + "client_uri": "https://udemy.com/", + } + metadata = OAuthClientMetadata.model_validate(data) + assert str(metadata.client_uri) == "https://udemy.com/" + + +def test_information_full_inherits_coercion(): + """OAuthClientInformationFull subclasses OAuthClientMetadata, so the + same coercion applies to DCR responses parsed via the full model.""" + data = { + "client_id": "abc123", + "redirect_uris": ["https://example.com/callback"], + "client_uri": "", + "logo_uri": "", + "tos_uri": "", + "policy_uri": "", + "jwks_uri": "", + } + info = OAuthClientInformationFull.model_validate(data) + assert info.client_id == "abc123" + assert info.client_uri is None + assert info.logo_uri is None + assert info.tos_uri is None + assert info.policy_uri is None + assert info.jwks_uri is None + + +def 
test_invalid_non_empty_url_still_rejected(): + """Coercion must only touch empty strings — garbage URLs still raise.""" + data = { + "redirect_uris": ["https://example.com/callback"], + "client_uri": "not a url", + } + with pytest.raises(ValidationError): + OAuthClientMetadata.model_validate(data) From 5cbd259c3bc97a9fe6ea661941feb637e72551be Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Mon, 13 Apr 2026 17:03:40 +0100 Subject: [PATCH 59/84] fix: catch PydanticUserError when generating output schema (pydantic 2.13 compat) (#2434) --- src/mcp/server/mcpserver/utilities/func_metadata.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 062b47d0ff..4a76106371 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -9,7 +9,7 @@ import anyio import anyio.to_thread import pydantic_core -from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema, create_model +from pydantic import BaseModel, ConfigDict, Field, PydanticUserError, WithJsonSchema, create_model from pydantic.fields import FieldInfo from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind from typing_extensions import is_typeddict @@ -402,9 +402,16 @@ def _try_create_model_and_schema( # Use StrictJsonSchema to raise exceptions instead of warnings try: schema = model.model_json_schema(schema_generator=StrictJsonSchema) - except (TypeError, ValueError, pydantic_core.SchemaError, pydantic_core.ValidationError) as e: + except ( + PydanticUserError, + TypeError, + ValueError, + pydantic_core.SchemaError, + pydantic_core.ValidationError, + ) as e: # These are expected errors when a type can't be converted to a Pydantic schema - # TypeError: When Pydantic can't handle the type + # PydanticUserError: When Pydantic can't handle the type (e.g. 
PydanticInvalidForJsonSchema); + # subclasses TypeError on pydantic <2.13 and RuntimeError on pydantic >=2.13 # ValueError: When there are issues with the type definition (including our custom warnings) # SchemaError: When Pydantic can't build a schema # ValidationError: When validation fails From 437d15aa7126b86ffecfa6fd8c0158851a340ced Mon Sep 17 00:00:00 2001 From: Wils Dawson Date: Tue, 14 Apr 2026 03:48:07 -0700 Subject: [PATCH 60/84] SEP-2207: Refresh token guidance (#2039) --- src/mcp/client/auth/oauth2.py | 15 +- src/mcp/client/auth/utils.py | 40 ++-- tests/client/test_auth.py | 354 ++++++++++++++++++++++++++++++++++ 3 files changed, 392 insertions(+), 17 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 25075dec3b..72309f5775 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -320,7 +320,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: - auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover + auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) else: auth_base_url = self.context.get_authorization_base_url(self.context.server_url) auth_endpoint = urljoin(auth_base_url, "/authorize") @@ -343,11 +343,16 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: # Only include resource param if conditions are met if self.context.should_include_resource_param(self.context.protocol_version): - auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover + auth_params["resource"] = self.context.get_resource_url() # RFC 8707 if self.context.client_metadata.scope: # pragma: no branch auth_params["scope"] = self.context.client_metadata.scope + # OIDC requires prompt=consent when 
offline_access is requested + # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess + if "offline_access" in self.context.client_metadata.scope.split(): + auth_params["prompt"] = "consent" + authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" await self.context.redirect_handler(authorization_url) @@ -576,6 +581,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. extract_scope_from_www_auth(response), self.context.protected_resource_metadata, self.context.oauth_metadata, + self.context.client_metadata.grant_types, ) # Step 4: Register client or use URL-based client ID (CIMD) @@ -622,7 +628,10 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. try: # Step 2a: Update the required scopes self.context.client_metadata.scope = get_client_metadata_scopes( - extract_scope_from_www_auth(response), self.context.protected_resource_metadata + extract_scope_from_www_auth(response), + self.context.protected_resource_metadata, + self.context.oauth_metadata, + self.context.client_metadata.grant_types, ) # Step 2b: Perform (re-)authorization and token exchange diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index 0ca36b98d8..d75324f2f0 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -99,24 +99,36 @@ def get_client_metadata_scopes( www_authenticate_scope: str | None, protected_resource_metadata: ProtectedResourceMetadata | None, authorization_server_metadata: OAuthMetadata | None = None, + client_grant_types: list[str] | None = None, ) -> str | None: - """Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec.""" - # Per MCP spec, scope selection priority order: - # 1. Use scope from WWW-Authenticate header (if provided) - # 2. Use all scopes from PRM scopes_supported (if available) - # 3. 
Omit scope parameter if neither is available - + """Select effective scopes and augment for refresh token support.""" + selected_scope: str | None = None + + # MCP spec scope selection priority: + # 1. WWW-Authenticate header scope + # 2. PRM scopes_supported + # 3. AS scopes_supported (SDK fallback) + # 4. Omit scope parameter if www_authenticate_scope is not None: - # Priority 1: WWW-Authenticate header scope - return www_authenticate_scope + selected_scope = www_authenticate_scope elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None: - # Priority 2: PRM scopes_supported - return " ".join(protected_resource_metadata.scopes_supported) + selected_scope = " ".join(protected_resource_metadata.scopes_supported) elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None: - return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover - else: - # Priority 3: Omit scope parameter - return None + selected_scope = " ".join(authorization_server_metadata.scopes_supported) + + # SEP-2207: append offline_access when the AS supports it and the client can use refresh tokens + if ( + selected_scope is not None + and authorization_server_metadata is not None + and authorization_server_metadata.scopes_supported is not None + and "offline_access" in authorization_server_metadata.scopes_supported + and client_grant_types is not None + and "refresh_token" in client_grant_types + and "offline_access" not in selected_scope.split() + ): + selected_scope = f"{selected_scope} offline_access" + + return selected_scope def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 5aa985e360..bb0bce4c92 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -2264,3 +2264,357 @@ async def callback_handler() -> tuple[str, 
str | None]: await auth_flow.asend(final_response) except StopAsyncIteration: pass + + +class TestSEP2207OfflineAccessScope: + """Test SEP-2207: offline_access scope augmentation for OIDC-flavored refresh tokens.""" + + def _make_as_metadata(self, scopes_supported: list[str] | None = None) -> OAuthMetadata: + return OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + scopes_supported=scopes_supported, + ) + + def _make_prm(self, scopes_supported: list[str] | None = None) -> ProtectedResourceMetadata: + return ProtectedResourceMetadata( + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + scopes_supported=scopes_supported, + ) + + def test_offline_access_added_when_as_supports_and_client_has_refresh_token(self): + """offline_access is appended when AS advertises it and client supports refresh_token grant.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write offline_access" + + def test_offline_access_added_with_www_authenticate_scope(self): + """offline_access is appended even when scopes come from WWW-Authenticate header.""" + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope="read write", + protected_resource_metadata=None, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write offline_access" + + def 
test_offline_access_not_added_when_as_does_not_support(self): + """offline_access is not added when AS does not advertise it in scopes_supported.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write" + + def test_offline_access_not_added_when_client_has_no_refresh_token_grant(self): + """offline_access is not added when client does not support refresh_token grant.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code"], + ) + assert scopes == "read write" + + def test_offline_access_not_duplicated_when_already_present(self): + """offline_access is not added again if it already appears in the selected scopes.""" + prm = self._make_prm(scopes_supported=["read", "offline_access", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read offline_access write" + + def test_offline_access_not_added_when_no_scopes_selected(self): + """offline_access is not added when no base scopes are available (None).""" + asm = self._make_as_metadata(scopes_supported=["offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=None, + authorization_server_metadata=asm, + 
client_grant_types=["authorization_code", "refresh_token"], + ) + # When AS scopes are the only source and include offline_access, + # the base scope is "offline_access" and no duplication happens + assert scopes == "offline_access" + + def test_offline_access_not_added_when_as_scopes_supported_is_none(self): + """offline_access is not added when AS scopes_supported is None.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=None) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write" + + def test_offline_access_not_added_when_no_as_metadata(self): + """offline_access is not added when AS metadata is not available.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=None, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write" + + def test_offline_access_not_added_when_no_grant_types_provided(self): + """offline_access is not added when client_grant_types is None.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=None, + ) + assert scopes == "read write" + + def test_default_client_metadata_includes_refresh_token_grant(self): + """Default OAuthClientMetadata includes refresh_token in grant_types (SEP-2207 guidance).""" + metadata = OAuthClientMetadata(redirect_uris=[AnyUrl("http://localhost:3030/callback")]) + assert "refresh_token" in metadata.grant_types + + @pytest.mark.anyio 
+ async def test_auth_flow_adds_offline_access_when_as_advertises( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """E2E: auth flow includes offline_access in authorization request when AS advertises it.""" + + captured_auth_url: str | None = None + captured_state: str | None = None + + async def redirect_handler(url: str) -> None: + nonlocal captured_auth_url, captured_state + captured_auth_url = url + parsed = urlparse(url) + params = parse_qs(parsed.query) + captured_state = params.get("state", [None])[0] + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", captured_state + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + # Pre-set client info to skip DCR + provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send 401 + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp",' + b' "authorization_servers": ["https://auth.example.com"],' + b' "scopes_supported": ["read", "write"]}' + ), + request=prm_request, + ) + + # OAuth metadata discovery - AS advertises offline_access + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + 
content=( + b'{"issuer": "https://auth.example.com",' + b' "authorization_endpoint": "https://auth.example.com/authorize",' + b' "token_endpoint": "https://auth.example.com/token",' + b' "scopes_supported": ["read", "write", "offline_access"]}' + ), + request=oauth_request, + ) + + # This triggers authorization, which calls redirect_handler + token_request = await auth_flow.asend(oauth_response) + + # Verify the authorization URL included offline_access in the scope + assert captured_auth_url is not None + parsed = urlparse(captured_auth_url) + params = parse_qs(parsed.query) + scope_value = params["scope"][0] + scope_parts = scope_value.split() + assert "offline_access" in scope_parts + assert "read" in scope_parts + assert "write" in scope_parts + + # OIDC requires prompt=consent when offline_access is requested + assert params["prompt"][0] == "consent" + + # Complete the token exchange + token_response = httpx.Response( + 200, + content=( + b'{"access_token": "new_access_token", "token_type": "Bearer",' + b' "expires_in": 3600, "refresh_token": "new_refresh_token"}' + ), + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer new_access_token" + + # Close the generator + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass + + @pytest.mark.anyio + async def test_auth_flow_no_offline_access_when_as_does_not_advertise( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """E2E: auth flow does NOT include offline_access when AS doesn't advertise it.""" + + captured_auth_url: str | None = None + captured_state: str | None = None + + async def redirect_handler(url: str) -> None: + nonlocal captured_auth_url, captured_state + captured_auth_url = url + parsed = urlparse(url) + params = parse_qs(parsed.query) + captured_state = params.get("state", [None])[0] + + 
async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", captured_state + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + # Pre-set client info to skip DCR + provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + await auth_flow.__anext__() + + # Send 401 + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp",' + b' "authorization_servers": ["https://auth.example.com"],' + b' "scopes_supported": ["read", "write"]}' + ), + request=prm_request, + ) + + # OAuth metadata discovery - AS does NOT advertise offline_access + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com",' + b' "authorization_endpoint": "https://auth.example.com/authorize",' + b' "token_endpoint": "https://auth.example.com/token",' + b' "scopes_supported": ["read", "write"]}' + ), + request=oauth_request, + ) + + # This triggers authorization, which calls redirect_handler + token_request = await auth_flow.asend(oauth_response) + + # Verify the authorization URL does NOT include offline_access + assert captured_auth_url is not None + parsed = urlparse(captured_auth_url) + params = parse_qs(parsed.query) + scope_value = params["scope"][0] + 
scope_parts = scope_value.split() + assert "offline_access" not in scope_parts + assert "read" in scope_parts + assert "write" in scope_parts + + # prompt=consent should NOT be present without offline_access + assert "prompt" not in params + + # Complete the token exchange + token_response = httpx.Response( + 200, + content=b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer new_access_token" + + # Close the generator + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass From 3d7b311de07aade1281d18aa7b04689a81ab8793 Mon Sep 17 00:00:00 2001 From: "Gyeongjun Paik (Kent)" Date: Wed, 15 Apr 2026 06:41:51 +0900 Subject: [PATCH 61/84] fix: align Context logging methods with MCP spec data type (#2366) Co-authored-by: Claude Opus 4.6 Co-authored-by: Max Isbey <224885523+maxisbey@users.noreply.github.com> --- README.v2.md | 10 +++---- docs/migration.md | 20 +++++++++++++ src/mcp/server/mcpserver/context.py | 41 +++++++++++---------------- tests/client/test_logging_callback.py | 33 +++++++++------------ 4 files changed, 55 insertions(+), 49 deletions(-) diff --git a/README.v2.md b/README.v2.md index 55d867586d..d0851c04e5 100644 --- a/README.v2.md +++ b/README.v2.md @@ -681,11 +681,11 @@ The Context object provides the following capabilities: - `ctx.mcp_server` - Access to the MCPServer server instance (see [MCPServer Properties](#mcpserver-properties)) - `ctx.session` - Access to the underlying session for advanced communication (see [Session Properties and Methods](#session-properties-and-methods)) - `ctx.request_context` - Access to request-specific data and lifespan resources (see [Request Context Properties](#request-context-properties)) -- `await ctx.debug(message)` - Send debug log message -- `await 
ctx.info(message)` - Send info log message -- `await ctx.warning(message)` - Send warning log message -- `await ctx.error(message)` - Send error log message -- `await ctx.log(level, message, logger_name=None)` - Send log with custom level +- `await ctx.debug(data)` - Send debug log message +- `await ctx.info(data)` - Send info log message +- `await ctx.warning(data)` - Send warning log message +- `await ctx.error(data)` - Send error log message +- `await ctx.log(level, data, logger_name=None)` - Send log with custom level - `await ctx.report_progress(progress, total=None, message=None)` - Report operation progress - `await ctx.read_resource(uri)` - Read a resource by URI - `await ctx.elicit(message, schema)` - Request additional information from user with validation diff --git a/docs/migration.md b/docs/migration.md index 2528f046c6..8b70885e8d 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -467,6 +467,26 @@ mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscrib This is a private API and may change. A public way to register these handlers on `MCPServer` is planned; until then, use this workaround or use the lowlevel `Server` directly. +### `MCPServer`'s `Context` logging: `message` renamed to `data`, `extra` removed + +On the high-level `Context` object (`mcp.server.mcpserver.Context`), `log()`, `.debug()`, `.info()`, `.warning()`, and `.error()` now take `data: Any` instead of `message: str`, matching the MCP spec's `LoggingMessageNotificationParams.data` field which allows any JSON-serializable value. The `extra` parameter has been removed — pass structured data directly as `data`. + +The lowlevel `ServerSession.send_log_message(data: Any)` already accepted arbitrary data and is unchanged. + +`Context.log()` also now accepts all eight RFC-5424 log levels (`debug`, `info`, `notice`, `warning`, `error`, `critical`, `alert`, `emergency`) via the `LoggingLevel` type, not just the four it previously allowed. 
+ +```python +# Before +await ctx.info("Connection failed", extra={"host": "localhost", "port": 5432}) +await ctx.log(level="info", message="hello") + +# After +await ctx.info({"message": "Connection failed", "host": "localhost", "port": 5432}) +await ctx.log(level="info", data="hello") +``` + +Positional calls (`await ctx.info("hello")`) are unaffected. + ### Replace `RootModel` by union types with `TypeAdapter` validation The following union types are no longer `RootModel` subclasses: diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 1538adc7c7..e87388eee9 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Generic, Literal +from typing import TYPE_CHECKING, Any, Generic from pydantic import AnyUrl, BaseModel @@ -14,6 +14,7 @@ elicit_with_validation, ) from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.types import LoggingLevel if TYPE_CHECKING: from mcp.server.mcpserver.server import MCPServer @@ -186,29 +187,23 @@ async def elicit_url( async def log( self, - level: Literal["debug", "info", "warning", "error"], - message: str, + level: LoggingLevel, + data: Any, *, logger_name: str | None = None, - extra: dict[str, Any] | None = None, ) -> None: """Send a log message to the client. Args: - level: Log level (debug, info, warning, error) - message: Log message + level: Log level (debug, info, notice, warning, error, critical, + alert, emergency) + data: The data to be logged. Any JSON serializable type is allowed + (string, dict, list, number, bool, etc.) per the MCP specification. 
logger_name: Optional logger name - extra: Optional dictionary with additional structured data to include """ - - if extra: - log_data = {"message": message, **extra} - else: - log_data = message - await self.request_context.session.send_log_message( level=level, - data=log_data, + data=data, logger=logger_name, related_request_id=self.request_id, ) @@ -261,20 +256,18 @@ async def close_standalone_sse_stream(self) -> None: await self._request_context.close_standalone_sse_stream() # Convenience methods for common log levels - async def debug(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: + async def debug(self, data: Any, *, logger_name: str | None = None) -> None: """Send a debug log message.""" - await self.log("debug", message, logger_name=logger_name, extra=extra) + await self.log("debug", data, logger_name=logger_name) - async def info(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: + async def info(self, data: Any, *, logger_name: str | None = None) -> None: """Send an info log message.""" - await self.log("info", message, logger_name=logger_name, extra=extra) + await self.log("info", data, logger_name=logger_name) - async def warning( - self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None - ) -> None: + async def warning(self, data: Any, *, logger_name: str | None = None) -> None: """Send a warning log message.""" - await self.log("warning", message, logger_name=logger_name, extra=extra) + await self.log("warning", data, logger_name=logger_name) - async def error(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: + async def error(self, data: Any, *, logger_name: str | None = None) -> None: """Send an error log message.""" - await self.log("error", message, logger_name=logger_name, extra=extra) + await self.log("error", data, logger_name=logger_name) diff --git 
a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 1598fd55f6..454c1d3382 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -1,4 +1,4 @@ -from typing import Any, Literal +from typing import Literal import pytest @@ -36,24 +36,20 @@ async def test_tool_with_log( message: str, level: Literal["debug", "info", "warning", "error"], logger: str, ctx: Context ) -> bool: """Send a log notification to the client.""" - await ctx.log(level=level, message=message, logger_name=logger) + await ctx.log(level=level, data=message, logger_name=logger) return True - @server.tool("test_tool_with_log_extra") - async def test_tool_with_log_extra( - message: str, + @server.tool("test_tool_with_log_dict") + async def test_tool_with_log_dict( level: Literal["debug", "info", "warning", "error"], logger: str, - extra_string: str, - extra_dict: dict[str, Any], ctx: Context, ) -> bool: - """Send a log notification to the client with extra fields.""" + """Send a log notification with a dict payload.""" await ctx.log( level=level, - message=message, + data={"message": "Test log message", "extra_string": "example", "extra_dict": {"a": 1, "b": 2, "c": 3}}, logger_name=logger, - extra={"extra_string": extra_string, "extra_dict": extra_dict}, ) return True @@ -84,18 +80,15 @@ async def message_handler( "logger": "test_logger", }, ) - log_result_with_extra = await client.call_tool( - "test_tool_with_log_extra", + log_result_with_dict = await client.call_tool( + "test_tool_with_log_dict", { - "message": "Test log message", "level": "info", "logger": "test_logger", - "extra_string": "example", - "extra_dict": {"a": 1, "b": 2, "c": 3}, }, ) assert log_result.is_error is False - assert log_result_with_extra.is_error is False + assert log_result_with_dict.is_error is False assert len(logging_collector.log_messages) == 2 # Create meta object with related_request_id added dynamically log = logging_collector.log_messages[0] @@ 
-103,10 +96,10 @@ async def message_handler( assert log.logger == "test_logger" assert log.data == "Test log message" - log_with_extra = logging_collector.log_messages[1] - assert log_with_extra.level == "info" - assert log_with_extra.logger == "test_logger" - assert log_with_extra.data == { + log_with_dict = logging_collector.log_messages[1] + assert log_with_dict.level == "info" + assert log_with_dict.logger == "test_logger" + assert log_with_dict.data == { "message": "Test log message", "extra_string": "example", "extra_dict": {"a": 1, "b": 2, "c": 3}, From 2b0da5631a7cb460c17d60b8cc735626d5e1b3f8 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Thu, 7 May 2026 17:32:30 +0100 Subject: [PATCH 62/84] build: pin PEP 517 build dependencies (#2547) --- pyproject.toml | 18 ++++++++++++++++++ uv.lock | 13 +++++++++++++ 2 files changed, 31 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a5d2c3d80a..364b9add0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,24 @@ mcp = "mcp.cli:app [cli]" [tool.uv] default-groups = ["dev", "docs"] required-version = ">=0.9.5" +# PEP 517 build isolation fetches [build-system].requires (and transitives) at +# floating-latest with no hash check on every fresh sync; uv does not lock them +# (astral-sh/uv#5190). Pinning here narrows that to known-good versions. Covers +# the workspace builds (hatchling + uv-dynamic-versioning) and the legacy +# setuptools fallback used by the strict-no-cover git dep. 
+build-constraint-dependencies = [ + "hatchling==1.29.0", + "uv-dynamic-versioning==0.14.0", + "dunamai==1.26.1", + "jinja2==3.1.6", + "markupsafe==3.0.3", + "packaging==26.1", + "pathspec==1.0.4", + "pluggy==1.6.0", + "tomlkit==0.14.0", + "trove-classifiers==2026.1.14.14", + "setuptools==82.0.1", +] [dependency-groups] dev = [ diff --git a/uv.lock b/uv.lock index 705d014aa5..b396898b66 100644 --- a/uv.lock +++ b/uv.lock @@ -28,6 +28,19 @@ members = [ "mcp-sse-polling-demo", "mcp-structured-output-lowlevel", ] +build-constraints = [ + { name = "dunamai", specifier = "==1.26.1" }, + { name = "hatchling", specifier = "==1.29.0" }, + { name = "jinja2", specifier = "==3.1.6" }, + { name = "markupsafe", specifier = "==3.0.3" }, + { name = "packaging", specifier = "==26.1" }, + { name = "pathspec", specifier = "==1.0.4" }, + { name = "pluggy", specifier = "==1.6.0" }, + { name = "setuptools", specifier = "==82.0.1" }, + { name = "tomlkit", specifier = "==0.14.0" }, + { name = "trove-classifiers", specifier = "==2026.1.14.14" }, + { name = "uv-dynamic-versioning", specifier = "==0.14.0" }, +] [[package]] name = "annotated-types" From bf3e0010b87a6a2535711500eb9590741b1d1940 Mon Sep 17 00:00:00 2001 From: Dayna Blackwell Date: Fri, 8 May 2026 06:28:36 -0700 Subject: [PATCH 63/84] fix: chain exceptions in get_prompt and read_resource handlers (#2542) --- src/mcp/server/mcpserver/server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index be77705da6..b3471163b7 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -447,8 +447,8 @@ async def read_resource( context = Context(mcp_server=self) try: resource = await self._resource_manager.get_resource(uri, context) - except ValueError: - raise ResourceError(f"Unknown resource: {uri}") + except ValueError as exc: + raise ResourceError(f"Unknown resource: {uri}") from exc try: content = await 
resource.read() @@ -1109,4 +1109,4 @@ async def get_prompt( ) except Exception as e: logger.exception(f"Error getting prompt {name}") - raise ValueError(str(e)) + raise ValueError(str(e)) from e From 161834d4aee2633c42d3976c8f8751b6c4d947d5 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Fri, 8 May 2026 17:42:44 +0100 Subject: [PATCH 64/84] refactor: import SSEError from httpx_sse public API (#2560) --- src/mcp/client/sse.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 193204a153..74e5ba8062 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -7,8 +7,7 @@ import anyio import httpx from anyio.abc import TaskStatus -from httpx_sse import aconnect_sse -from httpx_sse._exceptions import SSEError +from httpx_sse import SSEError, aconnect_sse from mcp import types from mcp.shared._context_streams import create_context_streams From f4753440dac8b2b6fa6407808e06c51258b78322 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Mon, 18 May 2026 15:28:13 +0100 Subject: [PATCH 65/84] ci: deploy docs to py.sdk.modelcontextprotocol.io via Pages artifact (v1 at /, v2 at /v2/) (#2634) --- .github/workflows/deploy-docs.yml | 57 +++++++++++++++++++++ .github/workflows/publish-docs-manually.yml | 35 ------------- .github/workflows/publish-pypi.yml | 29 ----------- .gitignore | 1 + docs/index.md | 3 ++ mkdocs.yml | 2 +- pyproject.toml | 1 + scripts/build-docs.sh | 54 +++++++++++++++++++ 8 files changed, 117 insertions(+), 65 deletions(-) create mode 100644 .github/workflows/deploy-docs.yml delete mode 100644 .github/workflows/publish-docs-manually.yml create mode 100755 scripts/build-docs.sh diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml new file mode 100644 index 0000000000..d9362afd57 --- /dev/null +++ b/.github/workflows/deploy-docs.yml @@ -0,0 +1,57 @@ +name: Deploy Docs + +on: + push: + 
branches: + - main + - v1.x + paths: + - docs/** + - mkdocs.yml + - src/mcp/** + - scripts/build-docs.sh + - pyproject.toml + - uv.lock + - .github/workflows/deploy-docs.yml + workflow_dispatch: + +concurrency: + group: deploy-docs + cancel-in-progress: false + +jobs: + deploy-docs: + runs-on: ubuntu-latest + + permissions: + contents: read + pages: write + id-token: write + + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + + - name: Install uv + uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 + with: + enable-cache: true + version: 0.9.5 + + - name: Build combined docs (v1.x at /, main at /v2/) + run: bash scripts/build-docs.sh site + + - name: Configure Pages + uses: actions/configure-pages@45bfe0192ca1faeb007ade9deae92b16b8254a0d # v6.0.0 + + - name: Upload Pages artifact + uses: actions/upload-pages-artifact@fc324d3547104276b827a68afc52ff2a11cc49c9 # v5.0.0 + with: + path: site + + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@cd2ce8fcbc39b97be8ca5fce6e763baed58fa128 # v5.0.0 diff --git a/.github/workflows/publish-docs-manually.yml b/.github/workflows/publish-docs-manually.yml deleted file mode 100644 index ee45ab5c8a..0000000000 --- a/.github/workflows/publish-docs-manually.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Publish Docs manually - -on: - workflow_dispatch: - -jobs: - docs-publish: - runs-on: ubuntu-latest - permissions: - contents: write - steps: - - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 - - name: Configure Git Credentials - run: | - git config user.name github-actions[bot] - git config user.email 41898282+github-actions[bot]@users.noreply.github.com - - - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 - with: - enable-cache: true - version: 0.9.5 - - - run: echo "cache_id=$(date --utc '+%V')" 
>> $GITHUB_ENV - - uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 - with: - key: mkdocs-material-${{ env.cache_id }} - path: .cache - restore-keys: | - mkdocs-material- - - - run: uv sync --frozen --group docs - - run: uv run --frozen --no-sync mkdocs gh-deploy --force - env: - ENABLE_SOCIAL_CARDS: "true" diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index bfc3cc64e1..7ba11e86fa 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -51,32 +51,3 @@ jobs: - name: Publish package distributions to PyPI uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1 - - docs-publish: - runs-on: ubuntu-latest - needs: ["pypi-publish"] - permissions: - contents: write - steps: - - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 - - name: Configure Git Credentials - run: | - git config user.name github-actions[bot] - git config user.email 41898282+github-actions[bot]@users.noreply.github.com - - - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 - with: - enable-cache: true - version: 0.9.5 - - - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - - uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 - with: - key: mkdocs-material-${{ env.cache_id }} - path: .cache - restore-keys: | - mkdocs-material- - - - run: uv sync --frozen --group docs - - run: uv run --frozen --no-sync mkdocs gh-deploy --force diff --git a/.gitignore b/.gitignore index 5ff4ce9771..3443adf7c8 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,7 @@ venv.bak/ # mkdocs documentation /site +/.worktrees/ # mypy .mypy_cache/ diff --git a/docs/index.md b/docs/index.md index 436d1c8fcd..6a937da67f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,5 +1,8 @@ # MCP Python SDK +!!! 
info "You are viewing the in-development v2 documentation" + For the current stable release, see the [v1.x documentation](https://py.sdk.modelcontextprotocol.io/). + The **Model Context Protocol (MCP)** allows applications to provide context for LLMs in a standardized way, separating the concerns of providing context from the actual LLM interaction. This Python SDK implements the full MCP specification, making it easy to: diff --git a/mkdocs.yml b/mkdocs.yml index 3a555785a7..e48c64242d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -5,7 +5,7 @@ strict: true repo_name: modelcontextprotocol/python-sdk repo_url: https://github.com/modelcontextprotocol/python-sdk edit_uri: edit/main/docs/ -site_url: https://modelcontextprotocol.github.io/python-sdk +site_url: https://py.sdk.modelcontextprotocol.io/v2/ # TODO(Marcelo): Add Anthropic copyright? # copyright: © Model Context Protocol 2025 to present diff --git a/pyproject.toml b/pyproject.toml index 364b9add0b..d88869da1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,7 @@ bump = true [project.urls] Homepage = "https://modelcontextprotocol.io" +Documentation = "https://py.sdk.modelcontextprotocol.io/v2/" Repository = "https://github.com/modelcontextprotocol/python-sdk" Issues = "https://github.com/modelcontextprotocol/python-sdk/issues" diff --git a/scripts/build-docs.sh b/scripts/build-docs.sh new file mode 100755 index 0000000000..5a61309acf --- /dev/null +++ b/scripts/build-docs.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +# +# Build combined v1 + v2 MkDocs documentation for GitHub Pages. +# +# v1 docs (from the v1.x branch) are placed at the site root. +# v2 docs (from main) are placed under /v2/. +# +# Both branches are fetched fresh from origin, so the output is identical +# regardless of which branch triggered the workflow. This script is intended +# to run in CI; for local single-branch preview use `uv run mkdocs serve`. 
+# +# Usage: +# scripts/build-docs.sh [output-dir] +# +# Default output directory: site +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +OUTPUT_DIR="$(cd "$REPO_ROOT" && mkdir -p "${1:-site}" && cd "${1:-site}" && pwd)" +V1_WORKTREE="$REPO_ROOT/.worktrees/v1-docs" +V2_WORKTREE="$REPO_ROOT/.worktrees/v2-docs" + +cleanup() { + cd "$REPO_ROOT" + git worktree remove --force "$V1_WORKTREE" 2>/dev/null || true + git worktree remove --force "$V2_WORKTREE" 2>/dev/null || true + rmdir "$REPO_ROOT/.worktrees" 2>/dev/null || true +} +trap cleanup EXIT + +rm -rf "${OUTPUT_DIR:?}"/* + +build_branch() { + local branch="$1" worktree="$2" dest="$3" + + echo "=== Building docs for ${branch} ===" + git fetch origin "$branch" + git worktree remove --force "$worktree" 2>/dev/null || true + rm -rf "$worktree" + git worktree add --detach "$worktree" "origin/${branch}" + + ( + cd "$worktree" + uv sync --frozen --group docs + uv run --frozen --no-sync mkdocs build --site-dir "$dest" + ) +} + +build_branch v1.x "$V1_WORKTREE" "$OUTPUT_DIR" +build_branch main "$V2_WORKTREE" "$OUTPUT_DIR/v2" + +echo "=== Combined docs built at $OUTPUT_DIR ===" From e8e64842781c66b613872cf394de6e7d6f6925bf Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 20 May 2026 17:06:57 +0200 Subject: [PATCH 66/84] ci: add zizmor for GitHub Actions security analysis (#2648) --- .github/workflows/claude.yml | 3 ++- .github/workflows/comment-on-release.yml | 22 ++++++++++++----- .github/workflows/conformance.yml | 4 ++++ .github/workflows/deploy-docs.yml | 2 ++ .github/workflows/publish-pypi.yml | 7 +++++- .github/workflows/shared.yml | 6 +++++ .github/workflows/weekly-lockfile-update.yml | 2 ++ .github/workflows/zizmor.yml | 25 ++++++++++++++++++++ 8 files changed, 63 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/zizmor.yml diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 
59dac99dcb..18f5377cb6 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -30,12 +30,13 @@ jobs: uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 1 + persist-credentials: false - name: Run Claude Code id: claude uses: anthropics/claude-code-action@2f8ba26a219c06cfb0f468eef8d97055fa814f97 # v1.0.53 with: - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} # zizmor: ignore[secrets-outside-env] use_commit_signing: true additional_permissions: | actions: read diff --git a/.github/workflows/comment-on-release.yml b/.github/workflows/comment-on-release.yml index 15d6a1d26a..66f1fcc32a 100644 --- a/.github/workflows/comment-on-release.yml +++ b/.github/workflows/comment-on-release.yml @@ -16,13 +16,16 @@ jobs: uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 with: fetch-depth: 0 + persist-credentials: false - name: Get previous release id: previous_release uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + env: + CURRENT_TAG: ${{ github.event.release.tag_name }} with: script: | - const currentTag = '${{ github.event.release.tag_name }}'; + const currentTag = process.env.CURRENT_TAG; // Get all releases const { data: releases } = await github.rest.repos.listReleases({ @@ -54,10 +57,13 @@ jobs: - name: Get merged PRs between releases id: get_prs uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + env: + CURRENT_TAG: ${{ github.event.release.tag_name }} + PREVIOUS_TAG_JSON: ${{ steps.previous_release.outputs.result }} with: script: | - const currentTag = '${{ github.event.release.tag_name }}'; - const previousTag = ${{ steps.previous_release.outputs.result }}; + const currentTag = process.env.CURRENT_TAG; + const previousTag = JSON.parse(process.env.PREVIOUS_TAG_JSON); if (!previousTag) { console.log('No previous release found, skipping'); @@ -104,11 +110,15 @@ jobs: - 
name: Comment on PRs uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + env: + PR_NUMBERS_JSON: ${{ steps.get_prs.outputs.result }} + RELEASE_TAG: ${{ github.event.release.tag_name }} + RELEASE_URL: ${{ github.event.release.html_url }} with: script: | - const prNumbers = ${{ steps.get_prs.outputs.result }}; - const releaseTag = '${{ github.event.release.tag_name }}'; - const releaseUrl = '${{ github.event.release.html_url }}'; + const prNumbers = JSON.parse(process.env.PR_NUMBERS_JSON); + const releaseTag = process.env.RELEASE_TAG; + const releaseUrl = process.env.RELEASE_URL; const comment = `This pull request is included in [${releaseTag}](${releaseUrl})`; diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index d876da00b0..9c33d2936b 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -19,6 +19,8 @@ jobs: continue-on-error: true steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: enable-cache: true @@ -34,6 +36,8 @@ jobs: continue-on-error: true steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: enable-cache: true diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml index d9362afd57..fcf7a6c1a7 100644 --- a/.github/workflows/deploy-docs.yml +++ b/.github/workflows/deploy-docs.yml @@ -34,6 +34,8 @@ jobs: steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - name: Install uv uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 7ba11e86fa..c72061d12b 100644 --- 
a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -4,6 +4,9 @@ on: release: types: [published] +permissions: + contents: read + jobs: release-build: name: Build distribution @@ -11,11 +14,13 @@ jobs: needs: [checks] steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - name: Install uv uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: - enable-cache: true + enable-cache: false version: 0.9.5 - name: Set up Python 3.12 diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index efb45c8898..5f115aef89 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -14,6 +14,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: @@ -57,6 +59,8 @@ jobs: steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - name: Install uv uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 @@ -83,6 +87,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: diff --git a/.github/workflows/weekly-lockfile-update.yml b/.github/workflows/weekly-lockfile-update.yml index 5d79d06d52..c30c72991a 100644 --- a/.github/workflows/weekly-lockfile-update.yml +++ b/.github/workflows/weekly-lockfile-update.yml @@ -15,6 +15,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: diff --git a/.github/workflows/zizmor.yml 
b/.github/workflows/zizmor.yml new file mode 100644 index 0000000000..83d3eb3817 --- /dev/null +++ b/.github/workflows/zizmor.yml @@ -0,0 +1,25 @@ +name: GitHub Actions Security Analysis + +on: + push: + branches: ["main"] + pull_request: + branches: ["**"] + +permissions: {} + +jobs: + zizmor: + runs-on: ubuntu-latest + + permissions: + security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files. + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Run zizmor 🌈 + uses: zizmorcore/zizmor-action@5f14fd08f7cf1cb1609c1e344975f152c7ee938d # v0.5.6 From 3eb579948a4719d606d2adbd1f3f69371c9c0f48 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 26 May 2026 14:13:55 +0100 Subject: [PATCH 67/84] Add subject and claims to AccessToken (#2686) --- .../servers/simple-auth/mcp_simple_auth/auth_server.py | 2 ++ .../mcp_simple_auth/simple_auth_provider.py | 2 ++ .../simple-auth/mcp_simple_auth/token_verifier.py | 2 ++ src/mcp/server/auth/provider.py | 6 +++++- src/mcp/server/mcpserver/context.py | 7 ++++++- tests/server/mcpserver/auth/test_auth_integration.py | 10 ++++++++++ 6 files changed, 27 insertions(+), 2 deletions(-) diff --git a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py index 9d13fffe42..26c87c5ef2 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py @@ -120,6 +120,8 @@ async def introspect_handler(request: Request) -> Response: "iat": int(time.time()), "token_type": "Bearer", "aud": access_token.resource, # RFC 8707 audience claim + "sub": access_token.subject, # RFC 7662 subject + "iss": str(server_settings.server_url), } ) diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py 
b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index 3a3895cc57..48eb9a8414 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -181,6 +181,7 @@ async def handle_simple_callback(self, username: str, password: str, state: str) scopes=[self.settings.mcp_scope], code_challenge=code_challenge, resource=resource, # RFC 8707 + subject=username, ) self.auth_codes[new_code] = auth_code @@ -219,6 +220,7 @@ async def exchange_authorization_code( scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, resource=authorization_code.resource, # RFC 8707 + subject=authorization_code.subject, ) # Store user data mapping for this token diff --git a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py index 5228d034e4..641095a125 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py +++ b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py @@ -75,6 +75,8 @@ async def verify_token(self, token: str) -> AccessToken | None: scopes=data.get("scope", "").split() if data.get("scope") else [], expires_at=data.get("exp"), resource=data.get("aud"), # Include resource in token + subject=data.get("sub"), # RFC 7662 subject (resource owner) + claims=data, ) except Exception as e: logger.warning(f"Token introspection failed: {e}") diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 957082a854..4ce1137575 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Generic, Literal, Protocol, TypeVar +from typing import Any, Generic, Literal, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyUrl, BaseModel @@ -25,6 +25,7 @@ class AuthorizationCode(BaseModel): 
redirect_uri: AnyUrl redirect_uri_provided_explicitly: bool resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None # resource owner; propagate to the issued AccessToken class RefreshToken(BaseModel): @@ -32,6 +33,7 @@ class RefreshToken(BaseModel): client_id: str scopes: list[str] expires_at: int | None = None + subject: str | None = None # resource owner; propagate to refreshed AccessTokens class AccessToken(BaseModel): @@ -40,6 +42,8 @@ class AccessToken(BaseModel): scopes: list[str] expires_at: int | None = None resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None # RFC 7662/9068 `sub`: resource owner; unique only per issuer + claims: dict[str, Any] | None = None # additional claims (e.g. `iss`, `act`) RegistrationErrorCode = Literal[ diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index e87388eee9..39bba839bd 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -208,9 +208,14 @@ async def log( related_request_id=self.request_id, ) + # TODO(maxisbey): see if this is needed otherwise remove @property def client_id(self) -> str | None: - """Get the client ID if available.""" + """Get the client ID if available. + + Note: this reads from the MCP request's `_meta` params, not the OAuth + bearer token. For that, use `get_access_token().client_id`. 
+ """ return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover @property diff --git a/tests/server/mcpserver/auth/test_auth_integration.py b/tests/server/mcpserver/auth/test_auth_integration.py index 602f5cc752..35fec1c57e 100644 --- a/tests/server/mcpserver/auth/test_auth_integration.py +++ b/tests/server/mcpserver/auth/test_auth_integration.py @@ -53,6 +53,7 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, expires_at=time.time() + 300, scopes=params.scopes or ["read", "write"], + subject="test-user", ) self.auth_codes[code.code] = code @@ -79,6 +80,7 @@ async def exchange_authorization_code( client_id=client.client_id, scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, + subject=authorization_code.subject, ) self.refresh_tokens[refresh_token] = access_token @@ -108,6 +110,7 @@ async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_t client_id=token_info.client_id, scopes=token_info.scopes, expires_at=token_info.expires_at, + subject=token_info.subject, ) return refresh_obj @@ -141,6 +144,7 @@ async def exchange_refresh_token( client_id=client.client_id, scopes=scopes or token_info.scopes, expires_at=int(time.time()) + 3600, + subject=refresh_token.subject, ) self.refresh_tokens[new_refresh_token] = new_access_token @@ -169,6 +173,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: client_id=token_info.client_id, scopes=token_info.scopes, expires_at=token_info.expires_at, + subject=token_info.subject, ) async def revoke_token(self, token: AccessToken | RefreshToken) -> None: @@ -832,6 +837,7 @@ async def test_authorization_get( assert auth_info.client_id == client_info["client_id"] assert "read" in auth_info.scopes assert "write" in auth_info.scopes + assert auth_info.subject == "test-user" # 6. 
Refresh the token response = await test_client.post( @@ -852,6 +858,10 @@ async def test_authorization_get( assert new_token_response["access_token"] != access_token assert new_token_response["refresh_token"] != refresh_token + refreshed_auth_info = await mock_oauth_provider.load_access_token(new_token_response["access_token"]) + assert refreshed_auth_info + assert refreshed_auth_info.subject == "test-user" + # 7. Revoke the token response = await test_client.post( "/revoke", From 24725633f11276e94c5f7310ee0296a0e2117618 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Thu, 28 May 2026 19:48:30 +0100 Subject: [PATCH 68/84] test: interaction-model end-to-end suite with a requirements manifest (#2691) --- pyproject.toml | 5 +- src/mcp/client/auth/oauth2.py | 20 +- src/mcp/client/client.py | 2 +- src/mcp/client/session.py | 10 +- src/mcp/client/streamable_http.py | 16 +- src/mcp/server/auth/handlers/authorize.py | 2 +- src/mcp/server/auth/middleware/bearer_auth.py | 2 +- src/mcp/server/lowlevel/server.py | 16 +- src/mcp/server/mcpserver/context.py | 4 +- src/mcp/server/mcpserver/prompts/base.py | 2 +- src/mcp/server/mcpserver/server.py | 2 +- src/mcp/server/session.py | 8 +- src/mcp/server/sse.py | 13 +- src/mcp/server/streamable_http.py | 98 +- src/mcp/server/streamable_http_manager.py | 2 +- src/mcp/server/transport_security.py | 34 +- src/mcp/shared/auth.py | 4 +- src/mcp/shared/session.py | 2 +- tests/interaction/README.md | 228 ++ tests/interaction/__init__.py | 0 tests/interaction/_connect.py | 360 +++ tests/interaction/_helpers.py | 107 + tests/interaction/_requirements.py | 2816 +++++++++++++++++ tests/interaction/auth/__init__.py | 0 tests/interaction/auth/_harness.py | 465 +++ tests/interaction/auth/_provider.py | 186 ++ tests/interaction/auth/test_as_handlers.py | 300 ++ .../interaction/auth/test_authorize_token.py | 399 +++ tests/interaction/auth/test_bearer.py | 189 ++ tests/interaction/auth/test_discovery.py | 333 ++ 
tests/interaction/auth/test_flow.py | 239 ++ tests/interaction/auth/test_lifecycle.py | 445 +++ tests/interaction/conftest.py | 23 + tests/interaction/lowlevel/__init__.py | 0 .../interaction/lowlevel/test_cancellation.py | 234 ++ tests/interaction/lowlevel/test_completion.py | 131 + .../interaction/lowlevel/test_elicitation.py | 662 ++++ tests/interaction/lowlevel/test_flows.py | 203 ++ tests/interaction/lowlevel/test_initialize.py | 384 +++ .../interaction/lowlevel/test_list_changed.py | 136 + tests/interaction/lowlevel/test_logging.py | 127 + tests/interaction/lowlevel/test_meta.py | 63 + tests/interaction/lowlevel/test_pagination.py | 242 ++ tests/interaction/lowlevel/test_ping.py | 53 + tests/interaction/lowlevel/test_progress.py | 301 ++ tests/interaction/lowlevel/test_prompts.py | 209 ++ tests/interaction/lowlevel/test_resources.py | 309 ++ tests/interaction/lowlevel/test_roots.py | 166 + tests/interaction/lowlevel/test_sampling.py | 687 ++++ tests/interaction/lowlevel/test_timeouts.py | 114 + tests/interaction/lowlevel/test_tools.py | 512 +++ tests/interaction/lowlevel/test_wire.py | 309 ++ tests/interaction/mcpserver/__init__.py | 0 .../interaction/mcpserver/test_completion.py | 38 + tests/interaction/mcpserver/test_context.py | 271 ++ tests/interaction/mcpserver/test_prompts.py | 195 ++ tests/interaction/mcpserver/test_resources.py | 183 ++ tests/interaction/mcpserver/test_tools.py | 432 +++ tests/interaction/test_coverage.py | 105 + tests/interaction/transports/__init__.py | 0 tests/interaction/transports/_bridge.py | 169 + tests/interaction/transports/_event_store.py | 55 + tests/interaction/transports/_stdio_server.py | 63 + tests/interaction/transports/test_bridge.py | 94 + .../transports/test_client_transport_http.py | 247 ++ tests/interaction/transports/test_flows.py | 129 + .../transports/test_hosting_http.py | 344 ++ .../transports/test_hosting_resume.py | 372 +++ .../transports/test_hosting_session.py | 202 ++ 
tests/interaction/transports/test_sse.py | 90 + tests/interaction/transports/test_stdio.py | 143 + .../transports/test_streamable_http.py | 168 + uv.lock | 2 +- 73 files changed, 14355 insertions(+), 121 deletions(-) create mode 100644 tests/interaction/README.md create mode 100644 tests/interaction/__init__.py create mode 100644 tests/interaction/_connect.py create mode 100644 tests/interaction/_helpers.py create mode 100644 tests/interaction/_requirements.py create mode 100644 tests/interaction/auth/__init__.py create mode 100644 tests/interaction/auth/_harness.py create mode 100644 tests/interaction/auth/_provider.py create mode 100644 tests/interaction/auth/test_as_handlers.py create mode 100644 tests/interaction/auth/test_authorize_token.py create mode 100644 tests/interaction/auth/test_bearer.py create mode 100644 tests/interaction/auth/test_discovery.py create mode 100644 tests/interaction/auth/test_flow.py create mode 100644 tests/interaction/auth/test_lifecycle.py create mode 100644 tests/interaction/conftest.py create mode 100644 tests/interaction/lowlevel/__init__.py create mode 100644 tests/interaction/lowlevel/test_cancellation.py create mode 100644 tests/interaction/lowlevel/test_completion.py create mode 100644 tests/interaction/lowlevel/test_elicitation.py create mode 100644 tests/interaction/lowlevel/test_flows.py create mode 100644 tests/interaction/lowlevel/test_initialize.py create mode 100644 tests/interaction/lowlevel/test_list_changed.py create mode 100644 tests/interaction/lowlevel/test_logging.py create mode 100644 tests/interaction/lowlevel/test_meta.py create mode 100644 tests/interaction/lowlevel/test_pagination.py create mode 100644 tests/interaction/lowlevel/test_ping.py create mode 100644 tests/interaction/lowlevel/test_progress.py create mode 100644 tests/interaction/lowlevel/test_prompts.py create mode 100644 tests/interaction/lowlevel/test_resources.py create mode 100644 tests/interaction/lowlevel/test_roots.py create mode 100644 
tests/interaction/lowlevel/test_sampling.py create mode 100644 tests/interaction/lowlevel/test_timeouts.py create mode 100644 tests/interaction/lowlevel/test_tools.py create mode 100644 tests/interaction/lowlevel/test_wire.py create mode 100644 tests/interaction/mcpserver/__init__.py create mode 100644 tests/interaction/mcpserver/test_completion.py create mode 100644 tests/interaction/mcpserver/test_context.py create mode 100644 tests/interaction/mcpserver/test_prompts.py create mode 100644 tests/interaction/mcpserver/test_resources.py create mode 100644 tests/interaction/mcpserver/test_tools.py create mode 100644 tests/interaction/test_coverage.py create mode 100644 tests/interaction/transports/__init__.py create mode 100644 tests/interaction/transports/_bridge.py create mode 100644 tests/interaction/transports/_event_store.py create mode 100644 tests/interaction/transports/_stdio_server.py create mode 100644 tests/interaction/transports/test_bridge.py create mode 100644 tests/interaction/transports/test_client_transport_http.py create mode 100644 tests/interaction/transports/test_flows.py create mode 100644 tests/interaction/transports/test_hosting_http.py create mode 100644 tests/interaction/transports/test_hosting_resume.py create mode 100644 tests/interaction/transports/test_hosting_session.py create mode 100644 tests/interaction/transports/test_sse.py create mode 100644 tests/interaction/transports/test_stdio.py create mode 100644 tests/interaction/transports/test_streamable_http.py diff --git a/pyproject.toml b/pyproject.toml index d88869da1c..6d2319621a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ dev = [ # We add mcp[cli,ws] so `uv sync` considers the extras. 
"mcp[cli,ws]", "pyright>=1.1.400", - "pytest>=8.3.4", + "pytest>=8.4.0", "ruff>=0.8.5", "trio>=0.26.2", "pytest-flakefinder>=1.1.0", @@ -193,6 +193,9 @@ strict-no-cover = { git = "https://github.com/pydantic/strict-no-cover" } [tool.pytest.ini_options] log_cli = true xfail_strict = true +markers = [ + "requirement(id): links a test to the entry in tests/interaction/_requirements.py it exercises", +] addopts = """ --color=yes --capture=fd diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 72309f5775..3c546fda2b 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -360,10 +360,10 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: auth_code, returned_state = await self.context.callback_handler() if returned_state is None or not secrets.compare_digest(returned_state, state): - raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover + raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") if not auth_code: - raise OAuthFlowError("No authorization code received") # pragma: no cover + raise OAuthFlowError("No authorization code received") # Return auth code and code verifier for token exchange return auth_code, pkce_params.code_verifier @@ -452,7 +452,7 @@ async def _refresh_token(self) -> httpx.Request: return httpx.Request("POST", token_url, data=refresh_data, headers=headers) - async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover + async def _handle_refresh_response(self, response: httpx.Response) -> bool: """Handle token refresh response. 
Returns True if successful.""" if response.status_code != 200: logger.warning(f"Token refresh failed: {response.status_code}") @@ -468,12 +468,12 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p await self.context.storage.set_tokens(token_response) return True - except ValidationError: + except ValidationError: # pragma: no cover logger.exception("Invalid refresh response") self.context.clear_tokens() return False - async def _initialize(self) -> None: # pragma: no cover + async def _initialize(self) -> None: """Load stored tokens and client info.""" self.context.current_tokens = await self.context.storage.get_tokens() self.context.client_info = await self.context.storage.get_client_info() @@ -507,17 +507,17 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. """HTTPX auth flow integration.""" async with self.context.lock: if not self._initialized: - await self._initialize() # pragma: no cover + await self._initialize() # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token - refresh_request = await self._refresh_token() # pragma: no cover - refresh_response = yield refresh_request # pragma: no cover + refresh_request = await self._refresh_token() + refresh_response = yield refresh_request - if not await self._handle_refresh_response(refresh_response): # pragma: no cover + if not await self._handle_refresh_response(refresh_response): # Refresh failed, need full re-authentication self._initialized = False @@ -612,7 +612,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. 
# Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) - except Exception: # pragma: no cover + except Exception: logger.exception("OAuth flow error") raise diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 34d6a360fa..b33fea4052 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -305,4 +305,4 @@ async def list_tools(self, *, cursor: str | None = None, meta: RequestParamsMeta async def send_roots_list_changed(self) -> None: """Send a notification that the roots list has changed.""" # TODO(Marcelo): Currently, there is no way for the server to handle this. We should add support. - await self.session.send_roots_list_changed() # pragma: no cover + await self.session.send_roots_list_changed() diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0cea454a77..86113874be 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -74,7 +74,7 @@ async def _default_elicitation_callback( context: RequestContext[ClientSession], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: - return types.ErrorData( # pragma: no cover + return types.ErrorData( code=types.INVALID_REQUEST, message="Elicitation not supported", ) @@ -337,9 +337,7 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) - from jsonschema import SchemaError, ValidationError, validate if result.structured_content is None: - raise RuntimeError( - f"Tool {name} has an output schema but did not return structured content" - ) # pragma: no cover + raise RuntimeError(f"Tool {name} has an output schema but did not return structured content") try: validate(result.structured_content, output_schema) except ValidationError as e: @@ -408,7 +406,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None return result - async def send_roots_list_changed(self) -> 
None: # pragma: no cover + async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" await self.send_notification(types.RootsListChangedNotification()) @@ -449,7 +447,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques client_response = ClientResponse.validate_python(response) await responder.respond(client_response) - case types.PingRequest(): # pragma: no cover + case types.PingRequest(): with responder: return await responder.respond(types.EmptyResult()) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9a119c6338..aa3e50e07e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -210,7 +210,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: # Stream ended normally (server closed) - reset attempt counter attempt = 0 - except Exception: # pragma: lax no cover + except Exception: logger.debug("GET stream error", exc_info=True) attempt += 1 @@ -267,8 +267,8 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: logger.debug("Received 202 Accepted") return - if response.status_code == 404: # pragma: no branch - if isinstance(message, JSONRPCRequest): # pragma: no branch + if response.status_code == 404: + if isinstance(message, JSONRPCRequest): error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) await ctx.read_stream_writer.send(session_message) @@ -492,17 +492,17 @@ async def handle_request_async(): async def terminate_session(self, client: httpx.AsyncClient) -> None: """Terminate the session by sending a DELETE request.""" - if not self.session_id: # pragma: lax no cover - return + if not self.session_id: + return # pragma: no cover try: headers = self._prepare_headers() response = await client.delete(self.url, headers=headers) - if response.status_code == 
405: # pragma: lax no cover + if response.status_code == 405: logger.debug("Server does not allow session termination") - elif response.status_code not in (200, 204): # pragma: lax no cover - logger.warning(f"Session termination failed: {response.status_code}") + elif response.status_code not in (200, 204): + logger.warning(f"Session termination failed: {response.status_code}") # pragma: no cover except Exception as exc: # pragma: no cover logger.warning(f"Session termination failed: {exc}") diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index dec6713b13..5cf93cf8c2 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -117,7 +117,7 @@ async def error_response( pass # the error response MUST contain the state specified by the client, if any - if state is None: # pragma: no cover + if state is None: # make last-ditch effort to load state state = best_effort_extract_string("state", params) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 6825c00b9e..2eafdc793e 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -95,7 +95,7 @@ async def _send_auth_error(self, send: Send, status_code: int, error: str, descr """Send an authentication error response with WWW-Authenticate header.""" # Build WWW-Authenticate header value www_auth_parts = [f'error="{error}"', f'error_description="{description}"'] - if self.resource_metadata_url: # pragma: no cover + if self.resource_metadata_url: www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"') www_authenticate = f"Bearer {', '.join(www_auth_parts)}" diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 59de0ace45..5e4e2e6f5b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -349,12 +349,12 @@ def 
session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - if self._session_manager is None: # pragma: no cover - raise RuntimeError( + if self._session_manager is None: + raise RuntimeError( # pragma: no cover "Session manager can only be accessed after calling streamable_http_app(). " "The session manager is created lazily to avoid unnecessary initialization." ) - return self._session_manager # pragma: no cover + return self._session_manager async def run( self, @@ -513,7 +513,7 @@ async def _handle_request( if raise_exceptions: # pragma: no cover raise err response = types.ErrorData(code=0, message=str(err)) - else: # pragma: no cover + else: response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") if isinstance(response, types.ErrorData) and span is not None: @@ -603,7 +603,7 @@ def streamable_http_app( required_scopes: list[str] = [] # Set up auth if configured - if auth: # pragma: no cover + if auth: required_scopes = auth.required_scopes or [] # Add auth middleware if token verifier is available @@ -629,10 +629,10 @@ def streamable_http_app( ) # Set up routes with or without auth - if token_verifier: # pragma: no cover + if token_verifier: # Determine resource metadata URL resource_metadata_url = None - if auth and auth.resource_server_url: + if auth and auth.resource_server_url: # pragma: no branch # Build compliant metadata URL for WWW-Authenticate header resource_metadata_url = build_resource_metadata_url(auth.resource_server_url) @@ -652,7 +652,7 @@ def streamable_http_app( ) # Add protected resource metadata endpoint if configured as RS - if auth and auth.resource_server_url: # pragma: no cover + if auth and auth.resource_server_url: routes.extend( create_protected_resource_routes( resource_url=auth.resource_server_url, diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 39bba839bd..92de074d34 100644 --- 
a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -94,7 +94,7 @@ async def report_progress(self, progress: float, total: float | None = None, mes """ progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None - if progress_token is None: # pragma: no cover + if progress_token is None: return await self.request_context.session.send_progress_notification( @@ -242,7 +242,7 @@ async def close_sse_stream(self) -> None: This is a no-op if not using StreamableHTTP transport with event_store. The callback is only available when event_store is configured. """ - if self._request_context and self._request_context.close_sse_stream: # pragma: no cover + if self._request_context and self._request_context.close_sse_stream: # pragma: no branch await self._request_context.close_sse_stream() async def close_standalone_sse_stream(self) -> None: diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index e5b2af7d82..2f778eb514 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -185,5 +185,5 @@ async def render( raise ValueError(f"Could not convert prompt result to message: {msg}") return messages - except Exception as e: # pragma: no cover + except Exception as e: raise ValueError(f"Error rendering prompt {self.name}: {e}") diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index b3471163b7..ec2365810e 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -244,7 +244,7 @@ def session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - return self._lowlevel_server.session_manager # pragma: no cover + return self._lowlevel_server.session_manager @overload def run(self, transport: Literal["stdio"] = ...) -> None: ... 
diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 20b640527a..fc2f97a9cb 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -223,7 +223,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover + async def send_resource_updated(self, uri: str | AnyUrl) -> None: """Send a resource updated notification.""" await self.send_notification( types.ResourceUpdatedNotification( @@ -447,7 +447,7 @@ async def elicit_url( metadata=ServerMessageMetadata(related_request_id=related_request_id), ) - async def send_ping(self) -> types.EmptyResult: # pragma: no cover + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( types.PingRequest(), @@ -479,11 +479,11 @@ async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" await self.send_notification(types.ResourceListChangedNotification()) - async def send_tool_list_changed(self) -> None: # pragma: no cover + async def send_tool_list_changed(self) -> None: """Send a tool list changed notification.""" await self.send_notification(types.ToolListChangedNotification()) - async def send_prompt_list_changed(self) -> None: # pragma: no cover + async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" await self.send_notification(types.PromptListChangedNotification()) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 48192ff612..be8e979c9d 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -116,15 +116,15 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover - if scope["type"] != "http": + async def connect_sse(self, scope: Scope, 
receive: Receive, send: Send): + if scope["type"] != "http": # pragma: no cover logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") # Validate request headers for DNS rebinding protection request = Request(scope, receive) error_response = await self._security.validate_request(request, is_post=False) - if error_response: + if error_response: # pragma: no cover await error_response(scope, receive, send) raise ValueError("Request validation failed") @@ -179,6 +179,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( scope, receive, send ) + await sse_stream_reader.aclose() await read_stream_writer.aclose() await write_stream_reader.aclose() self._read_stream_writers.pop(session_id, None) @@ -190,13 +191,13 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): logger.debug("Yielding read and write streams") yield (read_stream, write_stream) - async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") request = Request(scope, receive) # Validate request headers for DNS rebinding protection error_response = await self._security.validate_request(request, is_post=True) - if error_response: + if error_response: # pragma: no cover return await error_response(scope, receive, send) session_id_param = request.query_params.get("session_id") @@ -225,7 +226,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) try: message = types.jsonrpc_message_adapter.validate_json(body, by_name=False) logger.debug(f"Validated client message: {message}") - except ValidationError as err: + except ValidationError as err: # pragma: no cover logger.exception("Failed to parse message") response = 
Response("Could not parse message", status_code=400) await response(scope, receive, send) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f14201857c..f2f4407cea 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -179,7 +179,7 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated - def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover + def close_sse_stream(self, request_id: RequestId) -> None: """Close SSE connection for a specific request without terminating the stream. This method closes the HTTP connection for the specified request, triggering @@ -198,11 +198,11 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover the disconnect. """ writer = self._sse_stream_writers.pop(request_id, None) - if writer: + if writer: # pragma: no branch writer.close() # Also close and remove request streams - if request_id in self._request_streams: + if request_id in self._request_streams: # pragma: no branch send_stream, receive_stream = self._request_streams.pop(request_id) send_stream.close() receive_stream.close() @@ -242,7 +242,7 @@ def _create_session_message( # Only provide close callbacks when client supports resumability if self._event_store and protocol_version >= "2025-11-25": - async def close_stream_callback() -> None: # pragma: no cover + async def close_stream_callback() -> None: self.close_sse_stream(request_id) async def close_standalone_stream_callback() -> None: # pragma: no cover @@ -293,7 +293,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: no cover + if headers: response_headers.update(headers) if self.mcp_session_id: @@ -320,10 +320,10 @@ def _create_json_response( ) -> Response: """Create a JSON response from a 
JSONRPCMessage.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: lax no cover - response_headers.update(headers) + if headers: + response_headers.update(headers) # pragma: no cover - if self.mcp_session_id: # pragma: lax no cover + if self.mcp_session_id: response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id return Response( @@ -344,7 +344,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: } # If an event ID was provided, include it - if event_message.event_id: # pragma: no cover + if event_message.event_id: event_data["id"] = event_message.event_id return event_data @@ -374,7 +374,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await error_response(scope, receive, send) return - if self._terminated: # pragma: no cover + if self._terminated: # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -389,7 +389,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_get_request(request, send) elif request.method == "DELETE": await self._handle_delete_request(request, send) - else: # pragma: no cover + else: await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: @@ -421,7 +421,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se has_json, has_sse = self._check_accept_headers(request) if self.is_json_response_enabled: # For JSON-only responses, only require application/json - if not has_json: # pragma: lax no cover + if not has_json: # pragma: no cover response = self._create_error_response( "Not Acceptable: Client must accept application/json", HTTPStatus.NOT_ACCEPTABLE, @@ -469,7 +469,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re try: message = 
jsonrpc_message_adapter.validate_python(raw_message, by_name=False) - except ValidationError as e: # pragma: no cover + except ValidationError as e: response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, @@ -495,7 +495,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re ) await response(scope, receive, send) return - elif not await self._validate_request_headers(request, send): # pragma: no cover + elif not await self._validate_request_headers(request, send): return # For notifications and responses only, return 202 Accepted @@ -579,7 +579,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Store writer reference so close_sse_stream() can close it self._sse_stream_writers[request_id] = sse_stream_writer - async def sse_writer(): # pragma: lax no cover + async def sse_writer(): # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: @@ -595,10 +595,10 @@ async def sse_writer(): # pragma: lax no cover # If response, remove from pending streams and close if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): break - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in SSE writer") finally: logger.debug("Closing SSE writer") @@ -628,14 +628,14 @@ async def sse_writer(): # pragma: lax no cover # Then send the message to be processed by the server session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("SSE response error") await sse_stream_writer.aclose() await 
self._clean_up_memory_streams(request_id) finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: no cover + except Exception as err: logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -643,9 +643,9 @@ async def sse_writer(): # pragma: lax no cover INTERNAL_ERROR, ) await response(scope, receive, send) - if writer: + if writer: # pragma: no cover await writer.send(Exception(err)) - return + return # pragma: no cover async def _handle_get_request(self, request: Request, send: Send) -> None: """Handle GET request to establish SSE. @@ -661,7 +661,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) - if not has_sse: # pragma: no cover + if not has_sse: response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, @@ -673,7 +673,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: return # Handle resumability: check for Last-Event-ID header - if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): await self._replay_events(last_event_id, request, send) return @@ -683,11 +683,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Check if we already have an active GET stream - if GET_STREAM_KEY in self._request_streams: # pragma: no cover + if GET_STREAM_KEY in self._request_streams: response = self._create_error_response( "Conflict: Only one SSE stream is allowed per session", HTTPStatus.CONFLICT, @@ -707,7 +707,7 @@ async def standalone_sse_writer(): async with 
sse_stream_writer, standalone_stream_reader: # Process messages from the standalone stream - async for event_message in standalone_stream_reader: # pragma: lax no cover + async for event_message in standalone_stream_reader: # For the standalone stream, we handle: # - JSONRPCNotification (server sends notifications to client) # - JSONRPCRequest (server sends requests to client) @@ -716,8 +716,8 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except Exception: # pragma: no cover - logger.exception("Error in standalone SSE writer") + except Exception: + logger.exception("Error in standalone SSE writer") # pragma: no cover finally: logger.debug("Closing standalone SSE writer") await self._clean_up_memory_streams(GET_STREAM_KEY) @@ -775,7 +775,7 @@ async def terminate(self) -> None: request_stream_keys = list(self._request_streams.keys()) # Close all request streams asynchronously - for key in request_stream_keys: # pragma: lax no cover + for key in request_stream_keys: await self._clean_up_memory_streams(key) # Clear the request streams dictionary immediately @@ -793,13 +793,13 @@ async def terminate(self) -> None: # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, "Allow": "GET, POST, DELETE", } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id response = self._create_error_response( @@ -809,7 +809,7 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non ) await response(request.scope, request.receive, send) - 
async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: lax no cover + async def _validate_request_headers(self, request: Request, send: Send) -> bool: if not await self._validate_session(request, send): return False if not await self._validate_protocol_version(request, send): @@ -818,7 +818,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool: async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" - if not self.mcp_session_id: # pragma: no cover + if not self.mcp_session_id: # If we're not using session IDs, return True return True @@ -826,7 +826,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: request_session_id = self._get_session_id(request) # If no session ID provided but required, return error - if not request_session_id: # pragma: no cover + if not request_session_id: response = self._create_error_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, @@ -851,11 +851,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: # pragma: no cover + if protocol_version is None: protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) response = self._create_error_response( f"Bad Request: Unsupported protocol version: {protocol_version}. 
" @@ -867,14 +867,14 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. """ event_store = self._event_store if not event_store: - return + return # pragma: no cover try: headers = { @@ -883,7 +883,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Get protocol version from header (already validated in _validate_protocol_version) @@ -904,7 +904,7 @@ async def send_event(event_message: EventMessage) -> None: stream_id = await event_store.replay_events_after(last_event_id, send_event) # If stream ID not in mapping, create it - if stream_id and stream_id not in self._request_streams: + if stream_id and stream_id not in self._request_streams: # pragma: no branch # Register SSE writer so close_sse_stream() can close it self._sse_stream_writers[stream_id] = sse_stream_writer @@ -921,10 +921,10 @@ async def send_event(event_message: EventMessage) -> None: event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("Replay SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay sender") # Create and start EventSourceResponse @@ -936,13 +936,13 @@ async def send_event(event_message: EventMessage) -> None: try: await response(request.scope, request.receive, send) 
- except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay response") finally: await sse_stream_writer.aclose() await sse_stream_reader.aclose() - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error replaying events") response = self._create_error_response( "Error replaying events", @@ -993,7 +993,7 @@ async def message_router(): if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: target_request_id = str(message.id) # Extract related_request_id from meta if it exists - elif ( # pragma: no cover + elif ( session_message.metadata is not None and isinstance( session_message.metadata, @@ -1009,7 +1009,7 @@ async def message_router(): # regardless of whether a client is connected # messages will be replayed on the re-connect event_id = None - if self._event_store: # pragma: lax no cover + if self._event_store: event_id = await self._event_store.store_event(request_stream_id, message) logger.debug(f"Stored {event_id} from {request_stream_id}") @@ -1020,14 +1020,14 @@ async def message_router(): except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) - else: # pragma: no cover + else: logger.debug( f"""Request stream {request_stream_id} not found for message. 
Still processing message as the client might reconnect and replay.""" ) except anyio.ClosedResourceError: - if self._terminated: + if self._terminated: # pragma: lax no cover logger.debug("Read stream closed by client") else: logger.exception("Unexpected closure of read stream in message router") @@ -1041,7 +1041,7 @@ async def message_router(): # Yield the streams for the caller to use yield read_stream, write_stream finally: - for stream_id in list(self._request_streams.keys()): # pragma: lax no cover + for stream_id in list(self._request_streams.keys()): await self._clean_up_memory_streams(stream_id) self._request_streams.clear() diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index c25314eab6..39d434505c 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -173,7 +173,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA self.app.create_initialization_options(), stateless=True, ) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") # Assert task group is not None for type checking diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0e..707d4b61dd 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -40,19 +40,19 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def _validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" - if not host: + if not host: # pragma: no cover logger.warning("Missing Host header in request") return False # Check 
exact match first - if host in self.settings.allowed_hosts: + if host in self.settings.allowed_hosts: # pragma: no cover return True # Check wildcard port patterns for allowed in self.settings.allowed_hosts: - if allowed.endswith(":*"): + if allowed.endswith(":*"): # pragma: no branch # Extract base host from pattern base_host = allowed[:-2] # Check if the actual host starts with base host and has a port @@ -62,19 +62,19 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests - if not origin: + if not origin: # pragma: no cover return True # Check exact match first - if origin in self.settings.allowed_origins: + if origin in self.settings.allowed_origins: # pragma: no cover return True # Check wildcard port patterns for allowed in self.settings.allowed_origins: - if allowed.endswith(":*"): + if allowed.endswith(":*"): # pragma: no branch # Extract base origin from pattern base_origin = allowed[:-2] # Check if the actual origin starts with base origin and has a port @@ -103,14 +103,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover + # Validate Host header + host = request.headers.get("host") + if not self._validate_host(host): + return Response("Invalid Host header", status_code=421) - # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not 
self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover + # Validate Origin header + origin = request.headers.get("origin") + if not self._validate_origin(origin): + return Response("Invalid Origin header", status_code=403) - return None # pragma: no cover + return None diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index ebf534d792..3b48152d5b 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -91,9 +91,9 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None: requested_scopes = requested_scope.split(" ") allowed_scopes = [] if self.scope is None else self.scope.split(" ") for scope in requested_scopes: - if scope not in allowed_scopes: # pragma: no branch + if scope not in allowed_scopes: raise InvalidScopeError(f"Client was not registered with scope {scope}") - return requested_scopes # pragma: no cover + return requested_scopes def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: if redirect_uri is not None: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae6..9c72a23844 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -451,7 +451,7 @@ async def _handle_session_message(message: SessionMessage) -> None: try: await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) await stream.aclose() - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover # Stream might already be closed pass self._response_streams.clear() diff --git a/tests/interaction/README.md b/tests/interaction/README.md new file mode 100644 index 0000000000..be68c3b0f1 --- /dev/null +++ b/tests/interaction/README.md @@ -0,0 +1,228 @@ +# Interaction-model test suite + +This suite enumerates the MCP interaction model as end-to-end tests: one test per piece of +functionality, asserting the full client↔server round trip through the public API. 
It exists to +pin the SDK's observable behaviour — every request type, every notification direction, every +error plane — so that internal rewrites of the send/receive path can be proven equivalent by +running the suite before and after. + +```bash +uv run --frozen pytest tests/interaction/ +``` + +The whole suite is in-process and event-driven — including the streamable HTTP, SSE, and OAuth +flows — with a single subprocess test for stdio. + +## Ground rules + +- **Public API only.** Tests drive a `Client` connected to a `Server` or `MCPServer`. Nothing + reaches into session internals, so the suite keeps working when those internals change. + `ClientSession` is used directly only for behaviours `Client` cannot express (skipping + initialization, requesting a non-default protocol version). +- **Pin current behaviour.** Every test passes against the current `main`, including behaviours + that diverge from the specification. A failing or xfailed test proves nothing about whether a + rewrite preserved behaviour; a passing test that exactly pins even wrong output does. Known + divergences are recorded as data on the requirement (see below), not worked around in the test. +- **Spec-mandated assertions, not implementation quirks.** Error *codes* are asserted against + the constants in `mcp.types`; error *message strings* are pinned only where they are the + SDK's own deliberate output. +- **No sleeps, no real I/O.** Concurrency is coordinated with `anyio.Event`; every wait that + could hang is bounded by `anyio.fail_after(5)`. The HTTP and OAuth tests drive the Starlette + app in-process through the suite's streaming ASGI bridge (`transports/_bridge.py`), which + delivers each response chunk as the server produces it — full duplex, but still no sockets, + threads, or subprocesses anywhere outside the one stdio test. 
+ +## Layout + +```text +tests/interaction/ + _requirements.py the requirements manifest (see below) + _helpers.py shared type aliases + the wire-recording transport + _connect.py the transport-parametrized connection factories + conftest.py the connect fixture (the transport matrix) + test_coverage.py enforces the manifest ↔ test contract + lowlevel/ one file per feature area, against the low-level Server + mcpserver/ the same feature areas in MCPServer's natural idiom + transports/ behaviour specific to one transport (sessions, resumability, framing) + auth/ OAuth flows against an in-process authorization server +``` + +The two server APIs produce genuinely different wire output for the same conceptual feature +(`MCPServer` generates schemas, converts exceptions to `isError` results, attaches structured +content), so they get parallel directories with mirrored file names rather than one parametrized +test body — each directory pins its flavour's true output exactly. + +### The transport matrix + +Transport-agnostic tests take the `connect` fixture instead of constructing `Client(server)` +directly, and therefore run once per transport: over the in-memory transport, over the server's +real streamable HTTP app driven in-process through the streaming bridge, and over the legacy SSE +transport the same way. A test connects with `async with connect(server, ...) as client:` and +asserts the same output on every leg, because the transport is not supposed to change observable +behaviour. Tests that are tied to one transport do not use the fixture: the wire-recording tests +(their seam is the in-memory stream pair), the bare-`ClientSession` lifecycle tests, the +real-clock timeout tests (the timeout machinery is transport-independent and must not race +transport latency), and everything under `transports/`, which pins behaviour only observable on +that transport. 
+ +A transport conformance test in `transports/` speaks raw `httpx` against the mounted ASGI app +**only** when its assertion is about HTTP semantics that `Client` cannot observe — status codes, +response headers, SSE event fields, which stream a message travels on. Any other behaviour is +asserted through a `Client`, connected to the mounted app via `client_via_http(http)` so several +clients can share one session manager. + +## The requirements manifest + +`_requirements.py` maps every behaviour the suite covers to the reason it must hold: + +```python +"tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content.", +), +``` + +- **`source`** is a deep link into the MCP specification for externally mandated behaviour, + the literal string `"sdk"` for behaviour the SDK chose where the spec is silent, or + `"issue:#n"` for a regression lock. +- **`behavior`** describes the *required* behaviour — what the specification (or the SDK's own + contract) says should happen. Tests always pin the SDK's current behaviour; where that falls + short of `behavior`, the gap is recorded as data rather than hidden in the test. +- **`divergence`** records that gap for entries whose tests pin the divergent current behaviour. +- **`deferred`** marks a behaviour that is tracked but has no test in this suite, with a precise + reason: the SDK does not implement it, the negative cannot be observed, the assertion is + schema-level rather than interaction-level, the feature is experimental (tasks), or the test + would require real-time waits the suite refuses. +- **`transports`** names the transports a behaviour applies to; omitted means transport-independent. +- **`issue`** carries the tracking link for a recorded gap once one is filed. 
+ +Tests link themselves to the manifest with a decorator: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: ... +``` + +`test_coverage.py` enforces the contract in both directions: every non-deferred requirement must +be exercised by at least one test, every deferred requirement by none, and an unknown ID fails at +import time. A behaviour without a manifest entry cannot be silently half-tested, and a manifest +entry without a test cannot be silently aspirational. + +### The divergence lifecycle + +1. A test reveals that the SDK does not do what the spec says. The test pins what the SDK + *actually does* and a `Divergence(note=..., issue=...)` goes on the requirement. +2. When the behaviour is eventually fixed, the pinned test fails. Whoever makes the change finds + the divergence note explaining that the old behaviour was a known gap, re-pins the test to the + spec-correct output, and deletes the `Divergence`. +3. An empty divergence list means the SDK is spec-conformant on every behaviour the suite covers. + +A requirement may carry both `divergence` and `deferred`: the divergence records that the SDK falls +short of the spec, and the deferral records why no test pins it (typically because the divergent +behaviour cannot be driven through the public API). Divergence alone implies a test pins the +divergent behaviour; divergence plus deferred means the gap is known but unpinned. + +This is also the triage key for any rewrite: a test that fails on the new code path either has a +divergence note (the rewrite accidentally fixed a known gap — decide whether to keep the fix) or +it does not (the rewrite broke something that was correct — fix the rewrite). + +### When a new spec revision is released + +1. Update `SPEC_REVISION` and walk the new revision's changelog. +2. 
For each changed interaction, find its requirements (the IDs use the wire method strings the + changelog speaks in), re-audit the tests against the new text, and update `source` links and + assertions where behaviour legitimately changed. +3. New interactions get new requirements and new tests; removed interactions get their + requirements deleted along with their tests. +4. A behaviour that is correct under both revisions needs no change beyond the `source` link. + +## Writing a test + +The shortest complete example of the conventions: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) +``` + +- **The server is defined inside the test** (or in a small fixture at the top of the file when + several tests genuinely share it). The whole observable behaviour fits on one screen. +- **Test names are behaviour sentences** — they state the observable outcome, not the feature + being poked. Docstrings add the one or two sentences of context a reviewer needs, including + whether the assertion is spec-mandated, SDK-defined, or a known divergence. +- **Handlers assert their dispatch identity first** (`assert params.name == "add"`), proving the + request that arrived is the request the test sent. 
+- **The result proves the round trip.** Server-side observations travel back to the test through + the protocol itself (a tool returns what it saw) or through a closure-captured list; the test + asserts after the call returns. +- **Order within a test**: server handlers → server construction → client callbacks → connect → + act → assert. The test reads in the order the conversation happens. +- A registered handler or tool that a test never invokes gets a `raise NotImplementedError` body + so it cannot silently become load-bearing. +- A test that needs a peer no real `Server` or `Client` can play (a server that answers initialize + with an unsupported version, a client that sends malformed params) plays that side of the wire by + hand over `create_client_server_memory_streams()`. This scripted-peer pattern is the suite's only + way to drive behaviour the typed API cannot produce, and the docstring of every such test says so. + +Stack a second `@requirement` decorator only when a test's natural assertions incidentally prove +another behaviour — one capabilities snapshot proving four `*:capability:declared` entries, one +input-schema identity check proving each preserved keyword. Do not build a test around covering +many requirements at once; if the assertions would be separate, write separate tests. 
+### Choosing an assertion + +| The property under test is… | Assert with | +|---|---| +| the result of a transformation (arguments → output, exception → error result) | `result == snapshot(...)` of the full object, so any field the implementation adds or drops fails the test | +| pass-through of an opaque value (`_meta`, cursors) | identity against the same variable that was sent — a snapshot of a pass-through value only matches the input because a human checked two literals correspond | +| an error | `pytest.raises(MCPError)` and a snapshot of `exc.value.error` when the message is the SDK's own; a plain `==` on `.code` against the `mcp.types` constant when it is not | +| third-party output embedded in a result (validation messages) | the stable prefix only — never pin text that changes with a dependency upgrade | + +### Notifications and concurrency + +The client's receive loop dispatches each incoming message to completion before reading the next, +and the in-memory transport delivers everything on one ordered stream. Together these guarantee +that every notification a server handler emits ahead of its response reaches the client callback +before the originating request returns — so tests collect notifications into a plain list and +assert after the call, with no synchronisation. The exceptions: + +- a notification not triggered by a request the test is awaiting needs an `anyio.Event` set in + the receiving handler and awaited under `anyio.fail_after(5)`; +- the ordering guarantee does not survive transports that split messages across streams (the + streamable HTTP standalone GET stream) — see `transports/test_streamable_http.py`. + +### Coverage + +CI requires 100% line and branch coverage, including `tests/`, and `strict-no-cover` fails the +build if a line marked `# pragma: no cover` is ever executed. When a new test starts covering a +pragma'd line in `src/`, delete the pragma in the same change. 
Do not add new `# type: ignore` or +`# noqa` comments; restructure instead. Two pragmas are sanctioned in this suite's test code, both +for known-upstream tracer bugs and only after restructuring has been tried: `# pragma: no branch` +on a `with`/`async with` line whose only fault is coverage.py mis-tracing the exit arc of a nested +async context (reserve it for shapes that cannot collapse — a sync `with` adjacent to an +`async with`); and `# pragma: lax no cover` on a single statement that 3.11's tracer drops because +the preceding `async with` unwinds via `coro.throw()` (python/cpython#106749, wontfix on 3.11) — +this hits any test that must run statements after a `ClientSession`/`streamable_http_client` exits +but still inside an outer `async with`, and no restructure can avoid it. + +A handful of `# pragma: lax no cover` markers in `src/` cover teardown exception handlers whose +execution is timing-dependent under the in-process HTTP bridge — the POST-stream and +stateless-session `except Exception` handlers in `server/streamable_http*.py`, the `_terminated` +check in `message_router`, and the response-stream double-close guard in +`BaseSession._receive_loop`. `strict-no-cover` does not check `lax` lines; do not promote them to +strict `no cover` without first making the teardown ordering deterministic. The suite also relies +on a one-line `src/mcp/server/sse.py` fix (`sse_stream_reader.aclose()`) that closes a stream the +SSE leg would otherwise leak. diff --git a/tests/interaction/__init__.py b/tests/interaction/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py new file mode 100644 index 0000000000..1faf4aa8d6 --- /dev/null +++ b/tests/interaction/_connect.py @@ -0,0 +1,360 @@ +"""Transport-parametrized connection factories for the interaction suite. 
+ +The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body +runs over each transport without naming any of them: the factory is a drop-in replacement for +constructing `Client(server, ...)` and yields the connected client. The HTTP factories drive the +server's real Starlette app through the in-process streaming bridge, so the full transport layer +(session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses. +""" + +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, Protocol + +import httpx +from httpx_sse import ServerSentEvent, aconnect_sse +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route + +from mcp.client.client import Client +from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier +from mcp.server.auth.settings import AuthSettings +from mcp.server.mcpserver import MCPServer +from mcp.server.sse import SseServerTransport +from mcp.server.streamable_http import EventStore +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + jsonrpc_message_adapter, +) +from tests.interaction.transports._bridge import StreamingASGITransport + +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens 
here. +BASE_URL = "http://127.0.0.1:8000" + +# DNS-rebinding protection validates Host/Origin headers against a real network attack that cannot +# exist for an in-process ASGI app, so the in-process factories disable it; tests that exercise the +# protection itself pass explicit settings (or transport_security=None to get the localhost +# auto-enable behaviour). +NO_DNS_REBINDING_PROTECTION = TransportSecuritySettings(enable_dns_rebinding_protection=False) + + +class Connect(Protocol): + """Connect a Client to a server over the transport selected by the `connect` fixture. + + Accepts the same keyword arguments as `Client` and yields the connected client. + """ + + def __call__( + self, + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, + ) -> AbstractAsyncContextManager[Client]: ... 
+ + +@asynccontextmanager +async def connect_in_memory( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server over the in-memory transport.""" + async with Client( + server, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +@asynccontextmanager +async def connect_over_streamable_http( + server: Server | MCPServer, + *, + stateless_http: bool = False, + json_response: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server's streamable HTTP app, entirely in process. + + With the defaults this is the matrix leg (stateful sessions, SSE responses); the + transport-specific tests pass `stateless_http` or `json_response` to select the other + server modes, and the resumability tests pass an `event_store` (with `retry_interval=0` so + the client's reconnection wait is a no-op). 
+ """ + app = server.streamable_http_app( + stateless_http=stateless_http, + json_response=json_response, + event_store=event_store, + retry_interval=retry_interval, + transport_security=NO_DNS_REBINDING_PROTECTION, + ) + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client, + Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client, + ): + yield client + + +@asynccontextmanager +async def mounted_app( + server: Server | MCPServer, + *, + stateless_http: bool = False, + json_response: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + transport_security: TransportSecuritySettings | None = NO_DNS_REBINDING_PROTECTION, + on_request: Callable[[httpx.Request], Awaitable[None]] | None = None, + headers: dict[str, str] | None = None, + auth: AuthSettings | None = None, + token_verifier: TokenVerifier | None = None, + auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, +) -> AsyncIterator[tuple[httpx.AsyncClient, StreamableHTTPSessionManager]]: + """Mount the server's streamable HTTP app on the in-process bridge and yield an httpx client. + + Yields the httpx client (rooted at the in-process origin) and the live session manager. Tests + use this in two ways: for raw-httpx assertions (status codes, headers, SSE bytes) the test + speaks HTTP through the yielded client directly; for client-driven assertions the test wraps + that client in `client_via_http(http)`, which lets several `Client`s share the one mounted + session manager. 
`on_request` records every outgoing HTTP request before it leaves the + yielded client. + + DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the + localhost auto-enable behaviour) to test the protection itself. + """ + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + app = lowlevel.streamable_http_app( + stateless_http=stateless_http, + json_response=json_response, + event_store=event_store, + retry_interval=retry_interval, + transport_security=transport_security, + auth=auth, + token_verifier=token_verifier, + auth_server_provider=auth_server_provider, + ) + event_hooks = {"request": [on_request]} if on_request is not None else None + async with ( + server.session_manager.run(), + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, event_hooks=event_hooks, headers=headers + ) as http_client, + ): + yield http_client, server.session_manager + + +@asynccontextmanager +async def client_via_http( + http_client: httpx.AsyncClient, + *, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Connect a `Client` over an already-mounted streamable HTTP app. + + Use with `mounted_app(...)` so several `Client`s share the one session manager, or so a + client-driven assertion can sit alongside raw-httpx assertions in the same test. The + underlying `httpx.AsyncClient` is left open when the `Client` exits. 
+ """ + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + async with Client( + transport, + logging_callback=logging_callback, + message_handler=message_handler, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]: + """Decode SSE events into JSON-RPC messages, skipping priming events that carry no data.""" + return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data] + + +async def post_jsonrpc( + http: httpx.AsyncClient, body: dict[str, object], *, session_id: str | None = None +) -> tuple[httpx.Response, list[JSONRPCMessage]]: + """POST a JSON-RPC body and read its SSE response stream to completion. + + Returns the HTTP response (for header/status assertions) and the parsed JSON-RPC messages + that arrived on the response's SSE stream. Only meaningful for requests the server answers + with `text/event-stream`; for error responses or 202 notification acknowledgements, use + `httpx.AsyncClient.post` directly and assert on the response. + """ + async with aconnect_sse(http, "POST", "/mcp", json=body, headers=base_headers(session_id=session_id)) as source: + events = [event async for event in source.aiter_sse()] + return source.response, parse_sse_messages(events) + + +def base_headers(*, session_id: str | None = None) -> dict[str, str]: + """Standard request headers for raw-httpx streamable-HTTP tests. + + Every well-formed request carries these (Accept covering both response representations, + Content-Type for POST bodies, MCP-Protocol-Version at the latest revision, and the session + ID once one exists), so a test that wants to assert a specific rejection only varies the one + header under test. 
+ """ + headers = { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + "mcp-protocol-version": LATEST_PROTOCOL_VERSION, + } + if session_id is not None: + headers["mcp-session-id"] = session_id + return headers + + +def initialize_body(request_id: int = 1) -> dict[str, object]: + """A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it.""" + params = InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="raw", version="0.0.0"), + ) + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="initialize", params=params.model_dump(by_alias=True, exclude_none=True) + ).model_dump(by_alias=True, exclude_none=True) + + +async def initialize_via_http(http: httpx.AsyncClient) -> str: + """Perform the initialize handshake over a raw `httpx.AsyncClient` and return the session ID. + + Validates the SSE response and sends the `notifications/initialized` follow-up, so the server + is fully ready for subsequent feature requests when this returns. + """ + async with aconnect_sse(http, "POST", "/mcp", json=initialize_body(), headers=base_headers()) as source: + assert source.response.status_code == 200 + # An event-store-backed server opens the stream with a priming event (empty data); skip it. + events = [event async for event in source.aiter_sse() if event.data] + assert len(events) == 1 + assert JSONRPCResponse.model_validate_json(events[0].data).id == 1 + session_id = source.response.headers["mcp-session-id"] + initialized = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/initialized"}, + headers=base_headers(session_id=session_id), + ) + assert initialized.status_code == 202 + return session_id + + +def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]: + """Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/. 
+ + `MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which + the SSE-specific tests need; building the app explicitly here gives both server flavours the + same routing while keeping that handle. + """ + sse = SseServerTransport( + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + + async def handle_sse(request: Request) -> Response: + async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): + await lowlevel.run(read, write, lowlevel.create_initialization_options()) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse, methods=["GET"]), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + return app, sse + + +@asynccontextmanager +async def connect_over_sse( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" + app, _ = build_sse_app(server) + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + # The SSE server transport's connect_sse runs the entire MCP session inside the GET + # request and only releases its streams after that request observes a disconnect, so the + # bridge must let the application drain rather than cancelling at close. 
+ return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) + async with Client( + transport, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py new file mode 100644 index 0000000000..25833b0ca5 --- /dev/null +++ b/tests/interaction/_helpers.py @@ -0,0 +1,107 @@ +"""Shared helpers for the interaction suite. + +Keep this module small: it exists only for (a) types that every test would otherwise have to +assemble from the SDK's internals to annotate a client callback, and (b) the recording transport +used by the wire-level tests. Server fixtures and assertion helpers belong in the test that uses +them. +""" + +from types import TracebackType + +import anyio +from typing_extensions import Self + +from mcp.client._transport import ReadStream, Transport, TransportStreams, WriteStream +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ClientResult, ServerNotification, ServerRequest + +# TODO: this union is the parameter type of every client message handler (MessageHandlerFnT), +# but the SDK does not export a name for it -- writing a correctly-typed handler requires +# importing RequestResponder from mcp.shared.session and assembling the union by hand. It +# should be a named, exported alias next to MessageHandlerFnT (like ClientRequestContext is +# for the request callbacks), at which point this alias can be deleted. 
+IncomingMessage = RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception +"""Everything a client message handler can receive.""" + + +class _RecordingReadStream: + """Delegates to a read stream, appending every received message to a log.""" + + def __init__(self, inner: ReadStream[SessionMessage | Exception], log: list[SessionMessage | Exception]) -> None: + self._inner = inner + self._log = log + + async def receive(self) -> SessionMessage | Exception: + item = await self._inner.receive() + self._log.append(item) + return item + + async def aclose(self) -> None: + await self._inner.aclose() + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> SessionMessage | Exception: + try: + return await self.receive() + except anyio.EndOfStream: + raise StopAsyncIteration from None + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class _RecordingWriteStream: + """Delegates to a write stream, appending every sent message to a log.""" + + def __init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage]) -> None: + self._inner = inner + self._log = log + + async def send(self, item: SessionMessage, /) -> None: + self._log.append(item) + await self._inner.send(item) + + async def aclose(self) -> None: + await self._inner.aclose() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class RecordingTransport: + """Wraps a Transport and records every message crossing the client's transport boundary. + + `sent` holds everything the client wrote towards the server; `received` holds everything the + server delivered to the client. 
The recording sits at the transport seam -- the exact payloads + a real transport would serialise -- and never touches the session, so wire-level assertions + written against it survive changes to the receive path. + """ + + def __init__(self, inner: Transport) -> None: + self.inner = inner + self.sent: list[SessionMessage] = [] + self.received: list[SessionMessage | Exception] = [] + + async def __aenter__(self) -> TransportStreams: + read_stream, write_stream = await self.inner.__aenter__() + return _RecordingReadStream(read_stream, self.received), _RecordingWriteStream(write_stream, self.sent) + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + return await self.inner.__aexit__(exc_type, exc_val, exc_tb) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py new file mode 100644 index 0000000000..109b30fc77 --- /dev/null +++ b/tests/interaction/_requirements.py @@ -0,0 +1,2816 @@ +"""Requirements manifest for the interaction-model test suite. + +Every user-facing behaviour the SDK must satisfy, keyed by a stable `<area>:<feature>[:<case>]` +ID. Each entry owns the tests that exercise it: tests declare `@requirement("<id>")` (a test that +proves several behaviours stacks several decorators) and `test_coverage.py` enforces the contract +in both directions: every non-deferred requirement has at least one test, and every test carries +at least one requirement. + +Sources: + spec URL -- externally mandated by the MCP specification (deep link to the section) + `sdk` -- a behavioural guarantee the SDK chose; not spec-mandated + `issue:#n` -- regression lock-in for a previously fixed bug + +The `behavior` sentence describes the REQUIRED behaviour -- what the specification (or the SDK's +own contract) says should happen. Tests always pin the SDK's current behaviour.
Where current +behaviour falls short of `behavior`, the gap is recorded as data: `divergence` on entries whose +tests pin the divergent behaviour, or `deferred` on entries that are tracked but not yet covered +by a test in this suite. An entry may carry both: `divergence` records the spec-compliance gap +(issue-able) and `deferred` records why no test exists; `divergence` alone implies a test pins +the divergent behaviour. `issue` carries the tracking link for a recorded gap once one is filed. + +`deferred` reasons take one of three shapes: where the behaviour is exercised elsewhere in this +repo the reason names the covering test path; where the SDK does not implement the behaviour at +all the reason starts with "Not implemented in the SDK"; and where an interaction-level test is +planned but not yet written the reason starts with "Not yet covered here". + +`transports` records which transports a behaviour applies to (or is observable on); None means +the behaviour is transport-independent. + +The ID vocabulary and entry granularity are aligned with the TypeScript SDK's end-to-end +requirements suite, so coverage and recorded divergences can be compared across the two SDKs +entry by entry; IDs that exist in only one SDK reflect genuinely different API surface. +""" + +import re +from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal, TypeVar + +import pytest + +SPEC_REVISION = "2025-11-25" +SPEC_BASE_URL = f"https://modelcontextprotocol.io/specification/{SPEC_REVISION}" + +Transport = Literal["in-memory", "stdio", "streamable-http", "sse"] + +_TestFn = TypeVar("_TestFn", bound=Callable[..., object]) + +_SOURCE_PATTERN = re.compile(r"https://modelcontextprotocol\.io/specification/.+|sdk|issue:#\d+") + +_TASKS_DEFERRAL = ( + "Tasks are experimental and the spec is being substantially revised; python task behaviour is " + "covered by tests/experimental/tasks/ until the next spec revision settles." 
+) + + +@dataclass(frozen=True, kw_only=True) +class Divergence: + """A documented gap between the SDK behaviour this suite pins and what `source` mandates.""" + + note: str + issue: str | None = None + + +@dataclass(frozen=True, kw_only=True) +class Requirement: + """A single testable behaviour and the provenance of why it must hold.""" + + source: str + behavior: str + transports: tuple[Transport, ...] | None = None + divergence: Divergence | None = None + deferred: str | None = None + issue: str | None = None + + def __post_init__(self) -> None: + if not _SOURCE_PATTERN.fullmatch(self.source): + raise ValueError(f"source must be a specification URL, 'sdk', or 'issue:#n', got {self.source!r}") + + +REQUIREMENTS: dict[str, Requirement] = { + # ═══════════════════════════════════════════════════════════════════════════ + # Lifecycle & version negotiation + # ═══════════════════════════════════════════════════════════════════════════ + "lifecycle:capability:client-not-declared": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", + behavior=( + "The client rejects sending notifications or registering handlers for capabilities it did not declare." + ), + divergence=Divergence( + note=( + "The client does not check its own declared capabilities before sending notifications or " + "serving callbacks; nothing prevents a caller from violating the spec's MUST." + ), + ), + deferred=( + "Not implemented in the SDK: the client does not check its own declared capabilities before " + "sending notifications or serving callbacks." + ), + ), + "lifecycle:capability:server-not-advertised": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", + behavior=( + "The client rejects calls to methods (e.g. resources/list) for capabilities the server did not advertise." 
+ ), + divergence=Divergence( + note=( + "The client sends any request regardless of the server's advertised capabilities and " + "surfaces whatever the server answers; the spec's MUST is not enforced." + ), + ), + deferred=( + "Not implemented in the SDK: the client sends any request regardless of the server's " + "advertised capabilities and surfaces whatever the server answers." + ), + ), + "lifecycle:initialize:basic": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Connecting sends initialize with the protocol version, client capabilities, and client " + "info; the server responds with its own and the connection is established." + ), + ), + "lifecycle:initialize:server-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The initialize result identifies the server: name and version, plus title when declared.", + ), + "lifecycle:initialize:instructions": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="A server may include an instructions string in the initialize result; the client exposes it.", + ), + "lifecycle:initialize:capabilities:from-handlers": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The server advertises a capability for each feature area it has a registered handler for, " + "and omits the capability for areas it does not." 
+ ), + ), + "lifecycle:initialize:capabilities:minimal": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior="A server with no feature handlers advertises no feature capabilities.", + ), + "lifecycle:initialize:client-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The client's name, version, and title are visible to server handlers after initialization.", + ), + "lifecycle:initialize:client-capabilities": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The client capabilities visible to the server reflect which client callbacks are configured " + "(sampling, elicitation, roots)." + ), + ), + "lifecycle:initialized-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "After successful initialization, the client sends exactly one initialized notification, " + "before any non-ping request." + ), + ), + "lifecycle:ping": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="ping in either direction returns an empty result.", + ), + "ping:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A client-initiated ping receives an empty result from the server.", + ), + "ping:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A server-initiated ping receives an empty result from the client.", + ), + "lifecycle:requests-before-initialized": Requirement( + source="sdk", + behavior=( + "A request other than ping sent before the initialization handshake completes is rejected with an error." 
+ ), + ), + "lifecycle:pre-initialization-ordering": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Before initialization completes, the client sends no requests other than pings, and the " + "server sends no requests other than pings and logging." + ), + divergence=Divergence( + note=( + "The server's send methods (create_message / elicit_form / list_roots) do not check " + "initialization state before sending; on the client side, Client always completes the " + "handshake before any caller code runs." + ), + ), + deferred=( + "Not implemented in the SDK: neither side enforces sender-side restraint. The server's send " + "methods (create_message / elicit_form / list_roots) do not check initialization state before " + "sending, and there is no natural hook to issue a server-to-client request between the " + "initialize response and the initialized notification through the public API; on the client " + "side, Client always completes the handshake before any caller code runs." + ), + ), + "lifecycle:version:downgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "When the server returns an older supported protocol version, the client downgrades to it " + "and the connection succeeds at that version." + ), + ), + "lifecycle:version:match": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "When the server supports the requested protocol version it echoes that version in the " + "initialize result, and the connection proceeds at that version." + ), + ), + "lifecycle:version:server-fallback-latest": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "An initialize request carrying a protocol version the server does not support is answered " + "with another version the server supports — the latest one — rather than an error." 
+ ), + ), + "lifecycle:version:reject-unsupported": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "A client that receives an initialize response carrying a protocol version it does not " + "support fails initialization with an error rather than proceeding with the session." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Protocol primitives: cancellation, timeout, progress, errors, _meta + # ═══════════════════════════════════════════════════════════════════════════ + "protocol:request-id:unique": Requirement( + source=f"{SPEC_BASE_URL}/basic#requests", + behavior=( + "Every request sent on a session carries a unique, non-null string or integer id; ids are " + "never reused within the session." + ), + ), + "protocol:notifications:no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic#notifications", + behavior=( + "Notifications are never answered: every message the server delivers is either the response " + "to a request the client sent or a notification carrying no id." + ), + ), + "protocol:cancel:abort-signal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#cancellation-flow", + behavior=( + "Cancelling an in-flight request through the client API sends notifications/cancelled with " + "the request id and fails the local call." + ), + deferred=( + "Not implemented in the SDK: there is no public client-side API to cancel an in-flight " + "request; cancellation requires hand-constructing the notification (which is how " + "protocol:cancel:in-flight exercises the receiving side)." 
+ ), + ), + "protocol:cancel:handler-abort-propagates": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="On the receiving side, a cancellation notification stops the running request handler.", + ), + "protocol:cancel:in-flight": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A cancellation notification for an in-flight request stops the server-side handler, and the " + "receiver does not send a response for the cancelled request." + ), + divergence=Divergence( + note=( + "The spec says receivers of a cancellation SHOULD NOT send a response for the cancelled " + "request; the server sends an error response (code 0, 'Request cancelled'), which is what " + "unblocks the SDK client's pending call." + ), + ), + ), + "protocol:cancel:initialize-not-cancellable": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="The client never sends notifications/cancelled for the initialize request.", + deferred=( + "Not implemented in the SDK: the client has no public cancellation API at all, so no pathway " + "exists that could cancel initialize; there is no distinct behaviour to pin beyond that absence." + ), + ), + "protocol:cancel:late-response-ignored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A response that arrives after the sender issued notifications/cancelled is ignored; the " + "request stays failed and no error is raised." + ), + divergence=Divergence( + note=( + "A response whose id matches no in-flight request is delivered to the message handler " + "as a RuntimeError rather than being silently ignored. The post-cancellation case is the " + "same code path; tested in its unknown-id form because that is deterministic without the " + "client-side cancellation API the SDK does not yet provide." 
+ ), + ), + ), + "protocol:cancel:server-survives": Requirement( + source="sdk", + behavior="The session continues to serve new requests after an earlier request was cancelled.", + ), + "protocol:cancel:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A server that abandons an in-flight server-initiated request (sampling, elicitation, roots) " + "cancels it, and the client stops processing the cancelled request." + ), + divergence=Divergence( + note=( + "Abandoning a server-side send_request emits no cancellation notification, and the client " + "could not act on one anyway: client callbacks run inline in the receive loop, so a " + "cancellation is not even read until the callback has finished." + ), + ), + deferred=( + "Not implemented in the SDK: abandoning a server-side send_request emits no cancellation " + "notification (the same sender-side gap recorded on protocol:timeout:sends-cancellation), and " + "the client could not act on one anyway because client callbacks run inline in the receive " + "loop, so a cancellation would not even be read until the callback had already finished." + ), + ), + "protocol:cancel:unknown-id-ignored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#error-handling", + behavior=( + "The receiver silently ignores a cancellation notification referencing an unknown or " + "already-completed request id; no error response is sent and no exception is raised." + ), + ), + "protocol:cancel:sender-targeting": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "Cancellation notifications reference only requests that were previously issued in the same " + "direction and are believed to still be in flight." 
+ ), + deferred=( + "Not implemented in the SDK: there is no public client-side cancel API to drive (see " + "protocol:cancel:abort-signal), so the sender-side targeting rule has nothing to pin." + ), + ), + "protocol:error:connection-closed": Requirement( + source="sdk", + behavior="Closing the transport fails all in-flight requests with a connection-closed error.", + ), + "protocol:error:internal-error": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior=( + "An unhandled exception in a request handler is returned to the caller as JSON-RPC error " + "-32603 Internal error." + ), + divergence=Divergence( + note=( + "The low-level Server returns code 0 (not a defined JSON-RPC code) instead of -32603 and " + "leaks str(exc) as the error message." + ), + ), + ), + "protocol:error:invalid-params": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request with malformed params is answered with JSON-RPC error -32602 Invalid params.", + ), + "protocol:error:method-not-found": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request whose method has no registered handler is answered with a METHOD_NOT_FOUND error.", + ), + "protocol:meta:related-task": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="Messages may carry related-task _meta associating them with a task.", + deferred=_TASKS_DEFERRAL, + ), + "meta:request-to-handler": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object the client attaches to a request is visible to the server handler.", + ), + "meta:result-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object a handler attaches to its result is delivered to the client.", + ), + "protocol:progress:callback": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a handler during a request are delivered to 
the caller's " + "progress callback, in order, with their progress, total, and message." + ), + ), + "protocol:progress:token-injected": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Supplying a progress callback attaches a progress token to the outgoing request, which the " + "server-side handler can observe in its request metadata." + ), + ), + "protocol:progress:token-unique": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=("Concurrent in-flight requests that each supply a progress callback carry distinct progress tokens."), + ), + "protocol:progress:monotonic": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "The progress value increases with each notification for a given token, even when the total is unknown." + ), + divergence=Divergence( + note=( + "The spec MUST is not enforced: progress values are not validated on either side, so a " + "handler that emits non-increasing values has them forwarded to the callback unchanged." + ), + ), + ), + "protocol:progress:stops-after-completion": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#behavior-requirements", + behavior="Progress notifications for a token stop once the associated request completes.", + divergence=Divergence( + note=( + "send_progress_notification does not check whether the token's request has already " + "completed; the late notification is sent and reaches the client." + ), + ), + ), + "protocol:progress:late-dropped-by-client": Requirement( + source="sdk", + behavior=( + "A progress notification that arrives after its request has completed is not delivered to the " + "original progress callback." 
+ ), + ), + "protocol:progress:no-token": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior="Without a progress callback the request carries no progress token.", + ), + "protocol:progress:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior="A progress notification sent by the client is delivered to the server's progress handler.", + ), + "protocol:timeout:basic": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior=( + "A request that exceeds its read timeout fails with a request-timeout error instead of " + "waiting forever for the response." + ), + ), + "protocol:timeout:max-total": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A maximum total timeout is enforced even when progress notifications keep arriving.", + divergence=Divergence( + note=( + "There is no maximum-total-timeout option; only the per-request read timeout exists, so the " + "spec's SHOULD that an overall maximum is always enforced cannot be satisfied." + ), + ), + deferred=( + "Not implemented in the SDK: there is no maximum-total-timeout option; only the per-request " + "read timeout exists." + ), + ), + "protocol:timeout:reset-on-progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="When configured to do so, each progress notification resets the request's read timeout.", + deferred=( + "Not implemented in the SDK: progress notifications do not reset the request read timeout and " + "no option exists to enable that." + ), + ), + "protocol:timeout:sends-cancellation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior=( + "When a request times out, the sender issues notifications/cancelled for that request before " + "failing the local call." 
+ ), + divergence=Divergence( + note=( + "The client only raises locally and sends nothing on timeout, so the server keeps running the handler." + ), + ), + ), + "protocol:timeout:session-survives": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="The session continues to serve new requests after an earlier request timed out.", + ), + "protocol:timeout:session-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A session-level read timeout applies to every request that does not override it.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools + # ═══════════════════════════════════════════════════════════════════════════ + "tools:call:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#audio-content", + behavior="A tool result can carry audio content: base64 data with a mimeType.", + ), + "tools:call:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#embedded-resources", + behavior="A tool result can carry an embedded resource with full text or blob contents.", + ), + "tools:call:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#image-content", + behavior="A tool result can carry image content: base64 data with a mimeType.", + ), + "tools:call:content:mixed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-result", + behavior="A tool result can carry multiple content blocks of different types; order is preserved.", + ), + "tools:call:content:resource-link": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior="A tool result can carry a resource_link content block referencing a resource by URI.", + ), + "tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content to the caller.", + ), + "tools:call:concurrent": Requirement( + 
source="sdk", + behavior=( + "Multiple tool calls in flight on one session are dispatched concurrently, and each caller " + "receives the response to its own request." + ), + ), + "tools:call:elicitation-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#user-interaction-model", + behavior=( + "A tool handler that issues an elicitation receives the client's result and can embed it in " + "the tool call result." + ), + ), + "tools:call:is-error": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "A tool execution failure is returned as a result with isError true and the failure described " + "in content, not as a JSON-RPC error." + ), + ), + "tools:call:logging-mid-execution": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "Log notifications emitted by a tool handler during execution reach the client's logging " + "callback before the tool result returns." + ), + ), + "tools:call:progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a tool handler reach the caller's progress callback before " + "the tool result returns." + ), + ), + "tools:call:sampling-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A tool handler that issues a sampling request receives the client's completion and can embed " + "it in the tool call result." 
+ ), + ), + "tools:call:structured-content": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool result can carry structuredContent alongside content; the client receives both.", + ), + "tools:call:structured-content:text-mirror": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool returning structured content also returns the serialized JSON as a text content block.", + ), + "tools:call:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.", + ), + "tools:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#capabilities", + behavior="A server with a list_tools handler advertises the tools capability in its initialize result.", + ), + "tools:input-schema:json-schema-2020-12": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "A tool registered with a JSON Schema 2020-12 inputSchema (nested objects, $defs references) " + "is discoverable and callable." + ), + ), + "tools:input-schema:preserve-additional-properties": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema additionalProperties as registered.", + ), + "tools:input-schema:preserve-defs": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema $defs as registered.", + ), + "tools:input-schema:preserve-schema-dialect": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves the inputSchema $schema dialect URI as registered.", + ), + "tools:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#list-changed-notification", + behavior=( + "When the tool set changes, the server sends notifications/tools/list_changed and it reaches " + "the client's handler." 
+ ), + ), + "tools:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#listing-tools", + behavior="tools/list returns the registered tools with name, description, and inputSchema.", + ), + "tools:list:metadata": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "Optional Tool fields supplied by the server (title, annotations, outputSchema, icons, _meta) " + "are delivered to the client unchanged." + ), + ), + "tools:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "tools/list supports cursor pagination: the nextCursor returned by a list handler round-trips " + "back to the handler as an opaque cursor until the listing is exhausted." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "client:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="The client skips structured-content validation when the tool result has isError true.", + ), + "client:output-schema:validate": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior=( + "A tool result whose structuredContent does not conform to the tool's declared outputSchema " + "is rejected by the client: the call raises instead of returning the invalid result." + ), + ), + "client:output-schema:missing-structured": Requirement( + source="sdk", + behavior="A tool that declares an output schema but returns no structuredContent fails client-side validation.", + ), + "client:output-schema:auto-list": Requirement( + source="sdk", + behavior=( + "Calling a tool whose output schema is not yet cached issues an implicit tools/list to " + "populate the cache; subsequent calls of the same tool do not." 
+ ), + divergence=Divergence( + note=( + "Design concern rather than spec violation: the implicit request is invisible to the " + "caller, and against a server that registers only on_call_tool a successful call surfaces " + "as METHOD_NOT_FOUND from a tools/list the caller never asked for." + ), + ), + ), + "mcpserver:output-schema:missing-structured": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior="A tool with an output schema whose function returns no structured content produces a server error.", + ), + "mcpserver:output-schema:server-validate": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior=( + "MCPServer validates structured content against the tool's output schema before returning; a " + "mismatch produces a server error." + ), + ), + "mcpserver:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="Server-side output schema validation is skipped when the tool returns an isError result.", + ), + "mcpserver:tool:duplicate-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-names", + behavior="Registering a tool with a name already in use is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; " + "warn_on_duplicate_tools defaults to True and warning is the only effect -- there is " + "no rejection mode." + ), + ), + ), + "mcpserver:tool:extra": Requirement( + source="sdk", + behavior=( + "Tool functions can access request metadata (request id, client params, session) through the " + "Context parameter." + ), + ), + "mcpserver:tool:handler-throws": Requirement( + source="sdk", + behavior=( + "An exception raised by a tool function (ToolError or otherwise) is caught and returned as a " + "tool result with isError true and the failure text in content; it does not become a JSON-RPC error." 
+ ), + ), + "mcpserver:tool:input-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "Arguments that fail the tool's input validation produce a tool execution error (isError true " + "with the validation failure described in content) without invoking the function." + ), + ), + "mcpserver:tool:naming-validation": Requirement( + source="sdk", + behavior=( + "Registering a tool whose name violates the spec's tool-naming conventions emits a warning; " + "registration still succeeds." + ), + ), + "mcpserver:tool:output-schema:model": Requirement( + source="sdk", + behavior=( + "A tool returning a typed model advertises a matching generated outputSchema and returns the " + "model's fields as structuredContent alongside a serialised text block." + ), + ), + "mcpserver:tool:output-schema:wrapped": Requirement( + source="sdk", + behavior=( + "A tool returning a non-object type (primitive or list) wraps the value as {'result': ...} in " + "structuredContent, with a matching generated outputSchema." + ), + ), + "mcpserver:tool:schema-variants": Requirement( + source="sdk", + behavior=( + "Tool input schemas generated from complex parameter types (unions, nested models, " + "constrained types) validate and coerce arguments before the function runs." + ), + ), + "mcpserver:tool:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name that was never registered returns a JSON-RPC error.", + divergence=Divergence( + note=( + "The spec classifies unknown tools as a protocol error (its example uses -32602 Invalid " + "params); MCPServer reports a tool execution error (isError true) instead. The low-level " + "path follows the spec example (see tools:call:unknown-name)." 
+ ), + ), + ), + "mcpserver:tool:url-elicitation-error": Requirement( + source="sdk", + behavior=( + "A tool function that raises the URL-elicitation-required error surfaces to the caller as " + "error -32042 with the elicitation parameters intact." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # MCPServer: Context helpers (SDK) + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:context:logging": Requirement( + source="sdk", + behavior=( + "The Context logging helpers (debug/info/warning/error) send log message notifications at the " + "corresponding severity." + ), + ), + "mcpserver:context:progress": Requirement( + source="sdk", + behavior=( + "Context.report_progress sends a progress notification against the requesting client's progress token." + ), + ), + "mcpserver:context:elicit": Requirement( + source="sdk", + behavior=( + "Context.elicit sends a form elicitation built from a typed schema and returns a typed " + "accepted/declined/cancelled result." + ), + ), + "mcpserver:context:read-resource": Requirement( + source="sdk", + behavior="Context.read_resource reads a resource registered on the same server from inside a tool.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Resources + # ═══════════════════════════════════════════════════════════════════════════ + "resources:annotations": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#annotations", + behavior="Resource annotations supplied by the server round-trip to the client in the list result.", + divergence=Divergence( + note=( + "The SDK Annotations model is missing the schema's lastModified field; MCPModel uses the " + "pydantic default extra='ignore', so the value is silently dropped on parse." 
+ ), + ), + ), + "resources:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#capabilities", + behavior=( + "A server with resource handlers advertises the resources capability, including the subscribe " + "sub-flag when a subscribe handler is registered." + ), + ), + "resources:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#list-changed-notification", + behavior=( + "When the resource set changes, the server sends notifications/resources/list_changed and it " + "reaches the client's handler." + ), + ), + "resources:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#listing-resources", + behavior=( + "resources/list returns the registered resources with uri, name, and the optional descriptive " + "fields supplied by the server." + ), + ), + "resources:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/list supports cursor pagination.", + ), + "resources:read:blob": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns binary contents base64-encoded in blob.", + ), + "resources:read:template-vars": Requirement( + source="sdk", + behavior="Variables extracted from a templated resource URI reach the resource function as typed arguments.", + ), + "resources:read:text": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns text contents carrying uri, mimeType, and the text.", + ), + "resources:read:unknown-uri": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for an unknown URI returns JSON-RPC error -32002 (resource not found).", + ), + "resources:subscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="resources/subscribe delivers the URI to the server's subscribe handler and returns an empty 
result.", + ), + "resources:subscribe:capability-required": Requirement( + source="sdk", + behavior=( + "resources/subscribe to a server that did not advertise the subscribe capability is rejected with an error." + ), + ), + "resources:subscribe:updated": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/subscribe, changes to that resource send notifications/resources/updated.", + deferred=( + "Not implemented in the SDK: the server keeps no subscription state linking subscribe to " + "updated notifications; emitting updates is entirely handler code. The two halves are pinned " + "separately by resources:subscribe and resources:updated-notification." + ), + ), + "resources:templates:list": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#resource-templates", + behavior=( + "resources/templates/list returns the registered templates with their uriTemplate and descriptive fields." + ), + ), + "resources:templates:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/templates/list supports cursor pagination.", + ), + "resources:unsubscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "resources/unsubscribe delivers the URI to the server's unsubscribe handler and returns an empty result." + ), + ), + "resources:unsubscribe:stops-updates": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/unsubscribe the server stops sending updated notifications for that URI.", + deferred=( + "Not implemented in the SDK: the server keeps no subscription state, so whether updated " + "notifications stop after unsubscribe is entirely handler code; there is no SDK behaviour to " + "pin beyond the unsubscribe request reaching the handler (covered by resources:unsubscribe)." 
+ ), + ), + "resources:updated-notification": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "A resources/updated notification sent by the server reaches the client carrying the URI of " + "the changed resource." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Resources: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:resource:duplicate-name": Requirement( + source="sdk", + behavior="Registering a resource or template with a duplicate identifier is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name). " + "Templates differ: a duplicate uri_template silently replaces the first with no warning." + ), + ), + ), + "mcpserver:resource:read-throws-surfaced": Requirement( + source="sdk", + behavior="A resource function that raises is surfaced to the caller as a JSON-RPC error response.", + ), + "mcpserver:resource:static": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.resource() for a fixed URI is listed by resources/list and " + "served by resources/read at that URI." + ), + ), + "mcpserver:resource:template": Requirement( + source="sdk", + behavior=( + "A function registered with a URI template is listed by resources/templates/list and matched " + "by resources/read, receiving the parameters extracted from the requested URI." 
+ ), + ), + "mcpserver:resource:unknown-uri": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for a URI matching no registered resource returns JSON-RPC error -32002.", + divergence=Divergence( + note=( + "The spec reserves -32002 for resource-not-found; MCPServer raises ResourceError, which " + "the low-level server converts to error code 0." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts + # ═══════════════════════════════════════════════════════════════════════════ + "prompts:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#capabilities", + behavior="A server with a list_prompts handler advertises the prompts capability in its initialize result.", + ), + "prompts:get:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#audio-content", + behavior="Prompt messages may contain audio content with base64 data and a mimeType.", + ), + "prompts:get:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#embedded-resources", + behavior="Prompt messages may contain embedded resource content.", + ), + "prompts:get:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#image-content", + behavior="Prompt messages may contain image content.", + ), + "prompts:get:missing-required-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get omitting a required argument returns JSON-RPC error -32602 (Invalid params).", + divergence=Divergence( + note=( + "MCPServer's prompt renderer raises a plain ValueError before the prompt function runs, " + "which the low-level server converts to error code 0 with the exception text as the message." 
+ ), + ), + ), + "prompts:get:multi-message": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="A prompt can return multiple messages mixing user and assistant roles; order is preserved.", + ), + "prompts:get:no-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get with no arguments returns the prompt's messages.", + ), + "prompts:get:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get for an unknown prompt name returns JSON-RPC error -32602 (Invalid params).", + ), + "prompts:get:with-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get delivers the supplied arguments to the prompt handler and returns its messages.", + ), + "prompts:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#list-changed-notification", + behavior=( + "When the prompt set changes, the server sends notifications/prompts/list_changed and it " + "reaches the client's handler." 
+ ), + ), + "prompts:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#listing-prompts", + behavior="prompts/list returns the registered prompts with name, description, and argument declarations.", + ), + "prompts:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="prompts/list supports cursor pagination.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:prompt:args-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#implementation-considerations", + behavior="prompts/get arguments that fail the prompt's argument schema are rejected before the function runs.", + ), + "mcpserver:prompt:decorated": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.prompt() is listed with arguments derived from its signature " + "and rendered into prompt messages by prompts/get." + ), + ), + "mcpserver:prompt:duplicate-name": Requirement( + source="sdk", + behavior="Registering a duplicate prompt name is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name)." 
+ ), + ), + ), + "mcpserver:prompt:optional-args": Requirement( + source="sdk", + behavior="A prompt with optional arguments can be fetched without supplying them.", + ), + "mcpserver:prompt:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get for a name that was never registered returns JSON-RPC error -32602 (Invalid params).", + divergence=Divergence( + note=( + "The spec's example uses -32602 Invalid params for unknown prompts; MCPServer raises " + "ValueError, which the low-level server converts to error code 0." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Completion + # ═══════════════════════════════════════════════════════════════════════════ + "completion:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior="A server with a completion handler advertises the completions capability in its initialize result.", + ), + "completion:complete:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior=( + "A server with no completion handler does not advertise the completions capability and rejects " + "completion/complete with METHOD_NOT_FOUND." + ), + ), + "completion:context-arguments": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", + behavior="Previously-resolved argument values supplied in context.arguments reach the completion handler.", + ), + "completion:error:invalid-ref": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#error-handling", + behavior=( + "completion/complete with a ref naming an unknown prompt or non-matching resource URI returns " + "JSON-RPC error -32602 (Invalid params)." 
+ ), + ), + "completion:prompt-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/prompt returns suggested values for the named prompt argument.", + ), + "completion:resource-template-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/resource returns suggested values for a URI template variable.", + ), + "completion:result-shape": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#completion-results", + behavior="The completion result carries values (at most 100), an optional total, and an optional hasMore flag.", + ), + "mcpserver:completion:capability-auto": Requirement( + source="sdk", + behavior=( + "MCPServer advertises the completions capability when at least one completion source is " + "registered, and omits it otherwise." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Logging + # ═══════════════════════════════════════════════════════════════════════════ + "logging:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#capabilities", + behavior=( + "A server that emits log message notifications declares the logging capability in its initialize result." + ), + divergence=Divergence( + note=( + "MCPServer registers no setLevel handler, so capability derivation leaves logging unset " + "even though the Context helpers send log message notifications." 
+ ), + ), + ), + "logging:message:all-levels": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", + behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", + ), + "logging:message:fields": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "A log message sent by a server handler is delivered to the client's logging callback with its " + "severity level, logger name, and data." + ), + ), + "logging:message:filtered": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="After logging/setLevel, log messages below the configured level are not sent.", + divergence=Divergence( + note=( + "Neither MCPServer (which rejects logging/setLevel with method-not-found) nor the " + "low-level Server (which leaves the handler entirely to the author) implements any " + "filtering; messages are delivered at every severity regardless of the requested level." + ), + ), + ), + "logging:set-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="logging/setLevel delivers the requested level to the server's handler and returns an empty result.", + ), + "logging:set-level:invalid-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#error-handling", + behavior="logging/setLevel with an invalid level value returns JSON-RPC error -32602 (Invalid params).", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Sampling (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "sampling:capability:declare": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A client that handles sampling requests advertises the sampling capability in its initialize request." 
+ ), + ), + "sampling:create:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A sampling/createMessage request from a server handler is answered by the client's sampling " + "callback, and the callback's result (role, content, model, stopReason) is returned to the handler." + ), + ), + "sampling:create:include-context": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior="The includeContext value supplied by the server reaches the client callback intact.", + ), + "sampling:context:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "The server does not use includeContext values thisServer or allServers unless the client " + "declared the sampling.context capability." + ), + divergence=Divergence( + note=( + "include_context is forwarded regardless of the client's declared sampling.context " + "capability; the server-side validator only checks tools/tool_choice." + ), + ), + ), + "sampling:create:model-preferences": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#model-preferences", + behavior=( + "The model preferences supplied by the server (hints and the cost, speed, and intelligence " + "priorities) reach the client callback intact." + ), + ), + "sampling:create:system-prompt": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior="The system prompt supplied by the server reaches the client callback intact.", + ), + "sampling:create:tools": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", + behavior=( + "A sampling request carrying tools and toolChoice reaches the client, and a tool_use response " + "with a toolUse stop reason returns to the requesting handler." 
+ ), + deferred=( + "Not implemented in the SDK: Client does not expose ClientSession's sampling_capabilities " + "parameter, so a client can never declare sampling.tools and the server-side validator " + "rejects every tool-enabled request before it is sent." + ), + ), + "sampling:create-message:audio-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#audio-content", + behavior="Sampling messages can carry audio content: base64 data with a mimeType.", + ), + "sampling:create-message:image-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#image-content", + behavior="Sampling messages can carry image content: base64 data with a mimeType.", + ), + "sampling:create-message:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A sampling request to a client that did not declare the sampling capability fails with an " + "error rather than hanging or being silently dropped; the spec names no error code for this case." + ), + ), + "sampling:error:user-rejected": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#error-handling", + behavior=( + "A sampling request the user rejects is answered with a JSON-RPC error (the spec's code for " + "this case is -1, 'User rejected sampling request'), surfaced to the requesting handler as an MCPError." + ), + ), + "sampling:message:content-cardinality": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling", + behavior="A sampling message's content may be a single block or an array of blocks.", + ), + "sampling:result:no-tools-single-content": Requirement( + source="sdk", + behavior=( + "When the request carries no tools, a sampling callback result whose content is an array is " + "rejected by the client." 
+ ), + divergence=Divergence( + note=( + "The client does not validate the callback result against the request shape; an array-content " + "result for a tool-free request is accepted client-side and surfaces as a raw " + "pydantic.ValidationError from the server's response parsing (send_request) instead." + ), + ), + ), + "sampling:result:with-tools-array-content": Requirement( + source="sdk", + behavior=( + "When the request includes tools, the client accepts a callback result whose content is an " + "array including tool_use blocks." + ), + deferred=( + "Not implemented in the SDK: requires declaring sampling.tools, which the high-level client " + "cannot do (see sampling:create:tools)." + ), + ), + "sampling:tool-result:no-mixed-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-result-messages", + behavior=( + "A user sampling message that carries tool_result content contains only tool_result blocks; " + "mixing tool_result with text, image, or audio content is rejected as invalid." + ), + ), + "sampling:tool-use:result-balance": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-use-and-result-balance", + behavior=( + "In a sampling/createMessage request, every assistant tool_use block in messages MUST be " + "matched by a tool_result with the same toolUseId in the immediately-following user message; " + "an unmatched tool_use is rejected with -32602 Invalid params." + ), + divergence=Divergence( + note=( + "The client does not validate inbound tool_use/tool_result balance; the SDK enforces " + "the rule server-side instead, before the request leaves the server (see " + "sampling:tool-use:server-preflight)." + ), + ), + deferred=( + "Not implemented on the client receive path: validation runs only on the server send path " + "(pinned by sampling:tool-use:server-preflight)." 
+ ), + ), + "sampling:tool-use:server-preflight": Requirement( + source="sdk", + behavior=( + "The server validates tool_use/tool_result balance before sending a sampling/createMessage " + "request; an unmatched tool_use raises ValueError and the request never reaches the wire." + ), + ), + "sampling:tools:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", + behavior=( + "A tool-enabled sampling request to a client that did not declare sampling.tools is rejected " + "by the server before anything reaches the wire (the SDK surfaces this as an Invalid params error)." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Elicitation (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "elicitation:capability:empty-is-form": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior="A client advertising an empty elicitation capability accepts form-mode elicitation requests.", + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares explicit " + "form and url sub-capabilities, so an empty elicitation capability cannot be produced through " + "the public API." + ), + ), + "elicitation:capability:mode-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "The client answers elicitation requests for a mode it did not advertise with JSON-RPC error " + "-32602 (Invalid params)." + ), + deferred=( + "Not implemented in the SDK: a client cannot be configured form-only or url-only, so the " + "per-mode mismatch error cannot arise (see elicitation:url:not-supported)." 
+ ), + ), + "elicitation:capability:server-respects-mode": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior=( + "The server refuses to send an elicitation request with a mode the connected client did not " + "declare in its capabilities." + ), + divergence=Divergence( + note=( + "The server does not check the client's declared elicitation modes before sending " + "elicitation/create; the spec's MUST NOT is not enforced." + ), + ), + ), + "elicitation:form:action:accept": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior=( + "A form-mode elicitation answered with action 'accept' returns the user's content to the " + "requesting handler." + ), + ), + "elicitation:form:action:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'cancel' returns no content to the handler.", + ), + "elicitation:form:action:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'decline' returns no content to the handler.", + ), + "elicitation:form:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", + behavior=( + "A form-mode elicitation delivers the message and requested schema to the client callback " + "exactly as the server sent them." + ), + ), + "elicitation:form:defaults": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Optional default values declared in a form-mode requested schema are pre-populated into the " + "form presented to the user." + ), + deferred=( + "Not implemented in the SDK: there is no form-rendering layer that could pre-populate " + "defaults; client callbacks receive the requested schema as-is." 
+ ), + ), + "elicitation:form:mode-omitted-default": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#elicitation-requests", + behavior="An elicitation request with no mode field is treated as form mode by the client.", + ), + "elicitation:form:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "An elicitation request to a client that did not declare the elicitation capability is " + "answered with -32602 Invalid params." + ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32602.", + ), + ), + "elicitation:form:schema:enum-variants": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Requested-schema enum fields (including titled and multi-select variants) reach the client " + "callback as sent." + ), + ), + "elicitation:form:schema:primitives": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior="Requested-schema fields may be string (with format), number or integer, or boolean.", + ), + "elicitation:form:schema:restricted-subset": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Form-mode requested schemas are flat objects with primitive-typed properties only; nested " + "structures and arrays of objects are not used." + ), + divergence=Divergence( + note=( + "ServerSession.elicit_form forwards an arbitrary dict[str, Any] schema unchanged; no shape " + "validation at the low-level session layer (the high-level Context.elicit / " + "elicit_with_validation helper enforces primitive-only fields before generating the schema)." 
+ ), + ), + ), + "elicitation:form:response-validation": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-security", + behavior=( + "Accepted form-mode content is validated against the requested schema: the client validates " + "the response before sending and the server validates the content it receives." + ), + divergence=Divergence( + note=( + "The client never validates outbound content; ServerSession.elicit_form returns received " + "content unvalidated (the high-level Context.elicit / elicit_with_validation helper " + "validates server-side, but the low-level session API does not)." + ), + ), + ), + "elicitation:url:action:accept-no-content": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior=( + "A URL-mode elicitation delivers the message, URL, and elicitationId to the client; an accept " + "response carries no content (accept means the user agreed to visit the URL, not that the " + "interaction completed)." + ), + ), + "elicitation:url:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation-requests", + behavior=( + "A url-mode elicitation delivers the elicitation id and URL to the client callback exactly as " + "the server sent them." + ), + ), + "elicitation:url:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with cancel returns the action with no content.", + ), + "elicitation:url:complete-notification": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "An elicitation/complete notification sent by the server after an out-of-band elicitation " + "finishes reaches the client carrying the elicitationId." 
+ ), + ), + "elicitation:url:complete-unknown-ignored": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "The client ignores an elicitation/complete notification referencing an unknown or " + "already-completed elicitationId without error." + ), + ), + "elicitation:url:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with decline returns the action with no content.", + ), + "elicitation:url:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "A URL-mode elicitation to a client that declared only form-mode support is rejected with an " + "Invalid params error." + ), + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares both the " + "form and url sub-capabilities, so a form-only client cannot be constructed." + ), + ), + "elicitation:url:required-error": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A handler that cannot proceed without a URL elicitation rejects the request with error " + "-32042, carrying the pending elicitations in the error data." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Roots (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "roots:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", + ), + "roots:list-changed:client-emits": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior=( + "A client that declared roots.listChanged sends notifications/roots/list_changed when its set " + "of roots changes." 
+ ), + deferred=( + "Not implemented in the SDK: the client does not own the root set (it calls back to the host " + "via list_roots_callback), so there is no mutation it could observe to auto-emit on; the SDK " + "provides send_roots_list_changed() for the host to call when its roots change, and that " + "emission path is covered by roots:list-changed." + ), + ), + "roots:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior=( + "A roots/list request from a server handler is answered by the client's roots callback, and " + "the returned roots (uri, name) reach the handler." + ), + ), + "roots:list:client-error": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior="A roots callback that answers with an error surfaces to the requesting handler as an MCPError.", + ), + "roots:list:empty": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior="An empty roots list is a valid response and reaches the handler as such.", + ), + "roots:list:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior=( + "A roots/list request to a client that did not declare the roots capability is answered with " + "-32601 Method not found." + ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32601.", + ), + ), + "roots:uri:file-scheme": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root", + behavior="Every root returned by the client identifies itself with a file:// URI.", + deferred=( + "Schema-level validation: the FileUrl type on Root.uri rejects any non-file:// scheme at " + "construction and at parse, so a non-conforming root cannot reach the wire from either side; " + "type-level coverage belongs in tests/test_types.py rather than this interaction suite." 
+ ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # list_changed & dynamic registration + # ═══════════════════════════════════════════════════════════════════════════ + "client:list-changed:auto-refresh": Requirement( + source="sdk", + behavior=( + "A client configured to react to list_changed notifications automatically re-fetches the " + "corresponding list and delivers the fresh result to its callback." + ), + deferred=( + "Not implemented in the SDK: the client has no list-changed auto-refresh mechanism; " + "notifications are only delivered to the message handler." + ), + ), + "client:list-changed:capability-gated": Requirement( + source="sdk", + behavior=( + "The client does not activate list-changed handling for a kind the server did not advertise " + "with listChanged true." + ), + deferred="Not implemented in the SDK: no client-side list-changed handling exists to gate.", + ), + "client:list-changed:signal-only": Requirement( + source="sdk", + behavior="A client configured for signal-only list-changed handling is notified without auto-refreshing.", + deferred="Not implemented in the SDK: no client-side list-changed handling exists.", + ), + "mcpserver:list-changed:debounce": Requirement( + source="sdk", + behavior=( + "Bursts of registration changes on MCPServer are debounced into one list_changed notification per kind." + ), + deferred=( + "Not implemented in the SDK: MCPServer does not send list_changed notifications on " + "registration changes at all (see mcpserver:register:post-connect), so there is nothing to " + "debounce." + ), + ), + "mcpserver:register:post-connect": Requirement( + source="sdk", + behavior=( + "A tool, resource, or prompt registered or removed after the client connected appears in (or " + "disappears from) the corresponding list results, and the change is announced with a " + "list_changed notification." 
+ ), + divergence=Divergence( + note=( + "MCPServer never sends list_changed notifications on registration changes, so a connected " + "client cannot learn that the set changed without polling." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Pagination + # ═══════════════════════════════════════════════════════════════════════════ + "pagination:exhaustion": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "Following nextCursor until it is absent yields every page exactly once; a result without " + "nextCursor ends the sequence." + ), + ), + "pagination:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#error-handling", + behavior="A list request with an invalid cursor returns JSON-RPC error -32602 (Invalid params).", + ), + "pagination:client:cursor-handling": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#implementation-guidelines", + behavior=( + "The client treats cursors as opaque tokens — it does not parse, modify, or persist them — " + "and does not assume a fixed page size." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tasks (experimental) + # ═══════════════════════════════════════════════════════════════════════════ + "tasks:auth:context-isolation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-isolation-and-access-control", + behavior=( + "When an authorization context is available, task operations are scoped to the context that " + "created the task: other contexts cannot get it, retrieve its result, cancel it, or see it in " + "tasks/list." 
+ ), + transports=("streamable-http",), + deferred=_TASKS_DEFERRAL, + ), + "tasks:bidirectional": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#definitions", + behavior="Task APIs are bidirectional: the server may create, get, list, and cancel tasks on the client.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:no-handler-abort": Requirement( + source="sdk", + behavior=( + "tasks/cancel marks the task cancelled without aborting the originating request handler " + "(the spec says receivers SHOULD attempt to stop execution)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:remains-cancelled": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior=( + "After tasks/cancel, the task remains cancelled even if the underlying handler subsequently " + "completes or fails." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:terminal-rejected": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a task already in a terminal state returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a working task transitions it to cancelled and returns the updated task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:ttl-honored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#ttl-and-resource-management", + behavior=( + "tasks/get responses include the actual ttl applied by the receiver (or null for unlimited); " + "the create-task result carries the same value." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:via-tool-call": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#creating-tasks", + behavior="A task-augmented tools/call returns a create-task result instead of the tool result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:get": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#getting-tasks", + behavior="tasks/get returns the task's current status, ttl, timestamps, and status message.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:initial-working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-lifecycle", + behavior="A newly created task has status 'working'.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:input-required": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "While a task awaits a side-channel client response its status is input_required; once the " + "response arrives the task leaves input_required (typically returning to working)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/list with an invalid cursor returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#listing-tasks", + behavior="tasks/list returns created tasks and supports cursor pagination.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:no-capability:ignore-task-param": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-support-and-handling", + behavior=( + "A receiver that did not declare task capability for a request type processes the request " + "normally and returns the ordinary result, ignoring the task augmentation." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:progress:after-create": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-progress-notifications", + behavior=( + "After the create-task result, progress notifications keyed to the original progress token " + "continue to reach the caller until the task is terminal." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:request-cancel:no-task-cancel": Requirement( + source="sdk", + behavior="A cancellation notification for the originating request does not auto-cancel the created task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:failed": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-execution-errors", + behavior="tasks/result for a failed task returns the failure result (isError true).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:related-task-meta": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="The tasks/result response carries related-task _meta naming the requested task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:terminal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#result-retrieval", + behavior="tasks/result for a completed task returns the stored result of the original request type.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drain-fifo": Requirement( + source="sdk", + behavior="tasks/result drains queued related-task messages in FIFO order before returning the final result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drop-on-cancel": Requirement( + source="sdk", + behavior="When a task is cancelled before tasks/result, queued related-task messages are dropped.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:elicitation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "An elicitation issued mid-task is delivered through the tasks/result side-channel, and the " + "client's response routes back to the 
handler." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:queue": Requirement( + source="sdk", + behavior=( + "Server-to-client requests with related-task metadata sent while no tasks/result is open are queued." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:sampling": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "A sampling request issued mid-task is delivered through the tasks/result side-channel, and " + "the client's response routes back to the task." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:stream": Requirement( + source="sdk", + behavior=( + "Calling tasks/result while the task is working streams related-task messages as they are " + "produced, then returns the result." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:status-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-notification", + behavior="Task status notifications deliver status updates carrying the full task fields.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:forbidden-with-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=( + "A task-augmented tools/call on a tool that does not support tasks returns Method not found (-32601)." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:required-no-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=("A plain tools/call on a tool that requires task augmentation returns Method not found (-32601)."), + deferred=_TASKS_DEFERRAL, + ), + "tasks:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/get, tasks/result, and tasks/cancel for an unknown task id return Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Transports (in-suite coverage) + # ═══════════════════════════════════════════════════════════════════════════ + "transport:streamable-http:stateful": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip (initialize, tool calls, tool errors) works through the " + "streamable HTTP framing in its default stateful SSE-response mode." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:json-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="The interaction round trip works when the server answers with plain JSON instead of SSE.", + transports=("streamable-http",), + ), + "transport:streamable-http:stateless": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip works in stateless mode, where every request is served by a " + "fresh transport with no session id." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:notifications": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "Notifications emitted during a request are delivered on that request's SSE stream and reach " + "the client's callbacks, in order, before the response." 
+ ), + transports=("streamable-http",), + ), + "transport:streamable-http:stateless-restrictions": Requirement( + source="sdk", + behavior=( + "A handler that attempts a server-initiated request in stateless mode fails with an error " + "result, because there is no session to call back through." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:unrelated-messages": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-to-client message that is not related to an in-flight request is routed to the " + "standalone GET stream and delivered to the client listening on it, not to any request's " + "own stream." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-initiated request nested inside an in-flight call round-trips over stateful streamable HTTP." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:resumability": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="A client that reconnects with Last-Event-ID receives the events it missed.", + transports=("streamable-http",), + ), + "transport:streamable-http:origin-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior="Requests with an invalid Origin header are rejected with 403 before reaching the session.", + transports=("streamable-http",), + ), + "transport:sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "A client connected over the legacy HTTP+SSE transport completes the handshake and round-trips " + "requests, with server messages delivered on the SSE stream." 
+ ), + transports=("sse",), + ), + "transport:sse:endpoint-event": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior="Opening the SSE stream delivers an `endpoint` event naming the message-POST URL as the first event.", + transports=("sse",), + ), + "transport:sse:post:session-routing": Requirement( + source="sdk", + behavior=( + "The endpoint URL carries a fresh session identifier; the server registers the session before " + "the endpoint event is sent and releases it when the stream disconnects, and a POST that names " + "no session id, a malformed session id, or an unknown session id is rejected (400/400/404)." + ), + transports=("sse",), + ), + "transport:stdio": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior=( + "A Client connected to a real SDK Server over stdio initializes, calls a tool with arguments, " + "and receives notifications and results over the child process's stdin/stdout." + ), + transports=("stdio",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: session lifecycle + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:session:cors-expose": Requirement( + source="sdk", + behavior="CORS configuration exposes the Mcp-Session-Id header so browser clients can read it.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: CORS configuration is left to the hosting ASGI application.", + ), + "hosting:session:create": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "An initialize POST without a session id creates a session and returns Mcp-Session-Id in the " + "response headers." 
+ ), + transports=("streamable-http",), + ), + "hosting:session:delete": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="DELETE with a valid Mcp-Session-Id terminates the session.", + transports=("streamable-http",), + ), + "hosting:session:id-charset": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Generated Mcp-Session-Id values contain only visible ASCII characters.", + transports=("streamable-http",), + ), + "hosting:session:isolation": Requirement( + source="sdk", + behavior="Each session gets its own server instance; closing one session does not affect others.", + transports=("streamable-http",), + ), + "hosting:session:missing-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A non-initialize POST without Mcp-Session-Id in stateful mode returns 400.", + transports=("streamable-http",), + ), + "hosting:session:post-termination-404": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "After a session is terminated, any further request carrying that session ID is answered with " + "404 Not Found." + ), + transports=("streamable-http",), + ), + "hosting:session:reinitialize": Requirement( + source="sdk", + behavior="A second initialize on an already-initialized session transport is rejected.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The transport forwards a second initialize carrying the existing session ID to the running " + "server, which answers it as a fresh handshake; nothing rejects re-initialization." 
+ ), + ), + ), + "hosting:session:reuse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST carrying a valid Mcp-Session-Id routes to that session's transport with state preserved.", + transports=("streamable-http",), + ), + "hosting:session:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST, GET, or DELETE with an unknown Mcp-Session-Id returns 404.", + transports=("streamable-http",), + ), + "hosting:stateless:concurrent-clients": Requirement( + source="sdk", + behavior="Multiple independent clients can connect to a stateless server concurrently.", + transports=("streamable-http",), + ), + "hosting:stateless:no-reuse": Requirement( + source="sdk", + behavior="A stateless per-request transport cannot be reused for a second request.", + transports=("streamable-http",), + ), + "hosting:stateless:no-session-id": Requirement( + source="sdk", + behavior="In stateless mode no Mcp-Session-Id is emitted and no session validation is performed.", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: auth + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:auth:as-router": Requirement( + source="sdk", + behavior=( + "The authorization-server routes expose the authorize, token, and registration endpoints " + "(and revocation when supported)." + ), + transports=("streamable-http",), + ), + "hosting:auth:aud-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#access-token-usage", + behavior="The resource server validates that the token audience matches its resource identifier.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "BearerAuthBackend never inspects AccessToken.resource; a token issued for a different " + "resource is accepted. Spec MUST." 
+ ), + ), + ), + "hosting:auth:authinfo-propagates": Requirement( + source="sdk", + behavior="A valid token's auth info is exposed to request handlers.", + transports=("streamable-http",), + ), + "hosting:auth:expired-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="An expired token returns 401 invalid_token.", + transports=("streamable-http",), + divergence=Divergence( + note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", + ), + ), + "hosting:auth:invalid-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="A malformed bearer token or token-verification failure returns 401 with WWW-Authenticate.", + transports=("streamable-http",), + divergence=Divergence( + note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", + ), + ), + "hosting:auth:metadata-endpoints": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The MCP server publishes protected-resource metadata at its well-known endpoint, and the " + "authorization server (which the SDK can also host) publishes authorization-server metadata " + "at its own." + ), + transports=("streamable-http",), + ), + "hosting:auth:missing-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior=( + "A request without an Authorization header is rejected with 401; the WWW-Authenticate header " + "carries resource_metadata (one of the spec's two permitted discovery mechanisms)." 
+ ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK never emits a `scope` parameter in any WWW-Authenticate challenge — neither the " + "discovery-time 401 (#protected-resource-metadata-discovery-requirements SHOULD) nor the " + "runtime 403 (#runtime-insufficient-scope-errors SHOULD); and for the no-credentials case " + 'it emits error="invalid_token", which RFC 6750 Section 3.1 says SHOULD NOT appear when no ' + "authentication information was presented." + ), + ), + ), + "hosting:auth:prm:authorization-servers-field": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The protected-resource metadata document includes an authorization_servers array with at least one entry." + ), + transports=("streamable-http",), + ), + "hosting:auth:query-token-ignored": Requirement( + source="sdk", + behavior=( + "An access token presented in the URI query string is not accepted; the request is treated as " + "unauthenticated." + ), + transports=("streamable-http",), + ), + "hosting:auth:scope-403": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#runtime-insufficient-scope-errors", + behavior=( + "A token lacking a required scope returns 403 with WWW-Authenticate carrying " + "insufficient_scope, the required scope, and resource_metadata." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + 'The SDK emits error="insufficient_scope" and error_description but never the `scope` ' + "parameter the spec SHOULD include; the SDK client reads `scope` from this header to drive " + "step-up (utils.py extract_scope_from_www_auth) — a resource-server/client asymmetry." + ), + ), + ), + "hosting:auth:as:authorize-requires-pkce": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The bundled authorization endpoint rejects an authorize request that omits " + "`code_challenge` with `invalid_request`." 
+ ), + transports=("streamable-http",), + ), + "hosting:auth:as:verifier-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The bundled token endpoint rejects an authorization-code exchange whose `code_verifier` " + "does not hash to the stored `code_challenge` with `invalid_grant`." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:code-single-use": Requirement( + source="sdk", + behavior=( + "An authorization code can be exchanged exactly once; a second exchange of the same code " + "is rejected with `invalid_grant`. Enforced by the provider deleting the code on first use; " + "the handler relies on `load_authorization_code` returning None." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:redirect-uri-binding": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", + behavior=( + "The bundled token endpoint rejects an authorization-code exchange whose `redirect_uri` " + "differs from the one used at authorize; the bundled authorize endpoint rejects a " + "`redirect_uri` not in the client's registered list without redirecting to it." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "RFC 6749 §5.2 assigns redirect_uri mismatch at the token endpoint to invalid_grant; " + "the SDK's TokenHandler returns invalid_request (src/mcp/server/auth/handlers/token.py:157). " + "The rejection itself is the security-relevant property and is correct." + ), + ), + ), + "hosting:auth:as:redirect-uri-scheme": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#communication-security", + behavior=( + "The bundled registration endpoint accepts only redirect URIs that use HTTPS or target a loopback host." 
+ ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "Not enforced: the registration handler models redirect_uris as AnyUrl with no scheme or " + "host check, so http://evil.example/callback is accepted and registered. The spec's " + "localhost-or-HTTPS rule is left to the provider implementation." + ), + ), + ), + "hosting:auth:as:token-cache-headers": Requirement( + source="sdk", + behavior=("Every token-endpoint response carries `Cache-Control: no-store` and `Pragma: no-cache`."), + transports=("streamable-http",), + ), + "hosting:auth:as:register-error-response": Requirement( + source="sdk", + behavior=( + "The bundled registration endpoint answers invalid client metadata with HTTP 400 and an " + "RFC 7591 error body." + ), + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: resumability + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:resume:bad-event-id": Requirement( + source="sdk", + behavior="A Last-Event-ID that cannot be mapped to a stream is rejected.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The replay path returns an empty SSE stream rather than rejecting an unknown " + "Last-Event-ID; the client cannot tell an unknown ID apart from a stream with no missed " + "events." 
+ ), + ), + ), + "hosting:resume:buffered-replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Notifications emitted while no client is connected are replayed in order on reconnect.", + transports=("streamable-http",), + ), + "hosting:resume:close-stream": Requirement( + source="sdk", + behavior="Handlers can close an SSE stream cleanly when an event store is configured.", + transports=("streamable-http",), + ), + "hosting:resume:event-ids": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="With an event store configured, every SSE event carries an id field.", + transports=("streamable-http",), + ), + "hosting:resume:priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A server-initiated SSE stream begins with a priming event carrying an event ID and an empty " + "data field; a server that closes the connection before terminating the stream sends an SSE " + "retry field first." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The retry hint is attached to the priming event itself rather than sent as a separate " + "event before the connection closes, and a priming event is only sent when an event store " + "is configured and the negotiated protocol version is at least 2025-11-25." 
+ ), + ), + ), + "hosting:resume:replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="GET with Last-Event-ID replays stored events for that stream after the given id.", + transports=("streamable-http",), + ), + "hosting:resume:stream-scoped": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Replay via Last-Event-ID returns only messages from the stream that event id belongs to.", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: HTTP semantics + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:http:accept-406": Requirement( + source="sdk", + behavior="A request whose Accept header does not allow the response representation returns 406.", + transports=("streamable-http",), + ), + "hosting:http:batch": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body is a single JSON-RPC message; batched arrays are rejected for protocol revisions " + "that forbid them." + ), + transports=("streamable-http",), + ), + "hosting:http:content-type-415": Requirement( + source="sdk", + behavior="A POST with a Content-Type other than application/json returns 415.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The transport-security middleware rejects a non-JSON Content-Type with 400 'Invalid " + "Content-Type header' before the request reaches the transport, so the transport's own 415 " + "path is unreachable through any public entry point." + ), + ), + ), + "hosting:http:disconnect-not-cancel": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A client connection drop during an in-flight request does not cancel the server-side " + "handler; the request continues and its result remains retrievable." 
+ ), + transports=("streamable-http",), + ), + "hosting:http:dns-rebinding": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior=( + "The Origin header is validated on every incoming connection; a request with an invalid " + "Origin is rejected with 403 Forbidden." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The spec's Origin validation is an unconditional MUST; the SDK enables it only when the " + "host is a localhost address or explicit TransportSecuritySettings are passed (with no " + "settings, no Origin validation runs), and additionally validates the Host header " + "(returning 421 on mismatch), which the spec does not require." + ), + ), + ), + "hosting:http:json-response-mode": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="With JSON response mode enabled, POST returns application/json instead of SSE.", + transports=("streamable-http",), + ), + "hosting:http:method-405": Requirement( + source="sdk", + behavior="An unsupported HTTP method on the MCP endpoint returns 405.", + transports=("streamable-http",), + ), + "hosting:http:no-broadcast": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#multiple-connections", + behavior=( + "When multiple SSE streams are open for a session, each server-originated message is sent on " + "exactly one stream, never duplicated." 
+ ), + transports=("streamable-http",), + ), + "hosting:http:notifications-202": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A POST containing only notifications or responses returns 202 with no body.", + transports=("streamable-http",), + ), + "hosting:http:onerror": Requirement( + source="sdk", + behavior="Transport-level rejections are reported through an error callback on the server transport.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: the server transport has no error callback; rejections are logged.", + ), + "hosting:http:parse-error-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body that is not valid JSON or not a valid JSON-RPC message is rejected with HTTP 400; " + "the body may carry a JSON-RPC error response (the SDK sends a Parse error body)." + ), + transports=("streamable-http",), + ), + "hosting:http:protocol-version-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior="An invalid or unsupported MCP-Protocol-Version header returns 400 Bad Request.", + transports=("streamable-http",), + ), + "hosting:http:protocol-version-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior=( + "When no MCP-Protocol-Version header is received and the version cannot be determined another " + "way, the server assumes protocol version 2025-03-26." + ), + transports=("streamable-http",), + ), + "hosting:http:response-same-connection": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A response is delivered on the SSE stream opened by the POST that carried its request (or " + "that stream's resumed continuation), not on an unrelated stream." 
+ ), + transports=("streamable-http",), + ), + "hosting:http:second-sse-rejected": Requirement( + source="sdk", + behavior="A second concurrent standalone GET SSE stream on the same session is rejected.", + transports=("streamable-http",), + ), + "hosting:http:sse-close-after-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="The server terminates a POST-initiated SSE stream after writing the JSON-RPC response.", + transports=("streamable-http",), + ), + "hosting:http:standalone-sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="GET opens a standalone SSE stream that receives server-initiated messages.", + transports=("streamable-http",), + ), + "hosting:http:standalone-sse-no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior=( + "The standalone GET SSE stream carries server requests and notifications but never a JSON-RPC " + "response, except when resuming a prior request stream." + ), + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Client transport: streamable HTTP + # ═══════════════════════════════════════════════════════════════════════════ + "client-transport:http:404-surfaces": Requirement( + source="sdk", + behavior="A 404 (session expired) on a request surfaces as an error to the caller.", + transports=("streamable-http",), + ), + "client-transport:http:session-404-reinitialize": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "A 404 in response to a request carrying a session ID makes the client start a new session " + "with a fresh InitializeRequest and no session ID attached." 
+ ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The client surfaces the 404 as an error to the caller instead of re-initializing a new " + "session; the spec's MUST is not satisfied." + ), + ), + deferred=( + "Not implemented in the SDK: the client surfaces a Session terminated error instead of " + "re-initializing (the surfaced error is pinned by client-transport:http:404-surfaces)." + ), + ), + "client-transport:http:accept-header-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="The client GET to the MCP endpoint includes an Accept header listing text/event-stream.", + transports=("streamable-http",), + ), + "client-transport:http:accept-header-post": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "Every client POST to the MCP endpoint includes an Accept header listing both application/json " + "and text/event-stream." + ), + transports=("streamable-http",), + ), + "client-transport:http:concurrent-streams": Requirement( + source="sdk", + behavior="Multiple concurrent POST-initiated SSE streams each deliver their response to the right caller.", + transports=("streamable-http",), + ), + "client-transport:http:custom-client": Requirement( + source="sdk", + behavior=( + "A caller-supplied HTTP client (and its event hooks and headers) is used for all MCP traffic, " + "including auth flows." 
+ ), + transports=("streamable-http",), + ), + "client-transport:http:custom-headers": Requirement( + source="sdk", + behavior="Caller-supplied headers are sent on every POST, GET, and DELETE to the MCP endpoint.", + transports=("streamable-http",), + ), + "client-transport:http:json-response-parsed": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A Content-Type application/json response is parsed as a single JSON-RPC message.", + transports=("streamable-http",), + ), + "client-transport:http:no-reconnect-after-close": Requirement( + source="sdk", + behavior="After the transport is closed, no further reconnection attempts are scheduled.", + transports=("streamable-http",), + ), + "client-transport:http:no-reconnect-after-response": Requirement( + source="sdk", + behavior="A POST-initiated stream that already delivered its response is not reconnected when it closes.", + transports=("streamable-http",), + ), + "client-transport:http:protocol-version-header": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior=( + "After initialization, the client sends the negotiated MCP-Protocol-Version header on every " + "subsequent HTTP request." + ), + transports=("streamable-http",), + ), + "client-transport:http:protocol-version-stored": Requirement( + source="sdk", + behavior=( + "The client transport stores the negotiated protocol version and sends it on every subsequent request." + ), + transports=("streamable-http",), + ), + "client-transport:http:reconnect-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior=( + "A standalone GET SSE stream that errors is reconnected with the Last-Event-ID of the last received event." 
+ ), + transports=("streamable-http",), + deferred=( + "The server's standalone GET stream emits no priming event or retry hint, so the client's " + "reconnection path always sleeps the hard-coded 1 s default; a deterministic in-process test " + "would require accepting that real-time wait. The POST-stream reconnection path is covered " + "by client-transport:http:reconnect-post-priming." + ), + ), + "client-transport:http:reconnect-post-priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST-initiated SSE stream that errors before delivering its response is reconnected only " + "if a priming event (an event carrying an ID) was received on it." + ), + transports=("streamable-http",), + ), + "client-transport:http:reconnect-retry-value": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="Reconnection delay honours the server-provided SSE retry value when one was sent.", + transports=("streamable-http",), + ), + "client-transport:http:resume-stream-api": Requirement( + source="sdk", + behavior=( + "The client can capture a resumption token, reconnect with the same session id, and receive " + "the notifications it missed." + ), + transports=("streamable-http",), + ), + "client-transport:http:session-stored": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "The Mcp-Session-Id returned by initialize is stored by the client transport and sent on " + "every subsequent request." 
+ ), + transports=("streamable-http",), + ), + "client-transport:http:sse-405-tolerated": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="Opening the standalone GET SSE stream tolerates a 405 response without failing the connection.", + transports=("streamable-http",), + ), + "client-transport:http:terminate-405-ok": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Session termination succeeds without error if the server answers 405 (termination unsupported).", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Client auth + # ═══════════════════════════════════════════════════════════════════════════ + "client-auth:401-after-auth-throws": Requirement( + source="sdk", + behavior=( + "If the server still returns 401 after a successful authorization, the client fails instead of looping." + ), + transports=("streamable-http",), + ), + "client-auth:401-triggers-flow": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior="A 401 on a request triggers the OAuth authorization flow once.", + transports=("streamable-http",), + ), + "client-auth:403-scope-upgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#step-up-authorization-flow", + behavior=( + "A 403 with WWW-Authenticate triggers a scope-upgrade authorization attempt; repeated 403s do not loop." + ), + transports=("streamable-http",), + ), + "client-auth:as-metadata-discovery:priority-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", + behavior=( + "The client discovers authorization-server metadata by trying, in order, the OAuth " + "path-inserted, OIDC path-inserted, and OIDC path-appended well-known URLs (with the " + "root-path forms when the issuer URL has no path)." 
+ ), + transports=("streamable-http",), + ), + "client-auth:as-metadata-discovery:issuer-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", + behavior=( + "The client rejects authorization-server metadata whose issuer does not match the URL the " + "metadata was retrieved from (RFC 8414 section 3.3)." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK parses authorization-server metadata without comparing issuer to the discovery " + "URL; a mismatched issuer is accepted and the flow proceeds." + ), + ), + ), + "client-auth:authorize:error-surfaces": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-flow-steps", + behavior=( + "An OAuth error redirect from the authorize endpoint aborts the flow before any token " + "request is issued, surfacing as an error to the caller." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The callback contract has no error form, so the client surfaces 'No authorization code " + "received' rather than the redirect's `error`/`error_description` values." + ), + ), + ), + "client-auth:authorize:offline-access-consent": Requirement( + source="sdk", + behavior=( + "When the authorization server's metadata advertises offline_access in scopes_supported and " + "the client uses the refresh_token grant, offline_access is appended to the requested scope " + "and prompt=consent is added to the authorize request." + ), + transports=("streamable-http",), + ), + "client-auth:bearer-header:every-request": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-requirements", + behavior=( + "Once authorized, the client sends the bearer token in the Authorization header on every HTTP " + "request to the MCP server, never in the query string." 
+ ), + transports=("streamable-http",), + ), + "client-auth:cimd": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#client-id-metadata-documents", + behavior="The client can use a client-ID metadata document URL as its OAuth client_id instead of registration.", + transports=("streamable-http",), + ), + "client-auth:client-credentials": Requirement( + source="sdk", + behavior=( + "A client-credentials provider obtains a token without user interaction and the resulting " + "bearer token authorizes subsequent requests." + ), + transports=("streamable-http",), + ), + "client-auth:dcr:registration-error-surfaces": Requirement( + source="sdk", + behavior=( + "A 400 from the registration endpoint surfaces to the caller as an OAuthRegistrationError " + "carrying the status and the server's RFC 7591 error body." + ), + transports=("streamable-http",), + ), + "client-auth:dcr": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#dynamic-client-registration", + behavior=( + "The client performs dynamic client registration against the authorization server when no " + "client_id is preconfigured." + ), + transports=("streamable-http",), + ), + "client-auth:invalid-client-clears-all": Requirement( + source="sdk", + behavior=( + "An invalid-client or unauthorized-client error during authorization invalidates all stored credentials." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The token-response handlers do not parse the error body; an invalid_client or " + "unauthorized_client response leaves stored client_info untouched. The TypeScript SDK " + "clears it." + ), + ), + deferred=( + "Not implemented in the SDK: no token-response path inspects the error code to decide " + "whether to clear client_info." 
+ ), + ), + "client-auth:invalid-grant-clears-tokens": Requirement( + source="sdk", + behavior="An invalid-grant error during authorization invalidates only the stored tokens.", + transports=("streamable-http",), + ), + "client-auth:pkce:refuse-if-unsupported": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The client refuses to proceed when the authorization server's metadata does not include " + "code_challenge_methods_supported, since PKCE support cannot be verified." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The client never inspects code_challenge_methods_supported and proceeds with PKCE S256 " + "regardless; the spec MUST is not enforced." + ), + ), + ), + "client-auth:pkce:s256": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The authorization request includes a PKCE S256 code challenge and the token request includes " + "the matching verifier." + ), + transports=("streamable-http",), + ), + "client-auth:pre-registration": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#preregistration", + behavior=( + "A client with statically preconfigured credentials skips dynamic registration and uses them directly." + ), + transports=("streamable-http",), + ), + "client-auth:private-key-jwt": Requirement( + source="sdk", + behavior="The client can authenticate the client-credentials grant with a signed JWT assertion.", + transports=("streamable-http",), + ), + "client-auth:prm-discovery:fallback-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior=( + "The client uses resource_metadata from WWW-Authenticate when present, then falls back to the " + "well-known protected-resource locations in the documented order." 
+ ), + transports=("streamable-http",), + ), + "client-auth:prm-discovery:no-prm-fallback": Requirement( + source="sdk", + behavior=( + "When every protected-resource metadata probe fails, the client falls back to discovering " + "authorization-server metadata directly at the MCP server's origin (the legacy 2025-03-26 path) " + "rather than aborting." + ), + transports=("streamable-http",), + ), + "client-auth:prm-resource-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The client refuses to proceed when the protected-resource metadata's resource field does not " + "match the server URL it is connecting to." + ), + transports=("streamable-http",), + ), + "client-auth:refresh:transparent": Requirement( + source="sdk", + behavior=( + "An access token the client considers expired is transparently refreshed before the next " + "request, using the stored refresh token; the refresh request includes the resource indicator " + "and the new token is persisted." + ), + transports=("streamable-http",), + ), + "client-auth:resource-parameter": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#resource-parameter-implementation", + behavior=( + "The client includes the canonical server URI as the resource parameter in both the " + "authorization request and the token request." + ), + transports=("streamable-http",), + ), + "client-auth:scope-selection:priority": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#scope-selection-strategy", + behavior=( + "Client selects requested scope from the WWW-Authenticate scope param if present; otherwise " + "uses scopes_supported from the PRM document; otherwise omits scope." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK inserts an extra fallback step between PRM and omit: if the authorization " + "server metadata advertises scopes_supported, that list is used (client/auth/utils.py). 
" + "This is beyond the spec's two-step chain." + ), + ), + ), + "client-auth:state:verify": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", + behavior=( + "A state parameter is included in the authorization URL, and authorization results with a " + "missing or mismatched state are discarded." + ), + transports=("streamable-http",), + ), + "client-auth:token-endpoint-auth-method": Requirement( + source="sdk", + behavior="The client authenticates to the token endpoint using the auth method established at registration.", + transports=("streamable-http",), + ), + "client-auth:token-provenance": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior=( + "The client sends the MCP server only tokens issued by that server's authorization server, " + "never tokens obtained elsewhere." + ), + transports=("streamable-http",), + deferred=( + "Untestable negative through the public API: there is no path to inject a token obtained " + "elsewhere into the auth provider's state, so the absence cannot be observed end to end." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # stdio transport + # ═══════════════════════════════════════════════════════════════════════════ + "transport:stdio:clean-shutdown": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#shutdown", + behavior="Closing the client transport closes the child process's stdin and the server exits cleanly.", + transports=("stdio",), + ), + "transport:stdio:stream-purity": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior=( + "Nothing that is not a valid MCP message is written to the server's stdout, and nothing that " + "is not a valid MCP message is written to its stdin." 
+ ), + transports=("stdio",), + divergence=Divergence( + note=( + "stdio_server's own writes satisfy this, but it does not redirect or guard sys.stdout: " + "handler code that calls print() writes directly to the protocol stream and corrupts the " + "framing. The spec MUST is satisfied only as long as application code behaves." + ), + ), + ), + "transport:stdio:no-embedded-newlines": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior="Serialized JSON-RPC messages on stdio contain no embedded newlines; one message per line.", + transports=("stdio",), + ), + "transport:stdio:shutdown-escalation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#stdio", + behavior=( + "If the server process does not exit after stdin is closed, the client transport terminates " + "it (and kills it if still alive) after a grace period." + ), + transports=("stdio",), + deferred=( + "A server that ignores stdin close takes the full PROCESS_TERMINATION_TIMEOUT (2.0 s) grace " + "period plus up to a further 2.0 s for SIGTERM/SIGKILL escalation; testing that path is " + "real-time-bound (the constant is module-level with no public override) and so is deliberately " + "excluded from this suite. Covered by tests/client/test_stdio.py." + ), + ), + "transport:stdio:stderr-passthrough": Requirement( + source="sdk", + behavior="Server stderr is available to the client and is not consumed by the transport.", + transports=("stdio",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Composite end-to-end flows + # ═══════════════════════════════════════════════════════════════════════════ + "flow:compat:dual-transport-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "A single server instance can serve streamable HTTP and the legacy SSE transport " + "concurrently; clients on either transport can call the same tools." 
+ ), + transports=("streamable-http", "sse"), + ), + "flow:compat:streamable-then-sse-fallback": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "When a streamable HTTP initialize fails with 400, 404, or 405, falling back to the legacy " + "SSE client transport against the same server connects successfully." + ), + transports=("streamable-http", "sse"), + divergence=Divergence( + note=( + "The SDK provides no automatic streamable-HTTP-to-SSE client fallback; the spec's " + "client-side SHOULD is left to the application to compose from streamable_http_client " + "and sse_client. Both halves are independently proven by the matrix." + ), + ), + deferred=( + "A demonstration test would only re-prove what the matrix already covers (an SSE-only " + "server is reachable via sse_client; an unmounted route returns 404), with the application " + "doing the fallback in between rather than the SDK." + ), + ), + "flow:elicitation:multi-step-form": Requirement( + source="sdk", + behavior=( + "A single tool handler issues sequential elicitations; an accept on one step feeds the next, " + "and a decline or cancel at any step short-circuits to a final result." + ), + ), + "flow:elicitation:url-at-session-init": Requirement( + source="sdk", + behavior=( + "The server can issue a URL-mode elicitation over the standalone GET stream immediately after " + "session initialization, before any client request." + ), + transports=("streamable-http",), + deferred=( + "Not implemented in the SDK: no public per-session post-initialization hook exists on either " + "server flavour (Server.lifespan runs at server startup, not per session; ServerSession " + "handles the initialized notification internally with no callback). Driving 'before any " + "client request' deterministically would also require knowing the standalone GET stream is " + "established, which has no synchronization signal." 
+ ), + ), + "flow:elicitation:url-required-then-retry": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A tool call rejected with the URL-elicitation-required error can be retried successfully " + "after the client completes the URL flow and the server announces completion." + ), + ), + "flow:multi-client:stateful-isolation": Requirement( + source="sdk", + behavior=( + "Independent clients connected to one stateful server each receive a distinct session and " + "only the notifications produced by their own requests." + ), + transports=("streamable-http",), + ), + "flow:oauth:authorization-code-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-flow-steps", + behavior=( + "Connecting to a protected server walks the authorization-code flow end to end: the first " + "attempt requires authorization, the code is exchanged, and a subsequent connection succeeds." + ), + transports=("streamable-http",), + ), + "flow:resume:tool-call-resumption-token": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior=( + "A tool call interrupted mid-stream is transparently resumed by the client transport using " + "the last-seen event id, delivering only the remaining notifications and the final result." + ), + transports=("streamable-http",), + ), + "flow:session:terminate-then-reconnect": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=("After terminating a session, a fresh connection obtains a new session id and operations succeed."), + transports=("streamable-http",), + ), + "flow:tool-result:resource-link-follow": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior=( + "A resource_link returned by a tool call can be followed with resources/read on the linked " + "URI to retrieve the referenced contents." 
+        ),
+    ),
+}
+
+
+def requirement(requirement_id: str) -> Callable[[_TestFn], _TestFn]:
+    """Tag a test with the :data:`REQUIREMENTS` entry it exercises.
+
+    The returned decorator applies the `requirement` pytest marker and appends the test's
+    dotted name to the coverage record checked by `test_coverage.py`. An ID that is not in
+    :data:`REQUIREMENTS` raises `KeyError` as soon as this function is called.
+    """
+    if requirement_id not in REQUIREMENTS:
+        raise KeyError(f"Unknown requirement id {requirement_id!r}: add it to REQUIREMENTS in {__name__}")
+
+    def decorate(test_fn: _TestFn) -> _TestFn:
+        covered_by(requirement_id).append(".".join((test_fn.__module__, test_fn.__qualname__)))
+        return pytest.mark.requirement(requirement_id)(test_fn)
+
+    return decorate
+
+
+_COVERAGE: dict[str, list[str]] = {}  # requirement id -> dotted names of the tests recorded for it
+
+
+def covered_by(requirement_id: str) -> list[str]:
+    """Look up, creating on first use, the live list of test names exercising `requirement_id`."""
+    return _COVERAGE.setdefault(requirement_id, [])
diff --git a/tests/interaction/auth/__init__.py b/tests/interaction/auth/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py
new file mode 100644
index 0000000000..d013364f33
--- /dev/null
+++ b/tests/interaction/auth/_harness.py
@@ -0,0 +1,465 @@
+"""In-process harness for the auth interaction tests.
+
+Co-hosts the SDK's authorization-server routes, protected-resource metadata route, and the
+bearer-gated MCP endpoint on one Starlette app via `Server.streamable_http_app(auth=...,
+token_verifier=..., auth_server_provider=...)`, drives that app through the streaming bridge
+on a single `httpx.AsyncClient` carrying `auth=OAuthClientProvider(...)`, and completes the
+authorize redirect headlessly by GETing the URL through the same bridge and parsing the code
+from the 302 `Location`. The whole authorization-code flow runs in one event loop with no
+sockets, no threads, and no real time.
+""" + +import json +from collections.abc import AsyncIterator, Callable, Mapping, Sequence +from contextlib import AsyncExitStack, asynccontextmanager +from dataclasses import dataclass, field +from typing import Any +from urllib.parse import parse_qs, parse_qsl, urlsplit + +import httpx +from pydantic import AnyHttpUrl, AnyUrl, BaseModel +from starlette.types import ASGIApp, Receive, Scope, Send + +from mcp.client.auth import OAuthClientProvider +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server +from mcp.server.auth.provider import AccessToken, ProviderTokenVerifier +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions, RevocationOptions +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider +from tests.interaction.transports._bridge import StreamingASGITransport + +REDIRECT_URI = f"{BASE_URL}/oauth/callback" + +AppShim = Callable[[ASGIApp], ASGIApp] + + +@dataclass +class RecordedRequest: + """A snapshot of an `httpx.Request` at the moment it was sent. + + The auth flow re-yields the same `httpx.Request` object after mutating its headers in + place for the retry, so tests that need to assert on the first attempt's headers must + capture a copy rather than a live reference. `record_requests` produces these. 
+    """
+
+    method: str
+    url: httpx.URL
+    headers: dict[str, str]
+    content: bytes
+
+    @property
+    def path(self) -> str:
+        return self.url.path
+
+
+def record_requests() -> tuple[list[RecordedRequest], Callable[[httpx.Request], None]]:
+    """Return a fresh list plus an `on_request` callback that appends a snapshot of each request to it."""
+    snapshots: list[RecordedRequest] = []
+
+    def on_request(request: httpx.Request) -> None:
+        # Copy headers and body now; a live reference would show later in-place mutations.
+        snapshot = RecordedRequest(
+            method=request.method,
+            url=request.url,
+            headers=dict(request.headers),
+            content=bytes(request.content),
+        )
+        snapshots.append(snapshot)
+
+    return snapshots, on_request
+
+
+def metadata_body(model: BaseModel, **extra: object) -> bytes:
+    """Dump a metadata model as the JSON bytes that `shimmed_app(serve=...)` serves.
+
+    Keyword arguments in `extra` are layered over the dumped document, which lets a test add
+    members the model does not declare (e.g. an unknown extension field, to prove the client's
+    parser tolerates unrecognized members per RFC 8414/9728 §3.2). They are added after the
+    dump because the model itself would silently drop such fields at construction.
+    """
+    # `extra` comes second so its keys win over same-named keys from the dump.
+    document = {**model.model_dump(by_alias=True, mode="json", exclude_none=True), **extra}
+    return json.dumps(document).encode()
+
+
+class StaticTokenVerifier:
+    """A `TokenVerifier` backed by a fixed token→`AccessToken` mapping.
+
+    Any token string not in the mapping verifies to `None`, which the bearer middleware treats
+    as an unrecognized token. Tests seed the mapping with the exact token shapes (valid, expired,
+    wrong scope, wrong audience) they need so the resource-server gate's behaviour is asserted in
+    isolation from the authorization-server provider.
+ """ + + def __init__(self, tokens: Mapping[str, AccessToken]) -> None: + self._tokens = dict(tokens) + + async def verify_token(self, token: str) -> AccessToken | None: + return self._tokens.get(token) + + +class InMemoryTokenStorage: + """A `TokenStorage` that holds tokens and client info as instance attributes. + + Tests pre-seed `client_info` (via the constructor or by assignment) to drive the + pre-registered path, and read both attributes after the flow to assert what the SDK + persisted. + """ + + def __init__(self, *, client_info: OAuthClientInformationFull | None = None) -> None: + self.tokens: OAuthToken | None = None + self.client_info: OAuthClientInformationFull | None = client_info + + async def get_tokens(self) -> OAuthToken | None: + return self.tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self.tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self.client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self.client_info = client_info + + +class HeadlessOAuth: + """Completes the authorize step in-process by following the redirect through the bridge. + + `redirect_handler` GETs the authorize URL on the bound client (with `auth=None` so the + request does not re-enter the locked auth flow), parses `code` and `state` from the 302 + `Location`, and stashes them; `callback_handler` returns the stashed pair. Tests inspect + `authorize_url` to assert what the SDK put on the authorize request. + + `state_override`: when set, `callback_handler` returns this value as the state instead of + the one parsed from the redirect, so tests can drive the state-mismatch path. 
+ """ + + def __init__(self, *, state_override: str | None = None) -> None: + self.authorize_url: str | None = None + self.authorize_urls: list[str] = [] + self.error: str | None = None + self._state_override = state_override + self._http: httpx.AsyncClient | None = None + self._code: str = "" + self._state: str | None = None + + def bind(self, http_client: httpx.AsyncClient) -> None: + self._http = http_client + + async def redirect_handler(self, authorization_url: str) -> None: + assert self._http is not None + self.authorize_url = authorization_url + self.authorize_urls.append(authorization_url) + # auth=None is load-bearing: without it the GET re-enters OAuthClientProvider.async_auth_flow + # through its context lock and the flow deadlocks. + response = await self._http.get(authorization_url, follow_redirects=False, auth=None) + assert response.status_code == 302, f"authorize endpoint returned {response.status_code}: {response.text}" + params = parse_qs(urlsplit(response.headers["location"]).query) + self._code = params.get("code", [""])[0] + self._state = params.get("state", [None])[0] + self.error = params.get("error", [None])[0] + + async def callback_handler(self) -> tuple[str, str | None]: + return self._code, self._state_override if self._state_override is not None else self._state + + +def auth_settings( + *, required_scopes: Sequence[str] = ("mcp",), valid_scopes: Sequence[str] | None = None +) -> AuthSettings: + """Build `AuthSettings` for the co-hosted authorization + resource server. + + The issuer and resource URLs use the suite's loopback origin, which `validate_issuer_url` + accepts in lieu of HTTPS. Dynamic client registration is enabled. 
`valid_scopes` defaults
+    to `required_scopes` so a client requesting exactly those passes registration scope
+    validation; tests pass a wider set when they need the protected-resource metadata's
+    `scopes_supported` (which mirrors `required_scopes`) to differ from what the client may
+    register or when AS metadata should advertise additional scopes such as `offline_access`.
+    """
+    required = list(required_scopes)
+    valid = list(valid_scopes) if valid_scopes is not None else required
+    return AuthSettings(
+        issuer_url=AnyHttpUrl(BASE_URL),
+        resource_server_url=AnyHttpUrl(f"{BASE_URL}/mcp"),
+        required_scopes=required,
+        client_registration_options=ClientRegistrationOptions(
+            enabled=True, valid_scopes=valid, default_scopes=required
+        ),
+        revocation_options=RevocationOptions(enabled=False),
+    )
+
+
+def oauth_client_metadata() -> OAuthClientMetadata:
+    """Build the client's registration metadata.
+
+    `scope` is left unset so the SDK's scope-selection strategy chooses one from the server's
+    metadata before registration.
+    """
+    return OAuthClientMetadata(
+        client_name="interaction-suite",
+        redirect_uris=[AnyUrl(REDIRECT_URI)],
+        grant_types=["authorization_code", "refresh_token"],
+    )
+
+
+def shimmed_app(
+    app: ASGIApp,
+    *,
+    not_found: frozenset[str] = frozenset(),
+    serve: Mapping[str, bytes | tuple[int, bytes]] | None = None,
+) -> ASGIApp:
+    """Wrap an ASGI app so specific paths return canned responses before reaching the real app.
+
+    Paths in `serve` return the given body as `application/json` (status 200, or the supplied
+    status when the value is a `(status, body)` pair); paths in `not_found` return 404; every
+    other request, and every non-HTTP scope, reaches the wrapped app unchanged. Used by the
+    discovery tests to make a well-known endpoint 404 or return alternate metadata while
+    keeping the real authorization and MCP endpoints behind it.
+    """
+    overrides: dict[str, tuple[int, bytes]] = {
+        path: value if isinstance(value, tuple) else (200, value) for path, value in (serve or {}).items()
+    }
+
+    async def wrapped(scope: Scope, receive: Receive, send: Send) -> None:
+        path = scope["path"] if scope["type"] == "http" else None  # non-HTTP scopes always pass through
+        if path in overrides:
+            status, body = overrides[path]
+            await send(
+                {
+                    "type": "http.response.start",
+                    "status": status,
+                    "headers": [
+                        (b"content-type", b"application/json"),
+                        (b"content-length", str(len(body)).encode()),
+                    ],
+                }
+            )
+            await send({"type": "http.response.body", "body": body})
+            return
+        if path in not_found:
+            await send({"type": "http.response.start", "status": 404, "headers": []})
+            await send({"type": "http.response.body", "body": b""})
+            return
+        await app(scope, receive, send)
+
+    return wrapped
+
+
+def shim(
+    *, not_found: frozenset[str] = frozenset(), serve: Mapping[str, bytes | tuple[int, bytes]] | None = None
+) -> AppShim:
+    """Build an `app_shim` for `connect_with_oauth` that applies `shimmed_app` with these overrides."""
+    return lambda app: shimmed_app(app, not_found=not_found, serve=serve)
+
+
+@dataclass
+class _FirstChallenge:
+    """ASGI shim that answers the first request to a path with 401 + a given WWW-Authenticate.
+
+    Subsequent requests pass through to the wrapped app. Used to make the initial 401 carry
+    parameters (such as `scope=`) that the SDK's own bearer middleware cannot be configured
+    to emit, so client behaviour driven by those parameters is reachable end to end. Reserve
+    this pattern for behaviour the real server cannot be made to produce.
+ """ + + app: ASGIApp + path: str + www_authenticate: str + _seen: set[str] = field(default_factory=set[str]) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["path"] == self.path and self.path not in self._seen: + self._seen.add(self.path) + await send( + { + "type": "http.response.start", + "status": 401, + "headers": [(b"www-authenticate", self.www_authenticate.encode())], + } + ) + await send({"type": "http.response.body", "body": b""}) + return + await self.app(scope, receive, send) + + +def first_challenge_shim(www_authenticate: str, *, path: str = "/mcp") -> Callable[[ASGIApp], ASGIApp]: + """Build an `app_shim` that 401s the first request to `path` with the given header value.""" + return lambda app: _FirstChallenge(app, path, www_authenticate) + + +def step_up_shim(www_authenticate: str, *, on_nth_authenticated_post: int = 2) -> AppShim: + """Build an `app_shim` that 403s the Nth authenticated POST to `/mcp` with the given challenge. + + Subsequent requests pass through. Used to drive the client's `insufficient_scope` step-up + handling: the SDK's bearer middleware never emits `scope=` in its 403 challenge (see the + divergence on `hosting:auth:scope-403`), so the test supplies the 403 itself. Reserve this + pattern for behaviour the real server cannot be made to produce. + + The default `on_nth_authenticated_post=2` targets the `notifications/initialized` POST: the + first authenticated POST is the auth flow's retry of the original initialize request (yielded + after the 401 branch, where the generator ends without inspecting the response), so a 403 + there would not reach the step-up handler. 
+ """ + seen = 0 + fired = False + + def factory(app: ASGIApp) -> ASGIApp: + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + nonlocal seen, fired + if ( + not fired + and scope["type"] == "http" + and scope["path"] == "/mcp" + and scope["method"] == "POST" + and any(name == b"authorization" for name, _ in scope["headers"]) + ): + seen += 1 + if seen < on_nth_authenticated_post: + await app(scope, receive, send) + return + fired = True + await send( + { + "type": "http.response.start", + "status": 403, + "headers": [(b"www-authenticate", www_authenticate.encode())], + } + ) + await send({"type": "http.response.body", "body": b""}) + return + await app(scope, receive, send) + + return wrapped + + return factory + + +def m2m_token_shim(provider: InMemoryAuthorizationServerProvider, *, scopes: list[str]) -> AppShim: + """Build an `app_shim` that handles `grant_type=client_credentials` at `/token`. + + The SDK server's `TokenHandler` only routes `authorization_code` and `refresh_token`, so a + `client_credentials` request would fail discriminator validation. This shim mints a token via + `provider.mint_access_token` so the M2M client providers can complete e2e against the real + bearer middleware. The shim is harness; the SDK-under-test is the client provider, whose + outbound `/token` body the test asserts. The shim does not authenticate the client (no + credential check) because the test asserts the credentials on the recorded request, not on + the server's acceptance. + """ + + def factory(app: ASGIApp) -> ASGIApp: + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["path"] == "/token" and scope["method"] == "POST": + # The streaming bridge buffers the request body and delivers it in a single + # http.request event, so one receive is sufficient. 
+ message = await receive() + assert not message.get("more_body", False) + form = dict(parse_qsl(message.get("body", b"").decode())) + assert form.get("grant_type") == "client_credentials", ( + f"m2m_token_shim only handles client_credentials; got {form.get('grant_type')!r}" + ) + access = provider.mint_access_token(client_id="m2m", scopes=scopes, resource=form.get("resource")) + token = OAuthToken(access_token=access, token_type="Bearer", expires_in=3600, scope=" ".join(scopes)) + response_body = token.model_dump_json(exclude_none=True).encode() + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(response_body)).encode()), + (b"cache-control", b"no-store"), + ], + } + ) + await send({"type": "http.response.body", "body": response_body}) + return + await app(scope, receive, send) + + return wrapped + + return factory + + +@asynccontextmanager +async def connect_with_oauth( + server: Server, + *, + provider: InMemoryAuthorizationServerProvider, + settings: AuthSettings | None = None, + storage: InMemoryTokenStorage | None = None, + client_metadata: OAuthClientMetadata | None = None, + client_metadata_url: str | None = None, + headless: HeadlessOAuth | None = None, + auth: httpx.Auth | None = None, + verify_tokens: bool = True, + app_shim: Callable[[ASGIApp], ASGIApp] | None = None, + on_request: Callable[[httpx.Request], None] | None = None, +) -> AsyncIterator[tuple[Client, HeadlessOAuth]]: + """Connect a `Client` to a server's bearer-gated streamable-HTTP app, completing OAuth in process. + + Yields the connected `Client` and the `HeadlessOAuth` whose `authorize_url` records what the + SDK put on the authorize request. `on_request` records every HTTP request the underlying + `httpx.AsyncClient` issues, including those yielded from inside the auth flow. 
+ + `headless`: supply a pre-configured `HeadlessOAuth` to override the callback behaviour + (state mismatch, error redirects). `verify_tokens=False` mounts the MCP endpoint without + the bearer middleware so a flow driven by a shimmed 401 completes regardless of the granted + scopes. `app_shim` wraps the built Starlette app before it reaches the bridge transport, + for tests that need to intercept or rewrite specific server responses. + + `auth`: supply a pre-built `httpx.Auth` (such as `ClientCredentialsOAuthProvider`) to use + instead of constructing the default `OAuthClientProvider`; in that case `storage`, + `client_metadata`, `client_metadata_url`, and `headless` are unused (the yielded + `HeadlessOAuth` is never invoked and its `authorize_url` stays None). + """ + settings = settings if settings is not None else auth_settings() + storage = storage if storage is not None else InMemoryTokenStorage() + client_metadata = client_metadata if client_metadata is not None else oauth_client_metadata() + headless = headless if headless is not None else HeadlessOAuth() + + oauth = ( + auth + if auth is not None + else OAuthClientProvider( + server_url=f"{BASE_URL}/mcp", + client_metadata=client_metadata, + storage=storage, + redirect_handler=headless.redirect_handler, + callback_handler=headless.callback_handler, + client_metadata_url=client_metadata_url, + ) + ) + + app: ASGIApp = server.streamable_http_app( + auth=settings, + token_verifier=ProviderTokenVerifier(provider) if verify_tokens else None, + auth_server_provider=provider, + transport_security=NO_DNS_REBINDING_PROTECTION, + ) + if app_shim is not None: + app = app_shim(app) + + event_hooks: dict[str, list[Callable[..., Any]]] | None = None + if on_request is not None: + record = on_request + + async def hook(request: httpx.Request) -> None: + record(request) + + event_hooks = {"request": [hook]} + + async with AsyncExitStack() as stack: + await stack.enter_async_context(server.session_manager.run()) + 
http_client = await stack.enter_async_context( + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks + ) + ) + headless.bind(http_client) + client = await stack.enter_async_context( + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client)) + ) + yield client, headless diff --git a/tests/interaction/auth/_provider.py b/tests/interaction/auth/_provider.py new file mode 100644 index 0000000000..5c88995a30 --- /dev/null +++ b/tests/interaction/auth/_provider.py @@ -0,0 +1,186 @@ +"""An in-memory implementation of the SDK's OAuth authorization-server provider protocol. + +The provider holds clients, authorization codes, refresh tokens and access tokens in plain +instance dicts so tests can inspect them; tokens are minted from `secrets.token_hex` so the +values are unique without being predictable. The behaviour mirrors what the SDK's authorization +handlers expect: `authorize` immediately mints a code and returns the redirect, `exchange_*` +issue and rotate tokens, and `load_*` are simple lookups. Only the parts the auth interaction +suite drives are implemented; methods the suite does not exercise raise `NotImplementedError`. +""" + +import secrets +import time + +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthAuthorizationServerProvider, + RefreshToken, + TokenError, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +_TOKEN_LIFETIME_SECONDS = 3600 + + +class InMemoryAuthorizationServerProvider( + OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken] +): + """An OAuth authorization-server provider backed by in-memory dicts. + + Holds registered clients, issued codes, refresh tokens and access tokens as instance state + so tests can both drive the SDK's authorization handlers and inspect what was issued. 
+ + Knobs: + `default_scopes`: scopes granted when an authorize request supplies none. + `deny_authorize`: every authorize request returns an `error=access_denied` redirect. + `issue_expired_first`: the first issued token's `expires_in` is in the past so the + client immediately considers it expired and refreshes; the server-side + `AccessToken.expires_at` stays in the future so the bearer middleware accepts it + on the retry that completes the connect. + `fail_next_refresh`: the next refresh-token exchange raises `invalid_grant` once. + `reject_all_tokens`: `load_access_token` returns None for every token, so the bearer + middleware 401s every authenticated request. + """ + + def __init__( + self, + *, + default_scopes: list[str] | None = None, + deny_authorize: bool = False, + issue_expired_first: bool = False, + fail_next_refresh: bool = False, + reject_all_tokens: bool = False, + ) -> None: + self._default_scopes = list(default_scopes) if default_scopes is not None else ["mcp"] + self._issuer = "http://127.0.0.1:8000" + self._deny_authorize = deny_authorize + self._issue_expired_first = issue_expired_first + self._fail_next_refresh = fail_next_refresh + self._reject_all_tokens = reject_all_tokens + self._tokens_issued = 0 + self.clients: dict[str, OAuthClientInformationFull] = {} + self.codes: dict[str, AuthorizationCode] = {} + self.refresh_tokens: dict[str, RefreshToken] = {} + self.access_tokens: dict[str, AccessToken] = {} + + def _next_expires_in(self) -> int: + self._tokens_issued += 1 + if self._issue_expired_first and self._tokens_issued == 1: + return -_TOKEN_LIFETIME_SECONDS + return _TOKEN_LIFETIME_SECONDS + + def mint_access_token(self, *, client_id: str, scopes: list[str], resource: str | None = None) -> str: + """Mint and store an access token, returning its value. + + Used by the auth-code and refresh exchanges and by the M2M `/token` shim. 
The + server-side `expires_at` is always in the future regardless of `issue_expired_first`, + which only affects what the client is told. + """ + access = f"access_{secrets.token_hex(16)}" + self.access_tokens[access] = AccessToken( + token=access, + client_id=client_id, + scopes=scopes, + expires_at=int(time.time()) + _TOKEN_LIFETIME_SECONDS, + resource=resource, + ) + return access + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + assert client_info.client_id is not None + self.clients[client_info.client_id] = client_info + + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: + """Mint an authorization code immediately and return the redirect carrying it. + + A real provider would interpose user consent here; the test provider grants + unconditionally so the headless redirect handler can complete the flow in-process. + When `deny_authorize` is set, returns an `error=access_denied` redirect instead. + """ + assert client.client_id is not None + if self._deny_authorize: + return construct_redirect_uri( + str(params.redirect_uri), error="access_denied", error_description="user denied", state=params.state + ) + code = AuthorizationCode( + code=f"code_{secrets.token_hex(16)}", + client_id=client.client_id, + scopes=params.scopes or self._default_scopes, + expires_at=time.time() + 300, + code_challenge=params.code_challenge, + redirect_uri=params.redirect_uri, + redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, + resource=params.resource, + ) + self.codes[code.code] = code + # `iss` is RFC 9207's authorization-response issuer identifier — an extra parameter many + # real authorization servers send. 
Including it on every success redirect proves the + # client tolerates unrecognized callback parameters (RFC 6749 §4.1.2 MUST) by virtue of + # every flow test passing unchanged. + return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state, iss=self._issuer) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + return self.codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Mint an access token and a refresh token for a valid authorization code, then consume the code.""" + assert client.client_id is not None + access = self.mint_access_token( + client_id=client.client_id, scopes=authorization_code.scopes, resource=authorization_code.resource + ) + refresh = f"refresh_{secrets.token_hex(16)}" + self.refresh_tokens[refresh] = RefreshToken( + token=refresh, + client_id=client.client_id, + scopes=authorization_code.scopes, + ) + del self.codes[authorization_code.code] + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._next_expires_in(), + scope=" ".join(authorization_code.scopes), + refresh_token=refresh, + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + if self._reject_all_tokens: + return None + return self.access_tokens.get(token) + + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + return self.refresh_tokens.get(refresh_token) + + async def exchange_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: RefreshToken, scopes: list[str] + ) -> OAuthToken: + """Mint a new access token and rotate the refresh token, consuming the old one.""" + assert client.client_id is not None + if self._fail_next_refresh: + self._fail_next_refresh = False + raise 
TokenError(error="invalid_grant", error_description="refresh denied by harness") + access = self.mint_access_token(client_id=client.client_id, scopes=scopes) + new_refresh = f"refresh_{secrets.token_hex(16)}" + self.refresh_tokens[new_refresh] = RefreshToken(token=new_refresh, client_id=client.client_id, scopes=scopes) + del self.refresh_tokens[refresh_token.token] + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._next_expires_in(), + scope=" ".join(scopes), + refresh_token=new_refresh, + ) + + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: + """Not exercised by this suite; revocation is out of scope for the interaction tests.""" + raise NotImplementedError diff --git a/tests/interaction/auth/test_as_handlers.py b/tests/interaction/auth/test_as_handlers.py new file mode 100644 index 0000000000..5cb4e92d86 --- /dev/null +++ b/tests/interaction/auth/test_as_handlers.py @@ -0,0 +1,300 @@ +"""Error-plane behaviour of the SDK's bundled OAuth authorization-server handlers. + +The end-to-end OAuth tests prove the handlers' happy paths; these tests drive the same +mounted authorization server directly with raw httpx so the assertions are the HTTP +semantics (status, redirect target, error body, headers) the OAuth RFCs mandate. Almost +every behaviour here is enforced by the SDK's own handlers; where the pinned output +deviates from the RFC, the manifest entry carries the divergence. 
+""" + +import base64 +import hashlib +import secrets +from collections.abc import AsyncIterator +from urllib.parse import parse_qs, urlsplit + +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server +from mcp.server.auth.provider import ProviderTokenVerifier +from mcp.shared.auth import OAuthClientInformationFull +from tests.interaction._connect import mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import REDIRECT_URI, auth_settings, oauth_client_metadata +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +async def as_app() -> AsyncIterator[tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider]]: + """Co-host the SDK's authorization-server routes and yield a raw httpx client against them.""" + provider = InMemoryAuthorizationServerProvider() + settings = auth_settings() + async with mounted_app( + Server("guarded"), + auth=settings, + token_verifier=ProviderTokenVerifier(provider), + auth_server_provider=provider, + ) as (http, _): + yield http, provider + + +def _pkce_pair() -> tuple[str, str]: + """Generate a (code_verifier, code_challenge) pair the same way the SDK client does.""" + verifier = secrets.token_urlsafe(48)[:64] + challenge = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") + return verifier, challenge + + +async def _register_client(http: httpx.AsyncClient) -> OAuthClientInformationFull: + """Dynamically register a client and return its full credentials.""" + response = await http.post("/register", content=oauth_client_metadata().model_dump_json()) + assert response.status_code == 201 + return OAuthClientInformationFull.model_validate_json(response.content) + + +async def _mint_code(http: httpx.AsyncClient) -> tuple[OAuthClientInformationFull, str, str]: + """Register a client, complete a valid authorize step, and 
return (client_info, code, verifier).""" + client_info = await _register_client(http) + assert client_info.client_id is not None + verifier, challenge = _pkce_pair() + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": REDIRECT_URI, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": "s", + }, + follow_redirects=False, + ) + assert response.status_code == 302 + redirect = urlsplit(response.headers["location"]) + assert f"{redirect.scheme}://{redirect.netloc}{redirect.path}" == REDIRECT_URI + code = parse_qs(redirect.query)["code"][0] + return client_info, code, verifier + + +def _token_form(client_info: OAuthClientInformationFull, **overrides: str) -> dict[str, str]: + """Build the form body for an authorization-code token request, with the defaults a real client would send.""" + assert client_info.client_id is not None + assert client_info.client_secret is not None + form = { + "grant_type": "authorization_code", + "client_id": client_info.client_id, + "client_secret": client_info.client_secret, + "redirect_uri": REDIRECT_URI, + } + form.update(overrides) + return form + + +@requirement("hosting:auth:as:authorize-requires-pkce") +async def test_authorize_without_a_code_challenge_is_rejected_with_invalid_request( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorize request omitting `code_challenge` is redirected back with `error=invalid_request`. + + PKCE is mandatory: the bundled authorize handler models `code_challenge` as a required field, so + a code without a stored challenge can never be issued. That makes the PKCE-downgrade attack (a + token request carrying a verifier for a code minted without a challenge) structurally impossible + through these handlers, so no separate downgrade-guard test is needed. 
+ """ + http, _ = as_app + client_info = await _register_client(http) + assert client_info.client_id is not None + + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": REDIRECT_URI, + "state": "abc", + }, + follow_redirects=False, + ) + + assert response.status_code == 302 + redirect = urlsplit(response.headers["location"]) + assert f"{redirect.scheme}://{redirect.netloc}{redirect.path}" == REDIRECT_URI + params = parse_qs(redirect.query) + assert params["error"] == ["invalid_request"] + assert params["state"] == ["abc"] + assert "code_challenge" in params["error_description"][0] + + +@requirement("hosting:auth:as:verifier-mismatch") +async def test_a_mismatched_code_verifier_is_rejected_with_invalid_grant( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A token exchange whose `code_verifier` does not hash to the stored challenge is rejected.""" + http, _ = as_app + client_info, code, _ = await _mint_code(http) + + response = await http.post("/token", data=_token_form(client_info, code=code, code_verifier="0" * 64)) + + assert response.status_code == 400 + assert response.json() == snapshot({"error": "invalid_grant", "error_description": "incorrect code_verifier"}) + + +@requirement("hosting:auth:as:code-single-use") +async def test_reusing_an_authorization_code_is_rejected_with_invalid_grant( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorization code can be exchanged exactly once; a second exchange is `invalid_grant`. + + The handler does not track used codes itself: it returns `invalid_grant` whenever the provider's + `load_authorization_code` returns None, and the in-memory provider deletes the code on first + exchange. The test proves the combination enforces single-use; a provider that did not consume + codes would not get this guarantee from the handler. 
+ """ + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + form = _token_form(client_info, code=code, code_verifier=verifier) + + first = await http.post("/token", data=form) + assert first.status_code == 200 + assert first.json()["token_type"] == "Bearer" + + second = await http.post("/token", data=form) + assert second.status_code == 400 + assert second.json() == snapshot( + {"error": "invalid_grant", "error_description": "authorization code does not exist"} + ) + + +@requirement("hosting:auth:as:redirect-uri-binding") +async def test_a_redirect_uri_differing_from_authorize_is_rejected_at_the_token_endpoint( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A token exchange whose `redirect_uri` differs from the one used at authorize is rejected. + + This is the security-critical half of redirect-URI binding: a code intercepted via redirect + substitution cannot be redeemed because the attacker cannot reproduce the original authorize + redirect URI at the token endpoint. RFC 6749 §5.2 specifies `invalid_grant` for this case; + the SDK returns `invalid_request` (see the divergence on the requirement). The rejection + itself is the security property and is correct. 
+ """ + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + + response = await http.post( + "/token", + data=_token_form(client_info, code=code, code_verifier=verifier, redirect_uri=f"{REDIRECT_URI}/different"), + ) + + assert response.status_code == 400 + assert response.json() == snapshot( + { + "error": "invalid_request", + "error_description": "redirect_uri did not match the one used when creating auth code", + } + ) + + +@requirement("hosting:auth:as:token-cache-headers") +async def test_token_responses_carry_cache_control_no_store( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """Every token-endpoint response (success and error) carries `Cache-Control: no-store`.""" + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + form = _token_form(client_info, code=code, code_verifier=verifier) + + success = await http.post("/token", data=form) + assert success.status_code == 200 + assert success.headers["cache-control"] == "no-store" + assert success.headers["pragma"] == "no-cache" + + failure = await http.post("/token", data=form) + assert failure.status_code == 400 + assert failure.headers["cache-control"] == "no-store" + assert failure.headers["pragma"] == "no-cache" + + +@requirement("hosting:auth:as:register-error-response") +async def test_registration_with_invalid_metadata_is_rejected_with_400( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """Invalid client metadata at the registration endpoint returns 400 with an RFC 7591 error body.""" + http, _ = as_app + + malformed = await http.post("/register", json={"redirect_uris": ["not-a-url"]}) + assert malformed.status_code == 400 + assert malformed.json()["error"] == "invalid_client_metadata" + + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + + no_auth_code = await http.post("/register", json=body | {"grant_types": ["refresh_token"]}) + assert 
no_auth_code.status_code == 400 + assert no_auth_code.json() == snapshot( + {"error": "invalid_client_metadata", "error_description": "grant_types must include 'authorization_code'"} + ) + + bad_scope = await http.post("/register", json=body | {"scope": "forbidden"}) + assert bad_scope.status_code == 400 + body = bad_scope.json() + assert body["error"] == "invalid_client_metadata" + # The description embeds a set difference whose ordering is not stable, so assert the prefix. + assert body["error_description"].startswith("Requested scopes are not valid: ") + + +@requirement("hosting:auth:as:redirect-uri-binding") +async def test_authorize_with_an_unregistered_redirect_uri_is_rejected_directly( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorize request naming an unregistered `redirect_uri` returns 400 without redirecting to it. + + The security property is that the authorization server never redirects to an unvalidated URI: + the response is a direct JSON error to the user agent, not a 302 to the attacker's host. 
+ """ + http, _ = as_app + client_info = await _register_client(http) + assert client_info.client_id is not None + _, challenge = _pkce_pair() + + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": "http://127.0.0.1:8000/evil", + "code_challenge": challenge, + "code_challenge_method": "S256", + }, + follow_redirects=False, + ) + + assert response.status_code == 400 + assert "location" not in response.headers + body = response.json() + assert body["error"] == "invalid_request" + assert "not registered" in body["error_description"] + + +@requirement("hosting:auth:as:redirect-uri-scheme") +async def test_a_non_loopback_http_redirect_uri_is_accepted_at_registration( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A registration carrying a non-HTTPS, non-loopback redirect URI is accepted. + + The spec requires every redirect URI to be either HTTPS or a loopback host; the bundled + registration handler does not enforce this and registers `http://evil.example/callback` + successfully. See the divergence on the requirement. + """ + http, provider = as_app + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + body["redirect_uris"] = ["http://evil.example/callback"] + + response = await http.post("/register", json=body) + + assert response.status_code == 201 + info = OAuthClientInformationFull.model_validate_json(response.content) + assert [str(u) for u in (info.redirect_uris or [])] == ["http://evil.example/callback"] + assert info.client_id in provider.clients diff --git a/tests/interaction/auth/test_authorize_token.py b/tests/interaction/auth/test_authorize_token.py new file mode 100644 index 0000000000..cb8524c097 --- /dev/null +++ b/tests/interaction/auth/test_authorize_token.py @@ -0,0 +1,399 @@ +"""Authorization-request, token-request, and PKCE wire-level invariants of the SDK's OAuth client. 
+ +Every test connects a real `Client` end to end via `connect_with_oauth`; the assertions are on +the parsed authorize URL and the recorded `/token` form body, because those wire shapes are what +the spec mandates and `Client` cannot observe them. The recording uses `record_requests`, which +snapshots each request at send time so the auth flow's in-place header mutation on retry never +affects what was captured for the first attempt. + +Tests #1/#2/#4/#5 share one `recorded_oauth_flow` fixture (one connect, several disjoint +assertions on its recording); the others connect fresh because each needs a different harness +configuration. +""" + +import base64 +import hashlib +import json +import re +from collections.abc import AsyncIterator +from dataclasses import dataclass +from urllib.parse import parse_qsl, quote, urlsplit + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl, AnyUrl + +from mcp import types +from mcp.client.auth import OAuthFlowError +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata +from mcp.types import ListToolsResult, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + HeadlessOAuth, + InMemoryTokenStorage, + RecordedRequest, + auth_settings, + connect_with_oauth, + first_challenge_shim, + record_requests, + shimmed_app, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH = "/.well-known/oauth-protected-resource/mcp" +ASM_PATH = "/.well-known/oauth-authorization-server" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + +def authorize_params(authorize_url: str) -> 
dict[str, str]: + """Parse the authorize URL's query string into a flat dict (one value per key).""" + return dict(parse_qsl(urlsplit(authorize_url).query)) + + +def form_body(request: RecordedRequest) -> dict[str, str]: + """Parse an `application/x-www-form-urlencoded` request body into a flat dict.""" + return dict(parse_qsl(request.content.decode())) + + +def find(recorded: list[RecordedRequest], method: str, path: str) -> list[RecordedRequest]: + """Filter recorded requests by method and exact path.""" + return [r for r in recorded if r.method == method and r.path == path] + + +@dataclass +class RecordedFlow: + """One completed OAuth connect: every recorded request, plus the parsed authorize URL params.""" + + requests: list[RecordedRequest] + authorize_url: str + + @property + def authorize(self) -> dict[str, str]: + return authorize_params(self.authorize_url) + + @property + def token_request(self) -> RecordedRequest: + token_posts = find(self.requests, "POST", "/token") + assert len(token_posts) == 1 + return token_posts[0] + + +@pytest.fixture +async def recorded_oauth_flow() -> AsyncIterator[RecordedFlow]: + """Run one full OAuth connect with default configuration and yield its recorded wire traffic. + + `valid_scopes` includes `offline_access` so the AS metadata advertises it and the SDK's + SEP-2207 auto-append (and the resulting `prompt=consent`) is exercised; `required_scopes` + stays at `["mcp"]` so the issued token still passes the bearer middleware. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "offline_access"]) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, settings=settings, on_request=on_request) as ( + client, + headless, + ): + await client.list_tools() + + assert headless.authorize_url is not None + yield RecordedFlow(requests=recorded, authorize_url=headless.authorize_url) + + +@requirement("client-auth:pkce:s256") +@requirement("client-auth:resource-parameter") +@requirement("client-auth:authorize:offline-access-consent") +async def test_the_authorize_url_carries_s256_pkce_and_the_resource_indicator( + recorded_oauth_flow: RecordedFlow, +) -> None: + """Every spec-mandated parameter appears on the authorize URL with the right value. + + The full key set is snapshotted so a parameter added or dropped fails the test. The + `code_challenge` length bound is the RFC 7636 §4.2 grammar; an S256 challenge is in + practice always 43 characters, so the upper bound is never approached. + """ + params = recorded_oauth_flow.authorize + + assert sorted(params) == snapshot( + [ + "client_id", + "code_challenge", + "code_challenge_method", + "prompt", + "redirect_uri", + "resource", + "response_type", + "scope", + "state", + ] + ) + assert params["response_type"] == "code" + assert params["code_challenge_method"] == "S256" + assert 43 <= len(params["code_challenge"]) <= 128 + # The exact resource value depends on canonical-URI normalisation (a spec ambiguity); pin + # the stable prefix so the test does not lock in a trailing-slash decision. 
+ assert params["resource"].startswith(BASE_URL) + assert params["state"] != "" + + assert params["scope"].split(" ") == snapshot(["mcp", "offline_access"]) + assert params["prompt"] == "consent" + + +@requirement("client-auth:pkce:s256") +async def test_the_code_verifier_on_the_token_request_hashes_to_the_code_challenge( + recorded_oauth_flow: RecordedFlow, +) -> None: + """The PKCE verifier sent on /token is the S256 pre-image of the challenge sent on /authorize. + + The verifier is also checked against RFC 7636 §4.1's length and `unreserved` charset. + """ + challenge = recorded_oauth_flow.authorize["code_challenge"] + verifier = form_body(recorded_oauth_flow.token_request)["code_verifier"] + + assert re.fullmatch(r"[A-Za-z0-9._~-]{43,128}", verifier) + assert base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") == challenge + + +@requirement("client-auth:state:verify") +async def test_a_mismatched_state_on_the_callback_aborts_the_flow() -> None: + """A callback whose state does not match the value sent on /authorize raises and stops the flow. + + The auth flow runs inside the streamable-HTTP client's task group, so the `OAuthFlowError` + reaches the test wrapped in nested single-element exception groups; `pytest.RaisesGroup` + asserts the leaf type and the SDK-authored message prefix (the full message embeds two + random tokens). + """ + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + headless = HeadlessOAuth(state_override="wrong-state") + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^State parameter mismatch:"), flatten_subgroups=True + ): + # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), + # so an `async with` body would be unreachable; entering explicitly avoids dead code. 
+ await connect_with_oauth(server, provider=provider, headless=headless).__aenter__() + + +@requirement("client-auth:resource-parameter") +async def test_the_authorization_code_token_request_carries_grant_type_code_redirect_and_resource( + recorded_oauth_flow: RecordedFlow, +) -> None: + """The /token form body has exactly the auth-code grant fields, with redirect_uri and resource matching /authorize. + + `client_secret` is present because the SDK's dynamic-registration handler issues a secret + and the client defaults to `client_secret_post`. + """ + token_req = recorded_oauth_flow.token_request + body = form_body(token_req) + + assert sorted(body) == snapshot( + ["client_id", "client_secret", "code", "code_verifier", "grant_type", "redirect_uri", "resource"] + ) + assert body["grant_type"] == "authorization_code" + assert body["code"] != "" + assert body["redirect_uri"] == recorded_oauth_flow.authorize["redirect_uri"] + assert body["resource"] == recorded_oauth_flow.authorize["resource"] + assert token_req.headers["content-type"] == "application/x-www-form-urlencoded" + + +@requirement("client-auth:bearer-header:every-request") +async def test_every_mcp_request_after_auth_carries_the_bearer_header_and_never_a_query_token( + recorded_oauth_flow: RecordedFlow, +) -> None: + """Every MCP request after the flow has `Authorization: Bearer ...` and never `?access_token=`. + + The first /mcp POST is the unauthenticated trigger and is asserted to carry no Authorization + header; that assertion is only meaningful because the recording snapshots requests at send + time (the SDK mutates the same request object in place for the retry). 
+ """ + mcp_posts = find(recorded_oauth_flow.requests, "POST", "/mcp") + assert len(mcp_posts) >= 3 + + assert "authorization" not in mcp_posts[0].headers + for r in mcp_posts[1:]: + assert r.headers["authorization"].startswith("Bearer ") + assert r.headers["authorization"] != "Bearer " + assert "access_token" not in dict(r.url.params) + + +@requirement("client-auth:token-endpoint-auth-method") +async def test_a_client_with_a_secret_authenticates_the_token_request_with_http_basic() -> None: + """A `client_secret_basic` client sends URL-encoded credentials in HTTP Basic, not the body. + + Credentials are URL-encoded before base64 per RFC 6749 §2.3.1; the secret contains `/` so + the encoding is observable. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + client_info = OAuthClientInformationFull( + client_id="cid", + client_secret="s/cret", + token_endpoint_auth_method="client_secret_basic", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="mcp", + ) + await provider.register_client(client_info) + storage = InMemoryTokenStorage(client_info=client_info) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + await client.list_tools() + + assert find(recorded, "POST", "/register") == [] + [token_req] = find(recorded, "POST", "/token") + + decoded = base64.b64decode(token_req.headers["authorization"].removeprefix("Basic ")).decode() + assert decoded == f"{quote('cid', safe='')}:{quote('s/cret', safe='')}" + assert "client_secret" not in form_body(token_req) + + +@requirement("client-auth:token-endpoint-auth-method") +async def test_the_registered_auth_method_is_used_regardless_of_as_metadata_advertised_methods() -> None: + """The token-endpoint auth method comes from the registered client info, not from AS metadata. 
+ + The shim serves AS metadata advertising only `client_secret_basic`; the client dynamically + registers and the SDK's registration handler issues `client_secret_post`. The client uses + `client_secret_post` (secret in the body, no Basic header) because the SDK reads the + registered `token_endpoint_auth_method`, not `token_endpoint_auth_methods_supported`. Other + SDKs (TypeScript, Go) do consult the AS metadata; this test pins where the python SDK's + selection point lives. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + override = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + token_endpoint_auth_methods_supported=["client_secret_basic"], + ) + serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, app_shim=lambda app: shimmed_app(app, serve=serve), on_request=on_request + ) as (client, _): + await client.list_tools() + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content).get("token_endpoint_auth_method") is None + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert "client_secret" in body + assert body["client_secret"] != "" + assert "authorization" not in token_req.headers + + +@requirement("client-auth:scope-selection:priority") +async def test_scope_is_selected_from_the_www_authenticate_challenge_over_prm_metadata() -> None: + """When the 401 challenge carries `scope=`, that value is requested instead of the PRM scopes. 
+ + The SDK's bearer middleware never emits `scope=` in WWW-Authenticate (see the divergence + on `hosting:auth:scope-403`), so the test supplies the first 401 itself via + `first_challenge_shim` and disables token verification so the post-auth retry succeeds + regardless of the granted scope. PRM advertises `["from-prm"]` (it mirrors + `required_scopes`); the challenge says `from-header`; the authorize URL must carry + `from-header`. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(default_scopes=["from-header"]) + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["from-prm"], valid_scopes=["from-header", "from-prm"]) + challenge = f'Bearer scope="from-header", resource_metadata="{BASE_URL}{PRM_PATH}"' + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + settings=settings, + verify_tokens=False, + app_shim=first_challenge_shim(challenge), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["scope"] == "from-header" + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content)["scope"] == "from-header" + + +@requirement("client-auth:pkce:refuse-if-unsupported") +async def test_pkce_is_still_sent_when_as_metadata_omits_code_challenge_methods_supported() -> None: + """AS metadata without `code_challenge_methods_supported` does not stop the client sending PKCE. + + The spec says the client MUST refuse to proceed in this case; the SDK proceeds and the flow + completes. See the divergence on the requirement. 
+ """ + override = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + ) + assert override.code_challenge_methods_supported is None + serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} + + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, app_shim=lambda app: shimmed_app(app, serve=serve) + ) as (client, headless): + result = await client.list_tools() + + assert headless.authorize_url is not None + params = authorize_params(headless.authorize_url) + assert params["code_challenge_method"] == "S256" + assert params["code_challenge"] != "" + assert result.tools[0].name == "echo" + + +@requirement("client-auth:authorize:error-surfaces") +async def test_an_authorize_error_on_the_callback_aborts_the_flow_before_the_token_request() -> None: + """An `error=` redirect from /authorize aborts the flow with no /token request issued. + + The SDK's callback contract is `() -> (code, state)` with no error form, so the failure is + observed as an empty code reaching the SDK and `OAuthFlowError("No authorization code + received")` being raised. The actual `error` value from the redirect is not surfaced to the + caller; that gap is noted in the manifest. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(deny_authorize=True) + server = Server("guarded", on_list_tools=list_tools) + headless = HeadlessOAuth() + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^No authorization code received$"), flatten_subgroups=True + ): + await connect_with_oauth(server, provider=provider, headless=headless, on_request=on_request).__aenter__() + + assert headless.error == "access_denied" + assert find(recorded, "POST", "/token") == [] diff --git a/tests/interaction/auth/test_bearer.py b/tests/interaction/auth/test_bearer.py new file mode 100644 index 0000000000..341a8e0db9 --- /dev/null +++ b/tests/interaction/auth/test_bearer.py @@ -0,0 +1,189 @@ +"""Resource-server bearer-token gate: status codes and `WWW-Authenticate` for each token shape. + +These tests mount only the resource-server side of the auth wiring (a `StaticTokenVerifier` +seeded with hand-built tokens, no authorization-server provider) and speak raw HTTP, since +every assertion is about HTTP semantics the SDK `Client` cannot observe: the 401/403 status, +the `WWW-Authenticate` header structure, and that a wrong-audience token reaches the MCP +endpoint behind the gate. The flow side of the same 401 is `test_flow.py`'s flagship test. 
+""" + +import time +from collections.abc import AsyncIterator + +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server +from mcp.server.auth.provider import AccessToken +from mcp.types import JSONRPCResponse +from tests.interaction._connect import base_headers, initialize_body, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import StaticTokenVerifier, auth_settings + +pytestmark = pytest.mark.anyio + +REQUIRED_SCOPE = "mcp:read" +RESOURCE_METADATA_URL = "http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp" + +_FUTURE = int(time.time()) + 3600 +_PAST = int(time.time()) - 3600 + +TOKENS = { + "tok-valid": AccessToken(token="tok-valid", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_FUTURE), + "tok-expired": AccessToken(token="tok-expired", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_PAST), + "tok-noscope": AccessToken(token="tok-noscope", client_id="c", scopes=["other:thing"], expires_at=_FUTURE), + "tok-wrong-aud": AccessToken( + token="tok-wrong-aud", + client_id="c", + scopes=[REQUIRED_SCOPE], + expires_at=_FUTURE, + resource="https://other.example/mcp", + ), +} + + +@pytest.fixture +async def protected() -> AsyncIterator[httpx.AsyncClient]: + """A bearer-gated streamable-HTTP app (resource server only) on the in-process bridge.""" + server = Server("rs") + settings = auth_settings(required_scopes=[REQUIRED_SCOPE]) + async with mounted_app(server, auth=settings, token_verifier=StaticTokenVerifier(TOKENS)) as (http, _): + yield http + + +async def post_mcp( + http: httpx.AsyncClient, *, bearer: str | None = None, query: dict[str, str] | None = None +) -> httpx.Response: + """POST an initialize body to `/mcp`, optionally with a bearer token and/or a query string.""" + headers = base_headers() + if bearer is not None: + headers["authorization"] = f"Bearer {bearer}" + return await http.post("/mcp", headers=headers, params=query, 
json=initialize_body()) + + +def parse_www_authenticate(value: str) -> dict[str, str]: + """Parse a `Bearer k="v", k="v"` challenge into a dict. + + The SDK emits each parameter exactly once, comma-space separated, with double-quoted + values that contain no quotes themselves; this helper relies on that and would fail + visibly if the format changed. + """ + scheme, _, params = value.partition(" ") + assert scheme == "Bearer" + return {key: quoted.strip('"') for key, _, quoted in (pair.partition("=") for pair in params.split(", "))} + + +@requirement("hosting:auth:missing-401") +async def test_a_request_with_no_authorization_header_is_challenged_with_resource_metadata( + protected: httpx.AsyncClient, +) -> None: + """No `Authorization` header → 401 with a `WWW-Authenticate` carrying `resource_metadata`. + + The snapshot pins current behaviour: the SDK collapses the no-header, unknown-token, and + expired-token cases into one challenge (`error="invalid_token"`, no `scope` parameter). The + spec says the discovery-time challenge SHOULD include `scope` and RFC 6750 says the + no-credentials case SHOULD NOT carry an error code; both gaps are recorded as the divergence + on this requirement. Asserting the dict equals an exact key set also pins that no parameter + appears twice. 
+ """ + response = await post_mcp(protected) + + assert response.status_code == 401 + assert response.headers["www-authenticate"] == snapshot( + 'Bearer error="invalid_token", error_description="Authentication required", ' + 'resource_metadata="http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp"' + ) + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "Authentication required", + "resource_metadata": RESOURCE_METADATA_URL, + } + assert response.json() == snapshot({"error": "invalid_token", "error_description": "Authentication required"}) + + +@requirement("hosting:auth:invalid-401") +async def test_an_unrecognized_bearer_token_is_answered_401_invalid_token(protected: httpx.AsyncClient) -> None: + """A token the verifier does not recognize is answered 401 `invalid_token`. + + The challenge is identical to the no-header case (the backend returns `None` for both); the + missing `scope` parameter is the recorded divergence on this requirement. + """ + response = await post_mcp(protected, bearer="tok-unknown") + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "Authentication required", + "resource_metadata": RESOURCE_METADATA_URL, + } + + +@requirement("hosting:auth:expired-401") +async def test_an_expired_token_is_answered_401(protected: httpx.AsyncClient) -> None: + """A token whose `expires_at` is in the past is answered 401 `invalid_token`. + + The expiry check is the bearer backend's, against the wall clock; the test seeds a concrete + past timestamp so no time mocking is involved. The missing `scope` parameter is the recorded + divergence on this requirement. 
+ """ + response = await post_mcp(protected, bearer="tok-expired") + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" + + +@requirement("hosting:auth:scope-403") +async def test_a_token_missing_a_required_scope_is_answered_403_insufficient_scope_without_a_scope_param( + protected: httpx.AsyncClient, +) -> None: + """A token lacking the required scope is answered 403 `insufficient_scope`, with no `scope` parameter. + + The spec's runtime-insufficient-scope guidance says the challenge SHOULD include `scope` + naming the required scope; the SDK never emits it, recorded as the divergence on this + requirement. The SDK client reads `scope` from this header to drive step-up, so the gap is + a resource-server/client asymmetry. + """ + response = await post_mcp(protected, bearer="tok-noscope") + + assert response.status_code == 403 + parsed = parse_www_authenticate(response.headers["www-authenticate"]) + assert parsed == { + "error": "insufficient_scope", + "error_description": f"Required scope: {REQUIRED_SCOPE}", + "resource_metadata": RESOURCE_METADATA_URL, + } + assert "scope" not in parsed + + +@requirement("hosting:auth:aud-validation") +async def test_a_token_with_a_mismatched_audience_is_accepted(protected: httpx.AsyncClient) -> None: + """A token whose `resource` does not match the server's resource identifier is accepted. + + The spec mandates the resource server validate the token's audience; the bearer backend + never inspects `AccessToken.resource`, so the request passes the gate and the MCP endpoint + serves it. This pins current behaviour with the divergence recorded on the requirement. + """ + response = await post_mcp(protected, bearer="tok-wrong-aud") + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/event-stream") + # The body is finite SSE: a result event followed by stream close. 
Pull the JSON-RPC response + # out of the buffered text to prove the MCP endpoint actually answered the initialize request. + [data] = [line.removeprefix("data: ") for line in response.text.splitlines() if line.startswith("data: ")] + assert "protocolVersion" in JSONRPCResponse.model_validate_json(data).result + + +@requirement("hosting:auth:query-token-ignored") +async def test_an_access_token_in_the_query_string_is_not_accepted(protected: httpx.AsyncClient) -> None: + """A valid token presented in the URI query string is treated as no authentication. + + The bearer backend reads only the `Authorization` header, so `?access_token=...` is never + consulted; the request is treated as unauthenticated and answered 401. This satisfies, by + absence, the security best-practice that resource servers must not accept query-string + tokens. + """ + response = await post_mcp(protected, query={"access_token": "tok-valid"}) + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" diff --git a/tests/interaction/auth/test_discovery.py b/tests/interaction/auth/test_discovery.py new file mode 100644 index 0000000000..68c33c8a2d --- /dev/null +++ b/tests/interaction/auth/test_discovery.py @@ -0,0 +1,333 @@ +"""Protected-resource and authorization-server metadata discovery, end to end. + +Every client-side test connects a real `Client` via `connect_with_oauth` and asserts on the +recorded request paths the discovery probes produced; the discovery URL ordering is a wire +detail `Client` cannot observe directly but the recording can. Tests that need a metadata +endpoint to 404 or return alternate content wrap the SDK's app in `shimmed_app` while leaving +the real authorize and token endpoints behind it, so the rest of the flow runs unaltered. 
+ +The two server-side tests (#5, #6) drive raw httpx against `mounted_app` because their +assertions are the metadata response bodies and headers, which `Client` does not surface. +""" + +import json + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl + +from mcp import types +from mcp.client.auth import OAuthFlowError, OAuthRegistrationError +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata +from mcp.types import ListToolsResult, Tool +from tests.interaction._connect import BASE_URL, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + RecordedRequest, + auth_settings, + connect_with_oauth, + metadata_body, + record_requests, + shim, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH_SUFFIXED = "/.well-known/oauth-protected-resource/mcp" +PRM_ROOT = "/.well-known/oauth-protected-resource" +ASM_ROOT = "/.well-known/oauth-authorization-server" +OIDC_ROOT = "/.well-known/openid-configuration" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="probe", input_schema={"type": "object"})]) + + +def discovery_gets(recorded: list[RecordedRequest]) -> list[str]: + """Return the well-known GET paths in recorded order, ignoring everything else.""" + return [r.path for r in recorded if r.method == "GET" and "/.well-known/" in r.path] + + +def real_asm() -> OAuthMetadata: + """Build an authorization-server metadata document pointing at the real co-hosted endpoints.""" + return OAuthMetadata( + issuer=AnyHttpUrl(BASE_URL), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + 
scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + ) + + +@requirement("client-auth:prm-discovery:fallback-order") +async def test_prm_discovery_uses_the_resource_metadata_url_from_www_authenticate() -> None: + """The first protected-resource probe is the URL the 401's `WWW-Authenticate` header supplied. + + With co-hosted defaults the header carries the path-suffixed well-known URL; the client + fetches that one first and, because it succeeds, never falls back. The single-probe + sequence proves priority 1. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, on_request=on_request) as (client, _): + await client.list_tools() + + assert discovery_gets(recorded) == snapshot([PRM_PATH_SUFFIXED, ASM_ROOT]) + assert (recorded[0].method, recorded[0].path) == ("POST", "/mcp") + assert (recorded[1].method, recorded[1].path) == ("GET", PRM_PATH_SUFFIXED) + + +@requirement("client-auth:prm-discovery:fallback-order") +async def test_prm_discovery_falls_back_from_path_well_known_to_root_on_404() -> None: + """When the path-suffixed PRM well-known 404s, the client falls back to the root well-known. + + The exact GET count is not asserted: the WWW-Authenticate URL equals the path well-known + here, so the SDK probes it twice (once as priority 1, once as priority 2) before reaching + root. Asserting "path before root, root reached, then the flow proceeds" pins the spec + invariant; the duplicate probe is an implementation detail. The served PRM body carries an + unrecognized field to prove the client's parser ignores unknown members (RFC 9728 §3.2). 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + app_shim = shim( + not_found=frozenset({PRM_PATH_SUFFIXED}), + serve={PRM_ROOT: metadata_body(prm, x_unknown_extension="ignored")}, + ) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + await client.list_tools() + + well_known = discovery_gets(recorded) + assert PRM_PATH_SUFFIXED in well_known + assert PRM_ROOT in well_known + assert well_known.index(PRM_PATH_SUFFIXED) < well_known.index(PRM_ROOT) + assert any(r.path == "/authorize" for r in recorded) + + +@requirement("client-auth:prm-discovery:no-prm-fallback") +async def test_when_every_prm_probe_fails_the_client_discovers_as_metadata_at_the_server_origin() -> None: + """When every protected-resource metadata probe 404s, the client falls back to the legacy path. + + The legacy 2025-03-26 behaviour: with no PRM document available, treat the MCP server's + origin as the authorization server and fetch its `/.well-known/oauth-authorization-server` + directly. The real co-hosted ASM endpoint is at exactly that location, so the flow completes. + The recorded sequence shows both PRM well-known paths probed (and failed) before ASM_ROOT. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + app_shim = shim(not_found=frozenset({PRM_PATH_SUFFIXED, PRM_ROOT})) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + result = await client.list_tools() + + well_known = discovery_gets(recorded) + assert PRM_PATH_SUFFIXED in well_known + assert PRM_ROOT in well_known + assert well_known[-1] == ASM_ROOT + assert all(well_known.index(prm) < well_known.index(ASM_ROOT) for prm in (PRM_PATH_SUFFIXED, PRM_ROOT)) + assert result.tools[0].name == "probe" + + +@requirement("client-auth:dcr:registration-error-surfaces") +async def test_a_400_from_the_registration_endpoint_surfaces_as_a_registration_error() -> None: + """A 400 from `/register` surfaces as `OAuthRegistrationError` carrying the server's body. + + The shim makes `/register` return RFC 7591's `invalid_client_metadata`; the SDK reads the + body and raises with the status and text in the message, before any authorize or token + request is made. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + error_body = json.dumps({"error": "invalid_client_metadata", "error_description": "no"}).encode() + app_shim = shim(serve={"/register": (400, error_body)}) + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthRegistrationError, match=r"^Registration failed: 400 .*invalid_client_metadata"), + flatten_subgroups=True, + ): + await connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request).__aenter__() + + assert [r.path for r in recorded if r.path in ("/authorize", "/token")] == [] + + +@requirement("client-auth:prm-resource-mismatch") +async def test_prm_with_a_mismatched_resource_aborts_the_flow_before_authorize() -> None: + """A PRM document whose `resource` does not cover the server URL aborts the flow. + + The shim serves PRM at the URL the WWW-Authenticate header supplies, but with a `resource` + on a different path; `check_resource_allowed` rejects it and `OAuthFlowError` is raised + before any authorize or token request is made. The error reaches the test wrapped in nested + single-element exception groups by the streamable-HTTP client's task group. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/other"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + app_shim = shim(serve={PRM_PATH_SUFFIXED: metadata_body(prm)}) + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^Protected resource .* does not match expected"), + flatten_subgroups=True, + ): + await connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request).__aenter__() + + assert [r.path for r in recorded if r.path in ("/authorize", "/token")] == [] + + +@requirement("client-auth:as-metadata-discovery:priority-order") +@pytest.mark.parametrize( + ("authorization_server", "not_found", "serve_at", "expected_order"), + [ + pytest.param( + f"{BASE_URL}/", + frozenset({ASM_ROOT}), + OIDC_ROOT, + [ASM_ROOT, OIDC_ROOT], + id="root-issuer", + ), + pytest.param( + f"{BASE_URL}/tenant", + frozenset({f"{ASM_ROOT}/tenant", f"{OIDC_ROOT}/tenant"}), + "/tenant/.well-known/openid-configuration", + [f"{ASM_ROOT}/tenant", f"{OIDC_ROOT}/tenant", "/tenant/.well-known/openid-configuration"], + id="path-issuer", + ), + ], +) +async def test_as_metadata_discovery_falls_back_through_the_spec_endpoint_order( + authorization_server: str, not_found: frozenset[str], serve_at: str, expected_order: list[str] +) -> None: + """Authorization-server metadata is fetched at the spec's endpoints in the spec's order. + + The shim 404s every endpoint before the last so the recording proves each probe and its + position. For an issuer URL with no path the order is OAuth root then OIDC root; for an + issuer URL with a path component it is OAuth path-inserted, OIDC path-inserted, then OIDC + path-appended (the spec's three-endpoint MUST). 
The path-issuer case is driven by serving + a PRM whose `authorization_servers` carries the path; the SDK's own AS routes stay at root + (the served body points at the real `/authorize` and `/token`). The served bodies carry an + unrecognized field to prove the client's parser ignores unknown members (RFC 8414 §3.2). + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(authorization_server)] + ) + app_shim = shim( + not_found=not_found, + serve={ + PRM_PATH_SUFFIXED: metadata_body(prm), + serve_at: metadata_body(real_asm(), x_unknown_extension="ignored"), + }, + ) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + await client.list_tools() + + assert discovery_gets(recorded) == [PRM_PATH_SUFFIXED, *expected_order] + + +@requirement("hosting:auth:metadata-endpoints") +@requirement("hosting:auth:prm:authorization-servers-field") +async def test_the_prm_endpoint_serves_the_resource_url_and_at_least_one_authorization_server() -> None: + """The protected-resource metadata document the SDK serves identifies the resource and an authorization server. + + Also asserts the response is `application/json` (RFC 9728 §3.2) and that fields the SDK has + no value for are absent rather than null (`PydanticJSONResponse` serializes with + `exclude_none=True`, satisfying RFC 9728 §3.2's omit-zero-value rule). 
+ """ + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + + async with mounted_app(server, auth=auth_settings(), auth_server_provider=provider) as (http, _): + response = await http.get(PRM_PATH_SUFFIXED) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("application/json") + + document = json.loads(response.content) + assert "resource_documentation" not in document + assert "scopes_supported" in document + + metadata = ProtectedResourceMetadata.model_validate(document) + assert str(metadata.resource).rstrip("/") == f"{BASE_URL}/mcp" + assert len(metadata.authorization_servers) >= 1 + assert metadata.bearer_methods_supported == ["header"] + + +@requirement("hosting:auth:as-router") +async def test_as_metadata_advertises_authorize_token_registration_and_s256() -> None: + """The authorization-server metadata document the SDK serves names the required endpoints and S256.""" + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + + async with mounted_app(server, auth=auth_settings(), auth_server_provider=provider) as (http, _): + response = await http.get(ASM_ROOT) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("application/json") + + metadata = OAuthMetadata.model_validate_json(response.content) + assert str(metadata.issuer).rstrip("/") == BASE_URL + assert str(metadata.authorization_endpoint) == f"{BASE_URL}/authorize" + assert str(metadata.token_endpoint) == f"{BASE_URL}/token" + assert str(metadata.registration_endpoint) == f"{BASE_URL}/register" + assert metadata.response_types_supported == ["code"] + assert metadata.code_challenge_methods_supported is not None + assert "S256" in metadata.code_challenge_methods_supported + + +@requirement("client-auth:as-metadata-discovery:issuer-validation") +async def test_as_metadata_with_a_mismatched_issuer_is_accepted_and_the_flow_proceeds() -> None: + """Authorization-server metadata whose 
`issuer` does not match the discovery URL is accepted. + + RFC 8414 §3.3 requires the client to reject the document; the SDK parses and uses it + without comparing `issuer` to the URL it was fetched from. See the divergence on the + requirement. The served body carries an unrecognized field as a fold-in proof of + unknown-field tolerance. + """ + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + metadata = real_asm() + metadata.issuer = AnyHttpUrl(f"{BASE_URL}/wrong-issuer") + app_shim = shim(serve={ASM_ROOT: metadata_body(metadata, x_unknown_extension="ignored")}) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "probe" diff --git a/tests/interaction/auth/test_flow.py b/tests/interaction/auth/test_flow.py new file mode 100644 index 0000000000..968fc5f980 --- /dev/null +++ b/tests/interaction/auth/test_flow.py @@ -0,0 +1,239 @@ +"""End-to-end OAuth authorization-code flow against the SDK's own server, fully in process. + +Auth is HTTP-only so these tests are not transport-parametrized; each connects via +`connect_with_oauth`, which co-hosts the SDK's authorization server, protected-resource +metadata, and bearer-gated MCP endpoint on one bridge-backed Starlette app and drives the +whole flow through one `httpx.AsyncClient` carrying the SDK's `OAuthClientProvider`. The +authorize redirect completes headlessly through the same bridge, so every request the flow +makes is observable via `on_request`. 
+""" + +import json +from collections import Counter +from urllib.parse import parse_qs, urlsplit + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot +from pydantic import AnyUrl + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import get_access_token +from mcp.shared.auth import OAuthClientInformationFull +from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + InMemoryTokenStorage, + auth_settings, + connect_with_oauth, + oauth_client_metadata, + shimmed_app, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})]) + + +@requirement("flow:oauth:authorization-code-roundtrip") +@requirement("client-auth:401-triggers-flow") +@requirement("hosting:auth:missing-401") +async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow_connects() -> None: + """Connecting to a bearer-gated server walks the full authorization-code flow and succeeds. + + Three requirements are proven by one connect: the flow runs end to end (authorization-code + roundtrip), it was triggered by a 401 on the first MCP request (401-triggers-flow), and + that 401 carried `resource_metadata` in `WWW-Authenticate` for discovery (missing-401). + The flagship test pins the recorded request sequence so the discovery → registration → + authorize → token → retry order is asserted explicitly. + + Steps the SDK is expected to perform: + 1. 
POST /mcp without a token → 401 with `WWW-Authenticate: Bearer resource_metadata=...`. + 2. GET the protected-resource metadata. + 3. GET the authorization-server metadata. + 4. POST /register (dynamic client registration). + 5. GET /authorize → 302 with code+state (completed by the headless redirect). + 6. POST /token (authorization-code exchange). + 7. Retry POST /mcp with `Authorization: Bearer <token>` → succeeds. + """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as ( + client, + headless, + ): + result = await client.list_tools() + + assert result == snapshot(ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})])) + assert headless.authorize_url is not None + + paths = [(r.method, r.url.path) for r in requests] + assert Counter(paths) == snapshot( + Counter( + { + ("POST", "/mcp"): 4, + ("GET", "/.well-known/oauth-protected-resource/mcp"): 1, + ("GET", "/.well-known/oauth-authorization-server"): 1, + ("POST", "/register"): 1, + ("GET", "/authorize"): 1, + ("POST", "/token"): 1, + ("GET", "/mcp"): 1, + ("DELETE", "/mcp"): 1, + } + ) + ) + + assert (requests[0].method, requests[0].url.path) == ("POST", "/mcp") + # The recorded Request objects are live references: the auth flow mutates the original + # request's headers in place when it adds the bearer token for the retry, so the first + # entry's headers cannot be used to assert "no Authorization on the first attempt". The + # path multiset above proving discovery happened is the evidence the first attempt was 401. + + # The first PRM discovery GET carries the protocol-version header (an SDK behaviour, not a + # spec requirement on discovery requests). 
+ prm_get = next(r for r in requests if r.url.path == "/.well-known/oauth-protected-resource/mcp") + assert prm_get.headers.get("mcp-protocol-version") == snapshot("2025-11-25") + + authorize = parse_qs(urlsplit(headless.authorize_url).query) + assert authorize["response_type"] == ["code"] + assert authorize["code_challenge_method"] == ["S256"] + assert authorize["client_id"][0] in provider.clients + + assert storage.tokens is not None + bearer = f"Bearer {storage.tokens.access_token}" + authed_mcp = [r for r in requests if r.url.path == "/mcp" and r.headers.get("authorization") == bearer] + assert len(authed_mcp) > 0 + assert storage.tokens.access_token in provider.access_tokens + + +@requirement("hosting:auth:authinfo-propagates") +async def test_the_access_token_reaches_the_tool_handler_via_get_access_token() -> None: + """A tool handler reads the request's access token through `get_access_token()`.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "whoami" + token = get_access_token() + assert token is not None + return CallToolResult(content=[TextContent(text=" ".join(token.scopes))]) + + server = Server("guarded", on_list_tools=list_tools, on_call_tool=call_tool) + provider = InMemoryAuthorizationServerProvider() + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider) as (client, _): + result = await client.call_tool("whoami", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mcp")])) + + +@requirement("client-auth:pre-registration") +async def test_a_preregistered_client_skips_registration() -> None: + """A client whose storage already holds client info uses it instead of registering. + + The provider holds the same registration server-side so the authorize and token steps + accept it; the recorded requests prove no `/register` call was made. 
+ """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + client_info = OAuthClientInformationFull( + client_id="preregistered", + client_secret="s3cret", + token_endpoint_auth_method="client_secret_post", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="mcp", + ) + await provider.register_client(client_info) + storage.client_info = client_info + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as ( + client, + _, + ): + await client.list_tools() + + assert [r.url.path for r in requests].count("/register") == 0 + assert list(provider.clients) == ["preregistered"] + + +@requirement("client-auth:dcr") +async def test_the_dcr_request_carries_the_client_metadata() -> None: + """Dynamic registration sends the client's metadata and persists what the server issued. + + The body of the recorded `/register` POST carries the metadata the test supplied (with the + scope filled in from server discovery), and the server's issued client_id and secret are + persisted to storage and held by the provider. 
+ """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + client_metadata = oauth_client_metadata() + client_metadata.software_id = "interaction-test-suite" + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, storage=storage, client_metadata=client_metadata, on_request=requests.append + ) as (client, _): + await client.list_tools() + + register = next(r for r in requests if r.url.path == "/register") + assert register.headers["content-type"] == "application/json" + body = json.loads(register.content) + assert body == snapshot( + { + "redirect_uris": ["http://127.0.0.1:8000/oauth/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "scope": "mcp", + "client_name": "interaction-suite", + "software_id": "interaction-test-suite", + } + ) + + assert storage.client_info is not None + assert storage.client_info.client_id is not None + assert storage.client_info.client_secret is not None + assert list(provider.clients) == [storage.client_info.client_id] + + +async def test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_wrapped_app() -> None: + """Harness self-test: `shimmed_app` serves canned bodies, 404s, and forwards everything else. + + Wraps a real auth-hosting Starlette app so the forward path is exercised against the SDK's + own routing; provided here so the discovery tests can rely on the shim without each adding + their own contract test. 
+ """ + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + real_app = server.streamable_http_app(auth=auth_settings(), auth_server_provider=provider) + app = shimmed_app(real_app, not_found=frozenset({"/missing"}), serve={"/override": b'{"shimmed": true}'}) + async with server.session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + served = await http.get("/override") + assert served.status_code == 200 + assert served.headers["content-type"] == "application/json" + assert served.json() == {"shimmed": True} + + assert (await http.get("/missing")).status_code == 404 + + forwarded = await http.get("/.well-known/oauth-authorization-server") + assert forwarded.status_code == 200 + assert forwarded.json()["issuer"] == "http://127.0.0.1:8000/" diff --git a/tests/interaction/auth/test_lifecycle.py b/tests/interaction/auth/test_lifecycle.py new file mode 100644 index 0000000000..aa552ae8a6 --- /dev/null +++ b/tests/interaction/auth/test_lifecycle.py @@ -0,0 +1,445 @@ +"""Token lifecycle, step-up, and registration-variant flows of the SDK's OAuth client. + +Every test connects end to end via `connect_with_oauth`; the assertions are recording-first +(the recorded request sequence is asserted before, or independently of, the call result), so a +surprise in the refresh or step-up paths produces a readable diff of what fired rather than an +opaque failure. The provider knobs that drive each scenario are documented per test. 
+""" + +import base64 +from collections import Counter +from urllib.parse import parse_qsl, urlsplit + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl, AnyUrl + +from mcp import MCPError, types +from mcp.client.auth.extensions.client_credentials import ClientCredentialsOAuthProvider, PrivateKeyJWTOAuthProvider +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata +from mcp.types import INTERNAL_ERROR, ListToolsResult, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + InMemoryTokenStorage, + RecordedRequest, + auth_settings, + connect_with_oauth, + m2m_token_shim, + metadata_body, + record_requests, + shim, + step_up_shim, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH = "/.well-known/oauth-protected-resource/mcp" +ASM_PATH = "/.well-known/oauth-authorization-server" +CIMD_URL = "https://client.example/.well-known/mcp-client" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + +def form_body(request: RecordedRequest) -> dict[str, str]: + """Parse an `application/x-www-form-urlencoded` request body into a flat dict.""" + return dict(parse_qsl(request.content.decode())) + + +def authorize_params(authorize_url: str) -> dict[str, str]: + """Parse the authorize URL's query string into a flat dict.""" + return dict(parse_qsl(urlsplit(authorize_url).query)) + + +def find(recorded: list[RecordedRequest], method: str, path: str) -> list[RecordedRequest]: + return [r for r in recorded if r.method == method and r.path == path] + + +def path_counts(recorded: list[RecordedRequest]) -> 
Counter[tuple[str, str]]: + return Counter((r.method, r.path) for r in recorded) + + +def cimd_supported_metadata() -> bytes: + """AS metadata advertising `client_id_metadata_document_supported: true` (the SDK server never sets it).""" + metadata = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + response_types_supported=["code"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + client_id_metadata_document_supported=True, + ) + return metadata_body(metadata) + + +def seeded_client(provider: InMemoryAuthorizationServerProvider, **kwargs: object) -> OAuthClientInformationFull: + """Register a client with the provider and return its info, for pre-registration and CIMD scenarios.""" + base: dict[str, object] = { + "client_id": "preregistered", + "token_endpoint_auth_method": "none", + "redirect_uris": [AnyUrl(REDIRECT_URI)], + "grant_types": ["authorization_code", "refresh_token"], + "scope": "mcp", + } + base.update(kwargs) + info = OAuthClientInformationFull.model_validate(base) + assert info.client_id is not None + provider.clients[info.client_id] = info + return info + + +@requirement("client-auth:refresh:transparent") +async def test_an_expired_access_token_is_transparently_refreshed_before_the_next_request() -> None: + """An access token the client considers expired is refreshed and the new bearer is used. + + The provider tells the client `expires_in=-3600` for the first token while keeping the + server-side `expires_at` in the future, so the connect's retry succeeds and the next + request finds the token expired and refreshes. 
The recorded requests prove exactly one + `grant_type=refresh_token` exchange carrying the resource indicator, and the bearer used + after the refresh is the second access token, which is the one persisted to storage. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(issue_expired_first=True) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + token_posts = find(recorded, "POST", "/token") + bodies = [form_body(r) for r in token_posts] + assert [b["grant_type"] for b in bodies] == snapshot(["authorization_code", "refresh_token"]) + + refresh_body = bodies[1] + assert sorted(refresh_body) == snapshot(["client_id", "client_secret", "grant_type", "refresh_token", "resource"]) + assert refresh_body["refresh_token"].startswith("refresh_") + assert refresh_body["resource"].startswith(BASE_URL) + + bearers = {r.headers["authorization"] for r in recorded if r.path == "/mcp" and "authorization" in r.headers} + assert len(bearers) == 2 + assert storage.tokens is not None + assert f"Bearer {storage.tokens.access_token}" in bearers + assert storage.tokens.expires_in == 3600 + + +@requirement("client-auth:403-scope-upgrade") +async def test_a_403_insufficient_scope_triggers_one_reauthorize_with_the_challenged_scope() -> None: + """A 403 `insufficient_scope` challenge is answered by one re-authorize with the challenge's scope. + + The shim 403s the second authenticated `/mcp` POST (the `notifications/initialized` request, + which reaches the auth flow's step-up handler; the first authenticated POST is the post-401 + retry, after which the generator ends without inspecting the response). 
The challenge names a + wider scope; step-up reuses cached metadata and the existing client registration, + re-authorizes with the new scope, and the connect completes. The client is pre-registered + with both scopes so the server's authorize handler accepts the wider second request. One + re-authorize, one retry; the spec's SHOULD-retry-limit ("a few") is not enforced. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage(client_info=seeded_client(provider, scope="mcp write")) + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "write"]) + challenge = 'Bearer error="insufficient_scope", scope="mcp write"' + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + settings=settings, + app_shim=step_up_shim(challenge), + on_request=on_request, + ) as (client, headless): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + assert len(headless.authorize_urls) == 2 + assert authorize_params(headless.authorize_urls[0])["scope"] == "mcp" + assert authorize_params(headless.authorize_urls[1])["scope"] == "mcp write" + + counts = path_counts(recorded) + assert counts[("GET", PRM_PATH)] == 1 + assert counts[("GET", ASM_PATH)] == 1 + assert counts[("POST", "/register")] == 0 + assert counts[("GET", "/authorize")] == 2 + assert counts[("POST", "/token")] == 2 + + +@requirement("client-auth:401-after-auth-throws") +async def test_a_second_401_after_a_completed_oauth_flow_surfaces_without_looping() -> None: + """A 401 on the post-auth retry surfaces as an error rather than re-entering discovery. + + The provider rejects every token at verification, so the full flow runs once and the retry + is 401'd. 
The auth-flow generator ends after that retry, so the 401 propagates and the + transport converts it to an INTERNAL_ERROR result, raising during connect. Discovery, + registration, authorize, and token each ran exactly once: no loop. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(reject_all_tokens=True) + server = Server("guarded", on_list_tools=list_tools) + + def is_internal_error(error: MCPError) -> bool: + return error.error.code == INTERNAL_ERROR + + with anyio.fail_after(5): + with pytest.RaisesGroup(pytest.RaisesExc(MCPError, check=is_internal_error), flatten_subgroups=True): + # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), + # so an `async with` body would be unreachable; entering explicitly avoids dead code. + await connect_with_oauth(server, provider=provider, on_request=on_request).__aenter__() + + counts = path_counts(recorded) + assert counts[("GET", PRM_PATH)] == 1 + assert counts[("GET", ASM_PATH)] == 1 + assert counts[("POST", "/register")] == 1 + assert counts[("GET", "/authorize")] == 1 + assert counts[("POST", "/token")] == 1 + assert counts[("POST", "/mcp")] == 2 + + +@requirement("client-auth:cimd") +async def test_cimd_is_selected_when_the_as_advertises_support_and_a_metadata_url_is_supplied() -> None: + """A client-ID metadata-document URL is used as `client_id` instead of registering. + + AS metadata is shimmed to advertise `client_id_metadata_document_supported: true`; the + provider is pre-seeded so the server's authorize and token handlers accept the URL as a + client_id (the SDK server has no CIMD-aware client lookup of its own). The recorded + requests prove no `/register` call, the authorize URL's `client_id` is the CIMD URL, the + token request uses `token_endpoint_auth_method=none`, and storage persists the URL as + `client_id`. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + seeded_client(provider, client_id=CIMD_URL) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + client_metadata_url=CIMD_URL, + app_shim=shim(serve={ASM_PATH: cimd_supported_metadata()}), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert find(recorded, "POST", "/register") == [] + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["client_id"] == CIMD_URL + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body["client_id"] == CIMD_URL + assert "client_secret" not in body + assert "authorization" not in token_req.headers + + assert storage.client_info is not None + assert storage.client_info.client_id == CIMD_URL + assert storage.client_info.token_endpoint_auth_method == "none" + + +@requirement("client-auth:invalid-grant-clears-tokens") +async def test_a_failed_refresh_clears_stored_tokens_and_restarts_the_full_flow() -> None: + """A non-200 refresh response clears the in-memory tokens and the flow re-runs from discovery. + + The first token is reported expired so the next request refreshes; the provider denies the + refresh once with `invalid_grant`, the auth flow clears its tokens, the unauthenticated + request 401s, and discovery, authorize, and token run again. The original registration is + preserved (`client_info` is not cleared). The SDK clears tokens on any non-200 refresh + response, not specifically `error=invalid_grant`; `source="sdk"` so this is a precision + note rather than a divergence. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(issue_expired_first=True, fail_next_refresh=True) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + token_posts = find(recorded, "POST", "/token") + assert [form_body(r)["grant_type"] for r in token_posts] == snapshot( + ["authorization_code", "refresh_token", "authorization_code"] + ) + + counts = path_counts(recorded) + assert counts[("POST", "/register")] == 1 + assert counts[("GET", "/authorize")] == 2 + assert counts[("GET", PRM_PATH)] == 2 + assert counts[("GET", ASM_PATH)] == 2 + + assert storage.client_info is not None + assert storage.tokens is not None + assert storage.tokens.access_token in provider.access_tokens + + +@requirement("client-auth:client-credentials") +async def test_client_credentials_provider_obtains_a_token_without_an_authorize_step() -> None: + """The client-credentials provider connects with no authorize step and a `client_credentials` grant. + + The SDK server's `TokenHandler` does not route `client_credentials`, so the harness shim + handles it (the shim is harness; the SDK-under-test is the client provider). The recorded + `/token` body proves the grant type, scope, resource indicator, and HTTP-Basic client + authentication; no `/authorize` or `/register` request was made. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + auth = ClientCredentialsOAuthProvider( + server_url=f"{BASE_URL}/mcp", + storage=InMemoryTokenStorage(), + client_id="m2m-client", + client_secret="m2m-secret", + scopes="mcp", + ) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + auth=auth, + app_shim=m2m_token_shim(provider, scopes=["mcp"]), + on_request=on_request, + ) as (client, headless): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + assert headless.authorize_url is None + assert find(recorded, "GET", "/authorize") == [] + assert find(recorded, "POST", "/register") == [] + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body == snapshot( + {"grant_type": "client_credentials", "resource": "http://127.0.0.1:8000/mcp", "scope": "mcp"} + ) + decoded = base64.b64decode(token_req.headers["authorization"].removeprefix("Basic ")).decode() + assert decoded == "m2m-client:m2m-secret" + + +@requirement("client-auth:private-key-jwt") +async def test_private_key_jwt_provider_authenticates_the_token_request_with_an_assertion() -> None: + """The private-key-JWT provider sends a `client_assertion` on the token request, with the issuer as audience. + + The assertion provider is a closure that records the audience it was called with and returns + a fixed opaque value (the JWT contents are not the SDK's concern here); the test asserts the + `client_assertion`/`client_assertion_type` form fields and that the audience matches the AS + metadata's issuer. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + audiences: list[str] = [] + + async def assertion_provider(audience: str) -> str: + audiences.append(audience) + return "header.payload.sig" + + auth = PrivateKeyJWTOAuthProvider( + server_url=f"{BASE_URL}/mcp", + storage=InMemoryTokenStorage(), + client_id="m2m-jwt-client", + assertion_provider=assertion_provider, + scopes="mcp", + ) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + auth=auth, + app_shim=m2m_token_shim(provider, scopes=["mcp"]), + on_request=on_request, + ) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + assert audiences == [f"{BASE_URL}/"] + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body == snapshot( + { + "grant_type": "client_credentials", + "client_assertion": "header.payload.sig", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "resource": "http://127.0.0.1:8000/mcp", + "scope": "mcp", + } + ) + assert "client_secret" not in body + assert "authorization" not in token_req.headers + + +@pytest.mark.parametrize( + ("case", "preseed_storage", "advertise_cimd"), + [("cimd_unsupported_falls_through_to_dcr", False, False), ("preregistered_beats_cimd", True, True)], + ids=["cimd_unsupported_falls_through_to_dcr", "preregistered_beats_cimd"], +) +@requirement("client-auth:cimd") +async def test_registration_priority_prefers_preregistered_then_cimd_then_dcr( + case: str, preseed_storage: bool, advertise_cimd: bool +) -> None: + """The client picks pre-registration over CIMD over DCR, falling through when each is unavailable. 
+ + Two priority edges are exercised: with a CIMD URL configured but no AS support, DCR runs and + the registered `client_id` is used; with a CIMD URL configured and AS support but a + pre-registered client in storage, the stored `client_id` is used and neither CIMD nor DCR + runs. (The positive CIMD case and pre-registration over DCR are covered by their own tests.) + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + storage = InMemoryTokenStorage() + + expected_client_id: str + if preseed_storage: + info = seeded_client(provider) + storage.client_info = info + assert info.client_id is not None + expected_client_id = info.client_id + else: + expected_client_id = "" + + app_shim = shim(serve={ASM_PATH: cimd_supported_metadata()}) if advertise_cimd else None + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + client_metadata_url=CIMD_URL, + app_shim=app_shim, + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + chosen_client_id = authorize_params(headless.authorize_url)["client_id"] + assert chosen_client_id != CIMD_URL + + if case == "cimd_unsupported_falls_through_to_dcr": + assert len(find(recorded, "POST", "/register")) == 1 + assert chosen_client_id in provider.clients + else: + assert find(recorded, "POST", "/register") == [] + assert chosen_client_id == expected_client_id diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py new file mode 100644 index 0000000000..c2ace45077 --- /dev/null +++ b/tests/interaction/conftest.py @@ -0,0 +1,23 @@ +"""Shared fixtures for the interaction suite.""" + +import pytest + +from tests.interaction._connect import Connect, connect_in_memory, connect_over_sse, connect_over_streamable_http + +_FACTORIES: dict[str, Connect] = { + "in-memory": connect_in_memory, + 
"streamable-http": connect_over_streamable_http, + "sse": connect_over_sse, +} + + +@pytest.fixture(params=sorted(_FACTORIES)) +def connect(request: pytest.FixtureRequest) -> Connect: + """The transport-parametrized connection factory: a test using it runs once per transport. + + Tests that are tied to one transport (the wire-recording tests, the bare-ClientSession tests, + the transport-specific tests under transports/) do not use this fixture and connect directly. + """ + transport_name = request.param + assert isinstance(transport_name, str) + return _FACTORIES[transport_name] diff --git a/tests/interaction/lowlevel/__init__.py b/tests/interaction/lowlevel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py new file mode 100644 index 0000000000..6f1454e58a --- /dev/null +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -0,0 +1,234 @@ +"""Cancellation interactions against the low-level Server, driven through the public Client API. + +There is no client-side cancellation API: cancelling means sending a CancelledNotification +carrying the request id, which only the server-side handler can observe (`ctx.request_id`), so +these tests capture the id from inside the blocked handler before cancelling. The handler blocks +on an Event rather than a sleep, and every wait is bounded by `anyio.fail_after`. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + EmptyResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + PingRequest, + ServerCapabilities, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:cancel:in-flight") +@requirement("protocol:cancel:handler-abort-propagates") +async def test_cancellation_stops_in_flight_handler(connect: Connect) -> None: + """Cancelling an in-flight request interrupts its handler and fails the pending call. + + The server answers the cancelled request with an error response (the spec says it should + not respond at all; see the divergence note on the requirement), so the caller's pending + request raises rather than hanging. 
+ """ + started = anyio.Event() + handler_cancelled = anyio.Event() + request_ids: list[types.RequestId] = [] + errors: list[ErrorData] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + assert ctx.request_id is not None + request_ids.append(ctx.request_id) + started.set() + try: + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + except anyio.get_cancelled_exc_class(): + handler_cancelled.set() + raise + raise NotImplementedError # unreachable: the wait above never completes normally + + server = Server("blocker", on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_capture_error() -> None: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}) + errors.append(exc_info.value.error) + + task_group.start_soon(call_and_capture_error) + await started.wait() + await client.session.send_notification( + types.CancelledNotification( + params=types.CancelledNotificationParams(request_id=request_ids[0], reason="user aborted") + ) + ) + + await handler_cancelled.wait() + + assert errors == snapshot([ErrorData(code=0, message="Request cancelled")]) + + +@requirement("protocol:cancel:server-survives") +async def test_session_serves_requests_after_cancellation(connect: Connect) -> None: + """A request cancelled mid-flight does not poison the session: the next request succeeds.""" + started = anyio.Event() + request_ids: list[types.RequestId] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool(name="block", input_schema={"type": "object"}), + types.Tool(name="echo", input_schema={"type": "object"}), + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: 
types.CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + return CallToolResult(content=[TextContent(text="still alive")]) + assert ctx.request_id is not None + request_ids.append(ctx.request_id) + started.set() + await anyio.Event().wait() # blocks until cancelled + raise NotImplementedError # unreachable + + server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_swallow_cancellation_error() -> None: + with pytest.raises(MCPError): + await client.call_tool("block", {}) + + task_group.start_soon(call_and_swallow_cancellation_error) + await started.wait() + await client.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=request_ids[0])) + ) + + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + + +@requirement("protocol:cancel:unknown-id-ignored") +async def test_cancellation_for_unknown_request_is_ignored(connect: Connect) -> None: + """A cancellation referencing a request id that is not in flight is ignored without error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + return CallToolResult(content=[TextContent(text="unbothered")]) + + server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=9999)) + ) + result = await client.call_tool("echo", 
{}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="unbothered")])) + + +@requirement("protocol:cancel:late-response-ignored") +async def test_a_response_for_an_unknown_request_id_surfaces_to_the_message_handler() -> None: + """A response whose id matches no in-flight request is surfaced to the message handler as a RuntimeError. + + The spec says a sender SHOULD ignore a response that arrives after it issued a cancellation; + that is the same client-side code path as any response with an unknown id, and that form is + deterministic to test without depending on the cancellation API the SDK does not yet provide. + See the divergence note on the requirement. + + A real Server cannot be made to answer with a fabricated id, so the test plays the server's + side of the wire by hand. Reserve this pattern for behaviour no real server can produce. The + other tests in this file run over the transport matrix; this one is in-memory only because the + scripted-peer mechanism is the in-memory stream pair, not because the behaviour is + transport-specific. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + + def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage: + return SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + # Serialized exactly as a real server serializes results onto the wire. 
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + + init = await server_read.receive() + assert isinstance(init, SessionMessage) + assert isinstance(init.message, JSONRPCRequest) + assert init.message.method == "initialize" + await server_write.send( + respond( + init.message.id, + InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="scripted", version="0.0.1"), + ), + ) + ) + + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + + ping = await server_read.receive() + assert isinstance(ping, SessionMessage) + assert isinstance(ping.message, JSONRPCRequest) + assert ping.message.method == "ping" + # First answer with a fabricated id that matches nothing in flight, then the real id. + await server_write.send(respond(9999, EmptyResult())) + await server_write.send(respond(ping.message.id, EmptyResult())) + + incoming: list[IncomingMessage] = [] + + async def message_handler(message: IncomingMessage) -> None: + incoming.append(message) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as task_group, + ClientSession(client_read, client_write, message_handler=message_handler) as session, + ): + task_group.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + await session.initialize() + pong = await session.send_request(PingRequest(), EmptyResult) + + assert pong == snapshot(EmptyResult()) + assert len(incoming) == 1 + assert isinstance(incoming[0], RuntimeError) + # The full message embeds the response object's repr; only the prefix is stable. 
+ assert str(incoming[0]).startswith("Received response with an unknown request ID:") diff --git a/tests/interaction/lowlevel/test_completion.py b/tests/interaction/lowlevel/test_completion.py new file mode 100644 index 0000000000..6a35404df3 --- /dev/null +++ b/tests/interaction/lowlevel/test_completion.py @@ -0,0 +1,131 @@ +"""Completion interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + METHOD_NOT_FOUND, + CompleteResult, + Completion, + ErrorData, + PromptReference, + ResourceTemplateReference, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("completion:prompt-arg") +@requirement("completion:result-shape") +async def test_complete_prompt_argument(connect: Connect) -> None: + """Completing a prompt argument delivers the ref, argument name, and current value to the handler. + + The returned values are filtered by the argument's value, proving the value reached the handler. 
+ """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, PromptReference) + assert params.ref.name == "code_review" + assert params.argument.name == "language" + candidates = ["python", "pytorch", "ruby"] + matches = [candidate for candidate in candidates if candidate.startswith(params.argument.value)] + return CompleteResult(completion=Completion(values=matches, total=len(matches), has_more=False)) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + PromptReference(name="code_review"), argument={"name": "language", "value": "py"} + ) + + assert result == snapshot( + CompleteResult(completion=Completion(values=["python", "pytorch"], total=2, has_more=False)) + ) + + +@requirement("completion:resource-template-arg") +async def test_complete_resource_template_variable(connect: Connect) -> None: + """Completing a URI template variable delivers the template URI and variable name to the handler.""" + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, ResourceTemplateReference) + assert params.ref.uri == "github://repos/{owner}/{repo}" + assert params.argument.name == "owner" + return CompleteResult(completion=Completion(values=[f"{params.argument.value}contextprotocol"])) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + argument={"name": "owner", "value": "model"}, + ) + + assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol"]))) + + +@requirement("completion:context-arguments") +async def test_complete_receives_context_arguments(connect: Connect) -> None: + """Previously-resolved arguments passed as completion context 
reach the handler. + + The returned value is derived from the context, proving it arrived. + """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert params.argument.name == "repo" + assert params.context is not None + assert params.context.arguments is not None + return CompleteResult(completion=Completion(values=[f"{params.context.arguments['owner']}/python-sdk"])) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "modelcontextprotocol"}, + ) + + assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol/python-sdk"]))) + + +@requirement("completion:error:invalid-ref") +async def test_completion_against_an_unknown_ref_is_rejected_with_invalid_params(connect: Connect) -> None: + """completion/complete with a ref naming an unknown prompt is answered with -32602 Invalid params. + + The lowlevel server does not validate refs itself (it has no prompt/template registry to check + against); rejecting an unknown ref is the handler's job, and this test pins the spec-recommended + way to do it. 
+ """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, PromptReference) + raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.ref.name!r}") + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.complete(PromptReference(name="ghost"), argument={"name": "x", "value": ""}) + + assert exc_info.value.error.code == INVALID_PARAMS + + +@requirement("completion:complete:not-supported") +@requirement("protocol:error:method-not-found") +async def test_complete_without_handler_is_method_not_found(connect: Connect) -> None: + """A server with no completion handler advertises no completions capability and rejects the request.""" + server = Server("incomplete") + + async with connect(server) as client: + assert client.initialize_result.capabilities.completions is None + + with pytest.raises(MCPError) as exc_info: + await client.complete(PromptReference(name="anything"), argument={"name": "topic", "value": ""}) + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py new file mode 100644 index 0000000000..b8edf601d0 --- /dev/null +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -0,0 +1,662 @@ +"""Form- and URL-mode elicitation against the low-level Server, driven through the public Client API. + +The final test plays the server's side of the wire by hand to issue an elicitation request with no +mode field, because the typed server API (`elicit_form`/`elicit_url`) always serializes one. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, UrlElicitationRequiredError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + ElicitCompleteNotification, + ElicitCompleteNotificationParams, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +REQUESTED_SCHEMA: dict[str, object] = { + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], +} + + +@requirement("elicitation:form:action:accept") +@requirement("elicitation:form:basic") +@requirement("tools:call:elicitation-roundtrip") +async def test_elicit_form_accepted_content_returns_to_handler(connect: Connect) -> None: + """An accepted form elicitation returns the user's content to the requesting handler. + + The tool reports the action as text and the received content as structured content, proving + the client's answer made it back into the tool's own result. 
+ """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form("Choose a username.", REQUESTED_SCHEMA) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"username": "ada", "newsletter": True}) + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Choose a username.", + requested_schema={ + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept")], + structured_content={"username": "ada", "newsletter": True}, + ) + ) + + +@requirement("elicitation:form:action:decline") +async def test_elicit_form_decline_returns_no_content(connect: Connect) -> None: + """A declined form elicitation returns the decline action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", 
input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:form:action:cancel") +async def test_elicit_form_cancel_returns_no_content(connect: Connect) -> None: + """A cancelled form elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async with connect(server, elicitation_callback=answer_form) as client: + result = await 
client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:form:not-supported") +@requirement("elicitation:capability:server-respects-mode") +async def test_elicit_form_without_callback_is_error(connect: Connect) -> None: + """Eliciting from a client that configured no elicitation callback fails with an error. + + The client's default callback answers with an Invalid request error, which the server-side + elicit call raises as an MCPError; the tool reports the code and message it caught. The spec + requires -32602 for an undeclared mode (see the divergence note on the requirement). The + request reaching the client also shows the server does not check the client's declared + elicitation capability before sending (see the divergence on `server-respects-mode`). + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="ask", description="Ask the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask" + try: + await ctx.session.elicit_form("Anyone there?", {"type": "object", "properties": {}}) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # elicit_form cannot succeed without a client callback + + server = Server("asker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ask", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Elicitation not supported")])) + + +@requirement("elicitation:url:action:accept-no-content") +@requirement("elicitation:url:basic") +async def 
test_elicit_url_delivers_url_and_returns_accept_without_content(connect: Connect) -> None: + """A URL elicitation delivers the message, URL, and elicitation id to the client; accepting it + returns the action with no content. + + Accept means the user agreed to visit the URL, not that the out-of-band interaction finished, + so there is never form content to return. + """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert received == snapshot( + [ + ElicitRequestURLParams( + _meta={}, + message="Authorize access to your calendar.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + assert result == snapshot(CallToolResult(content=[TextContent(text="accept content=None")])) + + +@requirement("elicitation:url:decline") +async def test_elicit_url_decline_returns_no_content(connect: Connect) -> None: + """A declined URL elicitation returns the decline action to the handler with no content.""" + + async def list_tools( 
+ ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:url:cancel") +async def test_elicit_url_cancel_returns_no_content(connect: Connect) -> None: + """A cancelled URL elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, 
on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:url:complete-notification") +async def test_elicitation_complete_notification_carries_the_elicited_id_back_to_the_client(connect: Connect) -> None: + """After a URL elicitation finishes, the server announces it with a notification carrying the same id. + + The lifecycle under test: the tool elicits a URL interaction with an elicitationId, the user + agrees to visit the URL, the out-of-band interaction finishes, and the server emits + elicitation/complete so the client can correlate the completion with the elicitation it + accepted earlier. The completion notification carries ``related_request_id`` so over + streamable HTTP it rides the tool call's own stream and reaches the client before the call + returns; the same ordering already holds on in-memory and SSE transports. 
+ """ + elicitation_id = "auth-001" + elicited_ids: list[str] = [] + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="link_account", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "link_account" + answer = await ctx.session.elicit_url( + "Authorize access to your files.", "https://example.com/oauth/authorize", elicitation_id + ) + assert answer.action == "accept" + await ctx.session.send_elicit_complete(elicitation_id, related_request_id=ctx.request_id) + return CallToolResult(content=[TextContent(text="linked")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestURLParams) + elicited_ids.append(params.elicitation_id) + return ElicitResult(action="accept") + + async with connect(server, message_handler=collect, elicitation_callback=answer_url) as client: + await client.call_tool("link_account", {}) + + # The completion notification refers to the same elicitation the client accepted. + assert elicited_ids == [elicitation_id] + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="auth-001"))] + ) + + +@requirement("elicitation:url:required-error") +async def test_url_elicitation_required_error_carries_pending_elicitations(connect: Connect) -> None: + """A request that cannot proceed until a URL interaction completes is rejected with error -32042. 
+ + This is the non-interactive alternative to elicit_url: instead of asking and waiting, the + handler rejects the whole request and lists the required URL elicitations in the error data. + The client is expected to present those URLs, wait for the matching elicitation/complete + notifications, and retry the original request. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + server = Server("authorizer", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) + + +@requirement("elicitation:form:schema:primitives") +@requirement("elicitation:form:schema:enum-variants") +async def test_elicit_form_schema_with_every_primitive_and_enum_type_reaches_the_callback_as_sent( + connect: Connect, +) -> None: + """A requested schema covering every spec-listed property kind is delivered to the callback unchanged. + + One schema with one property per kind: a formatted string, an integer with bounds, a number, + a boolean, a plain enum, a oneOf-const titled enum, and a multi-select array-of-enum. The + callback observing the same schema as the handler sent proves both the primitive coverage and + the enum-variant coverage in one snapshot. 
+ """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "email": {"type": "string", "format": "email", "title": "Email", "description": "Contact address."}, + "age": {"type": "integer", "minimum": 0, "maximum": 150}, + "score": {"type": "number"}, + "subscribe": {"type": "boolean", "default": False}, + "tier": {"type": "string", "enum": ["free", "pro", "team"]}, + "region": { + "oneOf": [ + {"const": "eu", "title": "Europe"}, + {"const": "na", "title": "North America"}, + ], + }, + "channels": {"type": "array", "items": {"type": "string", "enum": ["email", "sms", "push"]}}, + }, + "required": ["email"], + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="onboard", description="Onboard the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + answer = await ctx.session.elicit_form("Tell us about yourself.", schema) + return CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("onboarder", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"email": "ada@example.com"}) + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("onboard", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:schema:restricted-subset") +async def test_elicit_form_with_a_nested_schema_is_forwarded_unchanged(connect: Connect) -> None: + """A requested schema with nested-object and 
array-of-object properties passes through unchanged. + + The spec restricts form-mode requested schemas to flat objects with primitive-typed properties; + this test pins that the SDK does not enforce that restriction on either side (see the + divergence on the requirement). + """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, + }, + "contacts": { + "type": "array", + "items": {"type": "object", "properties": {"name": {"type": "string"}}}, + }, + }, + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="profile", description="Collect a profile.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "profile" + answer = await ctx.session.elicit_form("Profile details.", schema) + return CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("profiler", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("profile", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:response-validation") +async def test_accepted_elicitation_content_that_violates_the_schema_reaches_the_handler_unchanged( + connect: Connect, +) -> None: + """Accepted form content that contradicts the requested schema is delivered to the handler unchanged. 
+ + The schema requires a string `name`; the callback answers with a wrong-type value and an extra + field. Nothing on either side validates the response against the schema (see the divergence on + the requirement), so the handler observes exactly what the callback sent. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form( + "Choose a name.", + {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, + ) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content={"name": 42, "extra": "field"}) + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="accept")], structured_content={"name": 42, "extra": "field"}) + ) + + +@requirement("elicitation:url:complete-unknown-ignored") +async def test_elicitation_complete_for_an_unknown_id_is_received_without_error(connect: Connect) -> None: + """An elicitation/complete for an id the client never elicited is delivered and does not fail anything. + + No URL elicitation precedes the notification; the client neither tracks elicitation ids nor + rejects unknown ones, so the call completes normally and the message handler observes the + notification as-is. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="noop", description="Send a stray complete.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "noop" + await ctx.session.send_elicit_complete("never-elicited", related_request_id=ctx.request_id) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("notifier", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(server, message_handler=collect) as client: + result = await client.call_tool("noop", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="never-elicited"))] + ) + + +@requirement("elicitation:form:mode-omitted-default") +async def test_a_mode_less_elicitation_request_is_treated_as_form_mode() -> None: + """An elicitation/create request with no mode field reaches the client callback as form-mode. + + The typed server API always serializes a mode (`elicit_form` writes 'form', `elicit_url` writes + 'url'), so this test plays the server's side of the wire by hand to send a request body without + one. Reserve this pattern for behaviour the typed server API cannot produce. 
+ """ + received: list[types.ElicitRequestParams] = [] + answered = anyio.Event() + server_received: list[JSONRPCMessage] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={}) + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + initialize = await server_read.receive() + assert isinstance(initialize, SessionMessage) + request = initialize.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="legacy", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + # No mode key: a server speaking a pre-mode revision of the spec sends only message + schema. 
+ await server_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="elicitation/create", + params={"message": "Legacy ask.", "requestedSchema": {"type": "object", "properties": {}}}, + ) + ) + ) + response = await server_read.receive() + assert isinstance(response, SessionMessage) + server_received.append(response.message) + answered.set() + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write, elicitation_callback=answer_form) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + await session.initialize() + await answered.wait() + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta=None, + message="Legacy ask.", + requested_schema={"type": "object", "properties": {}}, + ) + ] + ) + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].mode == "form" + assert len(server_received) == 1 + assert isinstance(server_received[0], JSONRPCResponse) + assert server_received[0].id == 2 diff --git a/tests/interaction/lowlevel/test_flows.py b/tests/interaction/lowlevel/test_flows.py new file mode 100644 index 0000000000..8d96582341 --- /dev/null +++ b/tests/interaction/lowlevel/test_flows.py @@ -0,0 +1,203 @@ +"""Composed multi-feature flows against the low-level Server, driven through the public Client API. + +Each test reads as the scenario it proves: the steps run top to bottom in the order a real client +would perform them, composing two or more feature areas (a tool call followed by a resource read; +a chain of elicitations inside one tool call; the full URL-elicitation-required retry loop). The +individual features are pinned by their own tests; these prove they compose. 
+""" + +from collections.abc import Awaitable, Callable + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, UrlElicitationRequiredError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.server.session import ServerSession +from mcp.types import ( + URL_ELICITATION_REQUIRED, + CallToolResult, + ElicitCompleteNotification, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + EmptyResult, + ListToolsResult, + ReadResourceResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ListToolsHandler = Callable[ + [ServerRequestContext, types.PaginatedRequestParams | None], Awaitable[types.ListToolsResult] +] + + +def _list_tools(*names: str) -> ListToolsHandler: + """A list_tools handler advertising the named tools, so call_tool's implicit list succeeds.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name=name, input_schema={"type": "object"}) for name in names]) + + return list_tools + + +@requirement("flow:tool-result:resource-link-follow") +async def test_a_resource_link_returned_by_a_tool_can_be_followed_with_read(connect: Connect) -> None: + """A tool returns a resource_link; reading that link's URI returns the referenced contents. + + Steps: (1) call the tool, (2) extract the link from its content, (3) read_resource on the + link's URI, (4) the read result carries the linked contents. 
+ """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "generate" + return CallToolResult(content=[ResourceLink(uri="file:///report.txt", name="report")]) + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + assert str(params.uri) == "file:///report.txt" + return ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + + server = Server( + "linker", on_list_tools=_list_tools("generate"), on_call_tool=call_tool, on_read_resource=read_resource + ) + + async with connect(server) as client: + called = await client.call_tool("generate", {}) + link = called.content[0] + assert isinstance(link, ResourceLink) + read = await client.read_resource(link.uri) + + assert called == snapshot(CallToolResult(content=[ResourceLink(name="report", uri="file:///report.txt")])) + assert read == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + ) + + +@requirement("flow:elicitation:multi-step-form") +async def test_a_tool_handler_chains_form_elicitations_feeding_each_answer_forward(connect: Connect) -> None: + """Sequential form elicitations inside one tool call: each accepted answer feeds the next step. + + Steps: (1) call the tool, (2) the handler issues a step-one form elicitation that the client + accepts with content, (3) the handler issues a step-two elicitation whose message references + the step-one answer, (4) the client accepts step two, (5) the tool result summarises both + answers. The callback is invoked exactly twice with the expected messages and schemas. The + short-circuit on decline is the application's choice (proven separately by the per-action + elicitation tests); what this flow pins is that the chain itself works end to end. 
+ """ + received: list[ElicitRequestFormParams] = [] + answers: list[dict[str, str | int | float | bool | list[str] | None]] = [{"name": "ada"}, {"age": 37}] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + first = await ctx.session.elicit_form( + "Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}} + ) + assert first.action == "accept" and first.content is not None + second = await ctx.session.elicit_form( + f"Step 2: confirm age for {first.content['name']}.", + {"type": "object", "properties": {"age": {"type": "integer"}}}, + ) + assert second.action == "accept" and second.content is not None + return CallToolResult(content=[TextContent(text=f"{first.content['name']} is {second.content['age']}")]) + + server = Server("onboarder", on_list_tools=_list_tools("onboard"), on_call_tool=call_tool) + + async def answer(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestFormParams) + received.append(params) + return ElicitResult(action="accept", content=answers[len(received) - 1]) + + async with connect(server, elicitation_callback=answer) as client: + result = await client.call_tool("onboard", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ada is 37")])) + assert [(p.message, p.requested_schema) for p in received] == snapshot( + [ + ("Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}}), + ("Step 2: confirm age for ada.", {"type": "object", "properties": {"age": {"type": "integer"}}}), + ] + ) + + +@requirement("flow:elicitation:url-required-then-retry") +async def test_a_tool_rejected_with_url_elicitation_required_succeeds_on_retry_after_completion( + connect: Connect, +) -> None: + """The full URL-elicitation-required retry loop: -32042, completion announced, retry succeeds. 
+ + Steps: (1) the first call is rejected with -32042 carrying the required URL elicitation in + its error data, (2) the client extracts the elicitation id from the error, (3) the server + announces completion via the elicitation/complete notification (driven via the captured + session, the same way a real out-of-band callback would reach a held session reference), + (4) the client observes the matching completion notification and retries, (5) the retry + succeeds. The handler distinguishes the two calls by a closure flag the test flips between + them; the test waits on the completion notification with an event so the retry only happens + after the announcement has arrived. + """ + elicitation_id = "auth-001" + authorised: list[bool] = [False] + captured: list[ServerSession] = [] + completed = anyio.Event() + notifications: list[ElicitCompleteNotification] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + captured.append(ctx.session) + if not authorised[0]: + # The log line gives the message handler a non-completion notification, so the test's + # filtering branch is exercised in both directions and the wait remains specific. 
+ await ctx.session.send_log_message(level="warning", data="authorisation required", logger="gate") + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorize file access.", + url="https://example.com/oauth/authorize", + elicitation_id=elicitation_id, + ) + ] + ) + return CallToolResult(content=[TextContent(text="contents")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server( + "gatekeeper", + on_list_tools=_list_tools("read_files"), + on_call_tool=call_tool, + on_set_logging_level=set_logging_level, + ) + + async def collect(message: IncomingMessage) -> None: + if isinstance(message, ElicitCompleteNotification): + notifications.append(message) + completed.set() + + async with connect(server, message_handler=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + required = UrlElicitationRequiredError.from_error(exc_info.value.error) + assert [e.elicitation_id for e in required.elicitations] == [elicitation_id] + + # The out-of-band interaction completes; the server announces it on the same session. 
+ await captured[0].send_elicit_complete(elicitation_id) + with anyio.fail_after(5): + await completed.wait() + assert notifications[0].params.elicitation_id == elicitation_id + + authorised[0] = True + result = await client.call_tool("read_files", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="contents")])) diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py new file mode 100644 index 0000000000..91adbf5611 --- /dev/null +++ b/tests/interaction/lowlevel/test_initialize.py @@ -0,0 +1,384 @@ +"""Initialization handshake against the low-level Server, driven through the public Client API. + +The later tests drive a bare ClientSession over an InMemoryTransport instead: Client always +performs the full handshake with the latest protocol version, so skipping initialization or +requesting a different version can only be expressed one level down. The final test goes one step +further and plays the server's side of the wire by hand, because no real Server can be made to +answer initialize with an unsupported protocol version. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.client._memory import InMemoryTransport +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + INVALID_PARAMS, + CallToolResult, + ClientCapabilities, + CompletionsCapability, + EmptyResult, + ErrorData, + Icon, + Implementation, + InitializeRequest, + InitializeRequestParams, + InitializeResult, + JSONRPCRequest, + JSONRPCResponse, + ListToolsRequest, + ListToolsResult, + LoggingCapability, + PromptsCapability, + ResourcesCapability, + ServerCapabilities, + TextContent, + ToolsCapability, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("lifecycle:initialize:basic") +@requirement("lifecycle:initialize:server-info") +async def test_initialize_returns_server_info(connect: Connect) -> None: + """Every identity field the server declares is returned to the client in server_info.""" + server = Server( + "greeter", + version="1.2.3", + title="Greeter", + description="Greets people.", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + + async with connect(server) as client: + server_info = client.initialize_result.server_info + + assert server_info == snapshot( + Implementation( + name="greeter", + title="Greeter", + description="Greets people.", + version="1.2.3", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + ) + + +@requirement("lifecycle:initialize:instructions") +async def test_initialize_returns_instructions(connect: Connect) -> None: + 
"""Instructions are returned when the server declares them and omitted when it does not.""" + async with connect(Server("guided", instructions="Call the add tool.")) as client: + assert client.initialize_result.instructions == snapshot("Call the add tool.") + + async with connect(Server("unguided")) as client: + assert client.initialize_result.instructions is None + + +@requirement("lifecycle:initialize:capabilities:from-handlers") +@requirement("tools:capability:declared") +@requirement("resources:capability:declared") +@requirement("prompts:capability:declared") +@requirement("completion:capability:declared") +async def test_initialize_capabilities_reflect_registered_handlers(connect: Connect) -> None: + """Each feature area with a registered handler is advertised as a capability. + + The in-memory transport connects with default initialization options, so the + list_changed flags are always False regardless of the server's notification behaviour. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> types.EmptyResult: + """Registered only so the subscribe sub-capability is advertised; never called.""" + raise NotImplementedError + + async def list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + """Registered only so the prompts capability is advertised; never called.""" + raise NotImplementedError + + async def set_logging_level(ctx: ServerRequestContext, params: 
types.SetLevelRequestParams) -> types.EmptyResult: + """Registered only so the logging capability is advertised; never called.""" + raise NotImplementedError + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + server = Server( + "full", + on_list_tools=list_tools, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + on_list_prompts=list_prompts, + on_set_logging_level=set_logging_level, + on_completion=completion, + ) + + async with connect(server) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot( + ServerCapabilities( + experimental={}, + logging=LoggingCapability(), + prompts=PromptsCapability(list_changed=False), + resources=ResourcesCapability(subscribe=True, list_changed=False), + tools=ToolsCapability(list_changed=False), + completions=CompletionsCapability(), + ) + ) + + +@requirement("lifecycle:initialize:capabilities:minimal") +async def test_initialize_minimal_server_advertises_no_capabilities(connect: Connect) -> None: + """A server with no feature handlers advertises no feature capabilities.""" + async with connect(Server("bare")) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot(ServerCapabilities(experimental={})) + + +@requirement("lifecycle:initialize:client-info") +async def test_initialize_server_sees_client_info(connect: Connect) -> None: + """The client identity supplied to Client is visible to server handlers after initialization.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="whoami", description="Report the caller.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: 
ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "whoami" + assert ctx.session.client_params is not None + client_info = ctx.session.client_params.client_info + return CallToolResult(content=[TextContent(text=f"{client_info.name} {client_info.version}")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + async with connect(server, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: + result = await client.call_tool("whoami", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="acme-agent 9.9.9")])) + + +@requirement("lifecycle:initialize:client-capabilities") +async def test_initialize_server_sees_client_capabilities(connect: Connect) -> None: + """The client capabilities visible to the server reflect which callbacks the client configured.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="abilities", description="Report capabilities.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "abilities" + assert ctx.session.client_params is not None + capabilities = ctx.session.client_params.capabilities + declared = [ + name + for name, value in ( + ("sampling", capabilities.sampling), + ("elicitation", capabilities.elicitation), + ) + if value is not None + ] + if capabilities.roots is not None: + declared.append(f"roots(list_changed={capabilities.roots.list_changed})") + return CallToolResult(content=[TextContent(text=",".join(declared) or "none")]) + + async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: + """Registered only so the client declares the roots capability; never called.""" + raise NotImplementedError + + server = Server("introspector", 
on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="none")])) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="roots(list_changed=True)")])) + + +@requirement("lifecycle:requests-before-initialized") +async def test_request_before_initialization_is_rejected() -> None: + """A feature request sent before the handshake completes is rejected; ping is exempt. + + Client always initializes on entry, so this drives a bare ClientSession that never sends + initialize. The server's stated reason for the rejection never reaches the client: the error + is reported as a generic invalid-params failure. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + """Registered so the request is routed to a real handler; never reached.""" + raise NotImplementedError + + server = Server("strict", on_list_tools=list_tools) + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc_info: + await session.send_request(ListToolsRequest(), ListToolsResult) + + # Ping is explicitly permitted before initialization completes. 
+ pong = await session.send_ping() + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + ) + assert pong == snapshot(EmptyResult()) + + +@requirement("lifecycle:version:match") +@requirement("lifecycle:version:server-fallback-latest") +async def test_initialize_negotiates_protocol_version() -> None: + """The server echoes a supported requested version and answers an unsupported one with its latest. + + Client always requests the latest version, so each half hand-builds an InitializeRequest on a + bare ClientSession to control the requested version. + """ + server = Server("negotiator") + + def initialize_request(protocol_version: str) -> InitializeRequest: + return InitializeRequest( + params=InitializeRequestParams( + protocol_version=protocol_version, + capabilities=ClientCapabilities(), + client_info=Implementation(name="time-traveller", version="0.0.1"), + ) + ) + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + result = await session.send_request(initialize_request("2025-03-26"), InitializeResult) + assert result.protocol_version == snapshot("2025-03-26") + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + result = await session.send_request(initialize_request("1999-01-01"), InitializeResult) + assert result.protocol_version == snapshot("2025-11-25") + + +@requirement("lifecycle:version:reject-unsupported") +async def test_unsupported_server_protocol_version_fails_initialization() -> None: + """An initialize response carrying a protocol version the client does not support fails initialization. 
+ + A real Server only ever answers with a version it supports, so this test alone plays the + server's side of the wire by hand: it reads the initialize request off the raw stream and + answers it with a hand-built result. Reserve this pattern for behaviour no real server can + be made to produce. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="1991-08-06", + capabilities=ServerCapabilities(), + server_info=Implementation(name="relic", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + with pytest.raises(RuntimeError) as exc_info: + await session.initialize() + + assert str(exc_info.value) == snapshot("Unsupported protocol version from the server: 1991-08-06") + + +@requirement("lifecycle:version:downgrade") +async def test_an_older_supported_protocol_version_from_the_server_is_accepted() -> None: + """An initialize response carrying an older supported protocol version completes the handshake at that version. + + A real Server answers with the version the client requested (or its own latest), so this test + plays the server's side of the wire by hand to return a fixed older version regardless of what + was requested. 
Reserve this pattern for behaviour no real server can be made to produce. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-06-18", + capabilities=ServerCapabilities(), + server_info=Implementation(name="conservative", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + initialize_result = await session.initialize() + + assert initialize_result.protocol_version == snapshot("2025-06-18") diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py new file mode 100644 index 0000000000..a2f85eeacf --- /dev/null +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -0,0 +1,136 @@ +"""List-changed notifications from the low-level Server, driven through the public Client API. + +``send_*_list_changed`` does not take a ``related_request_id``, so over streamable HTTP the +notification routes to the standalone GET stream and is not guaranteed to arrive before the tool +result on its POST stream. Tests therefore wait on an event the collector sets, the same pattern +as ``transports/test_streamable_http.py::test_unrelated_server_messages_arrive_on_the_standalone_stream``. 
+The collector still records every message it receives, so the snapshot also proves nothing else +was delivered. + +The servers register the parent capability (resources/prompts) so that part of the spec's +precondition holds, but the ``listChanged`` sub-capability stays ``False``: ``NotificationOptions`` +is not threaded through any of the suite's connection paths. The tests therefore rely on the +recorded ``lifecycle:capability:server-not-advertised`` divergence and will need updating +alongside the fix that introduces capability gating. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolResult, + PromptListChangedNotification, + ResourceListChangedNotification, + TextContent, + ToolListChangedNotification, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:list-changed") +async def test_tool_list_changed_notification(connect: Connect) -> None: + """A tools/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="install", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "install" + await ctx.session.send_tool_list_changed() + return CallToolResult(content=[TextContent(text="installed")]) + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + + 
async with connect(server, message_handler=collect) as client: + await client.call_tool("install", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([ToolListChangedNotification()]) + + +@requirement("resources:list-changed") +async def test_resource_list_changed_notification(connect: Connect) -> None: + """A resources/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="mount", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "mount" + await ctx.session.send_resource_list_changed() + return CallToolResult(content=[TextContent(text="mounted")]) + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool, on_list_resources=list_resources) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("mount", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([ResourceListChangedNotification()]) + + +@requirement("prompts:list-changed") +async def test_prompt_list_changed_notification(connect: Connect) -> None: + """A prompts/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: 
IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="learn", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "learn" + await ctx.session.send_prompt_list_changed() + return CallToolResult(content=[TextContent(text="learned")]) + + async def list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + """Registered so the prompts capability is advertised; the client never lists prompts.""" + raise NotImplementedError + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool, on_list_prompts=list_prompts) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("learn", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([PromptListChangedNotification()]) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py new file mode 100644 index 0000000000..fba632ef4d --- /dev/null +++ b/tests/interaction/lowlevel/test_logging.py @@ -0,0 +1,127 @@ +"""Logging interactions against the low-level Server, driven through the public Client API. + +Notification ordering: the in-memory transport delivers every server-to-client message on one +ordered stream, and the client's receive loop dispatches each incoming message to completion +before reading the next one. Over streamable HTTP that ordered single-stream guarantee holds +only for messages that carry a ``related_request_id`` (they ride the originating request's POST +stream); without it the message routes to the standalone GET stream and may arrive after the +response. 
These tests pass ``related_request_id`` so they can collect into a plain list and +assert after the request completes on every transport leg -- no events, no waiting. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, EmptyResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ALL_LEVELS: tuple[types.LoggingLevel, ...] = ( + "debug", + "info", + "notice", + "warning", + "error", + "critical", + "alert", + "emergency", +) + + +@requirement("logging:set-level") +async def test_set_logging_level_reaches_handler(connect: Connect) -> None: + """The level requested by the client is delivered to the server's handler verbatim.""" + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + assert params.level == "warning" + return EmptyResult() + + server = Server("logger", on_set_logging_level=set_logging_level) + + async with connect(server) as client: + result = await client.set_logging_level("warning") + + assert result == snapshot(EmptyResult()) + + +@requirement("logging:message:fields") +@requirement("tools:call:logging-mid-execution") +async def test_log_messages_reach_logging_callback_in_order(connect: Connect) -> None: + """Log messages sent during a tool call arrive at the logging callback, in order, before the call returns. + + The two messages pin the full notification shape: severity, optional logger name, and both + string and structured data payloads. 
+ """ + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="chatty", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "chatty" + await ctx.session.send_log_message( + level="info", data="starting up", logger="app.lifecycle", related_request_id=ctx.request_id + ) + await ctx.session.send_log_message( + level="error", data={"code": 502, "retryable": True}, related_request_id=ctx.request_id + ) + return CallToolResult(content=[TextContent(text="done")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + async with connect(server, logging_callback=collect) as client: + result = await client.call_tool("chatty", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")])) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="info", logger="app.lifecycle", data="starting up"), + LoggingMessageNotificationParams(level="error", data={"code": 502, "retryable": True}), + ] + ) + + +@requirement("logging:message:all-levels") +async def test_log_messages_at_every_severity_level(connect: Connect) -> None: + """Each of the eight RFC 5424 severity levels is deliverable as a log message notification.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + 
received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="siren", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "siren" + for level in ALL_LEVELS: + await ctx.session.send_log_message( + level=level, data=f"a {level} message", related_request_id=ctx.request_id + ) + return CallToolResult(content=[TextContent(text="logged")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + async with connect(server, logging_callback=collect) as client: + await client.call_tool("siren", {}) + + assert [params.level for params in received] == list(ALL_LEVELS) diff --git a/tests/interaction/lowlevel/test_meta.py b/tests/interaction/lowlevel/test_meta.py new file mode 100644 index 0000000000..a9e4f994d8 --- /dev/null +++ b/tests/interaction/lowlevel/test_meta.py @@ -0,0 +1,63 @@ +"""Request and result _meta round trips against the low-level Server, through the public Client API. + +Meta is opaque pass-through data, so these tests assert identity against the value that was sent +rather than snapshotting a literal: the expected value and the sent value are the same variable, +which also proves the SDK injected nothing alongside it. 
+""" + +import pytest + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, RequestParamsMeta, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("meta:request-to-handler") +async def test_request_meta_reaches_handler(connect: Connect) -> None: + """The _meta object the client attaches to a request arrives at the tool handler unchanged.""" + request_meta: RequestParamsMeta = {"example.com/trace": "abc-123"} + observed_metas: list[dict[str, object]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="traced", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "traced" + assert ctx.meta is not None + observed_metas.append(dict(ctx.meta)) + return CallToolResult(content=[TextContent(text="traced")]) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.call_tool("traced", {}, meta=request_meta) + + assert observed_metas == [dict(request_meta)] + + +@requirement("meta:result-to-client") +async def test_result_meta_reaches_client(connect: Connect) -> None: + """The _meta object a handler attaches to its result is delivered to the client unchanged.""" + result_meta = {"example.com/cost": 3} + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="metered", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "metered" + 
return CallToolResult(content=[TextContent(text="done")], _meta=result_meta) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("metered", {}) + + assert result == CallToolResult(content=[TextContent(text="done")], _meta=result_meta) diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py new file mode 100644 index 0000000000..77db90401e --- /dev/null +++ b/tests/interaction/lowlevel/test_pagination.py @@ -0,0 +1,242 @@ +"""Cursor pagination of the list operations against the low-level Server. + +The cursor is an opaque string chosen by the server: the suite only asserts that whatever the +handler returns as next_cursor comes back verbatim on the client's next call, not any particular +pagination scheme. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + Prompt, + Resource, + ResourceTemplate, + Tool, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:list:pagination") +async def test_next_cursor_round_trips_through_the_client(connect: Connect) -> None: + """The next_cursor a list handler returns reaches the client, and the cursor the client sends + back on the following call reaches the handler verbatim. 
+ """ + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None # the client always sends params, even without a cursor + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListToolsResult( + tools=[Tool(name="alpha", input_schema={"type": "object"})], + next_cursor=cursor, + ) + return ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})]) + + server = Server("paginated", on_list_tools=list_tools) + + async with connect(server) as client: + first_page = await client.list_tools() + second_page = await client.list_tools(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [tool.name for tool in first_page.tools] == ["alpha"] + assert second_page == snapshot(ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})])) + + +@requirement("pagination:exhaustion") +@requirement("tools:list:pagination") +async def test_paginating_until_next_cursor_is_absent_yields_every_page(connect: Connect) -> None: + """Following next_cursor until it is absent visits every page exactly once, in order.""" + pages: dict[str | None, tuple[str, str | None]] = { + None: ("alpha", "page-2"), + "page-2": ("beta", "page-3"), + "page-3": ("gamma", None), + } + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + tool_name, next_cursor = pages[params.cursor] + return ListToolsResult(tools=[Tool(name=tool_name, input_schema={"type": "object"})], next_cursor=next_cursor) + + server = Server("paginated", on_list_tools=list_tools) + + collected: list[str] = [] + cursor: str | None = None + requests_made = 0 + async with connect(server) as client: + while True: + result = await client.list_tools(cursor=cursor) + requests_made += 1 + assert 
requests_made <= len(pages), "the server kept returning next_cursor past the last page" + collected.extend(tool.name for tool in result.tools) + if result.next_cursor is None: + break + cursor = result.next_cursor + + assert collected == snapshot(["alpha", "beta", "gamma"]) + assert requests_made == len(pages) + + +@requirement("pagination:client:cursor-handling") +async def test_the_client_follows_opaque_cursors_through_pages_of_varying_sizes(connect: Connect) -> None: + """The client passes a server-issued cursor back byte-for-byte and follows pages of varying sizes. + + The cursors are deliberately base64-looking strings (with padding and URL-unsafe characters) to + show the client treats them as opaque tokens; the page sizes [3, 1, 2] show the loop relies only + on next_cursor, not on a fixed page size. + """ + cursor_to_page_2 = "YWxwaGE+YnJhdm8/Y2hhcmxpZQ==" + cursor_to_page_3 = "ZGVsdGE=" + pages: dict[str | None, tuple[list[str], str | None]] = { + None: (["alpha", "beta", "gamma"], cursor_to_page_2), + cursor_to_page_2: (["delta"], cursor_to_page_3), + cursor_to_page_3: (["epsilon", "zeta"], None), + } + received_cursors: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + received_cursors.append(params.cursor) + names, next_cursor = pages[params.cursor] + return ListToolsResult( + tools=[Tool(name=name, input_schema={"type": "object"}) for name in names], next_cursor=next_cursor + ) + + server = Server("paginated", on_list_tools=list_tools) + + page_sizes: list[int] = [] + cursor: str | None = None + async with connect(server) as client: + while True: + result = await client.list_tools(cursor=cursor) + page_sizes.append(len(result.tools)) + if result.next_cursor is None: + break + cursor = result.next_cursor + + # Identity, not a snapshot: what arrived at the handler is exactly what the handler issued. 
+ assert received_cursors == [None, cursor_to_page_2, cursor_to_page_3] + assert page_sizes == [3, 1, 2] + + +@requirement("pagination:invalid-cursor") +async def test_an_unrecognized_pagination_cursor_is_rejected_with_invalid_params(connect: Connect) -> None: + """A list request with a cursor the server did not issue is answered with -32602 Invalid params. + + The lowlevel server does not validate cursors itself (they are opaque to it); rejecting an + unrecognized cursor is the handler's job, and this test pins the spec-recommended way to do it. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + assert params.cursor == "never-issued" + raise MCPError(code=INVALID_PARAMS, message=f"Unknown cursor: {params.cursor!r}") + + server = Server("paginated", on_list_tools=list_tools) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.list_tools(cursor="never-issued") + + assert exc_info.value.error.code == INVALID_PARAMS + + +@requirement("resources:list:pagination") +async def test_resources_list_supports_cursor_pagination(connect: Connect) -> None: + """resources/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListResourcesResult(resources=[Resource(uri="memo://1", name="first")], next_cursor=cursor) + return ListResourcesResult(resources=[Resource(uri="memo://2", name="second")]) + + server = Server("paginated", on_list_resources=list_resources) + + async with connect(server) as client: + first_page = await client.list_resources() + second_page = await client.list_resources(cursor=first_page.next_cursor) + + assert 
first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [resource.name for resource in first_page.resources] == ["first"] + assert [resource.name for resource in second_page.resources] == ["second"] + assert second_page.next_cursor is None + + +@requirement("resources:templates:pagination") +async def test_resource_templates_list_supports_cursor_pagination(connect: Connect) -> None: + """resources/templates/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_resource_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListResourceTemplatesResult( + resource_templates=[ResourceTemplate(name="first", uri_template="users://{id}")], + next_cursor=cursor, + ) + return ListResourceTemplatesResult( + resource_templates=[ResourceTemplate(name="second", uri_template="teams://{id}")] + ) + + server = Server("paginated", on_list_resource_templates=list_resource_templates) + + async with connect(server) as client: + first_page = await client.list_resource_templates() + second_page = await client.list_resource_templates(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [template.name for template in first_page.resource_templates] == ["first"] + assert [template.name for template in second_page.resource_templates] == ["second"] + assert second_page.next_cursor is None + + +@requirement("prompts:list:pagination") +async def test_prompts_list_supports_cursor_pagination(connect: Connect) -> None: + """prompts/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> 
ListPromptsResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListPromptsResult(prompts=[Prompt(name="first")], next_cursor=cursor) + return ListPromptsResult(prompts=[Prompt(name="second")]) + + server = Server("paginated", on_list_prompts=list_prompts) + + async with connect(server) as client: + first_page = await client.list_prompts() + second_page = await client.list_prompts(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [prompt.name for prompt in first_page.prompts] == ["first"] + assert [prompt.name for prompt in second_page.prompts] == ["second"] + assert second_page.next_cursor is None diff --git a/tests/interaction/lowlevel/test_ping.py b/tests/interaction/lowlevel/test_ping.py new file mode 100644 index 0000000000..797e20dc35 --- /dev/null +++ b/tests/interaction/lowlevel/test_ping.py @@ -0,0 +1,53 @@ +"""Ping interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, EmptyResult, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("lifecycle:ping") +@requirement("ping:client-to-server") +async def test_client_ping_returns_empty_result(connect: Connect) -> None: + """A client ping is answered with an empty result, even by a server with no handlers.""" + server = Server("silent") + + async with connect(server) as client: + result = await client.send_ping() + + assert result == snapshot(EmptyResult()) + + +@requirement("lifecycle:ping") +@requirement("ping:server-to-client") +async def test_server_ping_returns_empty_result(connect: Connect) -> None: + """A server-initiated ping sent while a request is in flight is 
answered by the client. + + The tool returns the type of the ping response, proving the round trip completed inside + the handler before the tool result was produced. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="ping_back", description="Ping the client.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ping_back" + pong = await ctx.session.send_ping() + return CallToolResult(content=[TextContent(text=type(pong).__name__)]) + + server = Server("pinger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ping_back", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="EmptyResult")])) diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py new file mode 100644 index 0000000000..6350c33a33 --- /dev/null +++ b/tests/interaction/lowlevel/test_progress.py @@ -0,0 +1,301 @@ +"""Progress interactions against the low-level Server, driven through the public Client API. + +Server-to-client progress emitted during a request follows the same ordering guarantee as +logging notifications (see test_logging.py) -- on the in-memory transport unconditionally, and +over streamable HTTP only when sent with ``related_request_id`` so the notification rides the +originating request's POST stream rather than the standalone GET stream. These tests pass +``related_request_id`` so no synchronisation is needed. The client-to-server direction is a +standalone notification with no response to await, so that test waits on an event set by the +server's handler. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.server.session import ServerSession +from mcp.shared.session import ProgressFnT +from mcp.types import CallToolResult, ProgressNotification, ProgressNotificationParams, ProgressToken, TextContent +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:progress:callback") +@requirement("tools:call:progress") +async def test_progress_during_tool_call_reaches_callback_in_order(connect: Connect) -> None: + """Progress notifications emitted by a tool handler reach the caller's progress callback in order.""" + received: list[tuple[float, float | None, str | None]] = [] + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="download", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "download" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + await ctx.session.send_progress_notification( + token, 1.0, total=3.0, message="first chunk", related_request_id=str(ctx.request_id) + ) + await ctx.session.send_progress_notification( + token, 2.0, total=3.0, message="second chunk", related_request_id=str(ctx.request_id) + ) + await ctx.session.send_progress_notification( + token, 3.0, total=3.0, message="done", related_request_id=str(ctx.request_id) + ) + return CallToolResult(content=[TextContent(text="downloaded")]) + + server = 
Server("downloader", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("download", {}, progress_callback=collect) + + assert result == snapshot(CallToolResult(content=[TextContent(text="downloaded")])) + assert received == snapshot([(1.0, 3.0, "first chunk"), (2.0, 3.0, "second chunk"), (3.0, 3.0, "done")]) + + +@requirement("protocol:progress:token-injected") +async def test_progress_token_visible_to_handler(connect: Connect) -> None: + """Supplying a progress callback attaches a progress token that the handler can read from the request meta.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "inspect" + assert ctx.meta is not None + return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def ignore(progress: float, total: float | None, message: str | None) -> None: + """A progress callback that is never invoked; the tool only inspects the token.""" + raise NotImplementedError + + async with connect(server) as client: + result = await client.call_tool("inspect", {}, progress_callback=ignore) + + # The token is the request id of the tools/call request itself (initialize is request 0). + assert result == snapshot(CallToolResult(content=[TextContent(text="1")])) + + +@requirement("protocol:progress:no-token") +async def test_no_progress_callback_means_no_token(connect: Connect) -> None: + """Without a progress callback the request carries no progress token. 
+ + The low-level API has no way to report request-scoped progress without a token, so a handler + that sees no token has nothing to send progress against. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "inspect" + assert ctx.meta is not None + return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("inspect", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="None")])) + + +@requirement("protocol:progress:client-to-server") +async def test_client_progress_notification_reaches_server_handler(connect: Connect) -> None: + """A progress notification sent by the client is delivered to the server's progress handler.""" + received: list[ProgressNotificationParams] = [] + delivered = anyio.Event() + + async def on_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None: + received.append(params) + delivered.set() + + server = Server("observer", on_progress=on_progress) + + async with connect(server) as client: + await client.send_progress_notification("upload-1", 0.5, total=1.0, message="halfway") + with anyio.fail_after(5): + await delivered.wait() + + assert received == snapshot( + [ProgressNotificationParams(progress_token="upload-1", progress=0.5, total=1.0, message="halfway")] + ) + + +@requirement("protocol:progress:token-unique") +async def test_concurrent_requests_carry_distinct_progress_tokens(connect: Connect) -> None: + """Two concurrent requests carry distinct progress tokens, and each callback sees only its 
own progress. + + Without the barrier the first call could run to completion before the second starts, so only one + token would be live at a time and the demultiplexing would never be exercised. The handlers each + block until both have started and then hand control back and forth so the four progress + notifications are emitted in strict a, b, a, b order on the wire. The two handlers send different + progress values so a stream swap (token A delivered to callback B and vice versa) would fail: each + callback receiving exactly its own values proves notifications are routed by token, not by arrival + order or by chance. + """ + progress_values = {"a": (1.0, 2.0), "b": (10.0, 20.0)} + tokens: dict[str, ProgressToken] = {} + entered = {"a": anyio.Event(), "b": anyio.Event()} + # turns[n] is set to release the nth emission; each emission releases the next. + turns = [anyio.Event() for _ in range(4)] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + assert params.arguments is not None + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + label = params.arguments["label"] + tokens[label] = token + entered[label].set() + # The two handlers interleave by waiting on alternating turns: a takes 0 and 2, b takes 1 and 3. 
+ first, second = (0, 2) if label == "a" else (1, 3) + await turns[first].wait() + await ctx.session.send_progress_notification( + token, progress_values[label][0], related_request_id=str(ctx.request_id) + ) + turns[first + 1].set() + await turns[second].wait() + await ctx.session.send_progress_notification( + token, progress_values[label][1], related_request_id=str(ctx.request_id) + ) + if second + 1 < len(turns): + turns[second + 1].set() + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) + + received_a: list[float] = [] + received_b: list[float] = [] + + async def collect_a(progress: float, total: float | None, message: str | None) -> None: + received_a.append(progress) + + async def collect_b(progress: float, total: float | None, message: str | None) -> None: + received_b.append(progress) + + async with connect(server) as client: + + async def call(label: str, collect: ProgressFnT) -> None: + await client.call_tool("report", {"label": label}, progress_callback=collect) + + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: # pragma: no branch + task_group.start_soon(call, "a", collect_a) + task_group.start_soon(call, "b", collect_b) + await entered["a"].wait() + await entered["b"].wait() + turns[0].set() + + assert tokens["a"] != tokens["b"] + assert received_a == [1.0, 2.0] + assert received_b == [10.0, 20.0] + + +@requirement("protocol:progress:stops-after-completion") +@requirement("protocol:progress:late-dropped-by-client") +async def test_progress_sent_after_the_response_is_not_delivered_to_the_callback(connect: Connect) -> None: + """A progress notification sent after the response is emitted, and the client drops it from the callback. 
+ + This single body proves both halves: the server's `send_progress_notification` happily sends for + a token whose request has already completed (the spec MUST that progress stops is not enforced; + see the divergence on `stops-after-completion`), and the client, having removed the callback when + the call returned, does not deliver the late notification to it. The message handler observes the + late notification arriving so the test knows when to assert without polling. + """ + captured: list[tuple[ServerSession, ProgressToken]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + captured.append((ctx.session, token)) + await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[float] = [] + late_progress_arrived = anyio.Event() + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append(progress) + + async def message_handler(message: IncomingMessage) -> None: + if isinstance(message, ProgressNotification) and message.params.progress == 1.0: + late_progress_arrived.set() + + async with connect(server, message_handler=message_handler) as client: + with anyio.fail_after(5): + await client.call_tool("report", {}, progress_callback=collect) + assert received == [0.5] + + server_session, token = captured[0] + await server_session.send_progress_notification(token, 1.0) + await late_progress_arrived.wait() + + 
assert received == [0.5] + + +@requirement("protocol:progress:monotonic") +async def test_non_increasing_progress_values_are_forwarded_unchanged(connect: Connect) -> None: + """A handler that emits non-increasing progress values has them forwarded to the callback unchanged. + + The spec says progress MUST increase with each notification; the SDK does not enforce that on + either side. See the divergence note on the requirement. + """ + received: list[float] = [] + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append(progress) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="zigzag", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "zigzag" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) + await ctx.session.send_progress_notification(token, 0.3, related_request_id=str(ctx.request_id)) + await ctx.session.send_progress_notification(token, 0.9, related_request_id=str(ctx.request_id)) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("zigzagger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.call_tool("zigzag", {}, progress_callback=collect) + + assert received == snapshot([0.5, 0.3, 0.9]) diff --git a/tests/interaction/lowlevel/test_prompts.py b/tests/interaction/lowlevel/test_prompts.py new file mode 100644 index 0000000000..868b82692c --- /dev/null +++ b/tests/interaction/lowlevel/test_prompts.py @@ -0,0 +1,209 @@ +"""Prompt interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from 
inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + AudioContent, + EmbeddedResource, + ErrorData, + GetPromptResult, + Icon, + ImageContent, + ListPromptsResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("prompts:list:basic") +async def test_list_prompts_returns_registered_prompts(connect: Connect) -> None: + """The prompts returned by the handler reach the client with their argument declarations intact.""" + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: + return ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", description="The code to review.", required=True), + PromptArgument(name="style_guide", description="Optional style guide to apply."), + ], + icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], + ), + Prompt(name="daily_standup"), + ] + ) + + server = Server("prompter", on_list_prompts=list_prompts) + + async with connect(server) as client: + result = await client.list_prompts() + + assert result == snapshot( + ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", description="The code to review.", required=True), + PromptArgument(name="style_guide", description="Optional style guide to apply."), + ], + icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], + ), + Prompt(name="daily_standup"), + ] + ) + ) + + +@requirement("prompts:get:with-args") +async def test_get_prompt_substitutes_arguments(connect: 
Connect) -> None: + """Arguments supplied by the client reach the prompt handler; the templated message comes back.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "greet" + assert params.arguments is not None + return GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text=f"Hello, {params.arguments['name']}!"))], + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("greet", {"name": "Ada"}) + + assert result == snapshot( + GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text="Hello, Ada!"))], + ) + ) + + +@requirement("prompts:get:multi-message") +async def test_get_prompt_multiple_messages_preserve_roles_and_order(connect: Connect) -> None: + """A prompt returning a user/assistant conversation reaches the client with roles and order intact.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "geography_quiz" + return GetPromptResult( + messages=[ + PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), + PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), + PromptMessage(role="user", content=TextContent(text="And of Italy?")), + ] + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("geography_quiz") + + assert result == snapshot( + GetPromptResult( + messages=[ + PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), + PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), + PromptMessage(role="user", content=TextContent(text="And of 
Italy?")), + ] + ) + ) + + +@requirement("prompts:get:no-args") +async def test_get_prompt_without_arguments_returns_the_messages(connect: Connect) -> None: + """A prompt fetched with no arguments delivers None as the handler's arguments and returns its messages.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "static" + assert params.arguments is None + return GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("static") + + assert result == snapshot( + GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + ) + + +@requirement("prompts:get:content:image") +@requirement("prompts:get:content:audio") +@requirement("prompts:get:content:embedded-resource") +async def test_get_prompt_with_non_text_content_round_trips(connect: Connect) -> None: + """Prompt messages can carry image, audio, and embedded-resource content; all reach the client. + + A single full-result snapshot proves all three content types round-trip: each block in the result + is one of the three behaviours under test. Tiny fixed base64 payloads ("aW1n" is b"img", "YXVk" + is b"aud") so the snapshot pins the exact bytes. 
+ """ + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "media" + return GetPromptResult( + messages=[ + PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), + PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage( + role="user", + content=EmbeddedResource( + resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + ), + ), + ] + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("media", {}) + + assert result == snapshot( + GetPromptResult( + messages=[ + PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), + PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage( + role="user", + content=EmbeddedResource( + resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + ), + ), + ] + ) + ) + + +@requirement("prompts:get:unknown-name") +async def test_get_prompt_unknown_name_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised prompt name with MCPError produces a JSON-RPC error. + + The error's code and message chosen by the handler reach the client verbatim. 
+ """ + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.name}") + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("nope") + + assert exc_info.value.error == snapshot(ErrorData(code=INVALID_PARAMS, message="Unknown prompt: nope")) diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py new file mode 100644 index 0000000000..4e369d3645 --- /dev/null +++ b/tests/interaction/lowlevel/test_resources.py @@ -0,0 +1,309 @@ +"""Resource interactions against the low-level Server, driven through the public Client API.""" + +import base64 + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + METHOD_NOT_FOUND, + Annotations, + BlobResourceContents, + CallToolResult, + EmptyResult, + ErrorData, + Icon, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + Resource, + ResourceTemplate, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + TextContent, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("resources:list:basic") +@requirement("resources:annotations") +async def test_list_resources_returns_registered_resources(connect: Connect) -> None: + """Listed resources reach the client with their URIs, names, and optional descriptive fields intact. + + The fully-populated entry includes annotations, so the snapshot also proves they round-trip. 
+ The SDK's Annotations model omits the schema's lastModified field (see the divergence on + resources:annotations); the input is built via model_validate with lastModified set so the + snapshot pins the drop and will fail once the SDK adds the field. + """ + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + Resource(uri="memo://minimal", name="minimal"), + Resource( + uri="file:///project/README.md", + name="readme", + title="Project README", + description="The project's front page.", + mime_type="text/markdown", + size=1024, + annotations=Annotations.model_validate( + {"audience": ["user", "assistant"], "priority": 0.8, "lastModified": "2025-01-01T00:00:00Z"} + ), + icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + + server = Server("library", on_list_resources=list_resources) + + async with connect(server) as client: + result = await client.list_resources() + + assert result == snapshot( + ListResourcesResult( + resources=[ + Resource(uri="memo://minimal", name="minimal"), + Resource( + uri="file:///project/README.md", + name="readme", + title="Project README", + description="The project's front page.", + mime_type="text/markdown", + size=1024, + annotations=Annotations(audience=["user", "assistant"], priority=0.8), + icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + ) + + +@requirement("resources:read:text") +async def test_read_resource_text(connect: Connect) -> None: + """Reading a text resource returns its contents with the URI, MIME type, and text supplied by the handler.""" + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=params.uri, mime_type="text/plain", text="Hello, world!")] + ) + + 
server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + result = await client.read_resource("file:///greeting.txt") + + assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="file:///greeting.txt", mime_type="text/plain", text="Hello, world!")] + ) + ) + + +@requirement("resources:read:blob") +async def test_read_resource_binary(connect: Connect) -> None: + """Reading a binary resource returns its contents base64-encoded in the blob field.""" + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[ + BlobResourceContents( + uri=params.uri, + mime_type="image/png", + blob=base64.b64encode(b"\x89PNG").decode(), + ) + ] + ) + + server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + result = await client.read_resource("file:///pixel.png") + + assert result == snapshot( + ReadResourceResult( + contents=[BlobResourceContents(uri="file:///pixel.png", mime_type="image/png", blob="iVBORw==")] + ) + ) + + +@requirement("resources:read:unknown-uri") +async def test_read_resource_unknown_uri_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised URI with MCPError produces a JSON-RPC error. + + The spec reserves -32002 for resource-not-found; the code is the handler's choice and reaches + the client verbatim. 
+ """ + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + raise MCPError(code=-32002, message=f"Resource not found: {params.uri}") + + server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("file:///missing.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=-32002, message="Resource not found: file:///missing.txt")) + + +@requirement("resources:templates:list") +async def test_list_resource_templates_returns_registered_templates(connect: Connect) -> None: + """Listed resource templates reach the client with their URI templates and descriptive fields intact.""" + + async def list_resource_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + return ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + + server = Server("library", on_list_resource_templates=list_resource_templates) + + async with connect(server) as client: + result = await client.list_resource_templates() + + assert result == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + ) + + +@requirement("resources:subscribe") 
+async def test_subscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: + """Subscribing to a resource delivers the URI to the server's subscribe handler and returns an empty result.""" + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_subscribe_resource=subscribe_resource) + + async with connect(server) as client: + result = await client.subscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:subscribe:capability-required") +async def test_subscribe_without_a_subscribe_handler_is_method_not_found(connect: Connect) -> None: + """Subscribing to a server that registered no subscribe handler is rejected with METHOD_NOT_FOUND. + + The rejection comes from no handler being registered, not from any capability check; see the + divergence on lifecycle:capability:server-not-advertised. 
+ """ + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + server = Server("library", on_list_resources=list_resources) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.subscribe_resource("file:///watched.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + + +@requirement("resources:unsubscribe") +async def test_unsubscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: + """Unsubscribing from a resource delivers the URI to the server's unsubscribe handler.""" + + async def unsubscribe_resource(ctx: ServerRequestContext, params: types.UnsubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_unsubscribe_resource=unsubscribe_resource) + + async with connect(server) as client: + result = await client.unsubscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:updated-notification") +async def test_resource_updated_notification_reaches_client(connect: Connect) -> None: + """A resources/updated notification sent during a tool call reaches the client with the resource URI. + + ``send_resource_updated`` does not take a ``related_request_id``, so over streamable HTTP the + notification routes to the standalone GET stream and is not guaranteed to arrive before the + tool result; the test waits on an event the collector sets. The collector records every + message the handler receives, so the assertion also proves nothing else was delivered. 
+ """ + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="touch", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "touch" + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="touched")]) + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" + raise NotImplementedError + + server = Server( + "library", + on_list_tools=list_tools, + on_call_tool=call_tool, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + ) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("touch", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot( + [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + ) diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py new file mode 100644 index 0000000000..8149e0befb --- /dev/null +++ b/tests/interaction/lowlevel/test_roots.py @@ -0,0 +1,166 @@ +"""Roots interactions against the low-level Server, driven through the public Client API.""" + +import anyio +import pytest +from inline_snapshot 
import snapshot +from pydantic import FileUrl + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.types import INTERNAL_ERROR, CallToolResult, ErrorData, ListRootsResult, Root, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("roots:list:basic") +async def test_list_roots_round_trip(connect: Connect) -> None: + """A roots/list request from a tool handler is answered by the client's roots callback. + + The tool reports the URIs and names it received, proving the client's roots reached the server. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + result = await ctx.session.list_roots() + lines = [f"{root.uri} name={root.name}" for root in result.roots] + return CallToolResult(content=[TextContent(text="\n".join(lines))]) + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult( + roots=[ + Root(uri=FileUrl("file:///home/alice/project"), name="project"), + Root(uri=FileUrl("file:///home/alice/scratch")), + ] + ) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="file:///home/alice/project name=project\nfile:///home/alice/scratch name=None")] + ) + ) + + +@requirement("roots:list:empty") +async def test_list_roots_empty(connect: Connect) -> None: + """A client 
with no roots to offer answers roots/list with an empty list, not an error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="count_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "count_roots" + result = await ctx.session.list_roots() + return CallToolResult(content=[TextContent(text=str(len(result.roots)))]) + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult(roots=[]) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("count_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="0")])) + + +@requirement("roots:list:not-supported") +async def test_list_roots_without_callback_is_error(connect: Connect) -> None: + """A roots/list request to a client with no roots callback fails with an error the handler can observe. + + The client's default callback answers with INVALID_REQUEST rather than leaving the server + hanging; the spec names -32601 for this case (see the divergence note on the requirement). 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # list_roots cannot succeed without a client callback + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: List roots not supported")])) + + +@requirement("roots:list:client-error") +async def test_list_roots_callback_error_surfaces_to_the_handler(connect: Connect) -> None: + """A roots callback that answers with an error fails the roots/list request with that exact error. + + The callback's code and message reach the requesting handler verbatim as an MCPError. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ErrorData: + return ErrorData(code=INTERNAL_ERROR, message="roots provider crashed") + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32603: roots provider crashed")])) + + +@requirement("roots:list-changed") +async def test_roots_list_changed_reaches_server_handler(connect: Connect) -> None: + """A roots/list_changed notification from the client is delivered to the server's handler. + + Unlike a request, a notification has no response to await: the handler sets an event and the + test waits on it, which is the only synchronisation point proving delivery. 
+ """ + delivered = anyio.Event() + received: list[types.NotificationParams | None] = [] + + async def roots_list_changed(ctx: ServerRequestContext, params: types.NotificationParams | None) -> None: + received.append(params) + delivered.set() + + server = Server("rooted", on_roots_list_changed=roots_list_changed) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + """Registered so the client declares the roots capability; the server never asks for roots.""" + raise NotImplementedError + + async with connect(server, list_roots_callback=list_roots) as client: + await client.send_roots_list_changed() + with anyio.fail_after(5): + await delivered.wait() + + assert received == snapshot([None]) diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py new file mode 100644 index 0000000000..260e564192 --- /dev/null +++ b/tests/interaction/lowlevel/test_sampling.py @@ -0,0 +1,687 @@ +"""Sampling interactions against the low-level Server, driven through the public Client API. + +Each test nests a sampling/createMessage request inside a tool call: the tool handler calls +ctx.session.create_message(), the client's sampling callback answers it, and the handler +round-trips what it received back to the test through its tool result. 
+""" + +import pydantic +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + AudioContent, + CallToolResult, + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, + ErrorData, + ImageContent, + ModelHint, + ModelPreferences, + SamplingCapability, + SamplingMessage, + TextContent, + ToolResultContent, + ToolUseContent, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("sampling:create:basic") +@requirement("tools:call:sampling-roundtrip") +async def test_create_message_round_trip(connect: Connect) -> None: + """A handler's sampling request is answered by the client callback, and the callback's result + (role, content, model, stop reason) is returned to the handler. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=f"{result.model}/{result.stop_reason}: {result.content.text}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult( + 
role="assistant", + content=TextContent(text="Hello to you too."), + model="mock-llm-1", + stop_reason="endTurn", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-llm-1/endTurn: Hello to you too.")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create:include-context") +@requirement("sampling:create:model-preferences") +@requirement("sampling:create:system-prompt") +@requirement("sampling:context:server-gated-by-capability") +async def test_create_message_params_reach_callback(connect: Connect) -> None: + """Every sampling parameter the handler supplies arrives at the client callback unchanged. + + The client has not declared the sampling.context capability (Client cannot declare it), yet + include_context="thisServer" reaches the callback regardless: the spec's SHOULD NOT is not + enforced. See the divergence note on `sampling:context:server-gated-by-capability`. 
+ """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + max_tokens=50, + system_prompt="You are terse.", + include_context="thisServer", + temperature=0.7, + stop_sequences=["\n\n", "END"], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult(role="assistant", content=TextContent(text="ok"), model="mock-llm-1") + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + system_prompt="You are terse.", + include_context="thisServer", + temperature=0.7, + max_tokens=50, + 
stop_sequences=["\n\n", "END"], + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_request_with_image_content_reaches_callback(connect: Connect) -> None: + """A sampling request message carrying image content arrives at the client callback intact. + + This is the server-to-client direction: the server includes an image in the conversation it + asks the client to sample from. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="describe_image", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "describe_image" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + image = params.messages[0].content + assert isinstance(image, ImageContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"described {image.mime_type} ({image.data})"), + model="mock-vision-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("describe_image", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="described image/png (aW1n)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", 
content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_result_with_image_content_returns_to_handler(connect: Connect) -> None: + """A sampling result whose content is an image is returned to the requesting handler intact. + + This is the client-to-server direction: the model's response is an image rather than text. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="draw", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "draw" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Draw a cat."))], + max_tokens=100, + ) + image = result.content + assert isinstance(image, ImageContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {image.mime_type} {image.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=ImageContent(data="Y2F0", mime_type="image/png"), + model="mock-vision-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("draw", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-vision-1: image/png Y2F0")])) + + +@requirement("sampling:error:user-rejected") +async def test_create_message_callback_error(connect: Connect) -> None: + """A sampling callback that answers with an error surfaces to the requesting handler as an MCPError. 
+ + The error here is the spec's own example for a user rejecting a sampling request (code -1); + the callback's code and message reach the handler verbatim, whatever they are. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback(context: ClientRequestContext, params: CreateMessageRequestParams) -> ErrorData: + return ErrorData(code=-1, message="User rejected sampling request") + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-1: User rejected sampling request")])) + + +@requirement("sampling:create-message:not-supported") +async def test_create_message_without_callback_is_error(connect: Connect) -> None: + """A sampling request to a client with no sampling callback fails with the SDK's default error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert 
params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # create_message cannot succeed without a client callback + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Sampling not supported")])) + + +@requirement("sampling:tools:server-gated-by-capability") +async def test_create_message_with_tools_is_rejected_for_unsupporting_client(connect: Connect) -> None: + """A tool-enabled sampling request to a client that has not declared sampling.tools never leaves the server. + + The client supports plain sampling but cannot declare the tools sub-capability (Client does not + expose it), so the server-side validator rejects the request before anything reaches the wire. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="What is the weather?"))], + max_tokens=100, + tools=[types.Tool(name="get_weather", input_schema={"type": "object"})], + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the validator rejects every tool-enabled request + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the plain sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="-32602: Client does not support sampling tools capability")]) + ) + + +@requirement("sampling:tool-result:no-mixed-content") +async def test_create_message_with_mixed_tool_result_content_is_rejected(connect: Connect) -> None: + """A sampling request whose user message mixes tool_result with other content never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, even + when no tools are passed, so the client callback is never invoked and the handler observes the + ValueError directly. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="summarise_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "summarise_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + ToolResultContent(tool_use_id="call-1", content=[TextContent(text="42")]), + TextContent(text="Also, a comment alongside the result."), + ], + ) + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("summarise_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="ValueError: The last message must contain only tool_result content if any is present") + ] + ) + ) + + +@requirement("sampling:capability:declare") +async def test_a_client_with_a_sampling_callback_declares_the_sampling_capability(connect: Connect) -> None: + """A client connecting with a sampling callback advertises the sampling capability to the server. + + Client cannot declare any sub-capabilities (it does not expose ClientSession's + sampling_capabilities parameter), so the snapshot pins an empty SamplingCapability. 
+ """ + captured: list[SamplingCapability | None] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="capabilities", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "capabilities" + assert ctx.session.client_params is not None + captured.append(ctx.session.client_params.capabilities.sampling) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Registered only so the sampling capability is advertised; never called.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + await client.call_tool("capabilities", {}) + + assert captured == snapshot([SamplingCapability()]) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_request_with_audio_content_reaches_callback(connect: Connect) -> None: + """A sampling request message carrying audio content arrives at the client callback intact. + + This is the server-to-client direction: the server includes audio in the conversation it asks + the client to sample from. 
+ """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="transcribe", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "transcribe" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + audio = params.messages[0].content + assert isinstance(audio, AudioContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"transcribed {audio.mime_type} ({audio.data})"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("transcribe", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="transcribed audio/wav (c25k)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_result_with_audio_content_returns_to_handler(connect: Connect) -> None: + """A sampling result whose content is audio is returned to the requesting handler intact. + + This is the client-to-server direction: the model's response is audio rather than text. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="speak", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "speak" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello, aloud."))], + max_tokens=100, + ) + audio = result.content + assert isinstance(audio, AudioContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {audio.mime_type} {audio.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=AudioContent(data="aGVsbG8=", mime_type="audio/wav"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("speak", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-audio-1: audio/wav aGVsbG8=")])) + + +@requirement("sampling:message:content-cardinality") +async def test_create_message_with_list_valued_message_content_reaches_callback(connect: Connect) -> None: + """A sampling message whose content is a list of blocks arrives at the client callback as a list.""" + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="caption", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "caption" + result = await 
ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + content = params.messages[0].content + assert isinstance(content, list) + return CreateMessageResult( + role="assistant", content=TextContent(text=f"{len(content)} blocks"), model="mock-llm-1" + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("caption", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="2 blocks")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:tool-use:server-preflight") +async def test_create_message_with_mismatched_tool_use_and_result_ids_is_rejected(connect: Connect) -> None: + """A sampling request whose tool_result ids do not match the preceding tool_use ids never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, so the + client callback is never invoked and the handler observes the ValueError directly. The spec's + client-side -32602 check is tracked separately at sampling:tool-use:result-balance. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="continue_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "continue_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="assistant", + content=[ToolUseContent(id="call-1", name="weather", input={})], + ), + SamplingMessage( + role="user", + content=[ToolResultContent(tool_use_id="call-WRONG", content=[TextContent(text="42")])], + ), + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("continue_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="ValueError: ids of tool_result blocks and tool_use blocks from previous message do not match" + ) + ] + ) + ) + + +@requirement("sampling:result:no-tools-single-content") +async def test_array_content_result_for_a_tool_free_request_surfaces_as_a_validation_error(connect: Connect) -> None: + """An array-content sampling result for a tool-free request is accepted by the client and fails server-side. + + Only the exception type is asserted: the message is pydantic's, which changes across releases. 
+ See the divergence note on the requirement: the intended behaviour is that the client rejects + the result; instead the client accepts it and the server's response parsing raises. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Two thoughts, please."))], + max_tokens=100, + ) + except pydantic.ValidationError as exc: + return CallToolResult(content=[TextContent(text=type(exc).__name__)]) + raise NotImplementedError # the array-content result fails server-side parsing every time + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResultWithTools: + return CreateMessageResultWithTools( + role="assistant", + content=[TextContent(text="First thought."), TextContent(text="Second thought.")], + model="mock-llm-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ValidationError")])) diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py new file mode 100644 index 0000000000..a9c83d641d --- /dev/null +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -0,0 +1,114 @@ +"""Request timeouts against the low-level Server, driven through the public Client API. 
+ +The handler blocks on an event that is never set, so the awaited response can never arrive and +any positive timeout fires deterministically on the next event-loop pass. The timeout is therefore +set to an effectively-zero duration: the tests add no wall-clock time to the suite. (Zero itself +cannot be used: a falsy read_timeout_seconds is silently treated as "no timeout".) +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:timeout:basic") +@requirement("protocol:timeout:sends-cancellation") +async def test_request_timeout_fails_the_pending_call() -> None: + """A request whose response does not arrive within its read timeout fails with a timeout error. + + No cancellation is sent to the server (see the divergence note on the requirement): the handler + starts and is still running after the caller has already given up. The test waits for the + handler to have started only after the timeout has fired, so the timeout itself races nothing. + """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + # The request was already on the wire: the handler still runs even though the caller gave up. 
+ with anyio.fail_after(5): + await handler_started.wait() + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. Waited 1e-06 seconds.", + ) + ) + + +@requirement("protocol:timeout:session-survives") +async def test_session_serves_requests_after_timeout() -> None: + """A timed-out request does not poison the session: the next request succeeds.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool(name="block", input_schema={"type": "object"}), + types.Tool(name="echo", input_schema={"type": "object"}), + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + return CallToolResult(content=[TextContent(text="still alive")]) + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError): + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + + +@requirement("protocol:timeout:session-default") +async def test_session_level_timeout_applies_to_every_request() -> None: + """A read timeout configured on the client applies to requests that do not set their own.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + # The one real wall-clock wait in the suite, and 
it cannot be made effectively zero like the + # per-request timeouts: a session-level timeout also governs the initialize handshake, so the + # value must be long enough for the in-process handshake to complete before the blocked tool + # call waits it out in full. 50ms buys a ~50x safety margin over the handshake's actual + # latency; lowering it only erodes the margin against CI scheduler jitter without saving + # anything perceptible. + async with Client(server, read_timeout_seconds=0.05) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. Waited 0.05 seconds.", + ) + ) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py new file mode 100644 index 0000000000..e8053fbaa7 --- /dev/null +++ b/tests/interaction/lowlevel/test_tools.py @@ -0,0 +1,512 @@ +"""Tool interactions against the low-level Server, driven through the public Client API.""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + AudioContent, + CallToolResult, + EmbeddedResource, + ErrorData, + Icon, + ImageContent, + ListToolsResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, + ToolAnnotations, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content(connect: Connect) -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + 
tools=[types.Tool(name="add", description="Add two integers.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) + + +@requirement("tools:call:is-error") +async def test_call_tool_execution_error_is_returned_as_result(connect: Connect) -> None: + """A tool reporting its own failure with is_error=True reaches the client as a result, not an exception. + + Tool execution errors are part of the result so the caller (typically a model) can see + them; only protocol-level failures become JSON-RPC errors. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "flux" + return CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + ) + + +@requirement("tools:call:unknown-name") +async def test_call_tool_unknown_tool_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised tool name with MCPError produces a JSON-RPC error. + + The error's code, message, and data chosen by the handler reach the client verbatim. 
+ """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + raise MCPError(code=INVALID_PARAMS, message=f"Unknown tool: {params.name}", data={"requested": params.name}) + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("nope", {}) + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Unknown tool: nope", data={"requested": "nope"}) + ) + + +@requirement("protocol:error:internal-error") +async def test_call_tool_uncaught_exception_becomes_error_response(connect: Connect) -> None: + """An uncaught exception in the tool handler surfaces to the client as a JSON-RPC error. + + The low-level server reports it with code 0 and the exception text as the message; see the + divergence note on the requirement. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "explode" + raise ValueError("boom") + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("explode", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="boom")) + + +@requirement("tools:list:basic") +async def test_list_tools_returns_registered_tools(connect: Connect) -> None: + """The tools advertised by the server's list handler arrive at the client unchanged.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] 
+ ) + + server = Server("calculator", on_list_tools=list_tools) + + async with connect(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] + ) + ) + + +@requirement("tools:input-schema:json-schema-2020-12") +@requirement("tools:input-schema:preserve-additional-properties") +@requirement("tools:input-schema:preserve-defs") +@requirement("tools:input-schema:preserve-schema-dialect") +async def test_tools_list_preserves_arbitrary_input_schema_keywords(connect: Connect) -> None: + """A rich JSON Schema 2020-12 inputSchema reaches the client unchanged and the tool is callable. + + The single identity assertion below proves all four pass-through behaviours at once: the same + dict literal that was registered is the dict that arrives, so $schema, $defs, the nested object + property, and additionalProperties are each preserved by virtue of the whole schema being + preserved. The follow-up call proves the rich-schema tool is callable end to end. 
+ """ + schema = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$defs": {"positive": {"type": "integer", "exclusiveMinimum": 0}}, + "properties": { + "count": {"$ref": "#/$defs/positive"}, + "options": { + "type": "object", + "properties": {"verbose": {"type": "boolean"}}, + "additionalProperties": False, + }, + }, + "required": ["count"], + "additionalProperties": False, + } + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="typed", input_schema=schema)]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "typed" + assert params.arguments == {"count": 3, "options": {"verbose": True}} + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("typed", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + listed = await client.list_tools() + called = await client.call_tool("typed", {"count": 3, "options": {"verbose": True}}) + + assert listed.tools[0].input_schema == schema + assert called == snapshot(CallToolResult(content=[TextContent(text="ok")])) + + +@requirement("tools:list:metadata") +async def test_list_tools_optional_fields_round_trip(connect: Connect) -> None: + """Every optional Tool field the server supplies reaches the client unchanged.""" + + tool = Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + + async def list_tools(ctx: ServerRequestContext, params: 
types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[tool]) + + server = Server("annotated", on_list_tools=list_tools) + + async with connect(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + ] + ) + ) + + +@requirement("tools:call:content:mixed") +@requirement("tools:call:content:image") +@requirement("tools:call:content:audio") +@requirement("tools:call:content:resource-link") +@requirement("tools:call:content:embedded-resource") +async def test_call_tool_multiple_content_block_types(connect: Connect) -> None: + """A tool result can mix every content block type; all of them arrive in order. + + The payloads are tiny fixed base64 strings ("aW1n" is b"img", "YXVk" is b"aud") so the + snapshot pins the exact bytes the client receives. 
+ """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="render", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "render" + return CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + + server = Server("renderer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("render", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + ) + + +@requirement("tools:call:structured-content") +async def test_call_tool_structured_content(connect: Connect) -> None: + """A tool result carrying structured content alongside content delivers both to the client.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="sum", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "sum" + return CallToolResult(content=[TextContent(text="the 
sum is 5")], structured_content={"sum": 5}) + + server = Server("calculator", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("sum", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5})) + + +@requirement("tools:call:concurrent") +async def test_concurrent_tool_calls_complete_independently(connect: Connect) -> None: + """Two tool calls in flight at once run concurrently and each caller gets its own answer. + + Both handlers are held on a shared event after signalling that they have started, and the test + only releases them once both signals have arrived -- a server that processed requests + sequentially would never start the second handler and the test would time out instead. + """ + started: list[str] = [] + started_events = {"first": anyio.Event(), "second": anyio.Event()} + release = anyio.Event() + results: dict[str, CallToolResult] = {} + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + tag = params.arguments["tag"] + assert isinstance(tag, str) + started.append(tag) + started_events[tag].set() + await release.wait() + return CallToolResult(content=[TextContent(text=tag)]) + + server = Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: # pragma: no branch + + async def call_and_record(tag: str) -> None: + results[tag] = await client.call_tool("echo", {"tag": tag}) + + task_group.start_soon(call_and_record, "first") + 
task_group.start_soon(call_and_record, "second") + + # Both handlers are running at the same time before either is allowed to finish. + await started_events["first"].wait() + await started_events["second"].wait() + release.set() + + assert sorted(started) == ["first", "second"] + assert results == snapshot( + { + "first": CallToolResult(content=[TextContent(text="first")]), + "second": CallToolResult(content=[TextContent(text="second")]), + } + ) + + +@requirement("client:output-schema:validate") +async def test_call_tool_structured_content_violating_output_schema_is_rejected_by_the_client(connect: Connect) -> None: + """A result whose structured content does not conform to the tool's declared output schema never + reaches the caller: the client validates it against the schema cached from tools/list and raises. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")], structured_content={"temperature": "warm"}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + # The message embeds the jsonschema validation error, so only the SDK-authored prefix is pinned. 
+ assert str(exc_info.value).startswith("Invalid structured content returned by tool forecast") + + +@requirement("client:output-schema:skip-on-error") +async def test_is_error_result_bypasses_client_output_schema_validation(connect: Connect) -> None: + """A tool result with isError true is returned as-is even when its structured content violates the schema. + + The schema is cached up front so the client could validate, proving the bypass is specifically the + isError flag and not an empty cache. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult( + content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True + ) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + result = await client.call_tool("forecast", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True) + ) + + +@requirement("client:output-schema:missing-structured") +async def test_declared_output_schema_with_no_structured_content_is_rejected_by_the_client(connect: Connect) -> None: + """A tool that declared an output schema but returned no structuredContent fails the client-side check. + + The error is the SDK's own message, so the full text is snapshotted. 
+ """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")]) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + assert str(exc_info.value) == snapshot("Tool forecast has an output schema but did not return structured content") + + +@requirement("client:output-schema:auto-list") +async def test_call_tool_populates_the_output_schema_cache_via_an_implicit_tools_list(connect: Connect) -> None: + """Calling a tool whose schema is not cached issues exactly one implicit tools/list to populate it. + + The first call_tool of an uncached tool triggers a tools/list the caller never asked for; the + second call hits the cache and does not. This is the SDK's chosen cache strategy and the cause of + the surprising behaviour where a server with only on_call_tool sees a successful call answered + with METHOD_NOT_FOUND from a request the caller never made; see the divergence on the requirement. 
+ """ + list_calls: list[str] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + list_calls.append("called") + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + first = await client.call_tool("forecast", {}) + assert list_calls == ["called"] + second = await client.call_tool("forecast", {}) + + assert list_calls == ["called"] + assert first == snapshot(CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21})) + assert second == first diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py new file mode 100644 index 0000000000..0f9c58aa7a --- /dev/null +++ b/tests/interaction/lowlevel/test_wire.py @@ -0,0 +1,309 @@ +"""Wire-level invariants observed at the client's transport boundary. + +These behaviours are invisible to API callers -- they are properties of the raw JSON-RPC frames. +The tests wrap the in-memory transport in a RecordingTransport, which tees every message crossing +the transport seam into a list without touching the session, so the assertions hold for whatever +the session implementation sends rather than for what its API returns. + +The later tests drive the wire by hand instead: one closes the server-to-client stream while a +request is in flight to pin the connection-closed teardown, and the last two send deliberately +malformed JSON-RPC requests that the typed client API cannot produce. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.client._memory import InMemoryTransport +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CONNECTION_CLOSED, + INVALID_PARAMS, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + EmptyResult, + ErrorData, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListRootsResult, + TextContent, +) +from tests.interaction._helpers import RecordingTransport, _RecordingReadStream +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _echo_server() -> Server: + """A server with one echo tool, used by every test in this module.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + return CallToolResult(content=[TextContent(text="ok")]) + + return Server("wire", on_list_tools=list_tools, on_call_tool=call_tool) + + +@requirement("protocol:request-id:unique") +async def test_request_ids_are_unique_and_never_null() -> None: + """Every request the client sends carries a distinct, non-null id. + + The id sequence is pinned: sequential integers from zero, in send order. 
+ """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + await client.call_tool("echo", {}) + await client.call_tool("echo", {}) + await client.send_ping() + + sent = [message.message for message in recording.sent] + request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + assert all(request_id is not None for request_id in request_ids) + assert len(request_ids) == len(set(request_ids)) + # initialize, tools/list, tools/call, tools/call, ping -- the client does not issue a + # schema-cache refresh here because the explicit tools/list already populated the cache. + assert request_ids == snapshot([0, 1, 2, 3, 4]) + + +@requirement("protocol:notifications:no-response") +async def test_notifications_are_never_answered() -> None: + """A notification produces no response: everything the server sends back answers a request. + + The client sends two notifications (initialized and roots/list_changed) and several requests; + the messages received from the server must be exactly one response per request, each carrying + the id of the request it answers, and nothing else. 
+ """ + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + """Registered so the client declares the roots capability; the server never asks for roots.""" + raise NotImplementedError + + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording, list_roots_callback=list_roots) as client: + await client.send_roots_list_changed() + await client.send_ping() + + sent = [message.message for message in recording.sent] + sent_request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + sent_notifications = [message for message in sent if isinstance(message, JSONRPCNotification)] + received = [message.message for message in recording.received if isinstance(message, SessionMessage)] + received_responses = [message for message in received if isinstance(message, JSONRPCResponse)] + + assert len(sent_notifications) == 2 # notifications/initialized and notifications/roots/list_changed + assert len(received_responses) == len(received) # nothing the server sent was anything but a response + assert [message.id for message in received_responses] == sent_request_ids + + +async def test_recording_read_stream_ends_iteration_when_the_sender_closes() -> None: + """The recording wrapper preserves the end-of-stream behaviour of the stream it wraps. + + This exercises the helper itself rather than an interaction-model behaviour: a transport whose + far end closes must end the client's receive loop cleanly, and the wrapper must not swallow or + mistranslate that. 
+ """ + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + log: list[SessionMessage | Exception] = [] + async with send_stream, _RecordingReadStream(receive_stream, log) as wrapped: + await send_stream.aclose() + items = [item async for item in wrapped] + assert items == [] + assert log == [] + + +@requirement("lifecycle:initialized-notification") +async def test_exactly_one_initialized_notification_is_sent_after_the_handshake() -> None: + """The client sends initialized exactly once, between the initialize response and its first request. + + The full method sequence the client puts on the wire is pinned in send order. + """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + + sent_methods = [ + message.message.method + for message in recording.sent + if isinstance(message.message, JSONRPCRequest | JSONRPCNotification) + ] + assert sent_methods.count("notifications/initialized") == 1 + assert sent_methods == snapshot(["initialize", "notifications/initialized", "tools/list"]) + + +@requirement("protocol:error:connection-closed") +async def test_closing_the_transport_fails_in_flight_requests_with_connection_closed() -> None: + """When the server-to-client stream closes, every in-flight client request fails with CONNECTION_CLOSED. + + Driven over a bare ClientSession against a real Server so the test holds the transport stream + pair directly: once the request is in flight (the server handler signals it has started) the + test closes the server's write stream, which ends the client's receive loop and triggers the + teardown that fails the pending request. 
+ """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + raise NotImplementedError # unreachable: the wait above never completes normally + + server = Server("blocker", on_call_tool=call_tool) + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + errors: list[ErrorData] = [] + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + async with ClientSession(client_read, client_write) as session: + with anyio.fail_after(5): + await session.initialize() + + async def call_and_capture_error() -> None: + with pytest.raises(MCPError) as exc_info: + await session.send_request( + CallToolRequest(params=CallToolRequestParams(name="block")), CallToolResult + ) + errors.append(exc_info.value.error) + + async with anyio.create_task_group() as task_group: # pragma: no branch + task_group.start_soon(call_and_capture_error) + await handler_started.wait() + await server_write.aclose() + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=CONNECTION_CLOSED, message="Connection closed")]) + + +@requirement("protocol:error:invalid-params") +async def test_malformed_request_params_are_answered_with_invalid_params() -> None: + """A request whose params fail validation is answered with -32602 Invalid params. 
+ + The typed client API cannot construct a request with the wrong parameter types, so the test + plays the client's side of the wire by hand against a real Server: it completes the + initialization handshake at the JSON-RPC layer and then sends a tools/call whose `name` is an + integer. Reserve this pattern for behaviour the typed API cannot produce. + """ + server = Server("strict") + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": 42})) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="")]) + + +@requirement("logging:set-level:invalid-level") +async def test_set_level_with_an_unrecognized_value_is_answered_with_invalid_params() -> None: + """logging/setLevel with a value 
outside the spec's level enum is answered with -32602 Invalid params. + + The typed client API cannot construct a setLevel request with an unrecognized level (pyright and + the client-side model both reject it), so the test plays the client's side of the wire by hand + against a real Server. Reserve this pattern for behaviour the typed API cannot produce. + """ + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; never called -- params validation fails first.""" + raise NotImplementedError + + server = Server("logger", on_set_logging_level=set_logging_level) + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage( + JSONRPCRequest(jsonrpc="2.0", id=1, method="logging/setLevel", params={"level": "loud"}) + ) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + 
server_task_group.cancel_scope.cancel() + + assert len(errors) == 1 + assert errors[0].code == INVALID_PARAMS diff --git a/tests/interaction/mcpserver/__init__.py b/tests/interaction/mcpserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/mcpserver/test_completion.py b/tests/interaction/mcpserver/test_completion.py new file mode 100644 index 0000000000..7761066e94 --- /dev/null +++ b/tests/interaction/mcpserver/test_completion.py @@ -0,0 +1,38 @@ +"""Completion behaviour against MCPServer, driven through the public Client API.""" + +import pytest + +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + Completion, + CompletionArgument, + CompletionContext, + CompletionsCapability, + PromptReference, + ResourceTemplateReference, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:completion:capability-auto") +async def test_completion_capability_is_advertised_only_when_a_handler_is_registered(connect: Connect) -> None: + """An MCPServer with a registered completion handler advertises the completions capability; one without does not.""" + with_handler = MCPServer("completer") + + @with_handler.completion() + async def complete( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + async with connect(with_handler) as client: + assert client.initialize_result.capabilities.completions == CompletionsCapability() + + async with connect(MCPServer("plain")) as client: + assert client.initialize_result.capabilities.completions is None diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py new file mode 100644 index 0000000000..26556fea7a --- 
/dev/null +++ b/tests/interaction/mcpserver/test_context.py @@ -0,0 +1,271 @@ +"""The Context convenience methods MCPServer injects into tool functions, observed from the client.""" + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp import MCPError +from mcp.client import ClientRequestContext +from mcp.server.elicitation import AcceptedElicitation +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + METHOD_NOT_FOUND, + CallToolResult, + ElicitRequestFormParams, + ElicitRequestParams, + ElicitResult, + ErrorData, + Implementation, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:context:logging") +@requirement("logging:capability:declared") +async def test_context_logging_helpers_send_log_notifications(connect: Connect) -> None: + """Each Context logging helper sends a log message notification at the matching severity. + + All four notifications reach the client's logging callback before the tool call returns; none + of them carry a logger name unless one is passed explicitly. The server emits these without + advertising the logging capability (see the divergence note on logging:capability). 
+ """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("chatty") + + @mcp.tool() + async def narrate(ctx: Context) -> str: + await ctx.debug("d") + await ctx.info("i") + await ctx.warning("w") + await ctx.error("e") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with connect(mcp, logging_callback=collect) as client: + result = await client.call_tool("narrate", {}) + advertised_logging = client.initialize_result.capabilities.logging + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="d"), + LoggingMessageNotificationParams(level="info", data="i"), + LoggingMessageNotificationParams(level="warning", data="w"), + LoggingMessageNotificationParams(level="error", data="e"), + ] + ) + # The spec requires servers that emit log notifications to declare the logging capability. + assert advertised_logging is None + + +@requirement("mcpserver:context:progress") +async def test_context_report_progress_sends_progress_notifications(connect: Connect) -> None: + """Context.report_progress sends progress notifications correlated to the calling request. + + The caller's progress callback receives each report, in order, before the tool call returns. 
+ """ + received: list[tuple[float, float | None, str | None]] = [] + mcp = MCPServer("worker") + + @mcp.tool() + async def crunch(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.report_progress(2, 3, "halfway there") + return "crunched" + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with connect(mcp) as client: + result = await client.call_tool("crunch", {}, progress_callback=on_progress) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="crunched")], structured_content={"result": "crunched"}) + ) + assert received == snapshot([(1.0, 3.0, None), (2.0, 3.0, "halfway there")]) + + +@requirement("mcpserver:tool:extra") +async def test_context_exposes_request_id_and_client_info_to_a_tool(connect: Connect) -> None: + """A tool can read the per-request id and the connecting client's identity through Context. + + The request id is non-empty (its concrete value depends on transport-level sequencing, so the + test asserts the value the tool saw is the one returned, rather than pinning the literal); the + client info reflects what the caller passed to `Client`. 
+ """ + mcp = MCPServer("introspector") + + @mcp.tool() + async def whoami(ctx: Context) -> str: + client_params = ctx.session.client_params + assert client_params is not None + return f"request {ctx.request_id} from {client_params.client_info.name} {client_params.client_info.version}" + + async with connect(mcp, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: + result = await client.call_tool("whoami", {}) + + assert isinstance(result.content[0], TextContent) + text = result.content[0].text + assert text.startswith("request ") + assert text.endswith(" from acme-agent 9.9.9") + request_id = text.removeprefix("request ").removesuffix(" from acme-agent 9.9.9") + assert request_id + + +@requirement("protocol:progress:no-token") +async def test_report_progress_without_a_progress_token_sends_nothing(connect: Connect) -> None: + """When the caller supplied no progress callback, Context.report_progress is a silent no-op. + + The tool also emits one log message as a sentinel: the message handler receives only that, + proving the notification pipeline works and no progress notification was sent for the + token-less request. 
+ """ + received: list[IncomingMessage] = [] + mcp = MCPServer("quiet") + + @mcp.tool() + async def mill(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.info("milling done") + return "milled" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(mcp, message_handler=collect) as client: + result = await client.call_tool("mill", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="milled")], structured_content={"result": "milled"}) + ) + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="milling done"))] + ) + + +@requirement("mcpserver:context:elicit") +@requirement("tools:call:elicitation-roundtrip") +async def test_context_elicit_returns_typed_result(connect: Connect) -> None: + """Context.elicit sends a form elicitation built from a pydantic schema and returns a typed result. + + The client sees the JSON schema generated from the model; the accepted content is validated + back into the model and handed to the tool as result.data. 
+ """ + received: list[ElicitRequestParams] = [] + mcp = MCPServer("travel") + + class TravelPreferences(BaseModel): + destination: str + window_seat: bool + + @mcp.tool() + async def book_flight(ctx: Context) -> str: + answer = await ctx.elicit("Where to?", TravelPreferences) + assert isinstance(answer, AcceptedElicitation) + return f"{answer.action}: {answer.data.destination} window={answer.data.window_seat}" + + async def answer_form(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"destination": "Lisbon", "window_seat": True}) + + async with connect(mcp, elicitation_callback=answer_form) as client: + result = await client.call_tool("book_flight", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Where to?", + requested_schema={ + "properties": { + "destination": {"title": "Destination", "type": "string"}, + "window_seat": {"title": "Window Seat", "type": "boolean"}, + }, + "required": ["destination", "window_seat"], + "title": "TravelPreferences", + "type": "object", + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept: Lisbon window=True")], + structured_content={"result": "accept: Lisbon window=True"}, + ) + ) + + +@requirement("mcpserver:context:read-resource") +async def test_context_read_resource_reads_registered_resource(connect: Connect) -> None: + """Context.read_resource lets a tool read a resource registered on the same server. + + The tool reports the MIME type and content it read, proving the resource function ran and its + return value came back through the context. 
+ """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + @mcp.tool() + async def show_config(ctx: Context) -> str: + contents = list(await ctx.read_resource("config://app")) + return "\n".join(f"{item.mime_type}: {item.content!r}" for item in contents) + + async with connect(mcp) as client: + result = await client.call_tool("show_config", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="text/plain: 'theme = dark'")], + structured_content={"result": "text/plain: 'theme = dark'"}, + ) + ) + + +@requirement("logging:message:filtered") +async def test_set_logging_level_is_rejected_and_messages_are_never_filtered(connect: Connect) -> None: + """MCPServer does not support logging/setLevel, so log messages are never filtered by severity. + + The request is rejected with METHOD_NOT_FOUND because MCPServer registers no handler for it, + and every message a tool emits is delivered regardless of level. The spec says the server + should only send messages at or above the configured level; with no way to configure one, + everything is sent. 
+ """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("unfilterable") + + @mcp.tool() + async def chatter(ctx: Context) -> str: + await ctx.debug("noise") + await ctx.error("signal") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with connect(mcp, logging_callback=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.set_logging_level("error") + + await client.call_tool("chatter", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="noise"), + LoggingMessageNotificationParams(level="error", data="signal"), + ] + ) diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py new file mode 100644 index 0000000000..2095f086d4 --- /dev/null +++ b/tests/interaction/mcpserver/test_prompts.py @@ -0,0 +1,195 @@ +"""Prompt interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + GetPromptResult, + ListPromptsResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:prompt:decorated") +async def test_list_prompts_derives_arguments_from_signature(connect: Connect) -> None: + """A decorated prompt is listed with arguments derived from the function signature. + + Parameters without a default are required; the description comes from the docstring. 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def code_review(code: str, style_guide: str = "pep8") -> str: + """Review a piece of code.""" + raise NotImplementedError # registered for listing only; never rendered + + async with connect(mcp) as client: + result = await client.list_prompts() + + assert result == snapshot( + ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", required=True), + PromptArgument(name="style_guide", required=False), + ], + ) + ] + ) + ) + + +@requirement("mcpserver:prompt:decorated") +async def test_get_prompt_renders_function_return(connect: Connect) -> None: + """The decorated function's string return value is rendered as a single user message.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A personalised greeting.""" + return f"Say hello to {name}." + + async with connect(mcp) as client: + result = await client.get_prompt("greet", {"name": "Ada"}) + + assert result == snapshot( + GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text="Say hello to Ada."))], + ) + ) + + +@requirement("mcpserver:prompt:unknown-name") +async def test_get_unknown_prompt_is_error(connect: Connect) -> None: + """Getting a prompt name that was never registered fails with a JSON-RPC error. + + The spec reserves -32602 for this case; the SDK reports code 0 (see the divergence note on + the requirement). 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; the test requests a different name.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("nope") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown prompt: nope")) + + +@requirement("prompts:get:missing-required-args") +async def test_get_prompt_with_a_missing_required_argument_is_an_error(connect: Connect) -> None: + """Getting a prompt without one of its required arguments fails with a JSON-RPC error. + + The missing argument is detected before the prompt function is called, but the spec's -32602 + Invalid params is reported as error code 0 with the bare exception text (see the divergence + note on the requirement). + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; validation rejects the call before the function runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("greet") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Missing required arguments: {'name'}")) + + +@requirement("mcpserver:prompt:args-validation") +async def test_get_prompt_with_a_wrong_type_argument_is_rejected_before_the_function_runs(connect: Connect) -> None: + """An argument that fails the function signature's type validation is rejected before the function runs. + + The decorated function is wrapped in pydantic's validate_call, so a value that cannot be + coerced to the parameter's annotation fails before the body executes. The function body + raises NotImplementedError to prove it never ran. The error is wrapped in the SDK's stable + rendering-error prefix; the body of the message is raw pydantic output and is not asserted. 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def repeat(phrase: str, count: int) -> str: + """A registered prompt; type validation rejects the call before the function runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("repeat", {"phrase": "hi", "count": "many"}) + + assert exc_info.value.error.code == 0 + assert exc_info.value.error.message.startswith("Error rendering prompt repeat: 1 validation error") + + +@requirement("mcpserver:prompt:optional-args") +async def test_get_prompt_with_an_optional_argument_omitted_uses_the_default(connect: Connect) -> None: + """A prompt rendered without one of its optional arguments uses that parameter's default value.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def review(code: str, style: str = "pep8") -> str: + """Review a snippet of code against a style guide.""" + return f"Review {code} per {style}." + + async with connect(mcp) as client: + result = await client.get_prompt("review", {"code": "x = 1"}) + + assert result == snapshot( + GetPromptResult( + description="Review a snippet of code against a style guide.", + messages=[PromptMessage(role="user", content=TextContent(text="Review x = 1 per pep8."))], + ) + ) + + +@requirement("mcpserver:prompt:duplicate-name") +async def test_registering_a_duplicate_prompt_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second prompt with an already-used name keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via the decorator with an explicit name so the test does not + redefine the same function name in this scope. 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.prompt(name="greet") + def greet_second() -> str: + """Registered with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_prompts() + result = await client.get_prompt("greet") + + assert [prompt.name for prompt in listed.prompts] == ["greet"] + assert result == snapshot( + GetPromptResult( + description="The first registration; this is the one that wins.", + messages=[PromptMessage(role="user", content=TextContent(text="first"))], + ) + ) diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py new file mode 100644 index 0000000000..57b0fdc86d --- /dev/null +++ b/tests/interaction/mcpserver/test_resources.py @@ -0,0 +1,183 @@ +"""Resource interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:resource:static") +async def test_read_static_resource(connect: Connect) -> None: + """A function registered for a fixed URI is served at that URI with its return value as text.""" + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + async with connect(mcp) as client: + result = await client.read_resource("config://app") + + assert result == snapshot( + ReadResourceResult( + 
contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="theme = dark")] + ) + ) + + +@requirement("mcpserver:resource:static") +async def test_list_static_and_templated_resources(connect: Connect) -> None: + """Statically-registered resources appear in resources/list; templated ones only in templates/list. + + The name and description are derived from the function name and docstring; the MIME type + defaults to text/plain. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + raise NotImplementedError # registered for listing only; never read + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + raise NotImplementedError # registered for listing only; never read + + async with connect(mcp) as client: + resources = await client.list_resources() + templates = await client.list_resource_templates() + + assert resources == snapshot( + ListResourcesResult( + resources=[ + Resource( + name="app_config", + uri="config://app", + description="The application configuration.", + mime_type="text/plain", + ) + ] + ) + ) + assert templates == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate( + name="user_profile", + uri_template="users://{user_id}/profile", + description="A user's profile.", + mime_type="text/plain", + ) + ] + ) + ) + + +@requirement("mcpserver:resource:template") +@requirement("resources:read:template-vars") +async def test_read_templated_resource(connect: Connect) -> None: + """Reading a URI that matches a registered template invokes the function with the extracted parameters.""" + mcp = MCPServer("library") + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + return f"profile for {user_id}" + + async with connect(mcp) as client: + result = await client.read_resource("users://42/profile") + + 
assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="users://42/profile", mime_type="text/plain", text="profile for 42")] + ) + ) + + +@requirement("mcpserver:resource:unknown-uri") +async def test_read_unknown_uri_is_error(connect: Connect) -> None: + """Reading a URI that matches no registered resource fails with a JSON-RPC error. + + The spec reserves -32002 for resource-not-found; see the divergence note on the requirement. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """A registered resource; the test reads a different URI.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("config://missing") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown resource: config://missing")) + + +@requirement("mcpserver:resource:read-throws-surfaced") +async def test_resource_function_that_raises_is_surfaced_as_a_jsonrpc_error(connect: Connect) -> None: + """An exception raised by a resource function reaches the caller as a JSON-RPC error. + + MCPServer wraps the failure in a generic error that names only the URI, so the original + exception text is not leaked to the client. The wrapped exception becomes error code 0 the + same way every other unhandled server-side exception does. + """ + mcp = MCPServer("library") + + @mcp.resource("res://boom") + def boom() -> str: + raise RuntimeError("nope") + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("res://boom") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Error reading resource res://boom")) + + +@requirement("mcpserver:resource:duplicate-name") +async def test_registering_a_duplicate_resource_uri_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second static resource at an already-used URI keeps the first registration. 
+ + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The two + registrations use different function names so the test does not redefine a name in this scope; + the resource decorator keys on the URI, not the function name. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def config_first() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.resource("config://app") + def config_second() -> str: + """Registered at a duplicate URI; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_resources() + result = await client.read_resource("config://app") + + assert [resource.uri for resource in listed.resources] == ["config://app"] + assert listed.resources[0].name == "config_first" + assert result == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="first")]) + ) diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py new file mode 100644 index 0000000000..05135c1286 --- /dev/null +++ b/tests/interaction/mcpserver/test_tools.py @@ -0,0 +1,432 @@ +"""Tool interactions against MCPServer, driven through the public Client API.""" + +import logging +from typing import Annotated, Literal + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel, Field + +from mcp import MCPError +from mcp.server.mcpserver import Context, MCPServer +from mcp.server.mcpserver.exceptions import ToolError +from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.types import ( + URL_ELICITATION_REQUIRED, + CallToolResult, + ElicitRequestURLParams, + ErrorData, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from 
tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content(connect: Connect) -> None: + """Arguments reach the tool function; its return value comes back as text content. + + MCPServer also derives an output schema from the return annotation and attaches the + matching structuredContent to the result. + """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + return str(a + b) + + async with connect(mcp) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")], structured_content={"result": "5"})) + + +@requirement("mcpserver:tool:schema-variants") +async def test_complex_parameter_types_are_validated_and_coerced_before_the_tool_runs(connect: Connect) -> None: + """Literal, nested-model, and constrained parameters are validated and coerced from the wire arguments. + + The string "3" is coerced to `int` and the `point` dict to a `Point` instance before the function + body sees them, proving the generated input schema and validation pipeline cover non-trivial types. 
+ """ + mcp = MCPServer("typed") + + class Point(BaseModel): + x: int + y: int + + @mcp.tool() + def place(mode: Literal["fast", "slow"], point: Point, count: Annotated[int, Field(ge=1, le=10)]) -> str: + assert isinstance(point, Point) + return f"{mode} at ({point.x}, {point.y}) x{count}" + + async with connect(mcp) as client: + result = await client.call_tool("place", {"mode": "fast", "point": {"x": "3", "y": 4}, "count": 5}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="fast at (3, 4) x5")], structured_content={"result": "fast at (3, 4) x5"} + ) + ) + + +@requirement("mcpserver:tool:handler-throws") +@requirement("mcpserver:output-schema:skip-on-error") +async def test_call_tool_function_exception_becomes_error_result(connect: Connect) -> None: + """An exception raised by a tool function is returned as an is_error result, not a JSON-RPC error. + + The function's `-> str` annotation gives the tool a derived output schema, but the error + result is built before any schema validation runs, so no validation failure is layered on + top of the original exception. 
+ """ + mcp = MCPServer("errors") + + @mcp.tool() + def explode() -> str: + raise ValueError("boom") + + async with connect(mcp) as client: + result = await client.call_tool("explode", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool explode: boom")], is_error=True) + ) + + +@requirement("mcpserver:tool:handler-throws") +async def test_call_tool_tool_error_becomes_error_result(connect: Connect) -> None: + """A ToolError raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" + mcp = MCPServer("errors") + + @mcp.tool() + def flux() -> str: + raise ToolError("flux capacitor offline") + + async with connect(mcp) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool flux: flux capacitor offline")], is_error=True) + ) + + +@requirement("mcpserver:tool:unknown-name") +async def test_call_tool_unknown_name_returns_error_result(connect: Connect) -> None: + """Calling a tool name that was never registered is reported as an is_error result. + + The spec classifies unknown tools as a protocol error; see the divergence note on the + requirement. + """ + mcp = MCPServer("errors") + + @mcp.tool() + def add() -> None: + """A registered tool; the test calls a different name.""" + + async with connect(mcp) as client: + result = await client.call_tool("nope", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="Unknown tool: nope")], is_error=True)) + + +@requirement("mcpserver:tool:output-schema:model") +@requirement("tools:call:structured-content:text-mirror") +async def test_call_tool_model_return_becomes_structured_content(connect: Connect) -> None: + """A tool returning a pydantic model advertises the model's schema as the tool's output schema + and returns the model's fields as structured content alongside a serialised text block. 
+ """ + mcp = MCPServer("weather") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def get_weather() -> Weather: + return Weather(temperature=22.5, conditions="sunny") + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("get_weather", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": { + "temperature": {"title": "Temperature", "type": "number"}, + "conditions": {"title": "Conditions", "type": "string"}, + }, + "required": ["temperature", "conditions"], + "title": "Weather", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="""\ +{ + "temperature": 22.5, + "conditions": "sunny" +}\ +""" + ) + ], + structured_content={"temperature": 22.5, "conditions": "sunny"}, + ) + ) + + +@requirement("mcpserver:tool:output-schema:wrapped") +async def test_call_tool_list_return_is_wrapped_in_result_key(connect: Connect) -> None: + """A tool returning a list wraps the value under a "result" key in both the generated output + schema and the structured content. 
+ """ + mcp = MCPServer("primes") + + @mcp.tool() + def primes() -> list[int]: + return [2, 3, 5] + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("primes", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": {"result": {"items": {"type": "integer"}, "title": "Result", "type": "array"}}, + "required": ["result"], + "title": "primesOutput", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="2"), TextContent(text="3"), TextContent(text="5")], + structured_content={"result": [2, 3, 5]}, + ) + ) + + +@requirement("mcpserver:tool:input-validation") +async def test_call_tool_invalid_arguments_become_error_result(connect: Connect) -> None: + """Arguments that fail validation against the tool's signature are reported as an is_error + result describing the failure, not as a protocol error. + """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + """Validation rejects the arguments before the function is ever called.""" + raise NotImplementedError + + async with connect(mcp) as client: + result = await client.call_tool("add", {"b": 3}) + + # The description is raw pydantic output -- it embeds a pydantic-version-specific + # errors.pydantic.dev URL and the internal `addArguments` model name -- so only the stable + # prefix is asserted; a full snapshot would break on every pydantic upgrade. + assert result.is_error is True + assert isinstance(result.content[0], TextContent) + assert result.content[0].text.startswith("Error executing tool add: 1 validation error") + + +@requirement("mcpserver:output-schema:server-validate") +@requirement("mcpserver:output-schema:missing-structured") +async def test_tool_with_output_schema_returning_mismatched_structured_content_is_an_error_result( + connect: Connect, +) -> None: + """Structured content that fails the tool's own output schema is rejected on the server side. 
+ + A tool annotated `Annotated[CallToolResult, Model]` returns a hand-built CallToolResult while + declaring `Model` as its output schema; MCPServer validates the supplied structured_content + against that schema before returning. The two cases -- a content shape that does not match, + and no structured content at all -- both fail that validation and are reported as is_error + results carrying the (raw pydantic) validation error wrapped in the SDK's stable prefix. + """ + mcp = MCPServer("forecaster") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def mismatched() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")], structured_content={"nope": True}) + + @mcp.tool() + def missing() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")]) + + async with connect(mcp) as client: + mismatched_result = await client.call_tool("mismatched", {}) + missing_result = await client.call_tool("missing", {}) + + # The body of each message is raw pydantic ValidationError output (model name, field paths, + # an errors.pydantic.dev URL) and changes across pydantic versions, so only the SDK's stable + # prefix is asserted. + assert mismatched_result.is_error is True + assert isinstance(mismatched_result.content[0], TextContent) + assert mismatched_result.content[0].text.startswith("Error executing tool mismatched: 2 validation errors") + + assert missing_result.is_error is True + assert isinstance(missing_result.content[0], TextContent) + assert missing_result.content[0].text.startswith("Error executing tool missing: 1 validation error") + + +@requirement("mcpserver:tool:duplicate-name") +async def test_registering_a_duplicate_tool_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second tool with an already-used name keeps the first registration. 
+ + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via add_tool with an explicit name so the test does not + redefine the same function name in this scope. + """ + mcp = MCPServer("duplicates") + + @mcp.tool() + def echo() -> str: + return "first" + + def echo_second() -> str: + """Passed to add_tool with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + mcp.add_tool(echo_second, name="echo") + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("echo", {}) + + assert [tool.name for tool in listed.tools] == ["echo"] + assert result == snapshot( + CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + ) + + +@requirement("mcpserver:tool:naming-validation") +async def test_registering_a_tool_with_a_spec_invalid_name_warns_but_does_not_reject( + connect: Connect, caplog: pytest.LogCaptureFixture +) -> None: + """A tool name that violates the SEP-986 rules logs a warning at registration but is still registered. + + The intended behaviour is rejection at registration time; MCPServer instead logs the + naming-rule violation and proceeds (see the divergence note on the requirement). The warning + spans several SDK-authored log records, so only the stable prefix and inclusion of the + offending name are asserted. + """ + mcp = MCPServer("naming") + + with caplog.at_level(logging.WARNING, logger="mcp.shared.tool_name_validation"): + + @mcp.tool(name="bad name!") + def bad() -> str: + return "ok" + + assert any( + rec.levelno == logging.WARNING + and rec.message.startswith("Tool name validation warning") + and "bad name!" 
in rec.message + for rec in caplog.records + ) + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("bad name!", {}) + + assert [tool.name for tool in listed.tools] == ["bad name!"] + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")], structured_content={"result": "ok"})) + + +@requirement("mcpserver:tool:url-elicitation-error") +async def test_decorated_tool_raising_url_elicitation_required_surfaces_as_error_32042(connect: Connect) -> None: + """A decorated tool raising the URL-elicitation-required error reaches the client as error -32042. + + MCPServer wraps every other tool exception as an is_error result; this error is special-cased + so it propagates as the JSON-RPC error the client needs in order to present the listed URL + interactions and retry the call. + """ + mcp = MCPServer("authorizer") + + @mcp.tool() + def read_files() -> str: + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) + + +@requirement("mcpserver:register:post-connect") +async def test_adding_and_removing_tools_does_not_notify_connected_clients(connect: Connect) -> None: + """Mutating the tool set on a running server changes tools/list but sends no notification. 
+ + add_tool and remove_tool only update the registry: a connected client that listed the tools + before the mutation has no way to learn it should list them again. The spec provides + notifications/tools/list_changed for exactly this; MCPServer never sends it. The tool emits + one log message as a sentinel so the test proves notifications do reach the collector -- the + log message arrives, a list_changed does not. + """ + received: list[IncomingMessage] = [] + mcp = MCPServer("mutable") + + def extra() -> str: + """A tool registered at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + def doomed() -> str: + """A tool removed at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + async def grow(ctx: Context) -> str: + mcp.add_tool(extra, name="extra") + mcp.remove_tool("doomed") + await ctx.info("tool set changed") + return "mutated" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(mcp, message_handler=collect) as client: + before = await client.list_tools() + await client.call_tool("grow", {}) + after = await client.list_tools() + + assert [tool.name for tool in before.tools] == ["doomed", "grow"] + assert [tool.name for tool in after.tools] == ["grow", "extra"] + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="tool set changed"))] + ) diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py new file mode 100644 index 0000000000..7821c1eed5 --- /dev/null +++ b/tests/interaction/test_coverage.py @@ -0,0 +1,105 @@ +"""Enforces the contract between the requirements manifest and the test suite. + +The contract runs in both directions: every non-deferred entry in :data:`REQUIREMENTS` must be +exercised by at least one test, and every test in the suite must carry at least one +`@requirement(...)` mark referencing a manifest entry. 
Deferral reasons that point at coverage +elsewhere in the repo must point at paths that exist. Test modules are imported directly +(rather than relying on pytest collection) so the check holds even when only this file is run. +""" + +import importlib +import re +from pathlib import Path +from types import ModuleType + +import pytest + +from tests.interaction._requirements import REQUIREMENTS, Requirement, covered_by, requirement + +_SUITE_ROOT = Path(__file__).parent +_REPO_ROOT = _SUITE_ROOT.parent.parent + +# Repo paths cited inside deferral reasons ("Covered by tests/... "). +_CITED_PATH = re.compile(r"(?:tests|src)/[\w./-]*\w") + +# Tests that exercise the suite's own helpers rather than an interaction-model behaviour. +# Anything listed here is exempt from the every-test-has-a-requirement check. +_HARNESS_SELF_TESTS = { + "tests.interaction.lowlevel.test_wire.test_recording_read_stream_ends_iteration_when_the_sender_closes", + "tests.interaction.transports.test_bridge.test_response_chunks_arrive_as_the_application_sends_them", + "tests.interaction.transports.test_bridge.test_closing_the_response_delivers_a_disconnect_to_the_application", + "tests.interaction.transports.test_bridge.test_an_application_failure_before_the_response_starts_fails_the_request", + "tests.interaction.transports.test_bridge.test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect", + "tests.interaction.auth.test_flow.test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_wrapped_app", +} + + +def _import_all_test_modules() -> list[ModuleType]: + """Import every other test module in the suite so their `@requirement` decorators register.""" + modules: list[ModuleType] = [] + for path in sorted(_SUITE_ROOT.rglob("test_*.py")): + relative = path.relative_to(_SUITE_ROOT).with_suffix("") + name = f"{__package__}.{'.'.join(relative.parts)}" + if name != __name__: + modules.append(importlib.import_module(name)) + return modules + + +def 
test_every_requirement_is_exercised() -> None: + """Each non-deferred requirement is covered by at least one test (deferred ones by none).""" + _import_all_test_modules() + + uncovered = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is None and not covered_by(requirement_id) + ] + assert not uncovered, f"Requirements with no test and no deferred reason: {uncovered}" + + stale_deferrals = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is not None and covered_by(requirement_id) + ] + assert not stale_deferrals, f"Deferred requirements that now have tests (remove deferred): {stale_deferrals}" + + +def test_every_test_exercises_a_requirement() -> None: + """Each test in the suite carries at least one `@requirement` mark (harness self-tests excepted).""" + all_tests = { + f"{module.__name__}.{name}" + for module in _import_all_test_modules() + for name in vars(module) + if name.startswith("test_") + } + linked_tests = {test_name for requirement_id in REQUIREMENTS for test_name in covered_by(requirement_id)} + + unlinked = sorted(all_tests - linked_tests - _HARNESS_SELF_TESTS) + assert not unlinked, f"Tests with no @requirement mark: {unlinked}" + + stale_exemptions = sorted(_HARNESS_SELF_TESTS - all_tests) + assert not stale_exemptions, f"Harness self-test exemptions that no longer exist: {stale_exemptions}" + + +def test_deferral_reasons_cite_existing_paths() -> None: + """Every repo path named in a deferral reason exists, so coverage pointers cannot rot.""" + missing = sorted( + f"{requirement_id}: {cited}" + for requirement_id, spec in REQUIREMENTS.items() + if spec.deferred is not None + for cited in _CITED_PATH.findall(spec.deferred) + if not (_REPO_ROOT / cited).exists() + ) + assert not missing, f"Deferral reasons citing paths that do not exist: {missing}" + + +def test_unknown_requirement_id_is_rejected() -> None: + """Marking a test with an ID that is not in the 
manifest fails at decoration time.""" + with pytest.raises(KeyError, match="Unknown requirement id 'tools:call:does-not-exist'"): + requirement("tools:call:does-not-exist") + + +def test_invalid_requirement_source_is_rejected() -> None: + """A requirement whose source is not a spec URL, 'sdk', or an issue reference fails at construction.""" + with pytest.raises(ValueError, match="source must be a specification URL"): + Requirement(source="https://example.com/not-the-spec", behavior="Never constructed.") diff --git a/tests/interaction/transports/__init__.py b/tests/interaction/transports/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/transports/_bridge.py b/tests/interaction/transports/_bridge.py new file mode 100644 index 0000000000..f78c6d14b5 --- /dev/null +++ b/tests/interaction/transports/_bridge.py @@ -0,0 +1,169 @@ +"""An in-process, full-duplex HTTP transport for driving ASGI applications from httpx. + +`httpx.ASGITransport` runs the application to completion and only then hands the buffered +response to the caller, so a server that streams its response — the streamable HTTP transport's +SSE responses — can never converse with the client mid-request: a server-initiated request +nested inside a still-open call deadlocks. `StreamingASGITransport` removes that limitation by +running the application as a background task and forwarding every `http.response.body` chunk to +the client the moment it is sent. Everything happens on the one event loop: no sockets, no +threads, no sleeps, no extra dependencies. + +The behavioural contract, pinned by `test_bridge.py`: + +- The request body is buffered before the application is invoked (MCP requests are small JSON + documents); the response streams chunk by chunk. +- Closing the response — or the whole client — delivers `http.disconnect` to the application, + exactly as a real server sees when its peer goes away. 
+- An exception the application raises before sending `http.response.start` fails the originating + request with that same exception. After the response has started, a failure is visible to the + client only through the response itself (status code, truncated body) — the same signal a real + server over a real socket would give. + +The transport owns an anyio task group for the application tasks; it is opened and closed by +`httpx.AsyncClient`'s own context manager, so use the client as a context manager (the suite +always does). Closing the transport cancels every running application task by default; set +`cancel_on_close=False` to wait for the application's own disconnect handling instead. +""" + +import math +from collections.abc import AsyncIterator +from types import TracebackType + +import anyio +import anyio.abc +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream +from starlette.types import ASGIApp, Message, Scope + + +class _StreamingResponseBody(httpx.AsyncByteStream): + """A response body that yields chunks as the application produces them. + + Closing it tells the application the client has gone away (`http.disconnect`), mirroring a + peer that drops the connection mid-response. + """ + + def __init__(self, chunks: MemoryObjectReceiveStream[bytes], client_disconnected: anyio.Event) -> None: + self._chunks = chunks + self._client_disconnected = client_disconnected + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for chunk in self._chunks: + yield chunk + + async def aclose(self) -> None: + self._client_disconnected.set() + await self._chunks.aclose() + + +class StreamingASGITransport(httpx.AsyncBaseTransport): + """Drive an ASGI application in-process, streaming each response as it is produced. + + With `cancel_on_close` (the default), closing the transport cancels every application task + still running so harness teardown can never hang. 
Setting it to False makes the transport wait + for the application's own disconnect handling to complete instead, which is the path the legacy + SSE server transport relies on for resource cleanup. + """ + + _task_group: anyio.abc.TaskGroup + + def __init__(self, app: ASGIApp, *, cancel_on_close: bool = True) -> None: + self._app = app + self._cancel_on_close = cancel_on_close + + async def __aenter__(self) -> "StreamingASGITransport": + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + # httpx closes every streamed response before closing the transport, so by now each + # application task has been delivered `http.disconnect`. Either cancel immediately, or wait + # for the application's own disconnect handling to unwind. + if self._cancel_on_close: + self._task_group.cancel_scope.cancel() + await self._task_group.__aexit__(exc_type, exc_value, traceback) + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + assert isinstance(request.stream, httpx.AsyncByteStream) + request_body = b"".join([chunk async for chunk in request.stream]) + + scope: Scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path.split(b"?", maxsplit=1)[0], + "query_string": request.url.query, + "root_path": "", + "headers": [(name.lower(), value) for name, value in request.headers.raw], + "server": (request.url.host, request.url.port), + "client": ("127.0.0.1", 1234), + } + + request_delivered = False + client_disconnected = anyio.Event() + response_started = anyio.Event() + response_status = 0 + response_headers: list[tuple[bytes, bytes]] = [] + application_error: Exception | None = None + 
chunk_writer, chunk_reader = anyio.create_memory_object_stream[bytes](math.inf) + + async def receive_request() -> Message: + nonlocal request_delivered + if not request_delivered: + request_delivered = True + return {"type": "http.request", "body": request_body, "more_body": False} + await client_disconnected.wait() + return {"type": "http.disconnect"} + + async def send_response(message: Message) -> None: + nonlocal response_status, response_headers + if message["type"] == "http.response.start": + response_status = message["status"] + response_headers = list(message.get("headers", [])) + response_started.set() + return + assert message["type"] == "http.response.body" + body: bytes = message.get("body", b"") + if body: + await chunk_writer.send(body) + if not message.get("more_body", False): + await chunk_writer.aclose() + + async def run_application() -> None: + nonlocal application_error + try: + await self._app(scope, receive_request, send_response) + except Exception as exc: # The bridge is the application's outermost boundary: a crash + # must fail the originating request (or show up in the already-started response), + # never tear down the task group shared with every other in-flight request. + application_error = exc + finally: + response_started.set() + await chunk_writer.aclose() + + self._task_group.start_soon(run_application) + try: + await response_started.wait() + if application_error is not None: + raise application_error + except BaseException: + # No response will be built, so close the reader the response body would have owned + # and tell the application its peer has gone away. 
+ client_disconnected.set() + await chunk_reader.aclose() + raise + return httpx.Response( + status_code=response_status, + headers=response_headers, + stream=_StreamingResponseBody(chunk_reader, client_disconnected), + request=request, + ) diff --git a/tests/interaction/transports/_event_store.py b/tests/interaction/transports/_event_store.py new file mode 100644 index 0000000000..84d1a2646a --- /dev/null +++ b/tests/interaction/transports/_event_store.py @@ -0,0 +1,55 @@ +"""A predictable event store for resumability tests. + +The SDK's `EventStore` interface lets a streamable-HTTP server stamp every SSE event with an ID +and replay missed events when a client reconnects with `Last-Event-ID`. This implementation +issues sequential integer IDs starting at "1" so tests can assert exact IDs (the example store +uses uuid4, which cannot be snapshotted) and is small enough that every line is exercised by the +resumability tests themselves. +""" + +import anyio + +from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId +from mcp.types import JSONRPCMessage + + +class SequencedEventStore(EventStore): + """Stores every event in order and replays the same-stream tail after a given ID.""" + + def __init__(self) -> None: + self._events: list[tuple[StreamId, JSONRPCMessage | None]] = [] + self._milestones: dict[int, anyio.Event] = {} + + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + self._events.append((stream_id, message)) + count = len(self._events) + milestone = self._milestones.pop(count, None) + if milestone is not None: + milestone.set() + return str(count) + + async def wait_until_stored(self, count: int) -> None: + """Block until at least `count` events have been stored. 
+ + Tests use this to wait for the server's message router (which runs in another task) to + finish storing a known set of events before issuing a replay, so the replay's content is + deterministic rather than depending on task scheduling order. + """ + if len(self._events) >= count: + return + milestone = self._milestones.setdefault(count, anyio.Event()) + await milestone.wait() + + async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None: + try: + cursor = int(last_event_id) + except ValueError: + return None + if not 0 < cursor <= len(self._events): + return None + stream_id, _ = self._events[cursor - 1] + for index in range(cursor, len(self._events)): + event_stream_id, message = self._events[index] + if event_stream_id == stream_id and message is not None: + await send_callback(EventMessage(message, str(index + 1))) + return stream_id diff --git a/tests/interaction/transports/_stdio_server.py b/tests/interaction/transports/_stdio_server.py new file mode 100644 index 0000000000..5977cc3e99 --- /dev/null +++ b/tests/interaction/transports/_stdio_server.py @@ -0,0 +1,63 @@ +"""A real low-level Server over the stdio transport, for the suite's one subprocess test. + +Runnable as `python -m tests.interaction.transports._stdio_server` from the repo root; the test +launches it that way via `stdio_client`. Kept separate from the test module so the server lives in +its own importable file (subprocess coverage applies) while the test file follows the suite's +test-only-functions convention. 
+""" + +import sys + +import anyio + +from mcp.server import Server, ServerRequestContext +from mcp.server.stdio import stdio_server +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + EmptyResult, + ListToolsResult, + PaginatedRequestParams, + SetLevelRequestParams, + TextContent, + Tool, +) + + +async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo", + input_schema={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, + ) + ] + ) + + +async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + text = params.arguments["text"] + await ctx.session.send_log_message(level="info", data=f"echoing {text}", logger="echo") + return CallToolResult(content=[TextContent(text=text)]) + + +async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + +server = Server("stdio-echo", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + +async def main() -> None: + async with stdio_server() as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) + # Reached only when the run loop exits because stdin closed; if the process were terminated + # the test's stderr capture would not see this line. 
+ print("stdio-echo: clean exit", file=sys.stderr, flush=True) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/tests/interaction/transports/test_bridge.py b/tests/interaction/transports/test_bridge.py new file mode 100644 index 0000000000..7420b9d902 --- /dev/null +++ b/tests/interaction/transports/test_bridge.py @@ -0,0 +1,94 @@ +"""Contract tests for the suite's streaming ASGI bridge. + +These pin what `StreamingASGITransport` itself guarantees — chunk-by-chunk delivery, disconnect +propagation, and failure handling — against minimal hand-written ASGI applications, so the MCP +transport tests built on top of it never have to wonder what the harness provides. They are +harness self-tests, not interaction-model tests, and are exempted from the requirement-coverage +contract in `test_coverage.py`. +""" + +import anyio +import httpx +import pytest +from starlette.types import Message, Receive, Scope, Send + +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +async def test_response_chunks_arrive_as_the_application_sends_them() -> None: + """Each body chunk is delivered as sent, empty chunks are skipped, and the stream ends with the application.""" + + async def chunked_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": [(b"content-type", b"text/plain")]}) + await send({"type": "http.response.body", "body": b"first", "more_body": True}) + await send({"type": "http.response.body", "body": b"", "more_body": True}) + await send({"type": "http.response.body", "body": b"second", "more_body": False}) + + async with ( + httpx.AsyncClient(transport=StreamingASGITransport(chunked_app), base_url="http://bridge") as http, + http.stream("GET", "/chunks") as response, + ): + with anyio.fail_after(5): + chunks = [chunk async for chunk in 
response.aiter_raw()] + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/plain" + assert chunks == [b"first", b"second"] + + +async def test_closing_the_response_delivers_a_disconnect_to_the_application() -> None: + """A client that closes the response early is seen by the application as an http.disconnect.""" + seen_after_request: list[Message] = [] + disconnect_seen = anyio.Event() + + async def waiting_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": []}) + seen_after_request.append(await receive()) + disconnect_seen.set() + + async with httpx.AsyncClient(transport=StreamingASGITransport(waiting_app), base_url="http://bridge") as http: + async with http.stream("GET", "/wait") as response: + assert response.status_code == 200 + # Leaving the stream block closes the response while the application is still mid-response. 
+ with anyio.fail_after(5): + await disconnect_seen.wait() + + assert seen_after_request == [{"type": "http.disconnect"}] + + +async def test_an_application_failure_before_the_response_starts_fails_the_request() -> None: + """An exception raised before http.response.start reaches the caller as that same exception.""" + + async def broken_app(scope: Scope, receive: Receive, send: Send) -> None: + raise RuntimeError("the demo application is broken") + + async with httpx.AsyncClient(transport=StreamingASGITransport(broken_app), base_url="http://bridge") as http: + with pytest.raises(RuntimeError, match="the demo application is broken"): + await http.get("/broken") + + +async def test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect() -> None: + """With cancel_on_close=False, an application that runs cleanup after seeing http.disconnect + completes that cleanup before the transport finishes closing.""" + cleanup_ran = anyio.Event() + + async def lingering_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + await receive() + await send({"type": "http.response.start", "status": 200, "headers": []}) + assert (await receive())["type"] == "http.disconnect" + cleanup_ran.set() + + transport = StreamingASGITransport(lingering_app, cancel_on_close=False) + with anyio.fail_after(5): + async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: + async with http.stream("GET", "/linger") as response: + assert response.status_code == 200 + assert not cleanup_ran.is_set() + assert cleanup_ran.is_set() diff --git a/tests/interaction/transports/test_client_transport_http.py b/tests/interaction/transports/test_client_transport_http.py new file mode 100644 index 0000000000..65ed03f1e4 --- /dev/null +++ b/tests/interaction/transports/test_client_transport_http.py @@ -0,0 +1,247 @@ +"""Behaviour of the streamable-HTTP client transport itself, observed at the wire. 
+ +These tests connect a real `Client` to a real server over the in-process bridge, recording every +HTTP request the SDK client issues, so the assertions are about what the transport sends (headers, +methods, ordering) rather than what the protocol layer on top of it returns. The recording is the +wire-level instrument; the SDK client never exposes these details. +""" + +from collections.abc import AsyncIterator + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot +from starlette.types import Receive, Scope, Send + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.types import INVALID_REQUEST, CallToolResult, ErrorData, ListToolsResult, TextContent, Tool +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION, client_via_http, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _tooled_server() -> Server: + """A low-level server with one echo tool, used by every test in this file.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", description="Echo text.", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["text"]))]) + + return Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + +@pytest.fixture +async def recorded() -> AsyncIterator[list[httpx.Request]]: + """Connect a `Client` over a recording HTTP 
client, list tools, exit, and yield every request sent. + + The HTTP client carries one caller-supplied header (`x-trace`) so its propagation can be + asserted; the recording captures the closing DELETE because it is read after the `Client` has + fully exited. + """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record, headers={"x-trace": "abc"}) as (http, _): + async with client_via_http(http) as client: + result = await client.list_tools() + assert [tool.name for tool in result.tools] == ["echo"] + + yield requests + + +def _after_initialize(recorded: list[httpx.Request]) -> list[httpx.Request]: + """Every recorded request after the initialize POST (which carries no session yet).""" + assert recorded[0].method == "POST" + assert "mcp-session-id" not in recorded[0].headers + return recorded[1:] + + +@requirement("client-transport:http:custom-client") +@requirement("client-transport:http:custom-headers") +async def test_the_client_uses_the_supplied_http_client_and_propagates_its_headers( + recorded: list[httpx.Request], +) -> None: + """A caller-supplied `httpx.AsyncClient` is used for every request and carries its own headers. + + The recording itself proves the supplied client is the one in use; the propagated header + proves the SDK transport does not replace the caller's client configuration. + """ + # Exact ordering past the first request is not guaranteed (the standalone GET stream is + # scheduled concurrently with later POSTs), so methods are asserted as a multiset. 
+ assert sorted(request.method for request in recorded) == snapshot(["DELETE", "GET", "POST", "POST", "POST"]) + assert all(request.headers["x-trace"] == "abc" for request in recorded) + + +@requirement("client-transport:http:session-stored") +async def test_every_request_after_initialize_carries_the_issued_session_id(recorded: list[httpx.Request]) -> None: + """The session id from the initialize response is sent on every subsequent request.""" + session_ids = {request.headers["mcp-session-id"] for request in _after_initialize(recorded)} + assert len(session_ids) == 1 + (session_id,) = session_ids + assert session_id + + +@requirement("client-transport:http:protocol-version-stored") +@requirement("client-transport:http:protocol-version-header") +async def test_every_request_after_initialize_carries_the_negotiated_protocol_version( + recorded: list[httpx.Request], +) -> None: + """The negotiated protocol version is sent on every subsequent request (and not on initialize).""" + assert "mcp-protocol-version" not in recorded[0].headers + versions = {request.headers["mcp-protocol-version"] for request in _after_initialize(recorded)} + assert versions == snapshot({"2025-11-25"}) + + +@requirement("client-transport:http:accept-header-post") +@requirement("client-transport:http:accept-header-get") +async def test_accept_headers_cover_the_response_representations_the_transport_handles( + recorded: list[httpx.Request], +) -> None: + """POSTs accept both JSON and SSE; the standalone GET stream accepts SSE.""" + for request in recorded: + if request.method == "POST": + assert "application/json" in request.headers["accept"] + assert "text/event-stream" in request.headers["accept"] + if request.method == "GET": + assert "text/event-stream" in request.headers["accept"] + + +@requirement("client-transport:http:no-reconnect-after-close") +async def test_closing_the_client_sends_delete_and_does_not_reconnect(recorded: list[httpx.Request]) -> None: + """Client teardown sends DELETE 
and issues no further requests (no resumption GET).""" + assert recorded[-1].method == "DELETE" + assert all("last-event-id" not in request.headers for request in recorded) + + +@requirement("client-transport:http:concurrent-streams") +async def test_concurrent_tool_calls_each_open_a_post_stream_and_receive_their_own_response() -> None: + """Three tool calls issued at once each open their own POST stream and get the right answer.""" + requests: list[httpx.Request] = [] + results: dict[int, CallToolResult] = {} + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record) as (http, _), client_via_http(http) as client: + + async def call(n: int) -> None: + results[n] = await client.call_tool("echo", {"text": str(n)}) + + with anyio.fail_after(5): # pragma: no branch + async with anyio.create_task_group() as tg: # pragma: no branch + for n in (1, 2, 3): + tg.start_soon(call, n) + + assert results == snapshot( + { + 1: CallToolResult(content=[TextContent(text="1")]), + 2: CallToolResult(content=[TextContent(text="2")]), + 3: CallToolResult(content=[TextContent(text="3")]), + } + ) + tools_call_posts = [r for r in requests if r.method == "POST" and b'"tools/call"' in r.content] + assert len(tools_call_posts) == 3 + + +@requirement("client-transport:http:sse-405-tolerated") +@requirement("client-transport:http:terminate-405-ok") +async def test_client_tolerates_405_on_get_and_delete() -> None: + """A 405 on the standalone GET stream or the closing DELETE does not fail the connection. + + The GET-stream task swallows the failure and schedules a reconnect that the closing cancel + interrupts before it ever sleeps the full default delay; the DELETE 405 is logged and ignored. + Neither surfaces to the caller. 
+ """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + + async def filter_methods(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] in ("GET", "DELETE"): + await send({"type": "http.response.start", "status": 405, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + await real_app(scope, receive, send) + + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(filter_methods), base_url=BASE_URL) as http_client, + ): + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): # pragma: no branch + async with Client(transport) as client: # pragma: no branch + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + + +@requirement("client-transport:http:no-reconnect-after-response") +async def test_a_completed_post_stream_is_not_reconnected() -> None: + """A POST stream that delivered its response closes without a resumption GET. + + With an event store the server stamps every SSE event with an ID, so the client transport has a + Last-Event-ID it could resume from -- the test proves it does not, because the response arrived + and the stream completed normally. 
+ """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + server = _tooled_server() + async with ( + mounted_app(server, event_store=SequencedEventStore(), retry_interval=0, on_request=record) as (http, _), + client_via_http(http) as client, + ): + with anyio.fail_after(5): + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + resumption_gets = [r for r in requests if r.method == "GET" and "last-event-id" in r.headers] + assert resumption_gets == [] + + +@requirement("client-transport:http:404-surfaces") +async def test_a_404_mid_session_surfaces_as_a_session_terminated_error() -> None: + """A 404 in response to a request after initialization is reported to the caller as an MCP error. + + The spec says the client MUST start a new session in this situation; the SDK instead surfaces a + `Session terminated` error to the caller. The spec's MUST is tracked at + client-transport:http:session-404-reinitialize; this test pins the SDK's current behaviour. 
+ """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + initialize_seen = anyio.Event() + + async def first_post_then_404(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] == "POST" and initialize_seen.is_set(): + await send({"type": "http.response.start", "status": 404, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + if scope["type"] == "http" and scope["method"] == "POST": + initialize_seen.set() + await real_app(scope, receive, send) + + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(first_post_then_404), base_url=BASE_URL) as http_client, + ): + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): # pragma: no branch + async with Client(transport) as client: # pragma: no branch + with pytest.raises(MCPError) as exc_info: # pragma: no branch + await client.list_tools() + + assert exc_info.value.error == snapshot(ErrorData(code=INVALID_REQUEST, message="Session terminated")) diff --git a/tests/interaction/transports/test_flows.py b/tests/interaction/transports/test_flows.py new file mode 100644 index 0000000000..c428fe2d68 --- /dev/null +++ b/tests/interaction/transports/test_flows.py @@ -0,0 +1,129 @@ +"""Transport-level composed flows: multi-client isolation, reconnection, and dual-transport hosting. + +These scenarios are about how the transport layer holds together across more than one connection +or more than one transport, so they connect real `Client`s against one mounted server rather than +running over the matrix. 
+""" + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.session import LoggingFnT +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import CallToolResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import client_via_http, connect_over_sse, mounted_app +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("flow:multi-client:stateful-isolation") +async def test_concurrent_clients_on_one_stateful_server_receive_only_their_own_notifications() -> None: + """Two clients on one stateful manager each receive only the notifications their own request produced. + + Complements `test_terminating_one_session_leaves_others_working` (which proves session + independence under termination) with the notification-isolation dimension: a notification + emitted by one session's handler does not leak to another session's client. + """ + mcp = MCPServer("multi") + + @mcp.tool() + async def announce(label: str, ctx: Context) -> str: + """Emit one info-level log carrying the caller's label, then return it.""" + await ctx.info(label) + return label + + received_a: list[object] = [] + received_b: list[object] = [] + + async def collect_a(params: LoggingMessageNotificationParams) -> None: + received_a.append(params.data) + + async def collect_b(params: LoggingMessageNotificationParams) -> None: + received_b.append(params.data) + + async with mounted_app(mcp) as (http, _): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call(label: str, collect: LoggingFnT) -> None: + async with client_via_http(http, logging_callback=collect) as client: + await client.call_tool("announce", {"label": label}) + + tg.start_soon(call, "a", collect_a) + tg.start_soon(call, "b", collect_b) + + assert received_a == ["a"] + assert received_b == ["b"] + + 
+@requirement("flow:session:terminate-then-reconnect") +async def test_a_fresh_connection_after_termination_obtains_a_new_session_and_operates() -> None: + """After a client terminates, a fresh connection to the same manager gets a distinct session. + + Steps: (1) connect a client and call list_tools, (2) the client exits (its DELETE fires), + (3) connect a second client to the same mounted app, (4) the second client's call_tool + succeeds and the recorded session ids show two distinct sessions were issued. + """ + mcp = MCPServer("reconnectable") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + session_ids: list[str] = [] + + async def record(request: httpx.Request) -> None: + session_id = request.headers.get("mcp-session-id") + if session_id is not None: + session_ids.append(session_id) + + async with mounted_app(mcp, on_request=record) as (http, _): + async with client_via_http(http) as first: + first_result = await first.list_tools() + async with client_via_http(http) as second: + second_result = await second.call_tool("echo", {"text": "again"}) + + assert {tool.name for tool in first_result.tools} == {"echo"} + assert second_result == snapshot( + CallToolResult(content=[TextContent(text="again")], structured_content={"result": "again"}) + ) + distinct = set(session_ids) + assert len(distinct) == 2, f"expected two distinct session ids across the two connections, saw {distinct}" + + +@requirement("flow:compat:dual-transport-server") +async def test_one_server_serves_streamable_http_and_sse_clients_concurrently() -> None: + """One MCPServer instance serves a streamable-HTTP client and a legacy-SSE client at the same time. + + The two transports have independent connection management (the streamable-HTTP session manager + versus a per-connection SSE handler), but both dispatch into the same server's request + handlers. 
The test connects one client over each transport against the same instance and + proves both reach the same tool. Uses MCPServer because the low-level Server has no SSE + convenience; the entry is about hosting composition, not the low-level API. + """ + mcp = MCPServer("dual") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + async with ( + mounted_app(mcp) as (http, _), + connect_over_sse(mcp) as sse_client, + client_via_http(http) as shttp_client, + ): + with anyio.fail_after(5): + shttp_result = await shttp_client.call_tool("echo", {"text": "via http"}) + sse_result = await sse_client.call_tool("echo", {"text": "via sse"}) + + assert shttp_result == snapshot( + CallToolResult(content=[TextContent(text="via http")], structured_content={"result": "via http"}) + ) + assert sse_result == snapshot( + CallToolResult(content=[TextContent(text="via sse")], structured_content={"result": "via sse"}) + ) diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py new file mode 100644 index 0000000000..85e64ded42 --- /dev/null +++ b/tests/interaction/transports/test_hosting_http.py @@ -0,0 +1,344 @@ +"""Streamable HTTP semantics: status codes, header validation, message routing, and security. + +These tests speak HTTP directly to the server's mounted ASGI app via the in-process bridge, +asserting the wire contract -- which status code answers which condition, which stream a message +travels on -- that the SDK client never exposes. Transport-agnostic behaviour is covered by the +`connect`-fixture matrix. 
+""" + +import anyio +import pytest +from anyio.lowlevel import checkpoint +from httpx_sse import ServerSentEvent, aconnect_sse +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + INVALID_PARAMS, + PARSE_ERROR, + CallToolRequestParams, + CallToolResult, + EmptyResult, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListResourcesResult, + ListToolsResult, + PaginatedRequestParams, + SetLevelRequestParams, + SubscribeRequestParams, + TextContent, +) +from tests.interaction._connect import ( + base_headers, + initialize_body, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A low-level server with one tool that emits a related and an unrelated notification.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "narrate" + await ctx.session.send_log_message(level="info", data="related", logger=None, related_request_id=ctx.request_id) + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="done")]) + + async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + async def list_resources(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists 
resources.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: + """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" + raise NotImplementedError + + return Server( + "hosted", + on_list_tools=list_tools, + on_call_tool=call_tool, + on_set_logging_level=set_logging_level, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + ) + + +@requirement("hosting:http:method-405") +async def test_unsupported_http_methods_return_405() -> None: + """PUT and PATCH on the MCP endpoint return 405 with an Allow header naming the supported methods.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + put = await http.put("/mcp", json={}, headers=base_headers(session_id=session_id)) + patch = await http.patch("/mcp", json={}, headers=base_headers(session_id=session_id)) + + assert (put.status_code, put.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + assert (patch.status_code, patch.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + + +@requirement("hosting:http:accept-406") +async def test_missing_accept_media_types_return_406() -> None: + """A POST whose Accept header lacks both required types, or a GET lacking text/event-stream, returns 406.""" + async with mounted_app(_server()) as (http, _): + post = await http.post( + "/mcp", json=initialize_body(), headers={"accept": "text/plain", "mcp-protocol-version": "2025-11-25"} + ) + session_id = await initialize_via_http(http) + get = await http.get( + "/mcp", + headers={"accept": "application/json", "mcp-protocol-version": "2025-11-25", "mcp-session-id": session_id}, + ) + + assert (post.status_code, post.json()["error"]["message"]) == snapshot( + (406, "Not Acceptable: Client must accept both application/json and text/event-stream") + ) + assert (get.status_code, get.json()["error"]["message"]) 
== snapshot( + (406, "Not Acceptable: Client must accept text/event-stream") + ) + + +@requirement("hosting:http:content-type-415") +async def test_non_json_content_type_is_rejected() -> None: + """A POST with a non-JSON Content-Type is rejected before reaching the transport. + + See the divergence on the requirement: the security middleware rejects with 400, so the + transport's own 415 path is unreachable through any public entry point. + """ + async with mounted_app(_server()) as (http, _): + response = await http.post( + "/mcp", content=b"", headers=base_headers() | {"content-type": "text/plain"} + ) + + assert (response.status_code, response.text) == snapshot((400, "Invalid Content-Type header")) + + +@requirement("hosting:http:parse-error-400") +@requirement("hosting:http:batch") +async def test_malformed_and_batched_bodies_return_400() -> None: + """A non-JSON body returns 400 Parse error; a JSON array of requests returns 400 Invalid params.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + not_json = await http.post( + "/mcp", + content=b"this is not json", + headers=base_headers(session_id=session_id) | {"content-type": "application/json"}, + ) + batched = await http.post( + "/mcp", + json=[ + {"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + {"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + ], + headers=base_headers(session_id=session_id), + ) + + assert not_json.status_code == 400 + assert JSONRPCError.model_validate_json(not_json.text).error.code == PARSE_ERROR + assert batched.status_code == 400 + assert JSONRPCError.model_validate_json(batched.text).error.code == INVALID_PARAMS + + +@requirement("hosting:http:protocol-version-400") +@requirement("hosting:http:protocol-version-default") +async def test_protocol_version_header_is_validated() -> None: + """An unsupported MCP-Protocol-Version header returns 400; an absent header is accepted as the default.""" + async with mounted_app(_server()) 
as (http, _): + session_id = await initialize_via_http(http) + + bad = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + headers=base_headers(session_id=session_id) | {"mcp-protocol-version": "1991-01-01"}, + ) + # Only Accept and the session ID -- no MCP-Protocol-Version header at all. + defaulted = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers={"accept": "application/json, text/event-stream", "mcp-session-id": session_id}, + ) + + assert bad.status_code == 400 + assert JSONRPCError.model_validate_json(bad.text).error.message.startswith( + "Bad Request: Unsupported protocol version: 1991-01-01." + ) + # 202 proves the request was accepted under the assumed default version (2025-03-26). + assert defaulted.status_code == 202 + + +@requirement("hosting:http:json-response-mode") +async def test_json_response_mode_answers_with_application_json_not_sse() -> None: + """With JSON response mode enabled, request POSTs are answered with a single application/json body. + + Asserted at the wire level because the SDK client parses either representation, so a + Client-driven round trip cannot distinguish a JSON response from an SSE one. 
+ """ + async with mounted_app(_server(), json_response=True) as (http, _): + initialized = await http.post("/mcp", json=initialize_body(), headers=base_headers()) + session_id = initialized.headers["mcp-session-id"] + ping = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "ping"}, + headers=base_headers(session_id=session_id), + ) + + assert initialized.status_code == 200 + assert initialized.headers["content-type"].split(";", 1)[0] == "application/json" + assert JSONRPCResponse.model_validate(initialized.json()).id == 1 + assert ping.status_code == 200 + assert ping.headers["content-type"].split(";", 1)[0] == "application/json" + assert JSONRPCResponse.model_validate(ping.json()).id == 2 + + +@requirement("hosting:http:notifications-202") +async def test_notification_post_returns_202_with_no_body() -> None: + """A POST containing only a notification (no request ID) returns 202 Accepted with no body.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + response = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers=base_headers(session_id=session_id), + ) + + assert (response.status_code, response.content) == snapshot((202, b"")) + + +@requirement("hosting:http:second-sse-rejected") +async def test_a_second_standalone_get_stream_on_the_same_session_returns_409() -> None: + """Opening a second standalone GET SSE stream while one is already established returns 409 Conflict.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + + async with aconnect_sse(http, "GET", "/mcp", headers=base_headers(session_id=session_id)) as first: + assert first.response.status_code == 200 + # The standalone-stream writer registers its key as its first action, then parks + # awaiting messages; one yield to the loop lets that registration complete before the + # second GET is 
dispatched. + await checkpoint() + second = await http.get("/mcp", headers=base_headers(session_id=session_id)) + + assert (second.status_code, second.json()["error"]["message"]) == snapshot( + (409, "Conflict: Only one SSE stream is allowed per session") + ) + + +@requirement("hosting:http:standalone-sse") +@requirement("hosting:http:standalone-sse-no-response") +@requirement("hosting:http:response-same-connection") +@requirement("hosting:http:sse-close-after-response") +@requirement("hosting:http:no-broadcast") +async def test_messages_are_routed_to_exactly_one_stream() -> None: + """Each server message travels on exactly one SSE stream and is never broadcast. + + A streamable-HTTP session has two kinds of server-to-client SSE stream: one short-lived stream + per POST request, carrying that request's response and any notifications related to it, and one + long-lived standalone stream (opened by GET) for notifications not tied to any request. The + spec's routing rule is that the POST stream delivers the response (and its related + notifications) and then closes, the standalone stream carries only unrelated notifications and + never a JSON-RPC response, and no message appears on both. The test opens both streams, calls a + tool whose handler emits one related and one unrelated notification, and asserts each message's + routing. 
+ """ + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + post_events: list[ServerSentEvent] = [] + get_events: list[ServerSentEvent] = [] + + async def read_standalone_stream() -> None: + async with aconnect_sse(http, "GET", "/mcp", headers=base_headers(session_id=session_id)) as get: + assert get.response.status_code == 200 + standalone_ready.set() + async for event in get.aiter_sse(): + get_events.append(event) + seen_on_standalone.set() + + standalone_ready = anyio.Event() + seen_on_standalone = anyio.Event() + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(read_standalone_stream) + await standalone_ready.wait() + + params = CallToolRequestParams(name="narrate", arguments={}) + body = JSONRPCRequest(jsonrpc="2.0", id=5, method="tools/call", params=params.model_dump()) + async with aconnect_sse( + http, + "POST", + "/mcp", + json=body.model_dump(by_alias=True, exclude_none=True), + headers=base_headers(session_id=session_id), + ) as post: + assert post.response.status_code == 200 + # The POST stream iterator ends when the server closes the stream after the response. + post_events = [event async for event in post.aiter_sse()] + + await seen_on_standalone.wait() + tg.cancel_scope.cancel() + + post_messages = parse_sse_messages(post_events) + get_messages = parse_sse_messages(get_events) + + # POST stream: the related log notification, then the response, then the iterator ends (close). 
+ assert [type(m).__name__ for m in post_messages] == snapshot(["JSONRPCNotification", "JSONRPCResponse"]) + assert isinstance(post_messages[0], JSONRPCNotification) + assert (post_messages[0].method, post_messages[0].params) == snapshot( + ("notifications/message", {"level": "info", "data": "related"}) + ) + assert isinstance(post_messages[1], JSONRPCResponse) + assert post_messages[1].id == 5 + + # Standalone stream: only the unrelated resource-updated notification, never a response. + assert [type(m).__name__ for m in get_messages] == snapshot(["JSONRPCNotification"]) + assert isinstance(get_messages[0], JSONRPCNotification) + assert get_messages[0].method == snapshot("notifications/resources/updated") + + +@requirement("hosting:http:dns-rebinding") +@requirement("transport:streamable-http:origin-validation") +async def test_origin_validation_rejects_disallowed_origins_when_enabled() -> None: + """A disallowed Origin returns 403 (and Host 421) with protection enabled; disabled lets both through. + + See the divergence on hosting:http:dns-rebinding: the spec's Origin validation is an + unconditional MUST, but the SDK enables it only when the host is localhost (or settings are + passed explicitly) and additionally checks the Host header (returning 421), which the spec + does not require. + """ + # transport_security=None triggers the localhost auto-enable behaviour. 
+ async with mounted_app(Server("guarded"), transport_security=None) as (http, _): + bad_origin = await http.post( + "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://evil.example"} + ) + bad_host = await http.post("/mcp", json=initialize_body(), headers=base_headers() | {"host": "evil.example"}) + async with aconnect_sse( + http, "POST", "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://127.0.0.1:8000"} + ) as ok: + assert ok.response.status_code == 200 + assert [event async for event in ok.aiter_sse()] + + assert (bad_origin.status_code, bad_origin.text) == snapshot((403, "Invalid Origin header")) + assert (bad_host.status_code, bad_host.text) == snapshot((421, "Invalid Host header")) + + async with mounted_app( + Server("unguarded"), transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) as (http, _): + async with aconnect_sse( + http, "POST", "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://evil.example"} + ) as unguarded: + status = unguarded.response.status_code + assert [event async for event in unguarded.aiter_sse()] + + assert status == 200 diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py new file mode 100644 index 0000000000..c7945d56c3 --- /dev/null +++ b/tests/interaction/transports/test_hosting_resume.py @@ -0,0 +1,372 @@ +"""Resumability over the streamable HTTP transport, exercised entirely in process. + +These tests configure the server with an event store, so every SSE event is stamped with an ID +and a client that loses its connection can resume by sending `Last-Event-ID`. The wire-level +tests (`mounted_app` + raw httpx) assert exactly what travels on the wire; the end-to-end test +drives the SDK client through a server-initiated stream close and proves the call still +completes. 
The bridge's `aclose()` delivers `http.disconnect` to the running application, so +closing a streaming response mid-read is a deterministic in-process disconnect -- no sockets, +no real time. Every server here uses `retry_interval=0` so reconnection waits are no-ops. +""" + +import json + +import anyio +import httpx +import pytest +from httpx_sse import EventSource, ServerSentEvent +from inline_snapshot import snapshot + +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client +from mcp.server.mcpserver import Context, MCPServer +from mcp.shared.message import ClientMessageMetadata +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, + jsonrpc_message_adapter, +) +from tests.interaction._connect import ( + BASE_URL, + base_headers, + connect_over_streamable_http, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _counting_server() -> MCPServer: + """A server with one tool that emits related notifications and one unrelated notification.""" + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context, n: int) -> str: + """Emit n log notifications related to this call, plus one unrelated resource update.""" + for i in range(1, n + 1): + await ctx.info(f"tick {i}") + await ctx.session.send_resource_updated("file:///elsewhere.txt") + return f"counted to {n}" + + return mcp + + +def _tools_call(request_id: int, name: str, arguments: dict[str, object]) -> str: + """A serialized tools/call JSON-RPC request body.""" + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="tools/call", params={"name": name, "arguments": arguments} + 
).model_dump_json(by_alias=True, exclude_none=True) + + +async def _read_events(response: httpx.Response, count: int) -> list[ServerSentEvent]: + """Read exactly `count` SSE events from a streaming response without closing it.""" + source = EventSource(response).aiter_sse() + return [await anext(source) for _ in range(count)] + + +@requirement("hosting:resume:event-ids") +@requirement("hosting:resume:priming") +async def test_a_post_sse_stream_begins_with_a_priming_event_and_stamps_every_event() -> None: + """A request's SSE stream opens with a priming event (id, empty data, retry) then stamps each message.""" + async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( # pragma: no branch + "POST", "/mcp", content=_tools_call(1, "count", {"n": 2}), headers=base_headers(session_id=session_id) + ) as response: + assert response.status_code == 200 + events = await _read_events(response, 4) + + priming, first, second, result = events + # The priming event is the only event a client could have seen before any work happened, so it + # is the resumption anchor: it carries an ID and empty data. The SDK attaches the retry hint + # to this event (see the divergence on hosting:resume:priming). + assert (priming.id, priming.data, priming.retry) == snapshot(("3", "", 0)) + assert priming.event == snapshot("message") + # Every subsequent event carries an event-store ID; the related notifications and the response + # all ride this stream and close it after the response. 
+ assert [event.id for event in (first, second, result)] == snapshot(["4", "5", "7"]) + assert [json.loads(event.data)["method"] for event in (first, second)] == snapshot( + ["notifications/message", "notifications/message"] + ) + assert jsonrpc_message_adapter.validate_json(result.data) == snapshot( + JSONRPCResponse( + jsonrpc="2.0", + id=1, + result={ + "content": [{"type": "text", "text": "counted to 2"}], + "structuredContent": {"result": "counted to 2"}, + "isError": False, + }, + ) + ) + + +@requirement("hosting:resume:replay") +@requirement("hosting:resume:stream-scoped") +@requirement("hosting:resume:buffered-replay") +async def test_get_with_last_event_id_replays_only_that_streams_missed_events() -> None: + """Reconnecting with Last-Event-ID returns the missed events from that one stream, in order. + + The handler also emits an unrelated notification (which the server stores under the + standalone-stream key); replay must not return it, proving replay is scoped to the stream + the given event ID belongs to. + + Steps: (1) initialize; (2) POST a tool call and read events until the first notification is + captured; (3) close the response mid-stream -- the bridge delivers `http.disconnect`, the + handler keeps running; (4) release the handler so it emits the remaining messages, which the + server buffers in the event store; (5) wait on the event store for the handler's response to + be stored, so the replay's content is independent of task scheduling; (6) GET with + `Last-Event-ID` and assert the replay is exactly the missed events from this request's stream. 
+ """ + release = anyio.Event() + store = SequencedEventStore() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context) -> str: + """Emit one related notification, wait for the test, then emit two more plus an unrelated one.""" + await ctx.info("tick 1") + await release.wait() + await ctx.info("tick 2") + await ctx.info("tick 3") + await ctx.session.send_resource_updated("file:///elsewhere.txt") + return "counted" + + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "count", {}), headers=base_headers(session_id=session_id) + ) as response: + # Read the priming event and the first notification, then drop the connection. + priming, first = await _read_events(response, 2) + assert (priming.id, first.id) == snapshot(("3", "4")) + last_seen = first.id + release.set() + # The handler keeps running after the disconnect; its remaining messages are stored. + # The first wait returns immediately (the priming and first tick are already stored); + # the second blocks until the response itself is stored so the replay content is fixed. + await store.wait_until_stored(4) + await store.wait_until_stored(8) + replay_headers = base_headers(session_id=session_id) | {"last-event-id": last_seen} + async with http.stream("GET", "/mcp", headers=replay_headers) as replay: # pragma: no branch + assert replay.status_code == 200 + missed = await _read_events(replay, 3) + + decoded = parse_sse_messages(missed) + # Exactly the two remaining related notifications and the response, with their original IDs. 
+ assert [event.id for event in missed] == snapshot(["5", "6", "8"]) + assert [type(message).__name__ for message in decoded] == snapshot( + ["JSONRPCNotification", "JSONRPCNotification", "JSONRPCResponse"] + ) + assert isinstance(decoded[2], JSONRPCResponse) + assert decoded[2].id == 1 + # The unrelated resource-updated notification was stored under the standalone-stream key, not + # this request's stream, so it must not appear in the replay. + assert all( + not (isinstance(message, JSONRPCNotification) and message.method == "notifications/resources/updated") + for message in decoded + ) + + +@requirement("hosting:resume:bad-event-id") +async def test_an_unknown_last_event_id_yields_an_empty_replay_stream() -> None: + """A Last-Event-ID the event store cannot map produces an empty SSE stream rather than an error. + + See the divergence on hosting:resume:bad-event-id: this pins current behaviour. + """ + async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + for unknown in ("no-such-event", "0"): + headers = base_headers(session_id=session_id) | {"last-event-id": unknown} + async with http.stream("GET", "/mcp", headers=headers) as replay: + assert replay.status_code == 200 + assert replay.headers["content-type"].startswith("text/event-stream") + events = [event async for event in EventSource(replay).aiter_sse()] + assert events == [] + + +@requirement("hosting:http:disconnect-not-cancel") +async def test_dropping_the_connection_mid_request_does_not_cancel_the_handler() -> None: + """Closing the request's SSE connection while the handler is running leaves the handler running. + + The handler signals when it has started and when it has finished; the test drops the + connection in between and then releases the handler. If the disconnect cancelled the handler, + `finished` would never be set and the test would time out. 
+ """ + started = anyio.Event() + release = anyio.Event() + finished = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def hold(ctx: Context) -> str: + """Signal start, wait for the test, signal completion.""" + started.set() + await release.wait() + await ctx.info("released") + finished.set() + return "held" + + async with mounted_app(mcp, event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "hold", {}), headers=base_headers(session_id=session_id) + ) as response: + await _read_events(response, 1) + await started.wait() + assert not finished.is_set() + release.set() + await finished.wait() + + +# This test intentionally carries every automatic-reconnection requirement: the +# close-then-resume scenario is indivisible, so splitting it would mean five near-identical bodies. +@requirement("hosting:resume:close-stream") +@requirement("transport:streamable-http:resumability") +@requirement("client-transport:http:reconnect-post-priming") +@requirement("client-transport:http:reconnect-retry-value") +@requirement("flow:resume:tool-call-resumption-token") +async def test_a_call_whose_stream_the_server_closes_is_resumed_by_the_client() -> None: + """A server-closed request stream is reconnected by the client and the call completes. + + The handler emits one notification, closes its own SSE stream, then (once released) emits + another and returns. The client observed the priming event (so it has a Last-Event-ID and a + retry hint of 0ms), sees the stream end, reconnects via GET with Last-Event-ID, and receives + the post-close notification and the result over the replay stream. The shared events make the + test deterministic: the handler only proceeds once the test knows the first notification has + arrived (and so the client's reconnection has begun). 
+ """ + received: list[object] = [] + before_seen = anyio.Event() + gate = anyio.Event() + done = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def interrupt(ctx: Context) -> str: + """Emit, close this call's SSE stream, then emit again after the test releases the gate.""" + await ctx.info("before close") + await ctx.close_sse_stream() + await gate.wait() + await ctx.info("after close") + done.set() + return "resumed" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params.data) + if params.data == "before close": + before_seen.set() + + result: list[CallToolResult] = [] + async with connect_over_streamable_http( + mcp, event_store=SequencedEventStore(), retry_interval=0, logging_callback=collect + ) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call() -> None: + result.append(await client.call_tool("interrupt", {})) + + tg.start_soon(call) + await before_seen.wait() + gate.set() + await done.wait() + + assert result == snapshot( + [CallToolResult(content=[TextContent(text="resumed")], structured_content={"result": "resumed"})] + ) + assert received == snapshot(["before close", "after close"]) + + +@requirement("client-transport:http:resume-stream-api") +async def test_a_captured_resumption_token_replays_missed_messages_on_a_new_connection() -> None: + """A resumption token captured via on_resumption_token_update on one connection lets a fresh + connection retrieve the messages it missed by passing resumption_token to send_request. + + This is the explicit ClientMessageMetadata API, distinct from the automatic reconnection the + previous test covers: the transport dispatches a resumption_token request as a GET with + Last-Event-ID instead of POSTing the body, and remaps the replayed response onto the new + request's id. 
Client.call_tool does not expose ClientMessageMetadata, so the test drives a + bare ClientSession via session.send_request -- the sanctioned drop-down for behaviour Client + cannot express. The second connection carries the original session id but does not initialize + (the server-side session already is), modelling a caller that resumes after a process restart. + """ + captured: list[str] = [] + received: list[object] = [] + first_seen = anyio.Event() + token_seen = anyio.Event() + release = anyio.Event() + store = SequencedEventStore() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def hold(ctx: Context) -> str: + """Emit one notification, wait for the test, emit another, return.""" + await ctx.info("first") + await release.wait() + await ctx.info("second") + return "done" + + async def on_token(token: str) -> None: + captured.append(token) + if len(captured) >= 2: + token_seen.set() + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params.data) + first_seen.set() + + call = CallToolRequest(params=CallToolRequestParams(name="hold", arguments={})) + capture = ClientMessageMetadata(on_resumption_token_update=on_token) + + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, manager): + with anyio.fail_after(5): # pragma: no branch + async with ( # pragma: no branch + streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r1, w1), + ClientSession(r1, w1, logging_callback=collect) as first, + anyio.create_task_group() as tg, + ): + await first.initialize() + tg.start_soon(first.send_request, call, CallToolResult, None, capture) + await first_seen.wait() + await token_seen.wait() + assert captured == snapshot(["3", "4"]) + assert received == snapshot(["first"]) + # The session id is only observable via the manager (the client transport does not expose it). 
+ (session_id,) = manager._server_instances + http.headers["mcp-session-id"] = session_id + http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION + tg.cancel_scope.cancel() + + with anyio.fail_after(5): # pragma: no branch + release.set() # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event + # init priming + init response + call priming + "first" + "second" + result = 6 stored events. + await store.wait_until_stored(6) + async with ( # pragma: no branch + streamable_http_client(f"{BASE_URL}/mcp", http_client=http) as (r2, w2), + ClientSession(r2, w2, logging_callback=collect) as second, + ): + result = await second.send_request( + call, CallToolResult, metadata=ClientMessageMetadata(resumption_token=captured[-1]) + ) + assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert received == snapshot(["first", "second"]) diff --git a/tests/interaction/transports/test_hosting_session.py b/tests/interaction/transports/test_hosting_session.py new file mode 100644 index 0000000000..a926c3e8a2 --- /dev/null +++ b/tests/interaction/transports/test_hosting_session.py @@ -0,0 +1,202 @@ +"""Streamable HTTP session lifecycle: creation, routing, termination, and stateless mode. + +A test here speaks raw HTTP only when its assertion is the wire contract -- which header is +issued, which status code answers which condition -- that the SDK `Client` cannot observe. +Everything else is `Client`-driven against the same mounted session manager. Transport-agnostic +behaviour is covered by the `connect`-fixture matrix. 
+""" + +import re + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.types import JSONRPCResponse, ListToolsResult, PaginatedRequestParams, Tool +from tests.interaction._connect import ( + base_headers, + client_via_http, + initialize_body, + initialize_via_http, + mounted_app, + post_jsonrpc, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A minimal low-level server with one tool, so subsequent-request routing can be observed.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="noop", description="Does nothing.", input_schema={"type": "object"})]) + + return Server("hosted", on_list_tools=list_tools) + + +@requirement("hosting:session:create") +@requirement("hosting:session:id-charset") +async def test_initialize_issues_a_visible_ascii_session_id() -> None: + """An initialize POST without a session ID creates a session and returns a visible-ASCII Mcp-Session-Id.""" + async with mounted_app(_server()) as (http, _): + response, messages = await post_jsonrpc(http, initialize_body()) + + assert response.status_code == 200 + session_id = response.headers.get("mcp-session-id") + assert session_id is not None + # The spec requires the session ID to consist only of visible ASCII (0x21-0x7E). 
+ assert re.fullmatch(r"[\x21-\x7E]+", session_id) + assert isinstance(messages[0], JSONRPCResponse) + assert messages[0].id == 1 + + +@requirement("hosting:session:reuse") +async def test_subsequent_requests_with_the_session_id_route_to_the_same_session() -> None: + """Requests carrying the issued Mcp-Session-Id reuse that session's transport rather than creating another.""" + async with mounted_app(_server()) as (http, manager): + async with client_via_http(http) as client: + await client.list_tools() + await client.list_tools() + # The session count is the only signal that distinguishes routing-to-existing from + # silently creating a second session: both produce a successful result. + assert len(manager._server_instances) == 1 + + +@requirement("hosting:session:unknown-id") +async def test_requests_with_an_unknown_session_id_return_404() -> None: + """POST, GET, and DELETE each carrying an unknown Mcp-Session-Id are answered 404 by the manager.""" + async with mounted_app(_server()) as (http, _): + post = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + headers=base_headers(session_id="not-a-session"), + ) + get = await http.get("/mcp", headers=base_headers(session_id="not-a-session")) + delete = await http.delete("/mcp", headers=base_headers(session_id="not-a-session")) + + assert (post.status_code, post.json()) == snapshot( + (404, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Session not found"}}) + ) + assert (get.status_code, delete.status_code) == (404, 404) + + +@requirement("hosting:session:missing-id") +async def test_non_initialize_post_without_a_session_id_returns_400() -> None: + """A non-initialize POST that omits Mcp-Session-Id in stateful mode is rejected with 400.""" + async with mounted_app(_server()) as (http, _): + await initialize_via_http(http) + response = await http.post( + "/mcp", json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, headers=base_headers() + ) + + assert 
(response.status_code, response.json()) == snapshot( + (400, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Bad Request: Missing session ID"}}) + ) + + +@requirement("hosting:session:delete") +@requirement("hosting:session:post-termination-404") +async def test_delete_terminates_the_session_and_subsequent_requests_return_404() -> None: + """DELETE with a valid Mcp-Session-Id terminates the session; further requests on that ID return 404.""" + async with mounted_app(_server()) as (http, manager): + session_id = await initialize_via_http(http) + + delete = await http.delete("/mcp", headers=base_headers(session_id=session_id)) + assert delete.status_code == 200 + + # The manager keeps the terminated transport registered, so the next request reaches the + # transport's own _terminated check rather than the manager's unknown-session path. + assert session_id in manager._server_instances + post = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + headers=base_headers(session_id=session_id), + ) + assert (post.status_code, post.json()) == snapshot( + ( + 404, + { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32600, "message": "Not Found: Session has been terminated"}, + }, + ) + ) + + +@requirement("hosting:session:isolation") +async def test_terminating_one_session_leaves_others_working() -> None: + """Terminating one session on a manager does not disturb a concurrent session on the same manager.""" + async with mounted_app(_server()) as (http, manager): + async with client_via_http(http) as survivor: + async with client_via_http(http) as terminated: + await terminated.list_tools() + assert len(manager._server_instances) == 2 + # `terminated` has exited (its DELETE has been sent); `survivor` still answers. 
+ result = await survivor.list_tools() + + assert result.tools[0].name == "noop" + + +@requirement("hosting:session:reinitialize") +async def test_second_initialize_on_an_existing_session_is_accepted() -> None: + """A second initialize POST carrying an existing session ID is processed rather than rejected. + + See the divergence on the requirement: the entry expects a rejection, but the SDK forwards the + second initialize to the running server, which answers it as a fresh handshake. + """ + async with mounted_app(_server()) as (http, manager): + session_id = await initialize_via_http(http) + response, messages = await post_jsonrpc(http, initialize_body(request_id=2), session_id=session_id) + assert len(manager._server_instances) == 1 + + assert response.status_code == snapshot(200) + assert isinstance(messages[0], JSONRPCResponse) + assert messages[0].id == 2 + + +@requirement("hosting:stateless:no-session-id") +@requirement("hosting:stateless:no-reuse") +async def test_stateless_mode_never_issues_a_session_id() -> None: + """A stateless server issues no Mcp-Session-Id and creates no persistent transport. + + The recording proves no request the SDK client sent carried an Mcp-Session-Id (the server + cannot have issued one, or the client would echo it); the empty instance map proves the + manager kept no transport between requests. 
+ """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_server(), stateless_http=True, on_request=record) as (http, manager): + async with client_via_http(http) as client: + result = await client.list_tools() + assert manager._server_instances == {} + + assert result.tools[0].name == "noop" + assert all("mcp-session-id" not in request.headers for request in requests) + assert "DELETE" not in {request.method for request in requests} + + +@requirement("hosting:stateless:concurrent-clients") +async def test_stateless_mode_serves_concurrent_clients_independently() -> None: + """Two clients connected concurrently to the same stateless app each complete a round trip.""" + results: dict[str, ListToolsResult] = {} + + async with mounted_app(_server(), stateless_http=True) as (http, _): + + async def list_via(label: str) -> None: + async with client_via_http(http) as client: + results[label] = await client.list_tools() + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(list_via, "a") + tg.start_soon(list_via, "b") + + assert results["a"].tools[0].name == "noop" + assert results["b"].tools[0].name == "noop" diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py new file mode 100644 index 0000000000..9c7353dda5 --- /dev/null +++ b/tests/interaction/transports/test_sse.py @@ -0,0 +1,90 @@ +"""Behaviour specific to the legacy HTTP+SSE transport, exercised entirely in process. + +Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of +the suite over this transport as well; this file pins only what is observable on the SSE wiring +itself: the GET-then-POST connection lifecycle, the endpoint event, and how the message endpoint +rejects requests it cannot route to a session. 
Every test drives the server's real Starlette app +through the suite's streaming ASGI bridge. +""" + +from uuid import UUID, uuid4 + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.sse import sse_client +from mcp.server import Server +from mcp.types import EmptyResult +from tests.interaction._connect import BASE_URL, build_sse_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +@requirement("transport:sse") +@requirement("transport:sse:endpoint-event") +async def test_endpoint_event_names_the_message_endpoint_with_a_fresh_session_id() -> None: + """Connecting opens a GET stream whose first event names the POST endpoint and a fresh + session id; messages POSTed there are answered on that stream, and disconnecting releases the + server's session entry.""" + app, sse = build_sse_app(Server("legacy")) + captured_session_id: list[str] = [] + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client( + f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append + ) + with anyio.fail_after(5): + async with Client(transport) as client: + assert len(captured_session_id) == 1 + assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers + assert await client.send_ping() == snapshot(EmptyResult()) + + assert sse._read_stream_writers == {} + + +@requirement("transport:sse:post:session-routing") +async def test_post_without_a_session_id_is_rejected() -> None: + """A POST to the message endpoint with no session_id 
query parameter is answered 400.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post("/messages/", json={"jsonrpc": "2.0", "method": "ping", "id": 1}) + assert (response.status_code, response.text) == snapshot((400, "session_id is required")) + + +@requirement("transport:sse:post:session-routing") +async def test_post_with_a_malformed_session_id_is_rejected() -> None: + """A POST whose session_id query parameter is not a UUID is answered 400.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post( + "/messages/", params={"session_id": "not-a-uuid"}, json={"jsonrpc": "2.0", "method": "ping", "id": 1} + ) + assert (response.status_code, response.text) == snapshot((400, "Invalid session ID")) + + +@requirement("transport:sse:post:session-routing") +async def test_post_for_an_unknown_session_is_rejected() -> None: + """A POST naming a well-formed session_id that no SSE stream owns is answered 404.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post( + "/messages/", params={"session_id": uuid4().hex}, json={"jsonrpc": "2.0", "method": "ping", "id": 1} + ) + assert (response.status_code, response.text) == snapshot((404, "Could not find session")) diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py new file mode 100644 index 0000000000..27cc65de42 --- /dev/null +++ b/tests/interaction/transports/test_stdio.py @@ -0,0 +1,143 @@ +"""The stdio transport: one subprocess end-to-end test and one in-process framing test. 
+ +Everything else in the suite runs in a single process; the subprocess test exists to prove the same +client↔server round trip works over the stdio transport's real boundary (a child process whose +stdin/stdout carry one newline-delimited JSON-RPC message per line). The server lives in +`_stdio_server.py` and is launched via `python -m` so subprocess coverage measurement applies. + +The framing test drives `stdio_server` in-process by passing it injected text streams instead of the +real stdin/stdout, so the raw lines the transport writes can be asserted directly without a process +boundary. + +stdio is deliberately not a leg of the `connect`-fixture matrix: spawning a subprocess per test +would be slow, and the matrix already proves transport-agnosticism over three in-process +transports. Process-lifecycle edge cases (escalation to terminate/kill, parse errors) are covered by +`tests/client/test_stdio.py` and stay deferred here. +""" + +import io +import json +import os +import sys +import tempfile +from pathlib import Path + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.server.stdio import stdio_server +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, +) +from mcp.types.jsonrpc import jsonrpc_message_adapter +from tests.interaction._connect import initialize_body +from tests.interaction._requirements import requirement +from tests.interaction.transports import _stdio_server + +pytestmark = pytest.mark.anyio + +_REPO_ROOT = Path(__file__).parents[3] + + +@requirement("transport:stdio") +@requirement("transport:stdio:clean-shutdown") +@requirement("transport:stdio:stderr-passthrough") +async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess() -> None: + """A 
Client connected over stdio initializes, calls a tool with arguments, receives the + server's log notification before the call returns, and the server exits when the transport + closes its stdin.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + with tempfile.TemporaryFile(mode="w+") as errlog: + transport = stdio_client( + StdioServerParameters( + command=sys.executable, + args=["-m", _stdio_server.__name__], + cwd=str(_REPO_ROOT), + # stdio_client deliberately filters the inherited environment to a safe minimum, + # which drops the variables coverage.py's subprocess support uses; pass them through + # so the server module is measured. Empty when not running under coverage. + env={key: value for key, value in os.environ.items() if key.startswith("COVERAGE_")}, + ), + errlog=errlog, + ) + + with anyio.fail_after(10): + async with Client(transport, logging_callback=collect) as client: + assert client.initialize_result.server_info.name == "stdio-echo" + result = await client.call_tool("echo", {"text": "across\nprocesses"}) + + errlog.seek(0) + captured_stderr = errlog.read() + + assert result == snapshot(CallToolResult(content=[TextContent(text="across\nprocesses")])) + # stdio carries one ordered server→client stream, so the same notification-before-response + # guarantee holds here as for the in-memory transport. + assert received == snapshot( + [LoggingMessageNotificationParams(level="info", logger="echo", data="echoing across\nprocesses")] + ) + # The server writes this line only after its run loop returns, which happens when stdin closes: + # seeing it proves the process exited on its own rather than via the transport's terminate + # escalation, without a timing-based assertion. The capture itself proves stderr passthrough: + # the transport routes the child's stderr to the caller's `errlog` without consuming it. 
+ assert captured_stderr == snapshot("stdio-echo: clean exit\n") + + +@requirement("transport:stdio:stream-purity") +@requirement("transport:stdio:no-embedded-newlines") +async def test_stdio_server_writes_one_jsonrpc_message_per_line() -> None: + """Everything `stdio_server` writes is a valid JSON-RPC message on its own line, and nothing else. + + The transport's stdin/stdout parameters are public, so the test injects in-process text streams + instead of the real process handles and drives the read/write streams directly: a JSON-RPC line on + stdin is parsed and delivered, and every message sent on the write stream appears as exactly one + newline-terminated line whose payload newlines are JSON-escaped. This proves the transport's own + framing; it does not guard `sys.stdout` against handler code that prints to it directly (see the + divergence on `transport:stdio:stream-purity`). + """ + captured = io.StringIO() + sent_line = json.dumps(initialize_body(request_id=1)) + "\n" + + with anyio.fail_after(5): + async with ( + stdio_server(stdin=anyio.wrap_file(io.StringIO(sent_line)), stdout=anyio.wrap_file(captured)) as ( + read_stream, + write_stream, + ), + read_stream, + write_stream, + ): + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert isinstance(received.message, JSONRPCRequest) + assert received.message.method == "initialize" + + response = JSONRPCResponse(jsonrpc="2.0", id=1, result={"text": "line\nbreak"}) + notification = JSONRPCNotification( + jsonrpc="2.0", method="notifications/message", params={"level": "info", "data": "two\nlines"} + ) + await write_stream.send(SessionMessage(response)) + await write_stream.send(SessionMessage(notification)) + + output = captured.getvalue() + assert output.endswith("\n") + lines = output.removesuffix("\n").split("\n") + assert len(lines) == 2 + messages = [jsonrpc_message_adapter.validate_json(line) for line in lines] + assert [type(message).__name__ for message in 
messages] == snapshot(["JSONRPCResponse", "JSONRPCNotification"]) + # The newline inside the payload is JSON-escaped on the wire, not a literal newline that would + # break the one-message-per-line framing. + assert r"line\nbreak" in lines[0] + assert r"two\nlines" in lines[1] diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py new file mode 100644 index 0000000000..d38e2a0bb3 --- /dev/null +++ b/tests/interaction/transports/test_streamable_http.py @@ -0,0 +1,168 @@ +"""Behaviour specific to the streamable HTTP transport, exercised entirely in process. + +Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of +the suite over this transport as well; this file only pins what cannot be observed in memory: the +server's stateless and JSON-response modes, the standalone GET stream, and the full-duplex +server-initiated exchange on a still-open call. Every test drives the server's real Starlette app +through the suite's streaming ASGI bridge — no sockets, threads, or subprocesses. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp.client import ClientRequestContext +from mcp.server.elicitation import AcceptedElicitation +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + CallToolResult, + ElicitRequestParams, + ElicitResult, + LoggingMessageNotification, + LoggingMessageNotificationParams, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + TextContent, +) +from tests.interaction._connect import connect_over_streamable_http +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _smoke_server() -> MCPServer: + """A server exercising each message shape the transport-specific tests need.""" + mcp = MCPServer("smoke", instructions="Talk to the smoke server.") + + @mcp.tool() + def echo(text: str) -> str: + """Echo the text back.""" + return text + + class Confirmation(BaseModel): + confirmed: bool + + @mcp.tool() + async def ask(ctx: Context) -> str: + """Elicit a confirmation from the client and report the outcome.""" + answer = await ctx.elicit("Proceed?", Confirmation) + # In stateless mode the elicit raises before this point: there is no session to call back through. 
+ assert isinstance(answer, AcceptedElicitation) + return f"confirmed={answer.data.confirmed}" + + @mcp.tool() + async def announce(ctx: Context) -> str: + """Send one notification related to this request and one that is not.""" + await ctx.info("about to announce") + await ctx.session.send_resource_updated("file:///watched.txt") + return "announced" + + return mcp + + +@requirement("transport:streamable-http:json-response") +@requirement("client-transport:http:json-response-parsed") +async def test_tool_call_over_streamable_http_with_json_responses() -> None: + """The round trip works when the server answers with a single JSON body instead of an SSE stream.""" + async with connect_over_streamable_http(_smoke_server(), json_response=True) as client: + assert client.initialize_result.server_info.name == "smoke" + result = await client.call_tool("echo", {"text": "as json"}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="as json")], structured_content={"result": "as json"}) + ) + + +@requirement("transport:streamable-http:stateless") +async def test_tool_calls_over_stateless_streamable_http() -> None: + """Consecutive requests each succeed against a stateless server with no session to share.""" + async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: + first = await client.call_tool("echo", {"text": "first"}) + second = await client.call_tool("echo", {"text": "second"}) + + assert first == snapshot( + CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + ) + assert second == snapshot( + CallToolResult(content=[TextContent(text="second")], structured_content={"result": "second"}) + ) + + +@requirement("transport:streamable-http:stateless-restrictions") +async def test_stateless_streamable_http_rejects_server_initiated_requests() -> None: + """A handler that tries to call back to the client in stateless mode fails: there is no session.""" + async with 
connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: + result = await client.call_tool("ask", {}) + + assert result.is_error is True + assert isinstance(result.content[0], TextContent) + # The exact message is the StatelessModeNotSupported exception text wrapped by the tool-error + # path; pin the stable prefix rather than the full exception prose. + assert result.content[0].text.startswith("Error executing tool ask:") + + +@requirement("transport:streamable-http:notifications") +@requirement("transport:streamable-http:unrelated-messages") +@requirement("hosting:http:standalone-sse") +async def test_unrelated_server_messages_arrive_on_the_standalone_stream() -> None: + """A server message with no related request reaches the client through the standalone GET stream. + + The log notification is related to the tool call and travels on that call's own SSE stream; + the resource-updated notification is not related to any request, so the only way it can reach + the client is the standalone stream the client opens after initialization. Delivery order + across the two streams is not guaranteed, so the unrelated message is awaited rather than + assumed to beat the tool result. + """ + received: list[IncomingMessage] = [] + resource_update_seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + if isinstance(message, ResourceUpdatedNotification): + resource_update_seen.set() + + async with connect_over_streamable_http(_smoke_server(), message_handler=collect) as client: + result = await client.call_tool("announce", {}) + with anyio.fail_after(5): + await resource_update_seen.wait() + + assert result == snapshot( + CallToolResult(content=[TextContent(text="announced")], structured_content={"result": "announced"}) + ) + # The related log notification rides the call's stream; the unrelated resource-updated + # notification rides the standalone stream. Both arrive, nothing else does. 
+ assert [message for message in received if isinstance(message, LoggingMessageNotification)] == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="about to announce"))] + ) + assert [message for message in received if isinstance(message, ResourceUpdatedNotification)] == snapshot( + [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + ) + assert len(received) == 2 + + +@requirement("transport:streamable-http:stateful") +@requirement("transport:streamable-http:server-to-client") +async def test_server_initiated_elicitation_round_trips_during_a_tool_call() -> None: + """An elicitation issued mid-call reaches the client and its answer reaches the handler over stateful HTTP. + + The elicitation request travels on the still-open SSE response of the tool call that triggered + it, and the client's answer arrives as a separate POST -- the full-duplex exchange the + streamable HTTP transport exists to provide. + """ + asked: list[ElicitRequestParams] = [] + + async def answer(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + asked.append(params) + return ElicitResult(action="accept", content={"confirmed": True}) + + async with connect_over_streamable_http(_smoke_server(), elicitation_callback=answer) as client: + # Bounded because a harness regression here historically meant deadlock, not failure. + with anyio.fail_after(5): + result = await client.call_tool("ask", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="confirmed=True")], structured_content={"result": "confirmed=True"}) + ) + assert [params.message for params in asked] == snapshot(["Proceed?"]) diff --git a/uv.lock b/uv.lock index b396898b66..5b72e97fce 100644 --- a/uv.lock +++ b/uv.lock @@ -939,7 +939,7 @@ dev = [ { name = "mcp", extras = ["cli", "ws"], editable = "." 
}, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, - { name = "pytest", specifier = ">=8.3.4" }, + { name = "pytest", specifier = ">=8.4.0" }, { name = "pytest-examples", specifier = ">=0.0.14" }, { name = "pytest-flakefinder", specifier = ">=1.1.0" }, { name = "pytest-pretty", specifier = ">=1.2.0" }, From 616476f6927a5c64213ea97bbd36a7466f410775 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 17:48:40 +0100 Subject: [PATCH 69/84] Bind transport sessions to the authenticated principal (#2718) --- src/mcp/server/auth/middleware/bearer_auth.py | 26 +- src/mcp/server/sse.py | 66 ++-- src/mcp/server/streamable_http_manager.py | 40 ++- src/mcp/server/transport_security.py | 14 +- tests/server/test_sse_security.py | 288 +++++++++++++++++- tests/server/test_streamable_http_manager.py | 167 +++++++++- tests/server/test_transport_security.py | 88 ++++++ 7 files changed, 647 insertions(+), 42 deletions(-) create mode 100644 tests/server/test_transport_security.py diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 2eafdc793e..ba66e94226 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,6 +1,6 @@ import json import time -from typing import Any +from typing import Any, TypedDict from pydantic import AnyHttpUrl from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser @@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken): self.scopes = auth_info.scopes +class AuthorizationContext(TypedDict): + client_id: str + issuer: str | None + subject: str | None + + +def authorization_context(user: AuthenticatedUser) -> AuthorizationContext: + """Identify the principal `user` represents, for transports to compare + against the principal that created a session. 
Components the token + verifier does not supply are `None`, so the comparison degrades to the + remaining components. + + See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for + a verifier that populates `subject` and `claims` from an introspection + response.""" + token = user.access_token + issuer = (token.claims or {}).get("iss") + return AuthorizationContext( + client_id=token.client_id, + issuer=str(issuer) if issuer is not None else None, + subject=token.subject, + ) + + class BearerAuthBackend(AuthenticationBackend): """Authentication backend that validates Bearer tokens using a TokenVerifier.""" diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index be8e979c9d..05e948332b 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -50,6 +50,7 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send from mcp import types +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.transport_security import ( TransportSecurityMiddleware, TransportSecuritySettings, @@ -73,6 +74,9 @@ class SseServerTransport: _endpoint: str _read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]] + # Identity of the credential that created each session; requests for a + # session must present the same credential. 
+ _session_owners: dict[UUID, AuthorizationContext] _security: TransportSecurityMiddleware def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: @@ -112,19 +116,20 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | self._endpoint = endpoint self._read_stream_writers = {} + self._session_owners = {} self._security = TransportSecurityMiddleware(security_settings) logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager async def connect_sse(self, scope: Scope, receive: Receive, send: Send): - if scope["type"] != "http": # pragma: no cover + if scope["type"] != "http": logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") # Validate request headers for DNS rebinding protection request = Request(scope, receive) error_response = await self._security.validate_request(request, is_post=False) - if error_response: # pragma: no cover + if error_response: await error_response(scope, receive, send) raise ValueError("Request validation failed") @@ -134,6 +139,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): write_stream, write_stream_reader = create_context_streams[SessionMessage](0) session_id = uuid4() + user = scope.get("user") + if isinstance(user, AuthenticatedUser): + self._session_owners[session_id] = authorization_context(user) self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") @@ -169,27 +177,30 @@ async def sse_writer(): } ) - async with anyio.create_task_group() as tg: - - async def response_wrapper(scope: Scope, receive: Receive, send: Send): - """The EventSourceResponse returning signals a client close / disconnect. - In this case we close our side of the streams to signal the client that - the connection has been closed. 
- """ - await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( - scope, receive, send - ) - await sse_stream_reader.aclose() - await read_stream_writer.aclose() - await write_stream_reader.aclose() - self._read_stream_writers.pop(session_id, None) - logging.debug(f"Client session disconnected {session_id}") + try: + async with anyio.create_task_group() as tg: + + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. + """ + await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( + scope, receive, send + ) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + await sse_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") - logger.debug("Starting SSE response task") - tg.start_soon(response_wrapper, scope, receive, send) + logger.debug("Starting SSE response task") + tg.start_soon(response_wrapper, scope, receive, send) - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) + finally: + self._read_stream_writers.pop(session_id, None) + self._session_owners.pop(session_id, None) async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") @@ -197,7 +208,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) # Validate request headers for DNS rebinding protection error_response = await self._security.validate_request(request, is_post=True) - if error_response: # pragma: no cover + if error_response: return await error_response(scope, receive, send) session_id_param = request.query_params.get("session_id") @@ -220,13 +231,22 @@ async def 
handle_post_message(self, scope: Scope, receive: Receive, send: Send) response = Response("Could not find session", status_code=404) return await response(scope, receive, send) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + if requestor != self._session_owners.get(session_id): + # A session can only be used with the credential that created it. + # Respond exactly as if the session did not exist. + logger.warning("Rejecting message for session %s: credential does not match", session_id) + response = Response("Could not find session", status_code=404) + return await response(scope, receive, send) + body = await request.body() logger.debug(f"Received JSON: {body}") try: message = types.jsonrpc_message_adapter.validate_json(body, by_name=False) logger.debug(f"Validated client message: {message}") - except ValidationError as err: # pragma: no cover + except ValidationError as err: logger.exception("Failed to parse message") response = Response("Could not parse message", status_code=400) await response(scope, receive, send) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 39d434505c..81350a8f24 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -5,7 +5,6 @@ import contextlib import logging from collections.abc import AsyncIterator -from http import HTTPStatus from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -15,6 +14,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, EventStore, @@ -89,6 +89,9 @@ def __init__( # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, 
StreamableHTTPServerTransport] = {} + # Identity of the credential that created each session; requests for a + # session must present the same credential. + self._session_owners: dict[str, AuthorizationContext] = {} # The task group will be set during lifespan self._task_group = None @@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: self._task_group = None # Clear any remaining server instances self._server_instances.clear() + self._session_owners.clear() async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. @@ -192,9 +196,29 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + # Existing session case if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] + if requestor != self._session_owners.get(request_mcp_session_id): + # A session can only be used with the credential that created + # it. Respond exactly as if the session did not exist. 
+ logger.warning( + "Rejecting request for session %s: credential does not match the one that created the session", + request_mcp_session_id[:64], + ) + body = JSONRPCError( + jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found") + ) + response = Response( + body.model_dump_json(by_alias=True, exclude_unset=True), + status_code=404, + media_type="application/json", + ) + await response(scope, receive, send) + return logger.debug("Session already exists, handling request directly") # Push back idle deadline on activity if transport.idle_scope is not None and self.session_idle_timeout is not None: @@ -216,6 +240,8 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S ) assert http_transport.mcp_session_id is not None + if requestor is not None: + self._session_owners[http_transport.mcp_session_id] = requestor self._server_instances[http_transport.mcp_session_id] = http_transport logger.info(f"Created new transport with session ID: {new_session_id}") @@ -246,6 +272,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE assert http_transport.mcp_session_id is not None logger.info(f"Session {http_transport.mcp_session_id} idle timeout") self._server_instances.pop(http_transport.mcp_session_id, None) + self._session_owners.pop(http_transport.mcp_session_id, None) await http_transport.terminate() except Exception: logger.exception(f"Session {http_transport.mcp_session_id} crashed") @@ -260,6 +287,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE f"{http_transport.mcp_session_id} from active instances." 
) del self._server_instances[http_transport.mcp_session_id] + self._session_owners.pop(http_transport.mcp_session_id, None) # Assert task group is not None for type checking assert self._task_group is not None @@ -273,15 +301,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE # TODO: Align error code once spec clarifies # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}") - error_response = JSONRPCError( - jsonrpc="2.0", - id=None, - error=ErrorData(code=INVALID_REQUEST, message="Session not found"), + body = JSONRPCError( + jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found") ) response = Response( - content=error_response.model_dump_json(by_alias=True, exclude_unset=True), - status_code=HTTPStatus.NOT_FOUND, - media_type="application/json", + body.model_dump_json(by_alias=True, exclude_unset=True), status_code=404, media_type="application/json" ) await response(scope, receive, send) diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 707d4b61dd..d9e9f965b3 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -42,17 +42,17 @@ def __init__(self, settings: TransportSecuritySettings | None = None): def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" - if not host: # pragma: no cover + if not host: logger.warning("Missing Host header in request") return False # Check exact match first - if host in self.settings.allowed_hosts: # pragma: no cover + if host in self.settings.allowed_hosts: return True # Check wildcard port patterns for allowed in self.settings.allowed_hosts: - if allowed.endswith(":*"): # pragma: no branch + if allowed.endswith(":*"): # Extract base host from pattern base_host = allowed[:-2] # Check if the actual host starts with 
base host and has a port @@ -65,16 +65,16 @@ def _validate_host(self, host: str | None) -> bool: def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests - if not origin: # pragma: no cover + if not origin: return True # Check exact match first - if origin in self.settings.allowed_origins: # pragma: no cover + if origin in self.settings.allowed_origins: return True # Check wildcard port patterns for allowed in self.settings.allowed_origins: - if allowed.endswith(":*"): # pragma: no branch + if allowed.endswith(":*"): # Extract base origin from pattern base_origin = allowed[:-2] # Check if the actual origin starts with base origin and has a port @@ -94,7 +94,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res Returns None if validation passes, or an error Response if validation fails. """ # Always validate Content-Type for POST requests - if is_post: # pragma: no branch + if is_post: content_type = request.headers.get("content-type") if not self._validate_content_type(content_type): return Response("Invalid Content-Type header", status_code=400) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a25..e95dc51b31 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,27 +1,44 @@ -"""Tests for SSE server DNS rebinding protection.""" +"""Tests for SSE server request validation.""" import logging import multiprocessing +import re import socket +import anyio import httpx import pytest +import sse_starlette.sse import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route +from starlette.types import Message, Receive, Scope, Send from mcp.server import Server +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from 
mcp.server.auth.provider import AccessToken from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool +from mcp.shared._stream_protocols import WriteStream +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCRequest, JSONRPCResponse, Tool from tests.test_helpers import wait_for_server logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" +@pytest.fixture(autouse=True) +def reset_sse_starlette_exit_event() -> None: + """sse-starlette<2 caches a module-level anyio.Event on AppStatus; reset it + between tests so it is not bound to a previous test's event loop.""" + app_status = getattr(sse_starlette.sse, "AppStatus", None) + if app_status is not None and hasattr(app_status, "should_exit_event"): # pragma: lax no cover + app_status.should_exit_event = None + + @pytest.fixture def server_port() -> int: with socket.socket() as s: @@ -291,3 +308,270 @@ async def test_sse_security_post_valid_content_type(server_port: int): finally: process.terminate() process.join() + + +def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _sse_scope( + method: str, path: str, user: AuthenticatedUser | None, *, query_string: bytes = b"", body: bytes = b"" +) -> tuple[Scope, Receive, Send, list[Message]]: + """Build an ASGI scope/receive/send triple for a request to the SSE transport.""" + scope: Scope = { + "type": "http", + "method": method, + "path": path, + "root_path": "", + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + } + if user is not None: + scope["user"] = user + 
sent: list[Message] = [] + + async def receive() -> Message: + return {"type": "http.request", "body": body, "more_body": False} + + async def send(message: Message) -> None: + sent.append(message) + + return scope, receive, send, sent + + +def _response_status(sent: list[Message]) -> int: + response_start = next(msg for msg in sent if msg["type"] == "http.response.start") + return response_start["status"] + + +async def _post_message(transport: SseServerTransport, session_id: str, user: AuthenticatedUser | None) -> int: + """POST a message to an SSE session as `user` and return the response status.""" + body = b'{"jsonrpc": "2.0", "id": 1, "method": "ping", "params": null}' + scope, receive, send, sent = _sse_scope( + "POST", "/messages/", user, query_string=f"session_id={session_id}".encode(), body=body + ) + await transport.handle_post_message(scope, receive, send) + return _response_status(sent) + + +_Principal = tuple[str] | tuple[str, str] | tuple[str, str, str] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("creator", "sender", "expected"), + [ + pytest.param(("client-a",), ("client-b",), 404, id="different-client"), + pytest.param(("client-a",), None, 404, id="unauthenticated-sender"), + pytest.param(("client-a", "alice"), ("client-a", "bob"), 404, id="same-client-different-subject"), + pytest.param(("client-a", "alice"), ("client-a",), 404, id="same-client-no-subject"), + pytest.param( + ("client-a", "alice", "https://i1"), ("client-a", "alice", "https://i2"), 404, id="different-issuer" + ), + pytest.param(None, ("client-a",), 404, id="unauthenticated-creator"), + pytest.param(("client-a",), ("client-a",), 202, id="same-client"), + pytest.param(("client-a", "alice"), ("client-a", "alice"), 202, id="same-client-and-subject"), + pytest.param(None, None, 202, id="both-unauthenticated"), + ], +) +async def test_sse_post_requires_the_credential_that_created_the_session( + creator: _Principal | None, + sender: _Principal | None, + expected: int, +): + 
"""The session endpoint URL issued to one authenticated principal must not + accept messages from a request authenticated as a different one.""" + transport = SseServerTransport("/messages/") + session_id_received = anyio.Event() + session_ids: list[str] = [] + client_disconnected = anyio.Event() + + async def get_send(message: Message) -> None: + # The first body chunk is the SSE event announcing the session URI to POST messages to. + if message["type"] == "http.response.body" and not session_ids: + match = re.search(rb"session_id=([0-9a-f]{32})", message.get("body", b"")) + assert match is not None, f"expected the endpoint event first, got {message!r}" + session_ids.append(match.group(1).decode()) + session_id_received.set() + + async def get_receive() -> Message: + # The SSE client stays connected until the test signals otherwise. + await client_disconnected.wait() + return {"type": "http.disconnect"} + + creator_user = _authenticated_user(*creator) if creator is not None else None + sender_user = _authenticated_user(*sender) if sender is not None else None + + async def hold_sse_connection() -> None: + """Establish the SSE session as `creator` and keep it open, as a server would.""" + scope, _, _, _ = _sse_scope("GET", "/sse", creator_user) + with anyio.fail_after(5): + async with transport.connect_sse(scope, get_receive, get_send) as (read_stream, write_stream): + async with read_stream, write_stream: # pragma: no branch + async for _ in read_stream: + pass + + async with anyio.create_task_group() as tg: + tg.start_soon(hold_sse_connection) + with anyio.fail_after(5): + await session_id_received.wait() + + assert await _post_message(transport, session_ids[0], sender_user) == expected + + client_disconnected.set() + + # Once the connection is gone the session is no longer routable. 
+ assert await _post_message(transport, session_ids[0], creator_user) == 404 + + +@pytest.mark.anyio +async def test_sse_connect_rejects_a_non_http_scope(): + """connect_sse refuses ASGI scopes that are not HTTP requests.""" + transport = SseServerTransport("/messages/") + with pytest.raises(ValueError): + async with transport.connect_sse({"type": "websocket"}, _no_receive, _no_send): + raise NotImplementedError + + +@pytest.mark.anyio +async def test_sse_connect_rejects_a_disallowed_host(): + """connect_sse rejects requests whose Host header fails the configured security check.""" + settings = TransportSecuritySettings(allowed_hosts=["allowed.example.com"]) + transport = SseServerTransport("/messages/", security_settings=settings) + scope, receive, send, sent = _sse_scope("GET", "/sse", None) + scope["headers"] = [(b"host", b"disallowed.example.com")] + + with pytest.raises(ValueError): + async with transport.connect_sse(scope, receive, send): + raise NotImplementedError + assert _response_status(sent) == 421 + + +@pytest.mark.anyio +async def test_sse_post_without_a_session_id_returns_400(): + """POSTs to the messages endpoint must include a session_id query parameter.""" + transport = SseServerTransport("/messages/") + scope, receive, send, sent = _sse_scope("POST", "/messages/", None) + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + + +@pytest.mark.anyio +async def test_sse_post_with_a_malformed_session_id_returns_400(): + """A session_id that is not 32 hex characters is rejected before any session lookup.""" + transport = SseServerTransport("/messages/") + scope, receive, send, sent = _sse_scope("POST", "/messages/", None, query_string=b"session_id=not-hex") + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + + +@pytest.mark.anyio +async def test_sse_post_with_a_disallowed_host_is_rejected_before_session_lookup(): + """The transport security check on POST 
runs before any session-ID handling.""" + settings = TransportSecuritySettings(allowed_hosts=["allowed.example.com"]) + transport = SseServerTransport("/messages/", security_settings=settings) + scope, receive, send, sent = _sse_scope("POST", "/messages/", None) + scope["headers"] = [(b"host", b"disallowed.example.com"), (b"content-type", b"application/json")] + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 421 + + +@pytest.mark.anyio +async def test_sse_round_trip_delivers_posted_messages_and_streams_responses(): + """A POSTed JSON-RPC message reaches the server's read stream, and a message + written to the server's write stream is sent to the client as an SSE event.""" + transport = SseServerTransport("/messages/") + session = _SseSession(transport) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(session.hold) + await session.ready.wait() + + # POST a parse-failing body: client gets 400, server's read stream receives the error. + scope, receive, send, sent = _sse_scope( + "POST", "/messages/", None, query_string=f"session_id={session.session_id}".encode(), body=b"not json" + ) + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + assert isinstance(await session.next_read_item(), Exception) + + # POST a valid message: client gets 202, server's read stream receives it. + assert await _post_message(transport, session.session_id, None) == 202 + received = await session.next_read_item() + assert isinstance(received, SessionMessage) + assert isinstance(received.message, JSONRPCRequest) + assert received.message.method == "ping" + + # Server writes a response: it appears as an SSE `message` event on the GET stream. 
+ outgoing = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + await session.write_stream.send(SessionMessage(outgoing)) + chunk = await session.next_body_chunk() + assert b"event: message" in chunk + assert outgoing.model_dump_json(by_alias=True, exclude_unset=True).encode() in chunk + + session.disconnect() + + +class _SseSession: + """Drive an in-process SSE GET connection and surface what the server reads and the client receives. + + `hold` runs the connection in a background task and consumes the server-side read stream + into a buffer so that `handle_post_message` (which writes to that stream with a zero-capacity + channel) never blocks the test body. + """ + + def __init__(self, transport: SseServerTransport) -> None: + self.transport = transport + self.ready = anyio.Event() + self._disconnected = anyio.Event() + self._body_send, self._body_recv = anyio.create_memory_object_stream[bytes](16) + self._read_send, self._read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](16) + self.session_id = "" + self.write_stream: WriteStream[SessionMessage] + + async def hold(self) -> None: + scope, _, _, _ = _sse_scope("GET", "/sse", None) + async with self.transport.connect_sse(scope, self._receive, self._send) as (read, write): + self.write_stream = write + async with read, write, self._body_send, self._body_recv, self._read_send, self._read_recv: + async for item in read: + await self._read_send.send(item) + + def disconnect(self) -> None: + self._disconnected.set() + + async def next_read_item(self) -> SessionMessage | Exception: + return await self._read_recv.receive() + + async def next_body_chunk(self) -> bytes: + return await self._body_recv.receive() + + async def _receive(self) -> Message: + await self._disconnected.wait() + return {"type": "http.disconnect"} + + async def _send(self, message: Message) -> None: + if message["type"] != "http.response.body": + return + body: bytes = message.get("body", b"") + if not self.session_id: + match = 
re.search(rb"session_id=([0-9a-f]{32})", body) + assert match is not None, f"expected the endpoint event first, got {message!r}" + self.session_id = match.group(1).decode() + self.ready.set() + else: + await self._body_send.send(body) + + +async def _no_receive() -> Message: + raise NotImplementedError + + +async def _no_send(message: Message) -> None: + raise NotImplementedError diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 47cfbf14a4..ba75547964 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -8,11 +8,13 @@ import anyio import httpx import pytest -from starlette.types import Message +from starlette.types import Message, Scope from mcp import Client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext, streamable_http_manager +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams @@ -413,3 +415,166 @@ def test_session_idle_timeout_rejects_non_positive(): def test_session_idle_timeout_rejects_stateless(): with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) + + +def _user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _request_scope( + *, 
session_id: str | None = None, user: AuthenticatedUser | None = None, method: str = "POST" +) -> Scope: + """Build an ASGI scope for a request to the MCP endpoint.""" + headers = [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ] + if session_id is not None: + headers.append((b"mcp-session-id", session_id.encode())) + scope: Scope = { + "type": "http", + "method": method, + "path": "/mcp", + "headers": headers, + } + if user is not None: + scope["user"] = user + return scope + + +async def _open_session(manager: StreamableHTTPSessionManager, user: AuthenticatedUser | None) -> str: + """Create a new session as `user` and return its session ID.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request(_request_scope(user=user), mock_receive, mock_send) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + headers = dict(response_start.get("headers", [])) + return headers[MCP_SESSION_ID_HEADER.encode()].decode() + + +async def _request_session( + manager: StreamableHTTPSessionManager, session_id: str, user: AuthenticatedUser | None, method: str = "POST" +) -> int: + """Send a request for an existing session as `user` and return the response status.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request( + _request_scope(session_id=session_id, user=user, method=method), mock_receive, mock_send + ) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + return response_start["status"] + + +@pytest.fixture +async def 
manager_with_live_session(): + """A running manager around a real `Server`. Sessions remain registered until + `manager.run()` exits because `Server.run` blocks waiting for an initialize message.""" + manager = StreamableHTTPSessionManager(app=Server("test-session-credentials")) + async with manager.run(): + yield manager + + +@pytest.mark.anyio +async def test_session_accepts_requests_from_the_credential_that_created_it( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Requests presenting the same credential as the one that created the session are served.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + status = await _request_session(manager, session_id, _user("client-a")) + + # The request passes the manager's credential check and reaches the + # session's transport, instead of being answered with 404 by the manager. + assert status != 404 + + +@pytest.mark.anyio +@pytest.mark.parametrize("method", ["POST", "GET", "DELETE"]) +async def test_session_rejects_requests_from_a_different_credential( + manager_with_live_session: StreamableHTTPSessionManager, method: str +) -> None: + """A session created by one credential cannot be used with another credential, whatever the method.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, _user("client-b"), method) == 404 + # The session is still registered and still serves its creator. 
+ assert await _request_session(manager, session_id, _user("client-a")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_from_a_different_subject_of_the_same_client( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Two end-users that share an OAuth client cannot use each other's sessions.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a", subject="alice")) + + assert await _request_session(manager, session_id, _user("client-a", subject="bob")) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject=None)) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_with_the_same_subject_from_a_different_issuer( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A subject is unique only per issuer, so a colliding subject from a different issuer is not the same principal.""" + manager = manager_with_live_session + creator = _user("client-a", subject="alice", issuer="https://issuer.one") + session_id = await _open_session(manager, creator) + + other_issuer = _user("client-a", subject="alice", issuer="https://issuer.two") + assert await _request_session(manager, session_id, other_issuer) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) == 404 + assert await _request_session(manager, session_id, creator) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_unauthenticated_requests_for_an_authenticated_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created with a credential cannot be used without one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, None) == 404 + + +@pytest.mark.anyio +async def 
test_session_rejects_authenticated_requests_for_an_anonymous_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created without a credential cannot be used with one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, _user("client-a")) == 404 + + +@pytest.mark.anyio +async def test_anonymous_session_accepts_anonymous_requests( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Servers without authentication keep working: no credential on either side.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, None) != 404 diff --git a/tests/server/test_transport_security.py b/tests/server/test_transport_security.py new file mode 100644 index 0000000000..be28980b53 --- /dev/null +++ b/tests/server/test_transport_security.py @@ -0,0 +1,88 @@ +"""Tests for the transport-security request validation middleware.""" + +import pytest +from starlette.requests import Request + +from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings + + +def _request(host: str | None, origin: str | None, content_type: str | None = "application/json") -> Request: + headers: list[tuple[bytes, bytes]] = [] + if content_type is not None: + headers.append((b"content-type", content_type.encode())) + if host is not None: + headers.append((b"host", host.encode())) + if origin is not None: + headers.append((b"origin", origin.encode())) + return Request({"type": "http", "method": "GET", "headers": headers}) + + +SETTINGS = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["good.example", "wild.example:*"], + allowed_origins=["http://good.example", "http://wild.example:*"], +) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("host", "origin", "expected"), + [ + 
pytest.param(None, None, 421, id="missing-host"), + pytest.param("evil.example", None, 421, id="host-no-match"), + pytest.param("evil.example:9000", None, 421, id="host-wildcard-base-mismatch"), + pytest.param("good.example", None, None, id="host-exact-no-origin"), + pytest.param("wild.example:9000", None, None, id="host-wildcard-match"), + pytest.param("good.example", "http://evil.example", 403, id="origin-no-match"), + pytest.param("good.example", "http://evil.example:9000", 403, id="origin-wildcard-base-mismatch"), + pytest.param("good.example", "http://good.example", None, id="origin-exact"), + pytest.param("good.example", "http://wild.example:9000", None, id="origin-wildcard-match"), + ], +) +async def test_validate_request_checks_host_then_origin( + host: str | None, origin: str | None, expected: int | None +) -> None: + """Host is checked first, then Origin; exact and wildcard-port allowlist entries are honoured.""" + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request(host, origin)) + assert (None if response is None else response.status_code) == expected + + +@pytest.mark.anyio +async def test_validate_request_skips_host_and_origin_when_protection_is_disabled() -> None: + """With DNS-rebinding protection off, any Host/Origin is accepted.""" + middleware = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + assert await middleware.validate_request(_request("evil.example", "http://evil.example")) is None + + +@pytest.mark.anyio +async def test_validate_request_defaults_to_protection_disabled() -> None: + """Constructing the middleware without settings leaves DNS-rebinding protection off.""" + middleware = TransportSecurityMiddleware() + assert await middleware.validate_request(_request("evil.example", "http://evil.example")) is None + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("content_type", "expected"), + [ + pytest.param("application/json", None, 
id="json"), + pytest.param("application/json; charset=utf-8", None, id="json-with-charset"), + pytest.param("APPLICATION/JSON", None, id="case-insensitive"), + pytest.param("text/plain", 400, id="wrong-type"), + pytest.param(None, 400, id="missing"), + ], +) +async def test_validate_request_checks_content_type_on_post(content_type: str | None, expected: int | None) -> None: + """POST requests must carry an application/json Content-Type, regardless of DNS-rebinding settings.""" + middleware = TransportSecurityMiddleware() + response = await middleware.validate_request(_request("any", None, content_type=content_type), is_post=True) + assert (None if response is None else response.status_code) == expected + + +@pytest.mark.anyio +async def test_validate_request_ignores_content_type_on_get() -> None: + """Content-Type is only enforced for POST requests.""" + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request("good.example", None, content_type="text/plain")) + assert response is None From 2a3d065417088eafef86d0e1b47eca3facc16130 Mon Sep 17 00:00:00 2001 From: devteamaegis Date: Tue, 2 Jun 2026 11:53:51 -0400 Subject: [PATCH 70/84] fix: rename `.gitattribute` to `.gitattributes` so git actually reads it (#2656) Co-authored-by: devteamaegis --- .gitattribute => .gitattributes | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename .gitattribute => .gitattributes (100%) diff --git a/.gitattribute b/.gitattributes similarity index 100% rename from .gitattribute rename to .gitattributes From 453cafb9a9633ca747e193c7e1694c10545d91d2 Mon Sep 17 00:00:00 2001 From: Siddhiraj Katkar <91388781+siddhirajkatkar@users.noreply.github.com> Date: Tue, 2 Jun 2026 21:31:50 +0530 Subject: [PATCH 71/84] fix: add 'invalid_target' to AuthorizationErrorCode (RFC 8707) (#2642) Co-authored-by: Marcelo Trylesinski --- src/mcp/server/auth/provider.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mcp/server/auth/provider.py 
b/src/mcp/server/auth/provider.py index 4ce1137575..bb47c19566 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -68,6 +68,7 @@ class RegistrationError(Exception): "invalid_scope", "server_error", "temporarily_unavailable", + "invalid_target", ] From a5b2ebb660e59d4dc52315c2827cde5d5e83e92f Mon Sep 17 00:00:00 2001 From: "S;Co" Date: Tue, 2 Jun 2026 17:02:36 +0100 Subject: [PATCH 72/84] Clarify CLI subprocess environment comment (#2672) Co-authored-by: scosemicolon <277933778+scosemicolon@users.noreply.github.com> --- src/mcp/cli/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index 62334a4a2c..44e2bae48e 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -277,7 +277,7 @@ def dev( [npx_cmd, "@modelcontextprotocol/inspector"] + uv_cmd, check=True, shell=shell, - env=dict(os.environ.items()), # Convert to list of tuples for env update + env=dict(os.environ.items()), # Copy the environment for subprocess launch ) sys.exit(process.returncode) except subprocess.CalledProcessError as e: From 8cc187fac06a85a0c80a84bdf47136a0accbb924 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:27:05 +0100 Subject: [PATCH 73/84] Remove Tasks (SEP-1686) from the SDK (#2714) --- README.v2.md | 1 - docs/experimental/index.md | 42 - docs/experimental/tasks-client.md | 361 -------- docs/experimental/tasks-server.md | 577 ------------ docs/experimental/tasks.md | 188 ---- docs/migration.md | 38 +- examples/clients/simple-task-client/README.md | 43 - .../mcp_simple_task_client/__init__.py | 0 .../mcp_simple_task_client/__main__.py | 5 - .../mcp_simple_task_client/main.py | 56 -- .../clients/simple-task-client/pyproject.toml | 43 - .../simple-task-interactive-client/README.md | 87 -- .../__init__.py | 0 .../__main__.py | 5 - .../main.py | 137 --- .../pyproject.toml | 43 - .../servers/simple-task-interactive/README.md | 74 -- 
.../mcp_simple_task_interactive/__init__.py | 0 .../mcp_simple_task_interactive/__main__.py | 5 - .../mcp_simple_task_interactive/server.py | 138 --- .../simple-task-interactive/pyproject.toml | 43 - examples/servers/simple-task/README.md | 37 - .../simple-task/mcp_simple_task/__init__.py | 0 .../simple-task/mcp_simple_task/__main__.py | 5 - .../simple-task/mcp_simple_task/server.py | 70 -- examples/servers/simple-task/pyproject.toml | 43 - mkdocs.yml | 6 - src/mcp/client/experimental/__init__.py | 8 - src/mcp/client/experimental/task_handlers.py | 293 ------ src/mcp/client/experimental/tasks.py | 208 ----- src/mcp/client/session.py | 53 +- src/mcp/server/context.py | 2 - src/mcp/server/experimental/__init__.py | 10 - .../server/experimental/request_context.py | 217 ----- .../server/experimental/session_features.py | 209 ----- src/mcp/server/experimental/task_context.py | 587 ------------ .../experimental/task_result_handler.py | 218 ----- src/mcp/server/experimental/task_support.py | 116 --- src/mcp/server/lowlevel/experimental.py | 210 ----- src/mcp/server/lowlevel/server.py | 52 +- src/mcp/server/session.py | 195 ---- src/mcp/server/validation.py | 3 +- src/mcp/shared/experimental/__init__.py | 6 - src/mcp/shared/experimental/tasks/__init__.py | 11 - .../shared/experimental/tasks/capabilities.py | 96 -- src/mcp/shared/experimental/tasks/context.py | 95 -- src/mcp/shared/experimental/tasks/helpers.py | 166 ---- .../tasks/in_memory_task_store.py | 217 ----- .../experimental/tasks/message_queue.py | 230 ----- src/mcp/shared/experimental/tasks/polling.py | 43 - src/mcp/shared/experimental/tasks/resolver.py | 58 -- src/mcp/shared/experimental/tasks/store.py | 144 --- src/mcp/shared/response_router.py | 61 -- src/mcp/shared/session.py | 38 +- src/mcp/types/__init__.py | 82 -- src/mcp/types/_types.py | 304 +----- tests/experimental/__init__.py | 0 tests/experimental/tasks/__init__.py | 1 - tests/experimental/tasks/client/__init__.py | 0 
.../tasks/client/test_capabilities.py | 312 ------- .../tasks/client/test_handlers.py | 874 ------------------ .../tasks/client/test_poll_task.py | 121 --- tests/experimental/tasks/client/test_tasks.py | 309 ------- tests/experimental/tasks/server/__init__.py | 0 .../experimental/tasks/server/test_context.py | 183 ---- .../tasks/server/test_integration.py | 247 ----- .../tasks/server/test_run_task_flow.py | 367 -------- .../experimental/tasks/server/test_server.py | 797 ---------------- .../tasks/server/test_server_task_context.py | 709 -------------- tests/experimental/tasks/server/test_store.py | 406 -------- .../tasks/server/test_task_result_handler.py | 354 ------- tests/experimental/tasks/test_capabilities.py | 283 ------ .../tasks/test_elicitation_scenarios.py | 695 -------------- .../experimental/tasks/test_message_queue.py | 330 ------- .../tasks/test_request_context.py | 166 ---- .../tasks/test_spec_compliance.py | 717 -------------- tests/interaction/_requirements.py | 5 +- tests/issues/test_176_progress_token.py | 2 - tests/server/mcpserver/test_server.py | 2 - tests/server/test_session.py | 43 + uv.lock | 124 --- 81 files changed, 63 insertions(+), 12963 deletions(-) delete mode 100644 docs/experimental/index.md delete mode 100644 docs/experimental/tasks-client.md delete mode 100644 docs/experimental/tasks-server.md delete mode 100644 docs/experimental/tasks.md delete mode 100644 examples/clients/simple-task-client/README.md delete mode 100644 examples/clients/simple-task-client/mcp_simple_task_client/__init__.py delete mode 100644 examples/clients/simple-task-client/mcp_simple_task_client/__main__.py delete mode 100644 examples/clients/simple-task-client/mcp_simple_task_client/main.py delete mode 100644 examples/clients/simple-task-client/pyproject.toml delete mode 100644 examples/clients/simple-task-interactive-client/README.md delete mode 100644 examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py delete mode 
100644 examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py delete mode 100644 examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py delete mode 100644 examples/clients/simple-task-interactive-client/pyproject.toml delete mode 100644 examples/servers/simple-task-interactive/README.md delete mode 100644 examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py delete mode 100644 examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py delete mode 100644 examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py delete mode 100644 examples/servers/simple-task-interactive/pyproject.toml delete mode 100644 examples/servers/simple-task/README.md delete mode 100644 examples/servers/simple-task/mcp_simple_task/__init__.py delete mode 100644 examples/servers/simple-task/mcp_simple_task/__main__.py delete mode 100644 examples/servers/simple-task/mcp_simple_task/server.py delete mode 100644 examples/servers/simple-task/pyproject.toml delete mode 100644 src/mcp/client/experimental/__init__.py delete mode 100644 src/mcp/client/experimental/task_handlers.py delete mode 100644 src/mcp/client/experimental/tasks.py delete mode 100644 src/mcp/server/experimental/__init__.py delete mode 100644 src/mcp/server/experimental/request_context.py delete mode 100644 src/mcp/server/experimental/session_features.py delete mode 100644 src/mcp/server/experimental/task_context.py delete mode 100644 src/mcp/server/experimental/task_result_handler.py delete mode 100644 src/mcp/server/experimental/task_support.py delete mode 100644 src/mcp/server/lowlevel/experimental.py delete mode 100644 src/mcp/shared/experimental/__init__.py delete mode 100644 src/mcp/shared/experimental/tasks/__init__.py delete mode 100644 src/mcp/shared/experimental/tasks/capabilities.py delete mode 100644 src/mcp/shared/experimental/tasks/context.py delete mode 100644 
src/mcp/shared/experimental/tasks/helpers.py delete mode 100644 src/mcp/shared/experimental/tasks/in_memory_task_store.py delete mode 100644 src/mcp/shared/experimental/tasks/message_queue.py delete mode 100644 src/mcp/shared/experimental/tasks/polling.py delete mode 100644 src/mcp/shared/experimental/tasks/resolver.py delete mode 100644 src/mcp/shared/experimental/tasks/store.py delete mode 100644 src/mcp/shared/response_router.py delete mode 100644 tests/experimental/__init__.py delete mode 100644 tests/experimental/tasks/__init__.py delete mode 100644 tests/experimental/tasks/client/__init__.py delete mode 100644 tests/experimental/tasks/client/test_capabilities.py delete mode 100644 tests/experimental/tasks/client/test_handlers.py delete mode 100644 tests/experimental/tasks/client/test_poll_task.py delete mode 100644 tests/experimental/tasks/client/test_tasks.py delete mode 100644 tests/experimental/tasks/server/__init__.py delete mode 100644 tests/experimental/tasks/server/test_context.py delete mode 100644 tests/experimental/tasks/server/test_integration.py delete mode 100644 tests/experimental/tasks/server/test_run_task_flow.py delete mode 100644 tests/experimental/tasks/server/test_server.py delete mode 100644 tests/experimental/tasks/server/test_server_task_context.py delete mode 100644 tests/experimental/tasks/server/test_store.py delete mode 100644 tests/experimental/tasks/server/test_task_result_handler.py delete mode 100644 tests/experimental/tasks/test_capabilities.py delete mode 100644 tests/experimental/tasks/test_elicitation_scenarios.py delete mode 100644 tests/experimental/tasks/test_message_queue.py delete mode 100644 tests/experimental/tasks/test_request_context.py delete mode 100644 tests/experimental/tasks/test_spec_compliance.py diff --git a/README.v2.md b/README.v2.md index d0851c04e5..a888f21bda 100644 --- a/README.v2.md +++ b/README.v2.md @@ -2489,7 +2489,6 @@ MCP servers declare capabilities during initialization: ## Documentation - [API 
Reference](https://modelcontextprotocol.github.io/python-sdk/api/) -- [Experimental Features (Tasks)](https://modelcontextprotocol.github.io/python-sdk/experimental/tasks/) - [Model Context Protocol documentation](https://modelcontextprotocol.io) - [Model Context Protocol specification](https://modelcontextprotocol.io/specification/latest) - [Officially supported servers](https://github.com/modelcontextprotocol/servers) diff --git a/docs/experimental/index.md b/docs/experimental/index.md deleted file mode 100644 index c97fe2a3d6..0000000000 --- a/docs/experimental/index.md +++ /dev/null @@ -1,42 +0,0 @@ -# Experimental Features - -!!! warning "Experimental APIs" - - The features in this section are experimental and may change without notice. - They track the evolving MCP specification and are not yet stable. - -This section documents experimental features in the MCP Python SDK. These features -implement draft specifications that are still being refined. - -## Available Experimental Features - -### [Tasks](tasks.md) - -Tasks enable asynchronous execution of MCP operations. Instead of waiting for a -long-running operation to complete, the server returns a task reference immediately. -Clients can then poll for status updates and retrieve results when ready. - -Tasks are useful for: - -- **Long-running computations** that would otherwise block -- **Batch operations** that process many items -- **Interactive workflows** that require user input (elicitation) or LLM assistance (sampling) - -## Using Experimental APIs - -Experimental features are accessed via the `.experimental` property: - -```python -# Server-side: enable task support (auto-registers default handlers) -server = Server(name="my-server") -server.experimental.enable_tasks() - -# Client-side -result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) -``` - -## Providing Feedback - -Since these features are experimental, feedback is especially valuable. 
If you encounter -issues or have suggestions, please open an issue on the -[python-sdk repository](https://github.com/modelcontextprotocol/python-sdk/issues). diff --git a/docs/experimental/tasks-client.md b/docs/experimental/tasks-client.md deleted file mode 100644 index 0374ed86b5..0000000000 --- a/docs/experimental/tasks-client.md +++ /dev/null @@ -1,361 +0,0 @@ -# Client Task Usage - -!!! warning "Experimental" - - Tasks are an experimental feature. The API may change without notice. - -This guide covers calling task-augmented tools from clients, handling the `input_required` status, and advanced patterns like receiving task requests from servers. - -## Quick Start - -Call a tool as a task and poll for the result: - -```python -from mcp.client.session import ClientSession -from mcp.types import CallToolResult - -async with ClientSession(read, write) as session: - await session.initialize() - - # Call tool as task - result = await session.experimental.call_tool_as_task( - "process_data", - {"input": "hello"}, - ttl=60000, - ) - task_id = result.task.taskId - - # Poll until complete - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status} - {status.statusMessage or ''}") - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) - print(f"Result: {final.content[0].text}") -``` - -## Calling Tools as Tasks - -Use `call_tool_as_task()` to invoke a tool with task augmentation: - -```python -result = await session.experimental.call_tool_as_task( - "my_tool", # Tool name - {"arg": "value"}, # Arguments - ttl=60000, # Time-to-live in milliseconds - meta={"key": "val"}, # Optional metadata -) - -task_id = result.task.taskId -print(f"Task: {task_id}, Status: {result.task.status}") -``` - -The response is a `CreateTaskResult` containing: - -- `task.taskId` - Unique identifier for polling -- `task.status` - Initial status (usually `"working"`) -- `task.pollInterval` - Suggested polling interval 
(milliseconds) -- `task.ttl` - Time-to-live for results -- `task.createdAt` - Creation timestamp - -## Polling with poll_task - -The `poll_task()` async iterator polls until the task reaches a terminal state: - -```python -async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - if status.statusMessage: - print(f"Progress: {status.statusMessage}") -``` - -It automatically: - -- Respects the server's suggested `pollInterval` -- Stops when status is `completed`, `failed`, or `cancelled` -- Yields each status for progress display - -### Handling input_required - -When a task needs user input (elicitation), it transitions to `input_required`. You must call `get_task_result()` to receive and respond to the elicitation: - -```python -async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - - if status.status == "input_required": - # This delivers the elicitation and waits for completion - final = await session.experimental.get_task_result(task_id, CallToolResult) - break -``` - -The elicitation callback (set during session creation) handles the actual user interaction. - -## Elicitation Callbacks - -To handle elicitation requests from the server, provide a callback when creating the session: - -```python -from mcp.types import ElicitRequestParams, ElicitResult - -async def handle_elicitation(context, params: ElicitRequestParams) -> ElicitResult: - # Display the message to the user - print(f"Server asks: {params.message}") - - # Collect user input (this is a simplified example) - response = input("Your response (y/n): ") - confirmed = response.lower() == "y" - - return ElicitResult( - action="accept", - content={"confirm": confirmed}, - ) - -async with ClientSession( - read, - write, - elicitation_callback=handle_elicitation, -) as session: - await session.initialize() - # ... 
call tasks that may require elicitation -``` - -## Sampling Callbacks - -Similarly, handle sampling requests with a callback: - -```python -from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent - -async def handle_sampling(context, params: CreateMessageRequestParams) -> CreateMessageResult: - # In a real implementation, call your LLM here - prompt = params.messages[-1].content.text if params.messages else "" - - # Return a mock response - return CreateMessageResult( - role="assistant", - content=TextContent(type="text", text=f"Response to: {prompt}"), - model="my-model", - ) - -async with ClientSession( - read, - write, - sampling_callback=handle_sampling, -) as session: - # ... -``` - -## Retrieving Results - -Once a task completes, retrieve the result: - -```python -if status.status == "completed": - result = await session.experimental.get_task_result(task_id, CallToolResult) - for content in result.content: - if hasattr(content, "text"): - print(content.text) - -elif status.status == "failed": - print(f"Task failed: {status.statusMessage}") - -elif status.status == "cancelled": - print("Task was cancelled") -``` - -The result type matches the original request: - -- `tools/call` → `CallToolResult` -- `sampling/createMessage` → `CreateMessageResult` -- `elicitation/create` → `ElicitResult` - -## Cancellation - -Cancel a running task: - -```python -cancel_result = await session.experimental.cancel_task(task_id) -print(f"Cancelled, status: {cancel_result.status}") -``` - -Note: Cancellation is cooperative—the server must check for and handle cancellation. 
- -## Listing Tasks - -View all tasks on the server: - -```python -result = await session.experimental.list_tasks() -for task in result.tasks: - print(f"{task.taskId}: {task.status}") - -# Handle pagination -while result.nextCursor: - result = await session.experimental.list_tasks(cursor=result.nextCursor) - for task in result.tasks: - print(f"{task.taskId}: {task.status}") -``` - -## Advanced: Client as Task Receiver - -Servers can send task-augmented requests to clients. This is useful when the server needs the client to perform async work (like complex sampling or user interaction). - -### Declaring Client Capabilities - -Register task handlers to declare what task-augmented requests your client accepts: - -```python -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.types import ( - CreateTaskResult, GetTaskResult, GetTaskPayloadResult, - TaskMetadata, ElicitRequestParams, -) -from mcp.shared.experimental.tasks import InMemoryTaskStore - -# Client-side task store -client_store = InMemoryTaskStore() - -async def handle_augmented_elicitation(context, params: ElicitRequestParams, task_metadata: TaskMetadata): - """Handle task-augmented elicitation from server.""" - # Create a task for this elicitation - task = await client_store.create_task(task_metadata) - - # Start async work (e.g., show UI, wait for user) - async def complete_elicitation(): - # ... do async work ... 
- result = ElicitResult(action="accept", content={"confirm": True}) - await client_store.store_result(task.taskId, result) - await client_store.update_task(task.taskId, status="completed") - - context.session._task_group.start_soon(complete_elicitation) - - # Return task reference immediately - return CreateTaskResult(task=task) - -async def handle_get_task(context, params): - """Handle tasks/get from server.""" - task = await client_store.get_task(params.taskId) - return GetTaskResult( - taskId=task.taskId, - status=task.status, - statusMessage=task.statusMessage, - createdAt=task.createdAt, - lastUpdatedAt=task.lastUpdatedAt, - ttl=task.ttl, - pollInterval=100, - ) - -async def handle_get_task_result(context, params): - """Handle tasks/result from server.""" - result = await client_store.get_result(params.taskId) - return GetTaskPayloadResult.model_validate(result.model_dump()) - -task_handlers = ExperimentalTaskHandlers( - augmented_elicitation=handle_augmented_elicitation, - get_task=handle_get_task, - get_task_result=handle_get_task_result, -) - -async with ClientSession( - read, - write, - experimental_task_handlers=task_handlers, -) as session: - # Client now accepts task-augmented elicitation from server - await session.initialize() -``` - -This enables flows where: - -1. Client calls a task-augmented tool -2. Server's tool work calls `task.elicit_as_task()` -3. Client receives task-augmented elicitation -4. Client creates its own task, does async work -5. Server polls client's task -6. Eventually both tasks complete - -## Complete Example - -A client that handles all task scenarios: - -```python -import anyio -from mcp.client.session import ClientSession -from mcp.client.stdio import stdio_client -from mcp.types import CallToolResult, ElicitRequestParams, ElicitResult - - -async def elicitation_callback(context, params: ElicitRequestParams) -> ElicitResult: - print(f"\n[Elicitation] {params.message}") - response = input("Confirm? 
(y/n): ") - return ElicitResult(action="accept", content={"confirm": response.lower() == "y"}) - - -async def main(): - async with stdio_client(command="python", args=["server.py"]) as (read, write): - async with ClientSession( - read, - write, - elicitation_callback=elicitation_callback, - ) as session: - await session.initialize() - - # List available tools - tools = await session.list_tools() - print("Tools:", [t.name for t in tools.tools]) - - # Call a task-augmented tool - print("\nCalling task tool...") - result = await session.experimental.call_tool_as_task( - "confirm_action", - {"action": "delete files"}, - ) - task_id = result.task.taskId - print(f"Task created: {task_id}") - - # Poll and handle input_required - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - - if status.status == "input_required": - final = await session.experimental.get_task_result(task_id, CallToolResult) - print(f"Result: {final.content[0].text}") - break - - if status.status == "completed": - final = await session.experimental.get_task_result(task_id, CallToolResult) - print(f"Result: {final.content[0].text}") - - -if __name__ == "__main__": - anyio.run(main) -``` - -## Error Handling - -Handle task errors gracefully: - -```python -from mcp.shared.exceptions import MCPError - -try: - result = await session.experimental.call_tool_as_task("my_tool", args) - task_id = result.task.taskId - - async for status in session.experimental.poll_task(task_id): - if status.status == "failed": - raise RuntimeError(f"Task failed: {status.statusMessage}") - - final = await session.experimental.get_task_result(task_id, CallToolResult) - -except MCPError as e: - print(f"MCP error: {e.message}") -except Exception as e: - print(f"Error: {e}") -``` - -## Next Steps - -- [Server Implementation](tasks-server.md) - Build task-supporting servers -- [Tasks Overview](tasks.md) - Review lifecycle and concepts diff --git a/docs/experimental/tasks-server.md 
b/docs/experimental/tasks-server.md deleted file mode 100644 index b350ee3bb6..0000000000 --- a/docs/experimental/tasks-server.md +++ /dev/null @@ -1,577 +0,0 @@ -# Server Task Implementation - -!!! warning "Experimental" - - Tasks are an experimental feature. The API may change without notice. - -This guide covers implementing task support in MCP servers, from basic setup to advanced patterns like elicitation and sampling within tasks. - -## Quick Start - -The simplest way to add task support: - -```python -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, TASK_REQUIRED - -server = Server("my-server") -server.experimental.enable_tasks() # Registers all task handlers automatically - -@server.list_tools() -async def list_tools(): - return [ - Tool( - name="process_data", - description="Process data asynchronously", - inputSchema={"type": "object", "properties": {"input": {"type": "string"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ) - ] - -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: - if name == "process_data": - return await handle_process_data(arguments) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) - -async def handle_process_data(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Processing...") - result = arguments.get("input", "").upper() - return CallToolResult(content=[TextContent(type="text", text=result)]) - - return await ctx.experimental.run_task(work) -``` - -That's it. 
`enable_tasks()` automatically: - -- Creates an in-memory task store -- Registers handlers for `tasks/get`, `tasks/result`, `tasks/list`, `tasks/cancel` -- Updates server capabilities - -## Tool Declaration - -Tools declare task support via the `execution.taskSupport` field: - -```python -from mcp.types import Tool, ToolExecution, TASK_REQUIRED, TASK_OPTIONAL, TASK_FORBIDDEN - -Tool( - name="my_tool", - inputSchema={"type": "object"}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), # or TASK_OPTIONAL, TASK_FORBIDDEN -) -``` - -| Value | Meaning | -|-------|---------| -| `TASK_REQUIRED` | Tool **must** be called as a task | -| `TASK_OPTIONAL` | Tool supports both sync and task execution | -| `TASK_FORBIDDEN` | Tool **cannot** be called as a task (default) | - -Validate the request matches your tool's requirements: - -```python -@server.call_tool() -async def handle_tool(name: str, arguments: dict): - ctx = server.request_context - - if name == "required_task_tool": - ctx.experimental.validate_task_mode(TASK_REQUIRED) # Raises if not task mode - return await handle_as_task(arguments) - - elif name == "optional_task_tool": - if ctx.experimental.is_task: - return await handle_as_task(arguments) - else: - return handle_sync(arguments) -``` - -## The run_task Pattern - -`run_task()` is the recommended way to execute task work: - -```python -async def handle_my_tool(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Your work here - return CallToolResult(content=[TextContent(type="text", text="Done")]) - - return await ctx.experimental.run_task(work) -``` - -**What `run_task()` does:** - -1. Creates a task in the store -2. Spawns your work function in the background -3. Returns `CreateTaskResult` immediately -4. Auto-completes the task when your function returns -5. 
Auto-fails the task if your function raises - -**The `ServerTaskContext` provides:** - -- `task.task_id` - The task identifier -- `task.update_status(message)` - Update progress -- `task.complete(result)` - Explicitly complete (usually automatic) -- `task.fail(error)` - Explicitly fail -- `task.is_cancelled` - Check if cancellation requested - -## Status Updates - -Keep clients informed of progress: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Starting...") - - for i, item in enumerate(items): - await task.update_status(f"Processing {i+1}/{len(items)}") - await process_item(item) - - await task.update_status("Finalizing...") - return CallToolResult(content=[TextContent(type="text", text="Complete")]) -``` - -Status messages appear in `tasks/get` responses, letting clients show progress to users. - -## Elicitation Within Tasks - -Tasks can request user input via elicitation. This transitions the task to `input_required` status. - -### Form Elicitation - -Collect structured data from the user: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Waiting for confirmation...") - - result = await task.elicit( - message="Delete these files?", - requestedSchema={ - "type": "object", - "properties": { - "confirm": {"type": "boolean"}, - "reason": {"type": "string"}, - }, - "required": ["confirm"], - }, - ) - - if result.action == "accept" and result.content.get("confirm"): - # User confirmed - return CallToolResult(content=[TextContent(type="text", text="Files deleted")]) - else: - # User declined or cancelled - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) -``` - -### URL Elicitation - -Direct users to external URLs for OAuth, payments, or other out-of-band flows: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Waiting for OAuth...") - - result = await task.elicit_url( - message="Please 
authorize with GitHub", - url="https://github.com/login/oauth/authorize?client_id=...", - elicitation_id="oauth-github-123", - ) - - if result.action == "accept": - # User completed OAuth flow - return CallToolResult(content=[TextContent(type="text", text="Connected to GitHub")]) - else: - return CallToolResult(content=[TextContent(type="text", text="OAuth cancelled")]) -``` - -## Sampling Within Tasks - -Tasks can request LLM completions from the client: - -```python -from mcp.types import SamplingMessage, TextContent - -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Generating response...") - - result = await task.create_message( - messages=[ - SamplingMessage( - role="user", - content=TextContent(type="text", text="Write a haiku about coding"), - ) - ], - max_tokens=100, - ) - - haiku = result.content.text if isinstance(result.content, TextContent) else "Error" - return CallToolResult(content=[TextContent(type="text", text=haiku)]) -``` - -Sampling supports additional parameters: - -```python -result = await task.create_message( - messages=[...], - max_tokens=500, - system_prompt="You are a helpful assistant", - temperature=0.7, - stop_sequences=["\n\n"], - model_preferences=ModelPreferences(hints=[ModelHint(name="claude-3")]), -) -``` - -## Cancellation Support - -Check for cancellation in long-running work: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - for i in range(1000): - if task.is_cancelled: - # Clean up and exit - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) - - await task.update_status(f"Step {i}/1000") - await process_step(i) - - return CallToolResult(content=[TextContent(type="text", text="Complete")]) -``` - -The SDK's default cancel handler updates the task status. Your work function should check `is_cancelled` periodically. 
- -## Custom Task Store - -For production, implement `TaskStore` with persistent storage: - -```python -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import Task, TaskMetadata, Result - -class RedisTaskStore(TaskStore): - def __init__(self, redis_client): - self.redis = redis_client - - async def create_task(self, metadata: TaskMetadata, task_id: str | None = None) -> Task: - # Create and persist task - ... - - async def get_task(self, task_id: str) -> Task | None: - # Retrieve task from Redis - ... - - async def update_task(self, task_id: str, status: str | None = None, ...) -> Task: - # Update and persist - ... - - async def store_result(self, task_id: str, result: Result) -> None: - # Store result in Redis - ... - - async def get_result(self, task_id: str) -> Result | None: - # Retrieve result - ... - - # ... implement remaining methods -``` - -Use your custom store: - -```python -store = RedisTaskStore(redis_client) -server.experimental.enable_tasks(store=store) -``` - -## Complete Example - -A server with multiple task-supporting tools: - -```python -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import ( - CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, - SamplingMessage, TASK_REQUIRED, -) - -server = Server("task-demo") -server.experimental.enable_tasks() - - -@server.list_tools() -async def list_tools(): - return [ - Tool( - name="confirm_action", - description="Requires user confirmation", - inputSchema={"type": "object", "properties": {"action": {"type": "string"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ), - Tool( - name="generate_text", - description="Generate text via LLM", - inputSchema={"type": "object", "properties": {"prompt": {"type": "string"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ), - ] - - -async def handle_confirm_action(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - 
ctx.experimental.validate_task_mode(TASK_REQUIRED) - - action = arguments.get("action", "unknown action") - - async def work(task: ServerTaskContext) -> CallToolResult: - result = await task.elicit( - message=f"Confirm: {action}?", - requestedSchema={ - "type": "object", - "properties": {"confirm": {"type": "boolean"}}, - "required": ["confirm"], - }, - ) - - if result.action == "accept" and result.content.get("confirm"): - return CallToolResult(content=[TextContent(type="text", text=f"Executed: {action}")]) - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) - - return await ctx.experimental.run_task(work) - - -async def handle_generate_text(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - prompt = arguments.get("prompt", "Hello") - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Generating...") - - result = await task.create_message( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], - max_tokens=200, - ) - - text = result.content.text if isinstance(result.content, TextContent) else "Error" - return CallToolResult(content=[TextContent(type="text", text=text)]) - - return await ctx.experimental.run_task(work) - - -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: - if name == "confirm_action": - return await handle_confirm_action(arguments) - elif name == "generate_text": - return await handle_generate_text(arguments) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) -``` - -## Error Handling in Tasks - -Tasks handle errors automatically, but you can also fail explicitly: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - try: - result = await risky_operation() - return CallToolResult(content=[TextContent(type="text", text=result)]) - except 
PermissionError: - await task.fail("Access denied - insufficient permissions") - raise - except TimeoutError: - await task.fail("Operation timed out after 30 seconds") - raise -``` - -When `run_task()` catches an exception, it automatically: - -1. Marks the task as `failed` -2. Sets `statusMessage` to the exception message -3. Propagates the exception (which is caught by the task group) - -For custom error messages, call `task.fail()` before raising. - -## HTTP Transport Example - -For web applications, use the Streamable HTTP transport: - -```python -import uvicorn - -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import ( - CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, TASK_REQUIRED, -) - - -server = Server("http-task-server") -server.experimental.enable_tasks() - - -@server.list_tools() -async def list_tools(): - return [ - Tool( - name="long_operation", - description="A long-running operation", - inputSchema={"type": "object", "properties": {"duration": {"type": "number"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ) - ] - - -async def handle_long_operation(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - duration = arguments.get("duration", 5) - - async def work(task: ServerTaskContext) -> CallToolResult: - import anyio - for i in range(int(duration)): - await task.update_status(f"Step {i+1}/{int(duration)}") - await anyio.sleep(1) - return CallToolResult(content=[TextContent(type="text", text=f"Completed after {duration}s")]) - - return await ctx.experimental.run_task(work) - - -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: - if name == "long_operation": - return await handle_long_operation(arguments) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) - - -if __name__ == 
"__main__": - uvicorn.run(server.streamable_http_app(), host="127.0.0.1", port=8000) -``` - -## Testing Task Servers - -Test task functionality with the SDK's testing utilities: - -```python -import pytest -import anyio -from mcp.client.session import ClientSession -from mcp.types import CallToolResult - - -@pytest.mark.anyio -async def test_task_tool(): - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream(10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream(10) - - async def run_server(): - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client(): - async with ClientSession(server_to_client_receive, client_to_server_send) as session: - await session.initialize() - - # Call the tool as a task - result = await session.experimental.call_tool_as_task("my_tool", {"arg": "value"}) - task_id = result.task.taskId - assert result.task.status == "working" - - # Poll until complete - async for status in session.experimental.poll_task(task_id): - if status.status in ("completed", "failed"): - break - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) - assert len(final.content) > 0 - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) -``` - -## Best Practices - -### Keep Work Functions Focused - -```python -# Good: focused work function -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Validating...") - validate_input(arguments) - - await task.update_status("Processing...") - result = await process_data(arguments) - - return CallToolResult(content=[TextContent(type="text", text=result)]) -``` - -### Check Cancellation in Loops - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - results = [] - for item in large_dataset: - if task.is_cancelled: - return 
CallToolResult(content=[TextContent(type="text", text="Cancelled")]) - - results.append(await process(item)) - - return CallToolResult(content=[TextContent(type="text", text=str(results))]) -``` - -### Use Meaningful Status Messages - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Connecting to database...") - db = await connect() - - await task.update_status("Fetching records (0/1000)...") - for i, record in enumerate(records): - if i % 100 == 0: - await task.update_status(f"Processing records ({i}/1000)...") - await process(record) - - await task.update_status("Finalizing results...") - return CallToolResult(content=[TextContent(type="text", text="Done")]) -``` - -### Handle Elicitation Responses - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - result = await task.elicit(message="Continue?", requestedSchema={...}) - - match result.action: - case "accept": - # User accepted, process content - return await process_accepted(result.content) - case "decline": - # User explicitly declined - return CallToolResult(content=[TextContent(type="text", text="User declined")]) - case "cancel": - # User cancelled the elicitation - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) -``` - -## Next Steps - -- [Client Usage](tasks-client.md) - Learn how clients interact with task servers -- [Tasks Overview](tasks.md) - Review lifecycle and concepts diff --git a/docs/experimental/tasks.md b/docs/experimental/tasks.md deleted file mode 100644 index 2d4d06a025..0000000000 --- a/docs/experimental/tasks.md +++ /dev/null @@ -1,188 +0,0 @@ -# Tasks - -!!! warning "Experimental" - - Tasks are an experimental feature tracking the draft MCP specification. - The API may change without notice. - -Tasks enable asynchronous request handling in MCP. Instead of blocking until an operation completes, the receiver creates a task, returns immediately, and the requestor polls for the result. 
- -## When to Use Tasks - -Tasks are designed for operations that: - -- Take significant time (seconds to minutes) -- Need progress updates during execution -- Require user input mid-execution (elicitation, sampling) -- Should run without blocking the requestor - -Common use cases: - -- Long-running data processing -- Multi-step workflows with user confirmation -- LLM-powered operations requiring sampling -- OAuth flows requiring user browser interaction - -## Task Lifecycle - -```text - ┌─────────────┐ - │ working │ - └──────┬──────┘ - │ - ┌────────────┼────────────┐ - │ │ │ - ▼ ▼ ▼ - ┌────────────┐ ┌───────────┐ ┌───────────┐ - │ completed │ │ failed │ │ cancelled │ - └────────────┘ └───────────┘ └───────────┘ - ▲ - │ - ┌────────┴────────┐ - │ input_required │◄──────┐ - └────────┬────────┘ │ - │ │ - └────────────────┘ -``` - -| Status | Description | -|--------|-------------| -| `working` | Task is being processed | -| `input_required` | Receiver needs input from requestor (elicitation/sampling) | -| `completed` | Task finished successfully | -| `failed` | Task encountered an error | -| `cancelled` | Task was cancelled by requestor | - -Terminal states (`completed`, `failed`, `cancelled`) are final—tasks cannot transition out of them. - -## Bidirectional Flow - -Tasks work in both directions: - -**Client → Server** (most common): - -```text -Client Server - │ │ - │── tools/call (task) ──────────────>│ Creates task - │<── CreateTaskResult ───────────────│ - │ │ - │── tasks/get ──────────────────────>│ - │<── status: working ────────────────│ - │ │ ... work continues ... 
- │── tasks/get ──────────────────────>│ - │<── status: completed ──────────────│ - │ │ - │── tasks/result ───────────────────>│ - │<── CallToolResult ─────────────────│ -``` - -**Server → Client** (for elicitation/sampling): - -```text -Server Client - │ │ - │── elicitation/create (task) ──────>│ Creates task - │<── CreateTaskResult ───────────────│ - │ │ - │── tasks/get ──────────────────────>│ - │<── status: working ────────────────│ - │ │ ... user interaction ... - │── tasks/get ──────────────────────>│ - │<── status: completed ──────────────│ - │ │ - │── tasks/result ───────────────────>│ - │<── ElicitResult ───────────────────│ -``` - -## Key Concepts - -### Task Metadata - -When augmenting a request with task execution, include `TaskMetadata`: - -```python -from mcp.types import TaskMetadata - -task = TaskMetadata(ttl=60000) # TTL in milliseconds -``` - -The `ttl` (time-to-live) specifies how long the task and result are retained after completion. - -### Task Store - -Servers persist task state in a `TaskStore`. The SDK provides `InMemoryTaskStore` for development: - -```python -from mcp.shared.experimental.tasks import InMemoryTaskStore - -store = InMemoryTaskStore() -``` - -For production, implement `TaskStore` with a database or distributed cache. - -### Capabilities - -Both servers and clients declare task support through capabilities: - -**Server capabilities:** - -- `tasks.requests.tools.call` - Server accepts task-augmented tool calls - -**Client capabilities:** - -- `tasks.requests.sampling.createMessage` - Client accepts task-augmented sampling -- `tasks.requests.elicitation.create` - Client accepts task-augmented elicitation - -The SDK manages these automatically when you enable task support. 
- -## Quick Example - -**Server** (simplified API): - -```python -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import CallToolResult, TextContent, TASK_REQUIRED - -server = Server("my-server") -server.experimental.enable_tasks() # One-line setup - -@server.call_tool() -async def handle_tool(name: str, arguments: dict): - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext): - await task.update_status("Processing...") - # ... do work ... - return CallToolResult(content=[TextContent(type="text", text="Done!")]) - - return await ctx.experimental.run_task(work) -``` - -**Client:** - -```python -from mcp.client.session import ClientSession -from mcp.types import CallToolResult - -async with ClientSession(read, write) as session: - await session.initialize() - - # Call tool as task - result = await session.experimental.call_tool_as_task("my_tool", {"arg": "value"}) - task_id = result.task.taskId - - # Poll until done - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) -``` - -## Next Steps - -- [Server Implementation](tasks-server.md) - Build task-supporting servers -- [Client Usage](tasks-client.md) - Call and poll tasks from clients diff --git a/docs/migration.md b/docs/migration.md index 8b70885e8d..9850f74cd4 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -595,7 +595,7 @@ The `RequestContext` class has been split to separate shared fields from server- **`RequestContext` changes:** - Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]` -- Server-specific fields (`lifespan_context`, `experimental`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in 
`mcp.server.context` +- Server-specific fields (`lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context` **Before (v1):** @@ -861,7 +861,7 @@ server = Server("my-server", on_list_tools=handle_list_tools, on_call_tool=handl **Key differences:** -- Handlers receive `(ctx, params)` instead of the full request object or unpacked arguments. `ctx` is a `ServerRequestContext` with `session`, `lifespan_context`, and `experimental` fields (plus `request_id`, `meta`, etc. for request handlers). `params` is the typed request params object. +- Handlers receive `(ctx, params)` instead of the full request object or unpacked arguments. `ctx` is a `ServerRequestContext` with `session` and `lifespan_context` fields (plus `request_id`, `meta`, etc. for request handlers). `params` is the typed request params object. - Handlers return the full result type (e.g. `ListToolsResult`) rather than unwrapped values (e.g. `list[Tool]`). - The automatic `jsonschema` input/output validation that the old `call_tool()` decorator performed has been removed. There is no built-in replacement — if you relied on schema validation in the lowlevel server, you will need to validate inputs yourself in your handler. @@ -872,7 +872,7 @@ All handlers receive `ctx: ServerRequestContext` as the first argument. 
The seco | v1 decorator | v2 constructor kwarg | `params` type | return type | |---|---|---|---| | `@server.list_tools()` | `on_list_tools` | `PaginatedRequestParams \| None` | `ListToolsResult` | -| `@server.call_tool()` | `on_call_tool` | `CallToolRequestParams` | `CallToolResult \| CreateTaskResult` | +| `@server.call_tool()` | `on_call_tool` | `CallToolRequestParams` | `CallToolResult` | | `@server.list_resources()` | `on_list_resources` | `PaginatedRequestParams \| None` | `ListResourcesResult` | | `@server.list_resource_templates()` | `on_list_resource_templates` | `PaginatedRequestParams \| None` | `ListResourceTemplatesResult` | | `@server.read_resource()` | `on_read_resource` | `ReadResourceRequestParams` | `ReadResourceResult` | @@ -1039,37 +1039,11 @@ from mcp.server import ServerRequestContext # but None in notification handlers ``` -### Experimental: task handler decorators removed +### Experimental Tasks support removed -The experimental decorator methods on `ExperimentalHandlers` (`@server.experimental.list_tasks()`, `@server.experimental.get_task()`, etc.) have been removed. +Tasks (SEP-1686) have been removed from the MCP specification and are no longer part of this SDK. The `mcp.client.experimental`, `mcp.server.experimental`, `mcp.shared.experimental`, and `mcp.server.lowlevel.experimental` modules have been removed, along with all `Task*` types, the `tasks` capability fields, `Tool.execution`, and the `experimental` properties on `ClientSession`, `ServerSession`, `Server`, and `ServerRequestContext`. -Default task handlers are still registered automatically via `server.experimental.enable_tasks()`. Custom handlers can be passed as `on_*` kwargs to override specific defaults. - -**Before (v1):** - -```python -server = Server("my-server") -server.experimental.enable_tasks() - -@server.experimental.get_task() -async def custom_get_task(request: GetTaskRequest) -> GetTaskResult: - ... 
-``` - -**After (v2):** - -```python -from mcp.server import Server, ServerRequestContext -from mcp.types import GetTaskRequestParams, GetTaskResult - - -async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - ... - - -server = Server("my-server") -server.experimental.enable_tasks(on_get_task=custom_get_task) -``` +Tasks are expected to return as a separate MCP extension in a future release. ## Deprecations diff --git a/examples/clients/simple-task-client/README.md b/examples/clients/simple-task-client/README.md deleted file mode 100644 index 103be0f1fb..0000000000 --- a/examples/clients/simple-task-client/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# Simple Task Client - -A minimal MCP client demonstrating polling for task results over streamable HTTP. - -## Running - -First, start the simple-task server in another terminal: - -```bash -cd examples/servers/simple-task -uv run mcp-simple-task -``` - -Then run the client: - -```bash -cd examples/clients/simple-task-client -uv run mcp-simple-task-client -``` - -Use `--url` to connect to a different server. - -## What it does - -1. Connects to the server via streamable HTTP -2. Calls the `long_running_task` tool as a task -3. Polls the task status until completion -4. Retrieves and prints the result - -## Expected output - -```text -Available tools: ['long_running_task'] - -Calling tool as a task... -Task created: - Status: working - Starting work... - Status: working - Processing step 1... - Status: working - Processing step 2... - Status: completed - - -Result: Task completed! 
-``` diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py deleted file mode 100644 index 2fc2cda8d9..0000000000 --- a/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .main import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/main.py b/examples/clients/simple-task-client/mcp_simple_task_client/main.py deleted file mode 100644 index f9e555c8e6..0000000000 --- a/examples/clients/simple-task-client/mcp_simple_task_client/main.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Simple task client demonstrating MCP tasks polling over streamable HTTP.""" - -import asyncio - -import click -from mcp import ClientSession -from mcp.client.streamable_http import streamable_http_client -from mcp.types import CallToolResult, TextContent - - -async def run(url: str) -> None: - async with streamable_http_client(url) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - - # List tools - tools = await session.list_tools() - print(f"Available tools: {[t.name for t in tools.tools]}") - - # Call the tool as a task - print("\nCalling tool as a task...") - - result = await session.experimental.call_tool_as_task( - "long_running_task", - arguments={}, - ttl=60000, - ) - task_id = result.task.task_id - print(f"Task created: {task_id}") - - status = None - # Poll until done (respects server's pollInterval hint) - async for status in session.experimental.poll_task(task_id): - print(f" Status: {status.status} - {status.status_message or ''}") - - # Check final status - if status and status.status != 
"completed": - print(f"Task ended with status: {status.status}") - return - - # Get the result - task_result = await session.experimental.get_task_result(task_id, CallToolResult) - content = task_result.content[0] - if isinstance(content, TextContent): - print(f"\nResult: {content.text}") - - -@click.command() -@click.option("--url", default="http://localhost:8000/mcp", help="Server URL") -def main(url: str) -> int: - asyncio.run(run(url)) - return 0 - - -if __name__ == "__main__": - main() diff --git a/examples/clients/simple-task-client/pyproject.toml b/examples/clients/simple-task-client/pyproject.toml deleted file mode 100644 index c7abf51159..0000000000 --- a/examples/clients/simple-task-client/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task-client" -version = "0.1.0" -description = "A simple MCP client demonstrating task polling" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks", "client"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["click>=8.0", "mcp"] - -[project.scripts] -mcp-simple-task-client = "mcp_simple_task_client.main:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task_client"] - -[tool.pyright] -include = ["mcp_simple_task_client"] -venvPath = "." 
-venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/clients/simple-task-interactive-client/README.md b/examples/clients/simple-task-interactive-client/README.md deleted file mode 100644 index 3397d3b5d7..0000000000 --- a/examples/clients/simple-task-interactive-client/README.md +++ /dev/null @@ -1,87 +0,0 @@ -# Simple Interactive Task Client - -A minimal MCP client demonstrating responses to interactive tasks (elicitation and sampling). - -## Running - -First, start the interactive task server in another terminal: - -```bash -cd examples/servers/simple-task-interactive -uv run mcp-simple-task-interactive -``` - -Then run the client: - -```bash -cd examples/clients/simple-task-interactive-client -uv run mcp-simple-task-interactive-client -``` - -Use `--url` to connect to a different server. - -## What it does - -1. Connects to the server via streamable HTTP -2. Calls `confirm_delete` - server asks for confirmation, client responds via terminal -3. Calls `write_haiku` - server requests LLM completion, client returns a hardcoded haiku - -## Key concepts - -### Elicitation callback - -```python -async def elicitation_callback(context, params) -> ElicitResult: - # Handle user input request from server - return ElicitResult(action="accept", content={"confirm": True}) -``` - -### Sampling callback - -```python -async def sampling_callback(context, params) -> CreateMessageResult: - # Handle LLM completion request from server - return CreateMessageResult(model="...", role="assistant", content=...) 
-``` - -### Using call_tool_as_task - -```python -# Call a tool as a task (returns immediately with task reference) -result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) -task_id = result.task.task_id - -# Get result - this delivers elicitation/sampling requests and blocks until complete -final = await session.experimental.get_task_result(task_id, CallToolResult) -``` - -**Important**: The `get_task_result()` call is what triggers the delivery of elicitation -and sampling requests to your callbacks. It blocks until the task completes and returns -the final result. - -## Expected output - -```text -Available tools: ['confirm_delete', 'write_haiku'] - ---- Demo 1: Elicitation --- -Calling confirm_delete tool... -Task created: - -[Elicitation] Server asks: Are you sure you want to delete 'important.txt'? -Your response (y/n): y -[Elicitation] Responding with: confirm=True -Result: Deleted 'important.txt' - ---- Demo 2: Sampling --- -Calling write_haiku tool... 
-Task created: - -[Sampling] Server requests LLM completion for: Write a haiku about autumn leaves -[Sampling] Responding with haiku -Result: -Haiku: -Cherry blossoms fall -Softly on the quiet pond -Spring whispers goodbye -``` diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py deleted file mode 100644 index 2fc2cda8d9..0000000000 --- a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .main import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py deleted file mode 100644 index ff5f499280..0000000000 --- a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Simple interactive task client demonstrating elicitation and sampling responses. - -This example demonstrates the spec-compliant polling pattern: -1. Poll tasks/get watching for status changes -2. On input_required, call tasks/result to receive elicitation/sampling requests -3. 
Continue until terminal status, then retrieve final result -""" - -import asyncio - -import click -from mcp import ClientSession -from mcp.client.context import ClientRequestContext -from mcp.client.streamable_http import streamable_http_client -from mcp.types import ( - CallToolResult, - CreateMessageRequestParams, - CreateMessageResult, - ElicitRequestParams, - ElicitResult, - TextContent, -) - - -async def elicitation_callback( - context: ClientRequestContext, - params: ElicitRequestParams, -) -> ElicitResult: - """Handle elicitation requests from the server.""" - print(f"\n[Elicitation] Server asks: {params.message}") - - # Simple terminal prompt - response = input("Your response (y/n): ").strip().lower() - confirmed = response in ("y", "yes", "true", "1") - - print(f"[Elicitation] Responding with: confirm={confirmed}") - return ElicitResult(action="accept", content={"confirm": confirmed}) - - -async def sampling_callback( - context: ClientRequestContext, - params: CreateMessageRequestParams, -) -> CreateMessageResult: - """Handle sampling requests from the server.""" - # Get the prompt from the first message - prompt = "unknown" - if params.messages: - content = params.messages[0].content - if isinstance(content, TextContent): - prompt = content.text - - print(f"\n[Sampling] Server requests LLM completion for: {prompt}") - - # Return a hardcoded haiku (in real use, call your LLM here) - haiku = """Cherry blossoms fall -Softly on the quiet pond -Spring whispers goodbye""" - - print("[Sampling] Responding with haiku") - return CreateMessageResult( - model="mock-haiku-model", - role="assistant", - content=TextContent(type="text", text=haiku), - ) - - -def get_text(result: CallToolResult) -> str: - """Extract text from a CallToolResult.""" - if result.content and isinstance(result.content[0], TextContent): - return result.content[0].text - return "(no text)" - - -async def run(url: str) -> None: - async with streamable_http_client(url) as (read, write): - async 
with ClientSession( - read, - write, - elicitation_callback=elicitation_callback, - sampling_callback=sampling_callback, - ) as session: - await session.initialize() - - # List tools - tools = await session.list_tools() - print(f"Available tools: {[t.name for t in tools.tools]}") - - # Demo 1: Elicitation (confirm_delete) - print("\n--- Demo 1: Elicitation ---") - print("Calling confirm_delete tool...") - - elicit_task = await session.experimental.call_tool_as_task("confirm_delete", {"filename": "important.txt"}) - elicit_task_id = elicit_task.task.task_id - print(f"Task created: {elicit_task_id}") - - # Poll until terminal, calling tasks/result on input_required - async for status in session.experimental.poll_task(elicit_task_id): - print(f"[Poll] Status: {status.status}") - if status.status == "input_required": - # Server needs input - tasks/result delivers the elicitation request - elicit_result = await session.experimental.get_task_result(elicit_task_id, CallToolResult) - break - else: - # poll_task exited due to terminal status - elicit_result = await session.experimental.get_task_result(elicit_task_id, CallToolResult) - - print(f"Result: {get_text(elicit_result)}") - - # Demo 2: Sampling (write_haiku) - print("\n--- Demo 2: Sampling ---") - print("Calling write_haiku tool...") - - sampling_task = await session.experimental.call_tool_as_task("write_haiku", {"topic": "autumn leaves"}) - sampling_task_id = sampling_task.task.task_id - print(f"Task created: {sampling_task_id}") - - # Poll until terminal, calling tasks/result on input_required - async for status in session.experimental.poll_task(sampling_task_id): - print(f"[Poll] Status: {status.status}") - if status.status == "input_required": - sampling_result = await session.experimental.get_task_result(sampling_task_id, CallToolResult) - break - else: - sampling_result = await session.experimental.get_task_result(sampling_task_id, CallToolResult) - - print(f"Result:\n{get_text(sampling_result)}") - - 
-@click.command() -@click.option("--url", default="http://localhost:8000/mcp", help="Server URL") -def main(url: str) -> int: - asyncio.run(run(url)) - return 0 - - -if __name__ == "__main__": - main() diff --git a/examples/clients/simple-task-interactive-client/pyproject.toml b/examples/clients/simple-task-interactive-client/pyproject.toml deleted file mode 100644 index 47191573f2..0000000000 --- a/examples/clients/simple-task-interactive-client/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task-interactive-client" -version = "0.1.0" -description = "A simple MCP client demonstrating interactive task responses" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks", "client", "elicitation", "sampling"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["click>=8.0", "mcp"] - -[project.scripts] -mcp-simple-task-interactive-client = "mcp_simple_task_interactive_client.main:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task_interactive_client"] - -[tool.pyright] -include = ["mcp_simple_task_interactive_client"] -venvPath = "." 
-venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/servers/simple-task-interactive/README.md b/examples/servers/simple-task-interactive/README.md deleted file mode 100644 index b8f384cb48..0000000000 --- a/examples/servers/simple-task-interactive/README.md +++ /dev/null @@ -1,74 +0,0 @@ -# Simple Interactive Task Server - -A minimal MCP server demonstrating interactive tasks with elicitation and sampling. - -## Running - -```bash -cd examples/servers/simple-task-interactive -uv run mcp-simple-task-interactive -``` - -The server starts on `http://localhost:8000/mcp` by default. Use `--port` to change. - -## What it does - -This server exposes two tools: - -### `confirm_delete` (demonstrates elicitation) - -Asks the user for confirmation before "deleting" a file. - -- Uses `task.elicit()` to request user input -- Shows the elicitation flow: task -> input_required -> response -> complete - -### `write_haiku` (demonstrates sampling) - -Asks the LLM to write a haiku about a topic. - -- Uses `task.create_message()` to request LLM completion -- Shows the sampling flow: task -> input_required -> response -> complete - -## Usage with the client - -In one terminal, start the server: - -```bash -cd examples/servers/simple-task-interactive -uv run mcp-simple-task-interactive -``` - -In another terminal, run the interactive client: - -```bash -cd examples/clients/simple-task-interactive-client -uv run mcp-simple-task-interactive-client -``` - -## Expected server output - -When a client connects and calls the tools, you'll see: - -```text -Starting server on http://localhost:8000/mcp - -[Server] confirm_delete called for 'important.txt' -[Server] Task created: -[Server] Sending elicitation request to client... 
-[Server] Received elicitation response: action=accept, content={'confirm': True} -[Server] Completing task with result: Deleted 'important.txt' - -[Server] write_haiku called for topic 'autumn leaves' -[Server] Task created: -[Server] Sending sampling request to client... -[Server] Received sampling response: Cherry blossoms fall -Softly on the quiet pon... -[Server] Completing task with haiku -``` - -## Key concepts - -1. **ServerTaskContext**: Provides `elicit()` and `create_message()` for user interaction -2. **run_task()**: Spawns background work, auto-completes/fails, returns immediately -3. **TaskResultHandler**: Delivers queued messages and routes responses -4. **Response routing**: Responses are routed back to waiting resolvers diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py deleted file mode 100644 index e7ef16530b..0000000000 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .server import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py deleted file mode 100644 index bc06e12088..0000000000 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Simple interactive task server demonstrating elicitation and sampling. 
- -This example shows the simplified task API where: -- server.experimental.enable_tasks() sets up all infrastructure -- ctx.experimental.run_task() handles task lifecycle automatically -- ServerTaskContext.elicit() and ServerTaskContext.create_message() queue requests properly -""" - -from typing import Any - -import click -import uvicorn -from mcp import types -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.task_context import ServerTaskContext - - -async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None -) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool( - name="confirm_delete", - description="Asks for confirmation before deleting (demonstrates elicitation)", - input_schema={ - "type": "object", - "properties": {"filename": {"type": "string"}}, - }, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - types.Tool( - name="write_haiku", - description="Asks LLM to write a haiku (demonstrates sampling)", - input_schema={"type": "object", "properties": {"topic": {"type": "string"}}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - ] - ) - - -async def handle_confirm_delete(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: - """Handle the confirm_delete tool - demonstrates elicitation.""" - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) - - filename = arguments.get("filename", "unknown.txt") - print(f"\n[Server] confirm_delete called for '{filename}'") - - async def work(task: ServerTaskContext) -> types.CallToolResult: - print(f"[Server] Task {task.task_id} starting elicitation...") - - result = await task.elicit( - message=f"Are you sure you want to delete '{filename}'?", - requested_schema={ - "type": "object", - "properties": {"confirm": {"type": "boolean"}}, - "required": ["confirm"], - }, - ) - - print(f"[Server] Received elicitation response: 
action={result.action}, content={result.content}") - - if result.action == "accept" and result.content: - confirmed = result.content.get("confirm", False) - text = f"Deleted '{filename}'" if confirmed else "Deletion cancelled" - else: - text = "Deletion cancelled" - - print(f"[Server] Completing task with result: {text}") - return types.CallToolResult(content=[types.TextContent(type="text", text=text)]) - - return await ctx.experimental.run_task(work) - - -async def handle_write_haiku(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: - """Handle the write_haiku tool - demonstrates sampling.""" - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) - - topic = arguments.get("topic", "nature") - print(f"\n[Server] write_haiku called for topic '{topic}'") - - async def work(task: ServerTaskContext) -> types.CallToolResult: - print(f"[Server] Task {task.task_id} starting sampling...") - - result = await task.create_message( - messages=[ - types.SamplingMessage( - role="user", - content=types.TextContent(type="text", text=f"Write a haiku about {topic}"), - ) - ], - max_tokens=50, - ) - - haiku = "No response" - if isinstance(result.content, types.TextContent): - haiku = result.content.text - - print(f"[Server] Received sampling response: {haiku[:50]}...") - return types.CallToolResult(content=[types.TextContent(type="text", text=f"Haiku:\n{haiku}")]) - - return await ctx.experimental.run_task(work) - - -async def handle_call_tool( - ctx: ServerRequestContext, params: types.CallToolRequestParams -) -> types.CallToolResult | types.CreateTaskResult: - """Dispatch tool calls to their handlers.""" - arguments = params.arguments or {} - - if params.name == "confirm_delete": - return await handle_confirm_delete(ctx, arguments) - elif params.name == "write_haiku": - return await handle_write_haiku(ctx, arguments) - - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], - is_error=True, - 
) - - -server = Server( - "simple-task-interactive", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, -) - -# Enable task support - this auto-registers all handlers -server.experimental.enable_tasks() - - -@click.command() -@click.option("--port", default=8000, help="Port to listen on") -def main(port: int) -> int: - starlette_app = server.streamable_http_app() - print(f"Starting server on http://localhost:{port}/mcp") - uvicorn.run(starlette_app, host="127.0.0.1", port=port) - return 0 diff --git a/examples/servers/simple-task-interactive/pyproject.toml b/examples/servers/simple-task-interactive/pyproject.toml deleted file mode 100644 index 4ec9770763..0000000000 --- a/examples/servers/simple-task-interactive/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task-interactive" -version = "0.1.0" -description = "A simple MCP server demonstrating interactive tasks (elicitation & sampling)" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks", "elicitation", "sampling"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["anyio>=4.5", "click>=8.0", "mcp", "starlette", "uvicorn"] - -[project.scripts] -mcp-simple-task-interactive = "mcp_simple_task_interactive.server:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task_interactive"] - -[tool.pyright] -include = ["mcp_simple_task_interactive"] -venvPath = "." 
-venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/servers/simple-task/README.md b/examples/servers/simple-task/README.md deleted file mode 100644 index 6914e0414f..0000000000 --- a/examples/servers/simple-task/README.md +++ /dev/null @@ -1,37 +0,0 @@ -# Simple Task Server - -A minimal MCP server demonstrating the experimental tasks feature over streamable HTTP. - -## Running - -```bash -cd examples/servers/simple-task -uv run mcp-simple-task -``` - -The server starts on `http://localhost:8000/mcp` by default. Use `--port` to change. - -## What it does - -This server exposes a single tool `long_running_task` that: - -1. Must be called as a task (with `task` metadata in the request) -2. Takes ~3 seconds to complete -3. Sends status updates during execution -4. Returns a result when complete - -## Usage with the client - -In one terminal, start the server: - -```bash -cd examples/servers/simple-task -uv run mcp-simple-task -``` - -In another terminal, run the client: - -```bash -cd examples/clients/simple-task-client -uv run mcp-simple-task-client -``` diff --git a/examples/servers/simple-task/mcp_simple_task/__init__.py b/examples/servers/simple-task/mcp_simple_task/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/servers/simple-task/mcp_simple_task/__main__.py b/examples/servers/simple-task/mcp_simple_task/__main__.py deleted file mode 100644 index e7ef16530b..0000000000 --- a/examples/servers/simple-task/mcp_simple_task/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .server import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py deleted file mode 100644 index 7583cd8f0e..0000000000 --- 
a/examples/servers/simple-task/mcp_simple_task/server.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Simple task server demonstrating MCP tasks over streamable HTTP.""" - -import anyio -import click -import uvicorn -from mcp import types -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.task_context import ServerTaskContext - - -async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None -) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool( - name="long_running_task", - description="A task that takes a few seconds to complete with status updates", - input_schema={"type": "object", "properties": {}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ) - ] - ) - - -async def handle_call_tool( - ctx: ServerRequestContext, params: types.CallToolRequestParams -) -> types.CallToolResult | types.CreateTaskResult: - """Dispatch tool calls to their handlers.""" - if params.name == "long_running_task": - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> types.CallToolResult: - await task.update_status("Starting work...") - await anyio.sleep(1) - - await task.update_status("Processing step 1...") - await anyio.sleep(1) - - await task.update_status("Processing step 2...") - await anyio.sleep(1) - - return types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) - - return await ctx.experimental.run_task(work) - - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], - is_error=True, - ) - - -server = Server( - "simple-task-server", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, -) - -# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task -server.experimental.enable_tasks() - - -@click.command() -@click.option("--port", default=8000, help="Port to listen on") -def main(port: 
int) -> int: - starlette_app = server.streamable_http_app() - - print(f"Starting server on http://localhost:{port}/mcp") - uvicorn.run(starlette_app, host="127.0.0.1", port=port) - return 0 diff --git a/examples/servers/simple-task/pyproject.toml b/examples/servers/simple-task/pyproject.toml deleted file mode 100644 index 921a1c34fc..0000000000 --- a/examples/servers/simple-task/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task" -version = "0.1.0" -description = "A simple MCP server demonstrating tasks" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["anyio>=4.5", "click>=8.0", "mcp", "starlette", "uvicorn"] - -[project.scripts] -mcp-simple-task = "mcp_simple_task.server:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task"] - -[tool.pyright] -include = ["mcp_simple_task"] -venvPath = "." 
-venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/mkdocs.yml b/mkdocs.yml index e48c64242d..cb89faf0f0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -19,12 +19,6 @@ nav: - Low-Level Server: low-level-server.md - Authorization: authorization.md - Testing: testing.md - - Experimental: - - Overview: experimental/index.md - - Tasks: - - Introduction: experimental/tasks.md - - Server Implementation: experimental/tasks-server.md - - Client Usage: experimental/tasks-client.md - API Reference: api/ theme: diff --git a/src/mcp/client/experimental/__init__.py b/src/mcp/client/experimental/__init__.py deleted file mode 100644 index 8d74cb3044..0000000000 --- a/src/mcp/client/experimental/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Experimental client features. - -WARNING: These APIs are experimental and may change without notice. -""" - -from mcp.client.experimental.tasks import ExperimentalClientFeatures - -__all__ = ["ExperimentalClientFeatures"] diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py deleted file mode 100644 index 0ab513236a..0000000000 --- a/src/mcp/client/experimental/task_handlers.py +++ /dev/null @@ -1,293 +0,0 @@ -"""Experimental task handler protocols for server -> client requests. - -This module provides Protocol types and default handlers for when servers -send task-related requests to clients (the reverse of normal client -> server flow). - -WARNING: These APIs are experimental and may change without notice. - -Use cases: -- Server sends task-augmented sampling/elicitation request to client -- Client creates a local task, spawns background work, returns CreateTaskResult -- Server polls client's task status via tasks/get, tasks/result, etc. 
-""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Protocol - -from pydantic import TypeAdapter - -from mcp import types -from mcp.shared._context import RequestContext -from mcp.shared.session import RequestResponder - -if TYPE_CHECKING: - from mcp.client.session import ClientSession - - -class GetTaskHandlerFnT(Protocol): - """Handler for tasks/get requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.GetTaskRequestParams, - ) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch - - -class GetTaskResultHandlerFnT(Protocol): - """Handler for tasks/result requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.GetTaskPayloadRequestParams, - ) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch - - -class ListTasksHandlerFnT(Protocol): - """Handler for tasks/list requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch - - -class CancelTaskHandlerFnT(Protocol): - """Handler for tasks/cancel requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, - ) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch - - -class TaskAugmentedSamplingFnT(Protocol): - """Handler for task-augmented sampling/createMessage requests from server. - - When server sends a CreateMessageRequest with task field, this callback - is invoked. 
The callback should create a task, spawn background work, - and return CreateTaskResult immediately. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.CreateMessageRequestParams, - task_metadata: types.TaskMetadata, - ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch - - -class TaskAugmentedElicitationFnT(Protocol): - """Handler for task-augmented elicitation/create requests from server. - - When server sends an ElicitRequest with task field, this callback - is invoked. The callback should create a task, spawn background work, - and return CreateTaskResult immediately. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.ElicitRequestParams, - task_metadata: types.TaskMetadata, - ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch - - -async def default_get_task_handler( - context: RequestContext[ClientSession], - params: types.GetTaskRequestParams, -) -> types.GetTaskResult | types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/get not supported", - ) - - -async def default_get_task_result_handler( - context: RequestContext[ClientSession], - params: types.GetTaskPayloadRequestParams, -) -> types.GetTaskPayloadResult | types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/result not supported", - ) - - -async def default_list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, -) -> types.ListTasksResult | types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/list not supported", - ) - - -async def default_cancel_task_handler( - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, -) -> types.CancelTaskResult | 
types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/cancel not supported", - ) - - -async def default_task_augmented_sampling( - context: RequestContext[ClientSession], - params: types.CreateMessageRequestParams, - task_metadata: types.TaskMetadata, -) -> types.CreateTaskResult | types.ErrorData: - return types.ErrorData( - code=types.INVALID_REQUEST, - message="Task-augmented sampling not supported", - ) - - -async def default_task_augmented_elicitation( - context: RequestContext[ClientSession], - params: types.ElicitRequestParams, - task_metadata: types.TaskMetadata, -) -> types.CreateTaskResult | types.ErrorData: - return types.ErrorData( - code=types.INVALID_REQUEST, - message="Task-augmented elicitation not supported", - ) - - -@dataclass -class ExperimentalTaskHandlers: - """Container for experimental task handlers. - - Groups all task-related handlers that handle server -> client requests. - This includes both pure task requests (get, list, cancel, result) and - task-augmented request handlers (sampling, elicitation with task field). - - WARNING: These APIs are experimental and may change without notice. 
- - Example: - ```python - handlers = ExperimentalTaskHandlers( - get_task=my_get_task_handler, - list_tasks=my_list_tasks_handler, - ) - session = ClientSession(..., experimental_task_handlers=handlers) - ``` - """ - - # Pure task request handlers - get_task: GetTaskHandlerFnT = field(default=default_get_task_handler) - get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler) - list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler) - cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler) - - # Task-augmented request handlers - augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling) - augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation) - - def build_capability(self) -> types.ClientTasksCapability | None: - """Build ClientTasksCapability from the configured handlers. - - Returns a capability object that reflects which handlers are configured - (i.e., not using the default "not supported" handlers). 
- - Returns: - ClientTasksCapability if any handlers are provided, None otherwise - """ - has_list = self.list_tasks is not default_list_tasks_handler - has_cancel = self.cancel_task is not default_cancel_task_handler - has_sampling = self.augmented_sampling is not default_task_augmented_sampling - has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation - - # If no handlers are provided, return None - if not any([has_list, has_cancel, has_sampling, has_elicitation]): - return None - - # Build requests capability if any request handlers are provided - requests_capability: types.ClientTasksRequestsCapability | None = None - if has_sampling or has_elicitation: - requests_capability = types.ClientTasksRequestsCapability( - sampling=types.TasksSamplingCapability(create_message=types.TasksCreateMessageCapability()) - if has_sampling - else None, - elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability()) - if has_elicitation - else None, - ) - - return types.ClientTasksCapability( - list=types.TasksListCapability() if has_list else None, - cancel=types.TasksCancelCapability() if has_cancel else None, - requests=requests_capability, - ) - - @staticmethod - def handles_request(request: types.ServerRequest) -> bool: - """Check if this handler handles the given request type.""" - return isinstance( - request, - types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest, - ) - - async def handle_request( - self, - ctx: RequestContext[ClientSession], - responder: RequestResponder[types.ServerRequest, types.ClientResult], - ) -> None: - """Handle a task-related request from the server. - - Call handles_request() first to check if this handler can handle the request. 
- """ - client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter( - types.ClientResult | types.ErrorData - ) - - match responder.request: - case types.GetTaskRequest(params=params): - response = await self.get_task(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case types.GetTaskPayloadRequest(params=params): - response = await self.get_task_result(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case types.ListTasksRequest(params=params): - response = await self.list_tasks(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case types.CancelTaskRequest(params=params): - response = await self.cancel_task(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case _: # pragma: no cover - raise ValueError(f"Unhandled request type: {type(responder.request)}") - - -# Backwards compatibility aliases -default_task_augmented_sampling_callback = default_task_augmented_sampling -default_task_augmented_elicitation_callback = default_task_augmented_elicitation diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py deleted file mode 100644 index a566df766b..0000000000 --- a/src/mcp/client/experimental/tasks.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Experimental client-side task support. - -This module provides client methods for interacting with MCP tasks. - -WARNING: These APIs are experimental and may change without notice. 
- -Example: - ```python - # Call a tool as a task - result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) - task_id = result.task.task_id - - # Get task status - status = await session.experimental.get_task(task_id) - - # Get task result when complete - if status.status == "completed": - result = await session.experimental.get_task_result(task_id, CallToolResult) - - # List all tasks - tasks = await session.experimental.list_tasks() - - # Cancel a task - await session.experimental.cancel_task(task_id) - ``` -""" - -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any, TypeVar - -from mcp import types -from mcp.shared.experimental.tasks.polling import poll_until_terminal -from mcp.types._types import RequestParamsMeta - -if TYPE_CHECKING: - from mcp.client.session import ClientSession - -ResultT = TypeVar("ResultT", bound=types.Result) - - -class ExperimentalClientFeatures: - """Experimental client features for tasks and other experimental APIs. - - WARNING: These APIs are experimental and may change without notice. - - Access via session.experimental: - status = await session.experimental.get_task(task_id) - """ - - def __init__(self, session: "ClientSession") -> None: - self._session = session - - async def call_tool_as_task( - self, - name: str, - arguments: dict[str, Any] | None = None, - *, - ttl: int = 60000, - meta: RequestParamsMeta | None = None, - ) -> types.CreateTaskResult: - """Call a tool as a task, returning a CreateTaskResult for polling. - - This is a convenience method for calling tools that support task execution. - The server will return a task reference instead of the immediate result, - which can then be polled via `get_task()` and retrieved via `get_task_result()`. 
- - Args: - name: The tool name - arguments: Tool arguments - ttl: Task time-to-live in milliseconds (default: 60000 = 1 minute) - meta: Optional metadata to include in the request - - Returns: - CreateTaskResult containing the task reference - - Example: - ```python - # Create task - result = await session.experimental.call_tool_as_task( - "long_running_tool", {"input": "data"} - ) - task_id = result.task.task_id - - # Poll for completion - while True: - status = await session.experimental.get_task(task_id) - if status.status == "completed": - break - await anyio.sleep(0.5) - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) - ``` - """ - return await self._session.send_request( - types.CallToolRequest( - params=types.CallToolRequestParams( - name=name, - arguments=arguments, - task=types.TaskMetadata(ttl=ttl), - _meta=meta, - ), - ), - types.CreateTaskResult, - ) - - async def get_task(self, task_id: str) -> types.GetTaskResult: - """Get the current status of a task. - - Args: - task_id: The task identifier - - Returns: - GetTaskResult containing the task status and metadata - """ - return await self._session.send_request( - types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id)), - types.GetTaskResult, - ) - - async def get_task_result( - self, - task_id: str, - result_type: type[ResultT], - ) -> ResultT: - """Get the result of a completed task. 
- - The result type depends on the original request type: - - tools/call tasks return CallToolResult - - Other request types return their corresponding result type - - Args: - task_id: The task identifier - result_type: The expected result type (e.g., CallToolResult) - - Returns: - The task result, validated against result_type - """ - return await self._session.send_request( - types.GetTaskPayloadRequest( - params=types.GetTaskPayloadRequestParams(task_id=task_id), - ), - result_type, - ) - - async def list_tasks( - self, - cursor: str | None = None, - ) -> types.ListTasksResult: - """List all tasks. - - Args: - cursor: Optional pagination cursor - - Returns: - ListTasksResult containing tasks and optional next cursor - """ - params = types.PaginatedRequestParams(cursor=cursor) if cursor else None - return await self._session.send_request( - types.ListTasksRequest(params=params), - types.ListTasksResult, - ) - - async def cancel_task(self, task_id: str) -> types.CancelTaskResult: - """Cancel a running task. - - Args: - task_id: The task identifier - - Returns: - CancelTaskResult with the updated task state - """ - return await self._session.send_request( - types.CancelTaskRequest( - params=types.CancelTaskRequestParams(task_id=task_id), - ), - types.CancelTaskResult, - ) - - async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: - """Poll a task until it reaches a terminal status. - - Yields GetTaskResult for each poll, allowing the caller to react to - status changes (e.g., handle input_required). Exits when the task reaches - a terminal status (completed, failed, cancelled). - - Respects the pollInterval hint from the server. 
- - Args: - task_id: The task identifier - - Yields: - GetTaskResult for each poll - - Example: - ```python - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - if status.status == "input_required": - # Handle elicitation request via tasks/result - pass - - # Task is now terminal, get the result - result = await session.experimental.get_task_result(task_id, CallToolResult) - ``` - """ - async for status in poll_until_terminal(self.get_task, task_id): - yield status diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 86113874be..08f532eca5 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,8 +8,6 @@ from mcp import types from mcp.client._transport import ReadStream, WriteStream -from mcp.client.experimental import ExperimentalClientFeatures -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder @@ -120,7 +118,6 @@ def __init__( client_info: types.Implementation | None = None, *, sampling_capabilities: types.SamplingCapability | None = None, - experimental_task_handlers: ExperimentalTaskHandlers | None = None, ) -> None: super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds) self._client_info = client_info or DEFAULT_CLIENT_INFO @@ -132,10 +129,6 @@ def __init__( self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._initialize_result: types.InitializeResult | None = None - self._experimental_features: ExperimentalClientFeatures | None = None - - # Experimental: Task handlers (use defaults if not provided) - self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() @property def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]: @@ -174,7 
+167,6 @@ async def initialize(self) -> types.InitializeResult: elicitation=elicitation, experimental=None, roots=roots, - tasks=self._task_handlers.build_capability(), ), client_info=self._client_info, ), @@ -199,23 +191,6 @@ def initialize_result(self) -> types.InitializeResult | None: """ return self._initialize_result - @property - def experimental(self) -> ExperimentalClientFeatures: - """Experimental APIs for tasks and other features. - - !!! warning - These APIs are experimental and may change without notice. - - Example: - ```python - status = await session.experimental.get_task(task_id) - result = await session.experimental.get_task_result(task_id, CallToolResult) - ``` - """ - if self._experimental_features is None: - self._experimental_features = ExperimentalClientFeatures(self) - return self._experimental_features - async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: """Send a ping request.""" return await self.send_request(types.PingRequest(params=types.RequestParams(_meta=meta)), types.EmptyResult) @@ -413,31 +388,16 @@ async def send_roots_list_changed(self) -> None: async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: ctx = RequestContext[ClientSession](request_id=responder.request_id, meta=responder.request_meta, session=self) - # Delegate to experimental task handler if applicable - if self._task_handlers.handles_request(responder.request): - with responder: - await self._task_handlers.handle_request(ctx, responder) - return None - - # Core request handling match responder.request: case types.CreateMessageRequest(params=params): with responder: - # Check if this is a task-augmented request - if params.task is not None: - response = await self._task_handlers.augmented_sampling(ctx, params, params.task) - else: - response = await self._sampling_callback(ctx, params) + response = await self._sampling_callback(ctx, params) client_response = 
ClientResponse.validate_python(response) await responder.respond(client_response) case types.ElicitRequest(params=params): with responder: - # Check if this is a task-augmented request - if params.task is not None: - response = await self._task_handlers.augmented_elicitation(ctx, params, params.task) - else: - response = await self._elicitation_callback(ctx, params) + response = await self._elicitation_callback(ctx, params) client_response = ClientResponse.validate_python(response) await responder.respond(client_response) @@ -447,14 +407,9 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques client_response = ClientResponse.validate_python(response) await responder.respond(client_response) - case types.PingRequest(): + case types.PingRequest(): # pragma: no branch with responder: - return await responder.respond(types.EmptyResult()) - - case _: # pragma: no cover - pass # Task requests handled above by _task_handlers - - return None + await responder.respond(types.EmptyResult()) async def _handle_incoming( self, diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index d8e11d78b2..bc54c5d2eb 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -5,7 +5,6 @@ from typing_extensions import TypeVar -from mcp.server.experimental.request_context import Experimental from mcp.server.session import ServerSession from mcp.shared._context import RequestContext from mcp.shared.message import CloseSSEStreamCallback @@ -17,7 +16,6 @@ @dataclass(kw_only=True) class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContextT, RequestT]): lifespan_context: LifespanContextT - experimental: Experimental request: RequestT | None = None close_sse_stream: CloseSSEStreamCallback | None = None close_standalone_sse_stream: CloseSSEStreamCallback | None = None diff --git a/src/mcp/server/experimental/__init__.py b/src/mcp/server/experimental/__init__.py deleted file mode 100644 index fd1db623f2..0000000000 
--- a/src/mcp/server/experimental/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Server-side experimental features. - -WARNING: These APIs are experimental and may change without notice. - -Import directly from submodules: -- mcp.server.experimental.task_context.ServerTaskContext -- mcp.server.experimental.task_support.TaskSupport -- mcp.server.experimental.task_result_handler.TaskResultHandler -- mcp.server.experimental.request_context.Experimental -""" diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py deleted file mode 100644 index 3eba65822a..0000000000 --- a/src/mcp/server/experimental/request_context.py +++ /dev/null @@ -1,217 +0,0 @@ -"""Experimental request context features. - -This module provides the Experimental class which gives access to experimental -features within a request context, such as task-augmented request handling. - -WARNING: These APIs are experimental and may change without notice. -""" - -from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field -from typing import Any - -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.experimental.task_support import TaskSupport -from mcp.server.session import ServerSession -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY, is_terminal -from mcp.types import ( - METHOD_NOT_FOUND, - TASK_FORBIDDEN, - TASK_REQUIRED, - ClientCapabilities, - CreateTaskResult, - ErrorData, - Result, - TaskExecutionMode, - TaskMetadata, - Tool, -) - - -@dataclass -class Experimental: - """Experimental features context for task-augmented requests. - - Provides helpers for validating task execution compatibility and - running tasks with automatic lifecycle management. - - WARNING: This API is experimental and may change without notice. 
- """ - - task_metadata: TaskMetadata | None = None - _client_capabilities: ClientCapabilities | None = field(default=None, repr=False) - _session: ServerSession | None = field(default=None, repr=False) - _task_support: TaskSupport | None = field(default=None, repr=False) - - @property - def is_task(self) -> bool: - """Check if this request is task-augmented.""" - return self.task_metadata is not None - - @property - def client_supports_tasks(self) -> bool: - """Check if the client declared task support.""" - if self._client_capabilities is None: - return False - return self._client_capabilities.tasks is not None - - def validate_task_mode( - self, tool_task_mode: TaskExecutionMode | None, *, raise_error: bool = True - ) -> ErrorData | None: - """Validate that the request is compatible with the tool's task execution mode. - - Per MCP spec: - - "required": Clients MUST invoke as a task. Server returns -32601 if not. - - "forbidden" (or None): Clients MUST NOT invoke as a task. Server returns -32601 if they do. - - "optional": Either is acceptable. - - Args: - tool_task_mode: The tool's execution.taskSupport value - ("forbidden", "optional", "required", or None) - raise_error: If True, raises MCPError on validation failure. If False, returns ErrorData. 
- - Returns: - None if valid, ErrorData if invalid and raise_error=False - - Raises: - MCPError: If invalid and raise_error=True - """ - - mode = tool_task_mode or TASK_FORBIDDEN - - error: ErrorData | None = None - - if mode == TASK_REQUIRED and not self.is_task: - error = ErrorData(code=METHOD_NOT_FOUND, message="This tool requires task-augmented invocation") - elif mode == TASK_FORBIDDEN and self.is_task: - error = ErrorData(code=METHOD_NOT_FOUND, message="This tool does not support task-augmented invocation") - - if error is not None and raise_error: - raise MCPError.from_error_data(error) - - return error - - def validate_for_tool(self, tool: Tool, *, raise_error: bool = True) -> ErrorData | None: - """Validate that the request is compatible with the given tool. - - Convenience wrapper around validate_task_mode that extracts the mode from a Tool. - - Args: - tool: The Tool definition - raise_error: If True, raises MCPError on validation failure. - - Returns: - None if valid, ErrorData if invalid and raise_error=False - """ - mode = tool.execution.task_support if tool.execution else None - return self.validate_task_mode(mode, raise_error=raise_error) - - def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool: - """Check if this client can use a tool with the given task mode. - - Useful for filtering tool lists or providing warnings. - Returns False if the tool's task mode is "required" but the client doesn't support tasks. 
- - Args: - tool_task_mode: The tool's execution.taskSupport value - - Returns: - True if the client can use this tool, False otherwise - """ - mode = tool_task_mode or TASK_FORBIDDEN - if mode == TASK_REQUIRED and not self.client_supports_tasks: - return False - return True - - async def run_task( - self, - work: Callable[[ServerTaskContext], Awaitable[Result]], - *, - task_id: str | None = None, - model_immediate_response: str | None = None, - ) -> CreateTaskResult: - """Create a task, spawn background work, and return CreateTaskResult immediately. - - This is the recommended way to handle task-augmented tool calls. It: - 1. Creates a task in the store - 2. Spawns the work function in a background task - 3. Returns CreateTaskResult immediately - - The work function receives a ServerTaskContext with: - - elicit() for sending elicitation requests - - create_message() for sampling requests - - update_status() for progress updates - - complete()/fail() for finishing the task - - When work() returns a Result, the task is auto-completed with that result. - If work() raises an exception, the task is auto-failed. 
- - Args: - work: Async function that does the actual work - task_id: Optional task ID (generated if not provided) - model_immediate_response: Optional string to include in _meta as - io.modelcontextprotocol/model-immediate-response - - Returns: - CreateTaskResult to return to the client - - Raises: - RuntimeError: If task support is not enabled or task_metadata is missing - - Example: - ```python - async def handle_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: - async def work(task: ServerTaskContext) -> CallToolResult: - result = await task.elicit( - message="Are you sure?", - requested_schema={"type": "object", ...} - ) - confirmed = result.content.get("confirm", False) - return CallToolResult(content=[TextContent(text="Done" if confirmed else "Cancelled")]) - - return await ctx.experimental.run_task(work) - ``` - - WARNING: This API is experimental and may change without notice. - """ - if self._task_support is None: - raise RuntimeError("Task support not enabled. Call server.experimental.enable_tasks() first.") - if self._session is None: - raise RuntimeError("Session not available.") - if self.task_metadata is None: - raise RuntimeError( - "Request is not task-augmented (no task field in params). " - "The client must send a task-augmented request." 
- ) - - support = self._task_support - # Access task_group via TaskSupport - raises if not in run() context - task_group = support.task_group - - task = await support.store.create_task(self.task_metadata, task_id) - - task_ctx = ServerTaskContext( - task=task, - store=support.store, - session=self._session, - queue=support.queue, - handler=support.handler, - ) - - async def execute() -> None: - try: - result = await work(task_ctx) - if not is_terminal(task_ctx.task.status): - await task_ctx.complete(result) - except Exception as e: - if not is_terminal(task_ctx.task.status): - await task_ctx.fail(str(e)) - - task_group.start_soon(execute) - - meta: dict[str, Any] | None = None - if model_immediate_response is not None: - meta = {MODEL_IMMEDIATE_RESPONSE_KEY: model_immediate_response} - - return CreateTaskResult(task=task, **{"_meta": meta} if meta else {}) diff --git a/src/mcp/server/experimental/session_features.py b/src/mcp/server/experimental/session_features.py deleted file mode 100644 index 2f9d1b0320..0000000000 --- a/src/mcp/server/experimental/session_features.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Experimental server session features for server→client task operations. - -This module provides the server-side equivalent of ExperimentalClientFeatures, -allowing the server to send task-augmented requests to the client and poll for results. - -WARNING: These APIs are experimental and may change without notice. 
-""" - -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any, TypeVar - -from mcp import types -from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages -from mcp.shared.experimental.tasks.capabilities import ( - require_task_augmented_elicitation, - require_task_augmented_sampling, -) -from mcp.shared.experimental.tasks.polling import poll_until_terminal - -if TYPE_CHECKING: - from mcp.server.session import ServerSession - -ResultT = TypeVar("ResultT", bound=types.Result) - - -class ExperimentalServerSessionFeatures: - """Experimental server session features for server→client task operations. - - This provides the server-side equivalent of ExperimentalClientFeatures, - allowing the server to send task-augmented requests to the client and - poll for results. - - WARNING: These APIs are experimental and may change without notice. - - Access via session.experimental: - result = await session.experimental.elicit_as_task(...) - """ - - def __init__(self, session: "ServerSession") -> None: - self._session = session - - async def get_task(self, task_id: str) -> types.GetTaskResult: - """Send tasks/get to the client to get task status. - - Args: - task_id: The task identifier - - Returns: - GetTaskResult containing the task status - """ - return await self._session.send_request( - types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id)), - types.GetTaskResult, - ) - - async def get_task_result( - self, - task_id: str, - result_type: type[ResultT], - ) -> ResultT: - """Send tasks/result to the client to retrieve the final result. 
- - Args: - task_id: The task identifier - result_type: The expected result type - - Returns: - The task result, validated against result_type - """ - return await self._session.send_request( - types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(task_id=task_id)), - result_type, - ) - - async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: - """Poll a client task until it reaches terminal status. - - Yields GetTaskResult for each poll, allowing the caller to react to - status changes. Exits when task reaches a terminal status. - - Respects the pollInterval hint from the client. - - Args: - task_id: The task identifier - - Yields: - GetTaskResult for each poll - """ - async for status in poll_until_terminal(self.get_task, task_id): - yield status - - async def elicit_as_task( - self, - message: str, - requested_schema: types.ElicitRequestedSchema, - *, - ttl: int = 60000, - ) -> types.ElicitResult: - """Send a task-augmented elicitation to the client and poll until complete. - - The client will create a local task, process the elicitation asynchronously, - and return the result when ready. This method handles the full flow: - 1. Send elicitation with task field - 2. Receive CreateTaskResult from client - 3. Poll client's task until terminal - 4. 
Retrieve and return the final ElicitResult - - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response - ttl: Task time-to-live in milliseconds - - Returns: - The client's elicitation response - - Raises: - MCPError: If client doesn't support task-augmented elicitation - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_elicitation(client_caps) - - create_result = await self._session.send_request( - types.ElicitRequest( - params=types.ElicitRequestFormParams( - message=message, - requested_schema=requested_schema, - task=types.TaskMetadata(ttl=ttl), - ) - ), - types.CreateTaskResult, - ) - - task_id = create_result.task.task_id - - async for _ in self.poll_task(task_id): - pass - - return await self.get_task_result(task_id, types.ElicitResult) - - async def create_message_as_task( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - ttl: int = 60000, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - tools: list[types.Tool] | None = None, - tool_choice: types.ToolChoice | None = None, - ) -> types.CreateMessageResult: - """Send a task-augmented sampling request and poll until complete. - - The client will create a local task, process the sampling request - asynchronously, and return the result when ready. 
- - Args: - messages: The conversation messages for sampling - max_tokens: Maximum tokens in the response - ttl: Task time-to-live in milliseconds - system_prompt: Optional system prompt - include_context: Context inclusion strategy - temperature: Sampling temperature - stop_sequences: Stop sequences - metadata: Additional metadata - model_preferences: Model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - - Returns: - The sampling result from the client - - Raises: - MCPError: If client doesn't support task-augmented sampling or tools - ValueError: If tool_use or tool_result message structure is invalid - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_sampling(client_caps) - validate_sampling_tools(client_caps, tools, tool_choice) - validate_tool_use_result_messages(messages) - - create_result = await self._session.send_request( - types.CreateMessageRequest( - params=types.CreateMessageRequestParams( - messages=messages, - max_tokens=max_tokens, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - task=types.TaskMetadata(ttl=ttl), - ) - ), - types.CreateTaskResult, - ) - - task_id = create_result.task.task_id - - async for _ in self.poll_task(task_id): - pass - - return await self.get_task_result(task_id, types.CreateMessageResult) diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py deleted file mode 100644 index 1fc45badfd..0000000000 --- a/src/mcp/server/experimental/task_context.py +++ /dev/null @@ -1,587 +0,0 @@ -"""ServerTaskContext - Server-integrated task context with elicitation and sampling. 
- -This wraps the pure TaskContext and adds server-specific functionality: -- Elicitation (task.elicit()) -- Sampling (task.create_message()) -- Status notifications -""" - -from typing import Any - -import anyio - -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.server.session import ServerSession -from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.capabilities import ( - require_task_augmented_elicitation, - require_task_augmented_sampling, -) -from mcp.shared.experimental.tasks.context import TaskContext -from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import ( - INVALID_REQUEST, - TASK_STATUS_INPUT_REQUIRED, - TASK_STATUS_WORKING, - ClientCapabilities, - CreateMessageResult, - CreateTaskResult, - ElicitationCapability, - ElicitRequestedSchema, - ElicitResult, - IncludeContext, - ModelPreferences, - RequestId, - Result, - SamplingCapability, - SamplingMessage, - Task, - TaskMetadata, - TaskStatusNotification, - TaskStatusNotificationParams, - Tool, - ToolChoice, -) - - -class ServerTaskContext: - """Server-integrated task context with elicitation and sampling. 
- - This wraps a pure TaskContext and adds server-specific functionality: - - elicit() for sending elicitation requests to the client - - create_message() for sampling requests - - Status notifications via the session - - Example: - ```python - async def my_task_work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Starting...") - - result = await task.elicit( - message="Continue?", - requested_schema={"type": "object", "properties": {"ok": {"type": "boolean"}}} - ) - - if result.content.get("ok"): - return CallToolResult(content=[TextContent(text="Done!")]) - else: - return CallToolResult(content=[TextContent(text="Cancelled")]) - ``` - """ - - def __init__( - self, - *, - task: Task, - store: TaskStore, - session: ServerSession, - queue: TaskMessageQueue, - handler: TaskResultHandler | None = None, - ): - """Create a ServerTaskContext. - - Args: - task: The Task object - store: The task store - session: The server session - queue: The message queue for elicitation/sampling - handler: The result handler for response routing (required for elicit/create_message) - """ - self._ctx = TaskContext(task=task, store=store) - self._session = session - self._queue = queue - self._handler = handler - self._store = store - - # Delegate pure properties to inner context - - @property - def task_id(self) -> str: - """The task identifier.""" - return self._ctx.task_id - - @property - def task(self) -> Task: - """The current task state.""" - return self._ctx.task - - @property - def is_cancelled(self) -> bool: - """Whether cancellation has been requested.""" - return self._ctx.is_cancelled - - def request_cancellation(self) -> None: - """Request cancellation of this task.""" - self._ctx.request_cancellation() - - # Enhanced methods with notifications - - async def update_status(self, message: str, *, notify: bool = True) -> None: - """Update the task's status message. 
- - Args: - message: The new status message - notify: Whether to send a notification to the client - """ - await self._ctx.update_status(message) - if notify: - await self._send_notification() - - async def complete(self, result: Result, *, notify: bool = True) -> None: - """Mark the task as completed with the given result. - - Args: - result: The task result - notify: Whether to send a notification to the client - """ - await self._ctx.complete(result) - if notify: - await self._send_notification() - - async def fail(self, error: str, *, notify: bool = True) -> None: - """Mark the task as failed with an error message. - - Args: - error: The error message - notify: Whether to send a notification to the client - """ - await self._ctx.fail(error) - if notify: - await self._send_notification() - - async def _send_notification(self) -> None: - """Send a task status notification to the client.""" - task = self._ctx.task - await self._session.send_notification( - TaskStatusNotification( - params=TaskStatusNotificationParams( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - ) - ) - - # Server-specific methods: elicitation and sampling - - def _check_elicitation_capability(self) -> None: - """Check if the client supports elicitation.""" - if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())): - raise MCPError(code=INVALID_REQUEST, message="Client does not support elicitation capability") - - def _check_sampling_capability(self) -> None: - """Check if the client supports sampling.""" - if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())): - raise MCPError(code=INVALID_REQUEST, message="Client does not support sampling capability") - - async def elicit( - self, - message: str, - requested_schema: ElicitRequestedSchema, - ) -> 
ElicitResult: - """Send an elicitation request via the task message queue. - - This method: - 1. Checks client capability - 2. Updates task status to "input_required" - 3. Queues the elicitation request - 4. Waits for the response (delivered via tasks/result round-trip) - 5. Updates task status back to "working" - 6. Returns the result - - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - - Returns: - The client's response - - Raises: - MCPError: If client doesn't support elicitation capability - """ - self._check_elicitation_capability() - - if self._handler is None: - raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build the request using session's helper - request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage] - message=message, - requested_schema=requested_schema, - related_task_id=self.task_id, - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for response (routed back via TaskResultHandler) - response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return ElicitResult.model_validate(response_data) - except anyio.get_cancelled_exc_class(): - # This path is tested in test_elicit_restores_status_on_cancellation - # which verifies status is restored to "working" after cancellation. 
- await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise - - async def elicit_url( - self, - message: str, - url: str, - elicitation_id: str, - ) -> ElicitResult: - """Send a URL mode elicitation request via the task message queue. - - This directs the user to an external URL for out-of-band interactions - like OAuth flows, credential collection, or payment processing. - - This method: - 1. Checks client capability - 2. Updates task status to "input_required" - 3. Queues the elicitation request - 4. Waits for the response (delivered via tasks/result round-trip) - 5. Updates task status back to "working" - 6. Returns the result - - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - - Returns: - The client's response indicating acceptance, decline, or cancellation - - Raises: - MCPError: If client doesn't support elicitation capability - RuntimeError: If handler is not configured - """ - self._check_elicitation_capability() - - if self._handler is None: - raise RuntimeError("handler is required for elicit_url(). 
Pass handler= to ServerTaskContext.") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build the request using session's helper - request = self._session._build_elicit_url_request( # pyright: ignore[reportPrivateUsage] - message=message, - url=url, - elicitation_id=elicitation_id, - related_task_id=self.task_id, - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for response (routed back via TaskResultHandler) - response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return ElicitResult.model_validate(response_data) - except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise - - async def create_message( - self, - messages: list[SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: ModelPreferences | None = None, - tools: list[Tool] | None = None, - tool_choice: ToolChoice | None = None, - ) -> CreateMessageResult: - """Send a sampling request via the task message queue. - - This method: - 1. Checks client capability - 2. Updates task status to "input_required" - 3. Queues the sampling request - 4. Waits for the response (delivered via tasks/result round-trip) - 5. Updates task status back to "working" - 6. 
Returns the result - - Args: - messages: The conversation messages for sampling - max_tokens: Maximum tokens in the response - system_prompt: Optional system prompt - include_context: Context inclusion strategy - temperature: Sampling temperature - stop_sequences: Stop sequences - metadata: Additional metadata - model_preferences: Model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - - Returns: - The sampling result from the client - - Raises: - MCPError: If client doesn't support sampling capability or tools - ValueError: If tool_use or tool_result message structure is invalid - """ - self._check_sampling_capability() - client_caps = self._session.client_params.capabilities if self._session.client_params else None - validate_sampling_tools(client_caps, tools, tool_choice) - validate_tool_use_result_messages(messages) - - if self._handler is None: - raise RuntimeError("handler is required for create_message(). 
Pass handler= to ServerTaskContext.") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build the request using session's helper - request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage] - messages=messages, - max_tokens=max_tokens, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - related_task_id=self.task_id, - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for response (routed back via TaskResultHandler) - response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return CreateMessageResult.model_validate(response_data) - except anyio.get_cancelled_exc_class(): - # This path is tested in test_create_message_restores_status_on_cancellation - # which verifies status is restored to "working" after cancellation. - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise - - async def elicit_as_task( - self, - message: str, - requested_schema: ElicitRequestedSchema, - *, - ttl: int = 60000, - ) -> ElicitResult: - """Send a task-augmented elicitation via the queue, then poll client. - - This is for use inside a task-augmented tool call when you want the client - to handle the elicitation as its own task. The elicitation request is queued - and delivered when the client calls tasks/result. After the client responds - with CreateTaskResult, we poll the client's task until complete. 
- - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - ttl: Task time-to-live in milliseconds for the client's task - - Returns: - The client's elicitation response - - Raises: - MCPError: If client doesn't support task-augmented elicitation - RuntimeError: If handler is not configured - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_elicitation(client_caps) - - if self._handler is None: - raise RuntimeError("handler is required for elicit_as_task()") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage] - message=message, - requested_schema=requested_schema, - related_task_id=self.task_id, - task=TaskMetadata(ttl=ttl), - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for initial response (CreateTaskResult from client) - response_data = await resolver.wait() - create_result = CreateTaskResult.model_validate(response_data) - client_task_id = create_result.task.task_id - - # Poll the client's task using session.experimental - async for _ in self._session.experimental.poll_task(client_task_id): - pass - - # Get final result from client - result = await self._session.experimental.get_task_result( - client_task_id, - ElicitResult, - ) - - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return result - - except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, 
status=TASK_STATUS_WORKING) - raise - - async def create_message_as_task( - self, - messages: list[SamplingMessage], - *, - max_tokens: int, - ttl: int = 60000, - system_prompt: str | None = None, - include_context: IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: ModelPreferences | None = None, - tools: list[Tool] | None = None, - tool_choice: ToolChoice | None = None, - ) -> CreateMessageResult: - """Send a task-augmented sampling request via the queue, then poll client. - - This is for use inside a task-augmented tool call when you want the client - to handle the sampling as its own task. The request is queued and delivered - when the client calls tasks/result. After the client responds with - CreateTaskResult, we poll the client's task until complete. - - Args: - messages: The conversation messages for sampling - max_tokens: Maximum tokens in the response - ttl: Task time-to-live in milliseconds for the client's task - system_prompt: Optional system prompt - include_context: Context inclusion strategy - temperature: Sampling temperature - stop_sequences: Stop sequences - metadata: Additional metadata - model_preferences: Model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - - Returns: - The sampling result from the client - - Raises: - MCPError: If client doesn't support task-augmented sampling or tools - ValueError: If tool_use or tool_result message structure is invalid - RuntimeError: If handler is not configured - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_sampling(client_caps) - validate_sampling_tools(client_caps, tools, tool_choice) - validate_tool_use_result_messages(messages) - - if self._handler is None: - raise RuntimeError("handler is required for 
create_message_as_task()") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build request WITH task field for task-augmented sampling - request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage] - messages=messages, - max_tokens=max_tokens, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - related_task_id=self.task_id, - task=TaskMetadata(ttl=ttl), - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for initial response (CreateTaskResult from client) - response_data = await resolver.wait() - create_result = CreateTaskResult.model_validate(response_data) - client_task_id = create_result.task.task_id - - # Poll the client's task using session.experimental - async for _ in self._session.experimental.poll_task(client_task_id): - pass - - # Get final result from client - result = await self._session.experimental.get_task_result( - client_task_id, - CreateMessageResult, - ) - - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return result - - except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py deleted file mode 100644 index b2268bc1c8..0000000000 --- a/src/mcp/server/experimental/task_result_handler.py +++ /dev/null @@ -1,218 +0,0 @@ 
-"""TaskResultHandler - Integrated handler for tasks/result endpoint. - -This implements the dequeue-send-wait pattern from the MCP Tasks spec: -1. Dequeue all pending messages for the task -2. Send them to the client via transport with relatedRequestId routing -3. Wait if task is not in terminal state -4. Return final result when task completes - -This is the core of the task message queue pattern. -""" - -import logging -from typing import Any - -import anyio - -from mcp.server.session import ServerSession -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal -from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.types import ( - INVALID_PARAMS, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadResult, - RelatedTaskMetadata, - RequestId, -) - -logger = logging.getLogger(__name__) - - -class TaskResultHandler: - """Handler for tasks/result that implements the message queue pattern. - - This handler: - 1. Dequeues pending messages (elicitations, notifications) for the task - 2. Sends them to the client via the response stream - 3. Waits for responses and resolves them back to callers - 4. Blocks until task reaches terminal state - 5. Returns the final result - - Usage: - async def handle_task_result( - ctx: ServerRequestContext, params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - ... 
- - server.experimental.enable_tasks( - on_task_result=handle_task_result, - ) - """ - - def __init__( - self, - store: TaskStore, - queue: TaskMessageQueue, - ): - self._store = store - self._queue = queue - # Map from internal request ID to resolver for routing responses - self._pending_requests: dict[RequestId, Resolver[dict[str, Any]]] = {} - - async def send_message( - self, - session: ServerSession, - message: SessionMessage, - ) -> None: - """Send a message via the session. - - This is a helper for delivering queued task messages. - """ - await session.send_message(message) - - async def handle( - self, - request: GetTaskPayloadRequest, - session: ServerSession, - request_id: RequestId, - ) -> GetTaskPayloadResult: - """Handle a tasks/result request. - - This implements the dequeue-send-wait loop: - 1. Dequeue all pending messages - 2. Send each via transport with relatedRequestId = this request's ID - 3. If task not terminal, wait for status change - 4. Loop until task is terminal - 5. 
Return final result - - Args: - request: The GetTaskPayloadRequest - session: The server session for sending messages - request_id: The request ID for relatedRequestId routing - - Returns: - GetTaskPayloadResult with the task's final payload - """ - task_id = request.params.task_id - - while True: - task = await self._store.get_task(task_id) - if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}") - - await self._deliver_queued_messages(task_id, session, request_id) - - # If task is terminal, return result - if is_terminal(task.status): - result = await self._store.get_result(task_id) - # GetTaskPayloadResult is a Result with extra="allow" - # The stored result contains the actual payload data - # Per spec: tasks/result MUST include _meta with related-task metadata - related_task = RelatedTaskMetadata(task_id=task_id) - related_task_meta: dict[str, Any] = {RELATED_TASK_METADATA_KEY: related_task.model_dump(by_alias=True)} - if result is not None: - result_data = result.model_dump(by_alias=True) - existing_meta: dict[str, Any] = result_data.get("_meta") or {} - result_data["_meta"] = {**existing_meta, **related_task_meta} - return GetTaskPayloadResult.model_validate(result_data) - return GetTaskPayloadResult.model_validate({"_meta": related_task_meta}) - - # Wait for task update (status change or new messages) - await self._wait_for_task_update(task_id) - - async def _deliver_queued_messages( - self, - task_id: str, - session: ServerSession, - request_id: RequestId, - ) -> None: - """Dequeue and send all pending messages for a task. - - Each message is sent via the session's write stream with - relatedRequestId set so responses route back to this stream. 
- """ - while True: - message = await self._queue.dequeue(task_id) - if message is None: - break - - # If this is a request (not notification), wait for response - if message.type == "request" and message.resolver is not None: - # Store the resolver so we can route the response back - original_id = message.original_request_id - if original_id is not None: - self._pending_requests[original_id] = message.resolver - - logger.debug("Delivering queued message for task %s: %s", task_id, message.type) - - # Send the message with relatedRequestId for routing - session_message = SessionMessage( - message=message.message, - metadata=ServerMessageMetadata(related_request_id=request_id), - ) - await self.send_message(session, session_message) - - async def _wait_for_task_update(self, task_id: str) -> None: - """Wait for task to be updated (status change or new message). - - Races between store update and queue message - first one wins. - """ - async with anyio.create_task_group() as tg: - - async def wait_for_store() -> None: - try: - await self._store.wait_for_update(task_id) - except Exception: - pass - finally: - tg.cancel_scope.cancel() - - async def wait_for_queue() -> None: - try: - await self._queue.wait_for_message(task_id) - except Exception: - pass - finally: - tg.cancel_scope.cancel() - - tg.start_soon(wait_for_store) - tg.start_soon(wait_for_queue) - - def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: - """Route a response back to the waiting resolver. - - This is called when a response arrives for a queued request. 
- - Args: - request_id: The request ID from the response - response: The response data - - Returns: - True if response was routed, False if no pending request - """ - resolver = self._pending_requests.pop(request_id, None) - if resolver is not None and not resolver.done(): - resolver.set_result(response) - return True - return False - - def route_error(self, request_id: RequestId, error: ErrorData) -> bool: - """Route an error back to the waiting resolver. - - Args: - request_id: The request ID from the error response - error: The error data - - Returns: - True if error was routed, False if no pending request - """ - resolver = self._pending_requests.pop(request_id, None) - if resolver is not None and not resolver.done(): - resolver.set_exception(MCPError.from_error_data(error)) - return True - return False diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py deleted file mode 100644 index b542195048..0000000000 --- a/src/mcp/server/experimental/task_support.py +++ /dev/null @@ -1,116 +0,0 @@ -"""TaskSupport - Configuration for experimental task support. - -This module provides the TaskSupport class which encapsulates all the -infrastructure needed for task-augmented requests: store, queue, and handler. -""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from dataclasses import dataclass, field - -import anyio -from anyio.abc import TaskGroup - -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.server.session import ServerSession -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue -from mcp.shared.experimental.tasks.store import TaskStore - - -@dataclass -class TaskSupport: - """Configuration for experimental task support. 
- - Encapsulates the task store, message queue, result handler, and task group - for spawning background work. - - When enabled on a server, this automatically: - - Configures response routing for each session - - Provides default handlers for task operations - - Manages a task group for background task execution - - Example: - Simple in-memory setup: - - ```python - server.experimental.enable_tasks() - ``` - - Custom store/queue for distributed systems: - - ```python - server.experimental.enable_tasks( - store=RedisTaskStore(redis_url), - queue=RedisTaskMessageQueue(redis_url), - ) - ``` - """ - - store: TaskStore - queue: TaskMessageQueue - handler: TaskResultHandler = field(init=False) - _task_group: TaskGroup | None = field(init=False, default=None) - - def __post_init__(self) -> None: - """Create the result handler from store and queue.""" - self.handler = TaskResultHandler(self.store, self.queue) - - @property - def task_group(self) -> TaskGroup: - """Get the task group for spawning background work. - - Raises: - RuntimeError: If not within a run() context - """ - if self._task_group is None: - raise RuntimeError("TaskSupport not running. Ensure Server.run() is active.") - return self._task_group - - @asynccontextmanager - async def run(self) -> AsyncIterator[None]: - """Run the task support lifecycle. - - This creates a task group for spawning background task work. - Called automatically by Server.run(). - - Usage: - async with task_support.run(): - # Task group is now available - ... - """ - async with anyio.create_task_group() as tg: - self._task_group = tg - try: - yield - finally: - self._task_group = None - - def configure_session(self, session: ServerSession) -> None: - """Configure a session for task support. - - This registers the result handler as a response router so that - responses to queued requests (elicitation, sampling) are routed - back to the waiting resolvers. - - Called automatically by Server.run() for each new session. 
- - Args: - session: The session to configure - """ - session.add_response_router(self.handler) - - @classmethod - def in_memory(cls) -> "TaskSupport": - """Create in-memory task support. - - Suitable for development, testing, and single-process servers. - For distributed systems, provide custom store and queue implementations. - - Returns: - TaskSupport configured with in-memory store and queue - """ - return cls( - store=InMemoryTaskStore(), - queue=InMemoryTaskMessageQueue(), - ) diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py deleted file mode 100644 index 5a907b6407..0000000000 --- a/src/mcp/server/lowlevel/experimental.py +++ /dev/null @@ -1,210 +0,0 @@ -"""Experimental handlers for the low-level MCP server. - -WARNING: These APIs are experimental and may change without notice. -""" - -from __future__ import annotations - -import logging -from collections.abc import Awaitable, Callable -from typing import Any, Generic - -from typing_extensions import TypeVar - -from mcp.server.context import ServerRequestContext -from mcp.server.experimental.task_support import TaskSupport -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import cancel_task -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import ( - INVALID_PARAMS, - CancelTaskRequestParams, - CancelTaskResult, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - PaginatedRequestParams, - ServerCapabilities, - ServerTasksCapability, - ServerTasksRequestsCapability, - TasksCallCapability, - TasksCancelCapability, - TasksListCapability, - TasksToolsCapability, -) - -logger = logging.getLogger(__name__) - -LifespanResultT = 
TypeVar("LifespanResultT", default=Any) - - -class ExperimentalHandlers(Generic[LifespanResultT]): - """Experimental request/notification handlers. - - WARNING: These APIs are experimental and may change without notice. - """ - - def __init__( - self, - add_request_handler: Callable[ - [str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]], None - ], - has_handler: Callable[[str], bool], - ) -> None: - self._add_request_handler = add_request_handler - self._has_handler = has_handler - self._task_support: TaskSupport | None = None - - @property - def task_support(self) -> TaskSupport | None: - """Get the task support configuration, if enabled.""" - return self._task_support - - def update_capabilities(self, capabilities: ServerCapabilities) -> None: - # Only add tasks capability if handlers are registered - if not any(self._has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"]): - return - - capabilities.tasks = ServerTasksCapability() - if self._has_handler("tasks/list"): - capabilities.tasks.list = TasksListCapability() - if self._has_handler("tasks/cancel"): - capabilities.tasks.cancel = TasksCancelCapability() - - capabilities.tasks.requests = ServerTasksRequestsCapability( - tools=TasksToolsCapability(call=TasksCallCapability()) - ) # assuming always supported for now - - def enable_tasks( - self, - store: TaskStore | None = None, - queue: TaskMessageQueue | None = None, - *, - on_get_task: Callable[[ServerRequestContext[LifespanResultT], GetTaskRequestParams], Awaitable[GetTaskResult]] - | None = None, - on_task_result: Callable[ - [ServerRequestContext[LifespanResultT], GetTaskPayloadRequestParams], Awaitable[GetTaskPayloadResult] - ] - | None = None, - on_list_tasks: Callable[ - [ServerRequestContext[LifespanResultT], PaginatedRequestParams | None], Awaitable[ListTasksResult] - ] - | None = None, - on_cancel_task: Callable[ - [ServerRequestContext[LifespanResultT], CancelTaskRequestParams], 
Awaitable[CancelTaskResult] - ] - | None = None, - ) -> TaskSupport: - """Enable experimental task support. - - This sets up the task infrastructure and registers handlers for - tasks/get, tasks/result, tasks/list, and tasks/cancel. Custom handlers - can be provided via the on_* kwargs; any not provided will use defaults. - - Args: - store: Custom TaskStore implementation (defaults to InMemoryTaskStore) - queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue) - on_get_task: Custom handler for tasks/get - on_task_result: Custom handler for tasks/result - on_list_tasks: Custom handler for tasks/list - on_cancel_task: Custom handler for tasks/cancel - - Returns: - The TaskSupport configuration object - - Example: - Simple in-memory setup: - - ```python - server.experimental.enable_tasks() - ``` - - Custom store/queue for distributed systems: - - ```python - server.experimental.enable_tasks( - store=RedisTaskStore(redis_url), - queue=RedisTaskMessageQueue(redis_url), - ) - ``` - - WARNING: This API is experimental and may change without notice. 
- """ - if store is None: - store = InMemoryTaskStore() - if queue is None: - queue = InMemoryTaskMessageQueue() - - self._task_support = TaskSupport(store=store, queue=queue) - task_support = self._task_support - - # Register user-provided handlers - if on_get_task is not None: - self._add_request_handler("tasks/get", on_get_task) - if on_task_result is not None: - self._add_request_handler("tasks/result", on_task_result) - if on_list_tasks is not None: - self._add_request_handler("tasks/list", on_list_tasks) - if on_cancel_task is not None: - self._add_request_handler("tasks/cancel", on_cancel_task) - - # Fill in defaults for any not provided - if not self._has_handler("tasks/get"): - - async def _default_get_task( - ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams - ) -> GetTaskResult: - task = await task_support.store.get_task(params.task_id) - if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - self._add_request_handler("tasks/get", _default_get_task) - - if not self._has_handler("tasks/result"): - - async def _default_get_task_result( - ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - assert ctx.request_id is not None - req = GetTaskPayloadRequest(params=params) - result = await task_support.handler.handle(req, ctx.session, ctx.request_id) - return result - - self._add_request_handler("tasks/result", _default_get_task_result) - - if not self._has_handler("tasks/list"): - - async def _default_list_tasks( - ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None - ) -> ListTasksResult: - cursor = params.cursor if params else None - tasks, next_cursor = await 
task_support.store.list_tasks(cursor) - return ListTasksResult(tasks=tasks, next_cursor=next_cursor) - - self._add_request_handler("tasks/list", _default_list_tasks) - - if not self._has_handler("tasks/cancel"): - - async def _default_cancel_task( - ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams - ) -> CancelTaskResult: - result = await cancel_task(task_support.store, params.task_id) - return result - - self._add_request_handler("tasks/cancel", _default_cancel_task) - - return task_support diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 5e4e2e6f5b..37127c5621 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -59,8 +59,6 @@ async def main(): from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings from mcp.server.context import ServerRequestContext -from mcp.server.experimental.request_context import Experimental -from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.streamable_http import EventStore @@ -121,7 +119,7 @@ def __init__( | None = None, on_call_tool: Callable[ [ServerRequestContext[LifespanResultT], types.CallToolRequestParams], - Awaitable[types.CallToolResult | types.CreateTaskResult], + Awaitable[types.CallToolResult], ] | None = None, on_list_resources: Callable[ @@ -197,7 +195,6 @@ def __init__( self._notification_handlers: dict[ str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] ] = {} - self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None logger.debug("Initializing server %r", name) @@ -242,10 +239,6 @@ def _add_request_handler( """Add a request handler, silently replacing any existing 
handler for the same method.""" self._request_handlers[method] = handler - def _has_handler(self, method: str) -> bool: - """Check if a handler is registered for the given method.""" - return method in self._request_handlers or method in self._notification_handlers - # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities @@ -323,25 +316,8 @@ def get_capabilities( experimental=experimental_capabilities, completions=completions_capability, ) - if self._experimental_handlers: - self._experimental_handlers.update_capabilities(capabilities) return capabilities - @property - def experimental(self) -> ExperimentalHandlers[LifespanResultT]: - """Experimental APIs for tasks and other features. - - WARNING: These APIs are experimental and may change without notice. - """ - - # We create this inline so we only add these capabilities _if_ they're actually used - if self._experimental_handlers is None: - self._experimental_handlers = ExperimentalHandlers( - add_request_handler=self._add_request_handler, - has_handler=self._has_handler, - ) - return self._experimental_handlers - @property def session_manager(self) -> StreamableHTTPSessionManager: """Get the StreamableHTTP session manager. 
@@ -383,12 +359,6 @@ async def run( ) ) - # Configure task support for this session if enabled - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - if task_support is not None: - task_support.configure_session(session) - await stack.enter_async_context(task_support.run()) - async with anyio.create_task_group() as tg: try: async for message in session.incoming_messages: @@ -476,23 +446,11 @@ async def _handle_request( close_sse_stream_cb = message.message_metadata.close_sse_stream close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - # Get task metadata from request params if present - task_metadata = None - if hasattr(req, "params") and req.params is not None: # pragma: no branch - task_metadata = getattr(req.params, "task", None) ctx = ServerRequestContext( request_id=message.request_id, meta=message.request_meta, session=session, lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=task_metadata, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), request=request_data, close_sse_stream=close_sse_stream_cb, close_standalone_sse_stream=close_standalone_sse_stream_cb, @@ -543,17 +501,9 @@ async def _handle_notification( logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None ctx = ServerRequestContext( session=session, lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=None, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), ) await 
handler(ctx, notify.params) except Exception: # pragma: no cover diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index fc2f97a9cb..3fc7bbf0d3 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -37,13 +37,10 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: from pydantic import AnyUrl, TypeAdapter from mcp import types -from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import StatelessModeNotSupported -from mcp.shared.experimental.tasks.capabilities import check_tasks_capability -from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, @@ -76,7 +73,6 @@ class ServerSession( ): _initialized: InitializationState = InitializationState.NotInitialized _client_params: types.InitializeRequestParams | None = None - _experimental_features: ExperimentalServerSessionFeatures | None = None def __init__( self, @@ -109,16 +105,6 @@ def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification] def client_params(self) -> types.InitializeRequestParams | None: return self._client_params - @property - def experimental(self) -> ExperimentalServerSessionFeatures: - """Experimental APIs for server→client task operations. - - WARNING: These APIs are experimental and may change without notice. 
- """ - if self._experimental_features is None: - self._experimental_features = ExperimentalServerSessionFeatures(self) - return self._experimental_features - def check_client_capability(self, capability: types.ClientCapabilities) -> bool: """Check if the client supports a specific capability.""" if self._client_params is None: # pragma: lax no cover @@ -150,12 +136,6 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: return False - if capability.tasks is not None: # pragma: lax no cover - if client_caps.tasks is None: - return False - if not check_tasks_capability(capability.tasks, client_caps.tasks): - return False - return True async def _receive_loop(self) -> None: @@ -509,181 +489,6 @@ async def send_elicit_complete( related_request_id, ) - def _build_elicit_form_request( - self, - message: str, - requested_schema: types.ElicitRequestedSchema, - related_task_id: str | None = None, - task: types.TaskMetadata | None = None, - ) -> types.JSONRPCRequest: - """Build a form mode elicitation request without sending it. 
- - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - task: If provided, makes this a task-augmented request - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.ElicitRequestFormParams( - message=message, - requested_schema=requested_schema, - task=task, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="elicitation/create", - params=params_data, - ) - - def _build_elicit_url_request( - self, - message: str, - url: str, - elicitation_id: str, - related_task_id: str | None = None, - ) -> types.JSONRPCRequest: - """Build a URL mode elicitation request without sending it. 
- - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.ElicitRequestURLParams( - message=message, - url=url, - elicitation_id=elicitation_id, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="elicitation/create", - params=params_data, - ) - - def _build_create_message_request( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - tools: list[types.Tool] | None = None, - tool_choice: types.ToolChoice | None = None, - related_task_id: str | None = None, - task: types.TaskMetadata | None = None, - ) -> types.JSONRPCRequest: - """Build a sampling/createMessage request without sending it. 
- - Args: - messages: The conversation messages to send - max_tokens: Maximum number of tokens to generate - system_prompt: Optional system prompt - include_context: Optional context inclusion setting - temperature: Optional sampling temperature - stop_sequences: Optional stop sequences - metadata: Optional metadata to pass through to the LLM provider - model_preferences: Optional model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - task: If provided, makes this a task-augmented request - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.CreateMessageRequestParams( - messages=messages, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - max_tokens=max_tokens, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - task=task, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="sampling/createMessage", - params=params_data, - ) - - async def send_message(self, message: SessionMessage) -> None: - """Send a raw session message. 
- - This is primarily used by TaskResultHandler to deliver queued messages - (elicitation/sampling requests) to the client during task execution. - - WARNING: This is a low-level experimental method that may change without - notice. Prefer using higher-level methods like send_notification() or - send_request() for normal operations. - - Args: - message: The session message to send - """ - await self._write_stream.send(message) - async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) diff --git a/src/mcp/server/validation.py b/src/mcp/server/validation.py index 5708628074..08f5754f1e 100644 --- a/src/mcp/server/validation.py +++ b/src/mcp/server/validation.py @@ -1,7 +1,6 @@ """Shared validation functions for server requests. -This module provides validation logic for sampling and elicitation requests -that is shared across normal and task-augmented code paths. +This module provides validation logic for sampling and elicitation requests. """ from mcp.shared.exceptions import MCPError diff --git a/src/mcp/shared/experimental/__init__.py b/src/mcp/shared/experimental/__init__.py deleted file mode 100644 index fa6940acc6..0000000000 --- a/src/mcp/shared/experimental/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Pure experimental MCP features (no server dependencies). - -WARNING: These APIs are experimental and may change without notice. - -For server-integrated experimental features, use mcp.server.experimental. -""" diff --git a/src/mcp/shared/experimental/tasks/__init__.py b/src/mcp/shared/experimental/tasks/__init__.py deleted file mode 100644 index 52793e408b..0000000000 --- a/src/mcp/shared/experimental/tasks/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Pure task state management for MCP. - -WARNING: These APIs are experimental and may change without notice. 
- -Import directly from submodules: -- mcp.shared.experimental.tasks.store.TaskStore -- mcp.shared.experimental.tasks.context.TaskContext -- mcp.shared.experimental.tasks.in_memory_task_store.InMemoryTaskStore -- mcp.shared.experimental.tasks.message_queue.TaskMessageQueue -- mcp.shared.experimental.tasks.helpers.is_terminal -""" diff --git a/src/mcp/shared/experimental/tasks/capabilities.py b/src/mcp/shared/experimental/tasks/capabilities.py deleted file mode 100644 index 51fe64ecc3..0000000000 --- a/src/mcp/shared/experimental/tasks/capabilities.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Tasks capability checking utilities. - -This module provides functions for checking and requiring task-related -capabilities. All tasks capability logic is centralized here to keep -the main session code clean. - -WARNING: These APIs are experimental and may change without notice. -""" - -from mcp.shared.exceptions import MCPError -from mcp.types import INVALID_REQUEST, ClientCapabilities, ClientTasksCapability - - -def check_tasks_capability( - required: ClientTasksCapability, - client: ClientTasksCapability, -) -> bool: - """Check if client's tasks capability matches the required capability. 
- - Args: - required: The capability being checked for - client: The client's declared capabilities - - Returns: - True if client has the required capability, False otherwise - """ - if required.requests is None: - return True - if client.requests is None: - return False - - # Check elicitation.create - if required.requests.elicitation is not None: - if client.requests.elicitation is None: - return False - if required.requests.elicitation.create is not None: - if client.requests.elicitation.create is None: - return False - - # Check sampling.createMessage - if required.requests.sampling is not None: - if client.requests.sampling is None: - return False - if required.requests.sampling.create_message is not None: - if client.requests.sampling.create_message is None: - return False - - return True - - -def has_task_augmented_elicitation(caps: ClientCapabilities) -> bool: - """Check if capabilities include task-augmented elicitation support.""" - if caps.tasks is None: - return False - if caps.tasks.requests is None: - return False - if caps.tasks.requests.elicitation is None: - return False - return caps.tasks.requests.elicitation.create is not None - - -def has_task_augmented_sampling(caps: ClientCapabilities) -> bool: - """Check if capabilities include task-augmented sampling support.""" - if caps.tasks is None: - return False - if caps.tasks.requests is None: - return False - if caps.tasks.requests.sampling is None: - return False - return caps.tasks.requests.sampling.create_message is not None - - -def require_task_augmented_elicitation(client_caps: ClientCapabilities | None) -> None: - """Raise MCPError if client doesn't support task-augmented elicitation. 
- - Args: - client_caps: The client's declared capabilities, or None if not initialized - - Raises: - MCPError: If client doesn't support task-augmented elicitation - """ - if client_caps is None or not has_task_augmented_elicitation(client_caps): - raise MCPError(code=INVALID_REQUEST, message="Client does not support task-augmented elicitation") - - -def require_task_augmented_sampling(client_caps: ClientCapabilities | None) -> None: - """Raise MCPError if client doesn't support task-augmented sampling. - - Args: - client_caps: The client's declared capabilities, or None if not initialized - - Raises: - MCPError: If client doesn't support task-augmented sampling - """ - if client_caps is None or not has_task_augmented_sampling(client_caps): - raise MCPError(code=INVALID_REQUEST, message="Client does not support task-augmented sampling") diff --git a/src/mcp/shared/experimental/tasks/context.py b/src/mcp/shared/experimental/tasks/context.py deleted file mode 100644 index ed0d2b91b6..0000000000 --- a/src/mcp/shared/experimental/tasks/context.py +++ /dev/null @@ -1,95 +0,0 @@ -"""TaskContext - Pure task state management. - -This module provides TaskContext, which manages task state without any -server/session dependencies. It can be used standalone for distributed -workers or wrapped by ServerTaskContext for full server integration. -""" - -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, Result, Task - - -class TaskContext: - """Pure task state management - no session dependencies. - - This class handles: - - Task state (status, result) - - Cancellation tracking - - Store interactions - - For server-integrated features (elicit, create_message, notifications), - use ServerTaskContext from mcp.server.experimental. 
- - Example (distributed worker): - async def worker_job(task_id: str): - store = RedisTaskStore(redis_url) - task = await store.get_task(task_id) - ctx = TaskContext(task=task, store=store) - - await ctx.update_status("Working...") - result = await do_work() - await ctx.complete(result) - """ - - def __init__(self, task: Task, store: TaskStore): - self._task = task - self._store = store - self._cancelled = False - - @property - def task_id(self) -> str: - """The task identifier.""" - return self._task.task_id - - @property - def task(self) -> Task: - """The current task state.""" - return self._task - - @property - def is_cancelled(self) -> bool: - """Whether cancellation has been requested.""" - return self._cancelled - - def request_cancellation(self) -> None: - """Request cancellation of this task. - - This sets is_cancelled=True. Task work should check this - periodically and exit gracefully if set. - """ - self._cancelled = True - - async def update_status(self, message: str) -> None: - """Update the task's status message. - - Args: - message: The new status message - """ - self._task = await self._store.update_task( - self.task_id, - status_message=message, - ) - - async def complete(self, result: Result) -> None: - """Mark the task as completed with the given result. - - Args: - result: The task result - """ - await self._store.store_result(self.task_id, result) - self._task = await self._store.update_task( - self.task_id, - status=TASK_STATUS_COMPLETED, - ) - - async def fail(self, error: str) -> None: - """Mark the task as failed with an error message. 
- - Args: - error: The error message - """ - self._task = await self._store.update_task( - self.task_id, - status=TASK_STATUS_FAILED, - status_message=error, - ) diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py deleted file mode 100644 index 3f91cd0d06..0000000000 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Helper functions for pure task management. - -These helpers work with pure TaskContext and don't require server dependencies. -For server-integrated task helpers, use mcp.server.experimental. -""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from datetime import datetime, timezone -from uuid import uuid4 - -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.context import TaskContext -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import ( - INVALID_PARAMS, - TASK_STATUS_CANCELLED, - TASK_STATUS_COMPLETED, - TASK_STATUS_FAILED, - TASK_STATUS_WORKING, - CancelTaskResult, - Task, - TaskMetadata, - TaskStatus, -) - -# Metadata key for model-immediate-response (per MCP spec) -# Servers MAY include this in CreateTaskResult._meta to provide an immediate -# response string while the task executes in the background. -MODEL_IMMEDIATE_RESPONSE_KEY = "io.modelcontextprotocol/model-immediate-response" - -# Metadata key for associating requests with a task (per MCP spec) -RELATED_TASK_METADATA_KEY = "io.modelcontextprotocol/related-task" - - -def is_terminal(status: TaskStatus) -> bool: - """Check if a task status represents a terminal state. - - Terminal states are those where the task has finished and will not change. 
- - Args: - status: The task status to check - - Returns: - True if the status is terminal (completed, failed, or cancelled) - """ - return status in (TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED) - - -async def cancel_task( - store: TaskStore, - task_id: str, -) -> CancelTaskResult: - """Cancel a task with spec-compliant validation. - - Per spec: "Receivers MUST reject cancellation of terminal status tasks - with -32602 (Invalid params)" - - This helper validates that the task exists and is not in a terminal state - before setting it to "cancelled". - - Args: - store: The task store - task_id: The task identifier to cancel - - Returns: - CancelTaskResult with the cancelled task state - - Raises: - MCPError: With INVALID_PARAMS (-32602) if: - - Task does not exist - - Task is already in a terminal state (completed, failed, cancelled) - - Example: - ```python - async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult: - return await cancel_task(store, params.task_id) - ``` - """ - task = await store.get_task(task_id) - if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}") - - if is_terminal(task.status): - raise MCPError(code=INVALID_PARAMS, message=f"Cannot cancel task in terminal state '{task.status}'") - - # Update task to cancelled status - cancelled_task = await store.update_task(task_id, status=TASK_STATUS_CANCELLED) - return CancelTaskResult(**cancelled_task.model_dump()) - - -def generate_task_id() -> str: - """Generate a unique task ID.""" - return str(uuid4()) - - -def create_task_state( - metadata: TaskMetadata, - task_id: str | None = None, -) -> Task: - """Create a Task object with initial state. - - This is a helper for TaskStore implementations. 
- - Args: - metadata: Task metadata - task_id: Optional task ID (generated if not provided) - - Returns: - A new Task in "working" status - """ - now = datetime.now(timezone.utc) - return Task( - task_id=task_id or generate_task_id(), - status=TASK_STATUS_WORKING, - created_at=now, - last_updated_at=now, - ttl=metadata.ttl, - poll_interval=500, # Default 500ms poll interval - ) - - -@asynccontextmanager -async def task_execution( - task_id: str, - store: TaskStore, -) -> AsyncIterator[TaskContext]: - """Context manager for safe task execution (pure, no server dependencies). - - Loads a task from the store and provides a TaskContext for the work. - If an unhandled exception occurs, the task is automatically marked as failed - and the exception is suppressed (since the failure is captured in task state). - - This is useful for distributed workers that don't have a server session. - - Args: - task_id: The task identifier to execute - store: The task store (must be accessible by the worker) - - Yields: - TaskContext for updating status and completing/failing the task - - Raises: - ValueError: If the task is not found in the store - - Example (distributed worker): - async def worker_process(task_id: str): - store = RedisTaskStore(redis_url) - async with task_execution(task_id, store) as ctx: - await ctx.update_status("Working...") - result = await do_work() - await ctx.complete(result) - """ - task = await store.get_task(task_id) - if task is None: - raise ValueError(f"Task {task_id} not found") - - ctx = TaskContext(task, store) - try: - yield ctx - except Exception as e: - # Auto-fail the task if an exception occurs and task isn't already terminal - # Exception is suppressed since failure is captured in task state - if not is_terminal(ctx.task.status): - await ctx.fail(str(e)) - # Don't re-raise - the failure is recorded in task state diff --git a/src/mcp/shared/experimental/tasks/in_memory_task_store.py b/src/mcp/shared/experimental/tasks/in_memory_task_store.py 
deleted file mode 100644 index 42f4fb7035..0000000000 --- a/src/mcp/shared/experimental/tasks/in_memory_task_store.py +++ /dev/null @@ -1,217 +0,0 @@ -"""In-memory implementation of TaskStore for demonstration purposes. - -This implementation stores all tasks in memory and provides automatic cleanup -based on the TTL duration specified in the task metadata using lazy expiration. - -Note: This is not suitable for production use as all data is lost on restart. -For production, consider implementing TaskStore with a database or distributed cache. -""" - -from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone - -import anyio - -from mcp.shared.experimental.tasks.helpers import create_task_state, is_terminal -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import Result, Task, TaskMetadata, TaskStatus - - -@dataclass -class StoredTask: - """Internal storage representation of a task.""" - - task: Task - result: Result | None = None - # Time when this task should be removed (None = never) - expires_at: datetime | None = field(default=None) - - -class InMemoryTaskStore(TaskStore): - """A simple in-memory implementation of TaskStore. - - Features: - - Automatic TTL-based cleanup (lazy expiration) - - Thread-safe for single-process async use - - Pagination support for list_tasks - - Limitations: - - All data lost on restart - - Not suitable for distributed systems - - No persistence - - For production, implement TaskStore with Redis, PostgreSQL, etc. 
- """ - - def __init__(self, page_size: int = 10) -> None: - self._tasks: dict[str, StoredTask] = {} - self._page_size = page_size - self._update_events: dict[str, anyio.Event] = {} - - def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None: - """Calculate expiry time from TTL in milliseconds.""" - if ttl_ms is None: - return None - return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms) - - def _is_expired(self, stored: StoredTask) -> bool: - """Check if a task has expired.""" - if stored.expires_at is None: - return False - return datetime.now(timezone.utc) >= stored.expires_at - - def _cleanup_expired(self) -> None: - """Remove all expired tasks. Called lazily during access operations.""" - expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)] - for task_id in expired_ids: - del self._tasks[task_id] - - async def create_task( - self, - metadata: TaskMetadata, - task_id: str | None = None, - ) -> Task: - """Create a new task with the given metadata.""" - # Cleanup expired tasks on access - self._cleanup_expired() - - task = create_task_state(metadata, task_id) - - if task.task_id in self._tasks: - raise ValueError(f"Task with ID {task.task_id} already exists") - - stored = StoredTask( - task=task, - expires_at=self._calculate_expiry(metadata.ttl), - ) - self._tasks[task.task_id] = stored - - # Return a copy to prevent external modification - return Task(**task.model_dump()) - - async def get_task(self, task_id: str) -> Task | None: - """Get a task by ID.""" - # Cleanup expired tasks on access - self._cleanup_expired() - - stored = self._tasks.get(task_id) - if stored is None: - return None - - # Return a copy to prevent external modification - return Task(**stored.task.model_dump()) - - async def update_task( - self, - task_id: str, - status: TaskStatus | None = None, - status_message: str | None = None, - ) -> Task: - """Update a task's status and/or message.""" - stored = self._tasks.get(task_id) - 
if stored is None: - raise ValueError(f"Task with ID {task_id} not found") - - # Per spec: Terminal states MUST NOT transition to any other status - if status is not None and status != stored.task.status and is_terminal(stored.task.status): - raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'") - - status_changed = False - if status is not None and stored.task.status != status: - stored.task.status = status - status_changed = True - - if status_message is not None: - stored.task.status_message = status_message - - # Update last_updated_at on any change - stored.task.last_updated_at = datetime.now(timezone.utc) - - # If task is now terminal and has TTL, reset expiry timer - if status is not None and is_terminal(status) and stored.task.ttl is not None: - stored.expires_at = self._calculate_expiry(stored.task.ttl) - - # Notify waiters if status changed - if status_changed: - await self.notify_update(task_id) - - return Task(**stored.task.model_dump()) - - async def store_result(self, task_id: str, result: Result) -> None: - """Store the result for a task.""" - stored = self._tasks.get(task_id) - if stored is None: - raise ValueError(f"Task with ID {task_id} not found") - - stored.result = result - - async def get_result(self, task_id: str) -> Result | None: - """Get the stored result for a task.""" - stored = self._tasks.get(task_id) - if stored is None: - return None - - return stored.result - - async def list_tasks( - self, - cursor: str | None = None, - ) -> tuple[list[Task], str | None]: - """List tasks with pagination.""" - # Cleanup expired tasks on access - self._cleanup_expired() - - all_task_ids = list(self._tasks.keys()) - - start_index = 0 - if cursor is not None: - try: - cursor_index = all_task_ids.index(cursor) - start_index = cursor_index + 1 - except ValueError: - raise ValueError(f"Invalid cursor: {cursor}") - - page_task_ids = all_task_ids[start_index : start_index + self._page_size] - tasks = 
[Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids] - - # Determine next cursor - next_cursor = None - if start_index + self._page_size < len(all_task_ids) and page_task_ids: - next_cursor = page_task_ids[-1] - - return tasks, next_cursor - - async def delete_task(self, task_id: str) -> bool: - """Delete a task.""" - if task_id not in self._tasks: - return False - - del self._tasks[task_id] - return True - - async def wait_for_update(self, task_id: str) -> None: - """Wait until the task status changes.""" - if task_id not in self._tasks: - raise ValueError(f"Task with ID {task_id} not found") - - # Create a fresh event for waiting (anyio.Event can't be cleared) - self._update_events[task_id] = anyio.Event() - event = self._update_events[task_id] - await event.wait() - - async def notify_update(self, task_id: str) -> None: - """Signal that a task has been updated.""" - if task_id in self._update_events: - self._update_events[task_id].set() - - # --- Testing/debugging helpers --- - - def cleanup(self) -> None: - """Cleanup all tasks (useful for testing or graceful shutdown).""" - self._tasks.clear() - self._update_events.clear() - - def get_all_tasks(self) -> list[Task]: - """Get all tasks (useful for debugging). Returns copies to prevent modification.""" - self._cleanup_expired() - return [Task(**stored.task.model_dump()) for stored in self._tasks.values()] diff --git a/src/mcp/shared/experimental/tasks/message_queue.py b/src/mcp/shared/experimental/tasks/message_queue.py deleted file mode 100644 index e17c4a8650..0000000000 --- a/src/mcp/shared/experimental/tasks/message_queue.py +++ /dev/null @@ -1,230 +0,0 @@ -"""TaskMessageQueue - FIFO queue for task-related messages. - -This implements the core message queue pattern from the MCP Tasks spec. -When a handler needs to send a request (like elicitation) during a task-augmented -request, the message is enqueued instead of sent directly. 
Messages are delivered -to the client only through the `tasks/result` endpoint. - -This pattern enables: -1. Decoupling request handling from message delivery -2. Proper bidirectional communication via the tasks/result stream -3. Automatic status management (working <-> input_required) -""" - -from abc import ABC, abstractmethod -from collections import deque -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Any, Literal - -import anyio - -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.types import JSONRPCNotification, JSONRPCRequest, RequestId - - -@dataclass -class QueuedMessage: - """A message queued for delivery via tasks/result. - - Messages are stored with their type and a resolver for requests - that expect responses. - """ - - type: Literal["request", "notification"] - """Whether this is a request (expects response) or notification (one-way).""" - - message: JSONRPCRequest | JSONRPCNotification - """The JSON-RPC message to send.""" - - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - """When the message was enqueued.""" - - resolver: Resolver[dict[str, Any]] | None = None - """Resolver to set when response arrives (only for requests).""" - - original_request_id: RequestId | None = None - """The original request ID used internally, for routing responses back.""" - - -class TaskMessageQueue(ABC): - """Abstract interface for task message queuing. - - This is a FIFO queue that stores messages to be delivered via `tasks/result`. - When a task-augmented handler calls elicit() or sends a notification, the - message is enqueued here instead of being sent directly to the client. - - The `tasks/result` handler then dequeues and sends these messages through - the transport, with `relatedRequestId` set to the tasks/result request ID - so responses are routed correctly. - - Implementations can use in-memory storage, Redis, etc. 
- """ - - @abstractmethod - async def enqueue(self, task_id: str, message: QueuedMessage) -> None: - """Add a message to the queue for a task. - - Args: - task_id: The task identifier - message: The message to enqueue - """ - - @abstractmethod - async def dequeue(self, task_id: str) -> QueuedMessage | None: - """Remove and return the next message from the queue. - - Args: - task_id: The task identifier - - Returns: - The next message, or None if queue is empty - """ - - @abstractmethod - async def peek(self, task_id: str) -> QueuedMessage | None: - """Return the next message without removing it. - - Args: - task_id: The task identifier - - Returns: - The next message, or None if queue is empty - """ - - @abstractmethod - async def is_empty(self, task_id: str) -> bool: - """Check if the queue is empty for a task. - - Args: - task_id: The task identifier - - Returns: - True if no messages are queued - """ - - @abstractmethod - async def clear(self, task_id: str) -> list[QueuedMessage]: - """Remove and return all messages from the queue. - - This is useful for cleanup when a task is cancelled or completed. - - Args: - task_id: The task identifier - - Returns: - All queued messages (may be empty) - """ - - @abstractmethod - async def wait_for_message(self, task_id: str) -> None: - """Wait until a message is available in the queue. - - This blocks until either: - 1. A message is enqueued for this task - 2. The wait is cancelled - - Args: - task_id: The task identifier - """ - - @abstractmethod - async def notify_message_available(self, task_id: str) -> None: - """Signal that a message is available for a task. - - This wakes up any coroutines waiting in wait_for_message(). - - Args: - task_id: The task identifier - """ - - -class InMemoryTaskMessageQueue(TaskMessageQueue): - """In-memory implementation of TaskMessageQueue. - - This is suitable for single-process servers. For distributed systems, - implement TaskMessageQueue with Redis, RabbitMQ, etc. 
- - Features: - - FIFO ordering per task - - Async wait for message availability - - Thread-safe for single-process async use - """ - - def __init__(self) -> None: - self._queues: dict[str, deque[QueuedMessage]] = {} - self._events: dict[str, anyio.Event] = {} - - def _get_queue(self, task_id: str) -> deque[QueuedMessage]: - """Get or create the queue for a task.""" - if task_id not in self._queues: - self._queues[task_id] = deque() - return self._queues[task_id] - - async def enqueue(self, task_id: str, message: QueuedMessage) -> None: - """Add a message to the queue.""" - queue = self._get_queue(task_id) - queue.append(message) - # Signal that a message is available - await self.notify_message_available(task_id) - - async def dequeue(self, task_id: str) -> QueuedMessage | None: - """Remove and return the next message.""" - queue = self._get_queue(task_id) - if not queue: - return None - return queue.popleft() - - async def peek(self, task_id: str) -> QueuedMessage | None: - """Return the next message without removing it.""" - queue = self._get_queue(task_id) - if not queue: - return None - return queue[0] - - async def is_empty(self, task_id: str) -> bool: - """Check if the queue is empty.""" - queue = self._get_queue(task_id) - return len(queue) == 0 - - async def clear(self, task_id: str) -> list[QueuedMessage]: - """Remove and return all messages.""" - queue = self._get_queue(task_id) - messages = list(queue) - queue.clear() - return messages - - async def wait_for_message(self, task_id: str) -> None: - """Wait until a message is available.""" - # Check if there are already messages - if not await self.is_empty(task_id): - return - - # Create a fresh event for waiting (anyio.Event can't be cleared) - self._events[task_id] = anyio.Event() - event = self._events[task_id] - - # Double-check after creating event (avoid race condition) - if not await self.is_empty(task_id): - return - - # Wait for a new message - await event.wait() - - async def 
notify_message_available(self, task_id: str) -> None: - """Signal that a message is available.""" - if task_id in self._events: - self._events[task_id].set() - - def cleanup(self, task_id: str | None = None) -> None: - """Clean up queues and events. - - Args: - task_id: If provided, clean up only this task. Otherwise clean up all. - """ - if task_id is not None: - self._queues.pop(task_id, None) - self._events.pop(task_id, None) - else: - self._queues.clear() - self._events.clear() diff --git a/src/mcp/shared/experimental/tasks/polling.py b/src/mcp/shared/experimental/tasks/polling.py deleted file mode 100644 index e4e13b6640..0000000000 --- a/src/mcp/shared/experimental/tasks/polling.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Shared polling utilities for task operations. - -This module provides generic polling logic that works for both client→server -and server→client task polling. - -WARNING: These APIs are experimental and may change without notice. -""" - -from collections.abc import AsyncIterator, Awaitable, Callable - -import anyio - -from mcp.shared.experimental.tasks.helpers import is_terminal -from mcp.types import GetTaskResult - - -async def poll_until_terminal( - get_task: Callable[[str], Awaitable[GetTaskResult]], - task_id: str, - default_interval_ms: int = 500, -) -> AsyncIterator[GetTaskResult]: - """Poll a task until it reaches terminal status. - - This is a generic utility that works for both client→server and server→client - polling. The caller provides the get_task function appropriate for their direction. 
- - Args: - get_task: Async function that takes task_id and returns GetTaskResult - task_id: The task to poll - default_interval_ms: Fallback poll interval if server doesn't specify - - Yields: - GetTaskResult for each poll - """ - while True: - status = await get_task(task_id) - yield status - - if is_terminal(status.status): - break - - interval_ms = status.poll_interval if status.poll_interval is not None else default_interval_ms - await anyio.sleep(interval_ms / 1000) diff --git a/src/mcp/shared/experimental/tasks/resolver.py b/src/mcp/shared/experimental/tasks/resolver.py deleted file mode 100644 index 1d233a9309..0000000000 --- a/src/mcp/shared/experimental/tasks/resolver.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Resolver - An anyio-compatible future-like object for async result passing. - -This provides a simple way to pass a result (or exception) from one coroutine -to another without depending on asyncio.Future. -""" - -from typing import Generic, TypeVar, cast - -import anyio - -T = TypeVar("T") - - -class Resolver(Generic[T]): - """A simple resolver for passing results between coroutines. - - Unlike asyncio.Future, this works with any anyio-compatible async backend. 
- - Usage: - resolver: Resolver[str] = Resolver() - - # In one coroutine: - resolver.set_result("hello") - - # In another coroutine: - result = await resolver.wait() # returns "hello" - """ - - def __init__(self) -> None: - self._event = anyio.Event() - self._value: T | None = None - self._exception: BaseException | None = None - - def set_result(self, value: T) -> None: - """Set the result value and wake up waiters.""" - if self._event.is_set(): - raise RuntimeError("Resolver already completed") - self._value = value - self._event.set() - - def set_exception(self, exc: BaseException) -> None: - """Set an exception and wake up waiters.""" - if self._event.is_set(): - raise RuntimeError("Resolver already completed") - self._exception = exc - self._event.set() - - async def wait(self) -> T: - """Wait for the result and return it, or raise the exception.""" - await self._event.wait() - if self._exception is not None: - raise self._exception - # If we reach here, set_result() was called, so _value is set - return cast(T, self._value) - - def done(self) -> bool: - """Return True if the resolver has been completed.""" - return self._event.is_set() diff --git a/src/mcp/shared/experimental/tasks/store.py b/src/mcp/shared/experimental/tasks/store.py deleted file mode 100644 index 7de97d40ca..0000000000 --- a/src/mcp/shared/experimental/tasks/store.py +++ /dev/null @@ -1,144 +0,0 @@ -"""TaskStore - Abstract interface for task state storage.""" - -from abc import ABC, abstractmethod - -from mcp.types import Result, Task, TaskMetadata, TaskStatus - - -class TaskStore(ABC): - """Abstract interface for task state storage. - - This is a pure storage interface - it doesn't manage execution. - Implementations can use in-memory storage, databases, Redis, etc. - - All methods are async to support various backends. - """ - - @abstractmethod - async def create_task( - self, - metadata: TaskMetadata, - task_id: str | None = None, - ) -> Task: - """Create a new task. 
- - Args: - metadata: Task metadata (ttl, etc.) - task_id: Optional task ID. If None, implementation should generate one. - - Returns: - The created Task with status="working" - - Raises: - ValueError: If task_id already exists - """ - - @abstractmethod - async def get_task(self, task_id: str) -> Task | None: - """Get a task by ID. - - Args: - task_id: The task identifier - - Returns: - The Task, or None if not found - """ - - @abstractmethod - async def update_task( - self, - task_id: str, - status: TaskStatus | None = None, - status_message: str | None = None, - ) -> Task: - """Update a task's status and/or message. - - Args: - task_id: The task identifier - status: New status (if changing) - status_message: New status message (if changing) - - Returns: - The updated Task - - Raises: - ValueError: If task not found - ValueError: If attempting to transition from a terminal status - (completed, failed, cancelled). Per spec, terminal states - MUST NOT transition to any other status. - """ - - @abstractmethod - async def store_result(self, task_id: str, result: Result) -> None: - """Store the result for a task. - - Args: - task_id: The task identifier - result: The result to store - - Raises: - ValueError: If task not found - """ - - @abstractmethod - async def get_result(self, task_id: str) -> Result | None: - """Get the stored result for a task. - - Args: - task_id: The task identifier - - Returns: - The stored Result, or None if not available - """ - - @abstractmethod - async def list_tasks( - self, - cursor: str | None = None, - ) -> tuple[list[Task], str | None]: - """List tasks with pagination. - - Args: - cursor: Optional cursor for pagination - - Returns: - Tuple of (tasks, next_cursor). next_cursor is None if no more pages. - """ - - @abstractmethod - async def delete_task(self, task_id: str) -> bool: - """Delete a task. 
- - Args: - task_id: The task identifier - - Returns: - True if deleted, False if not found - """ - - @abstractmethod - async def wait_for_update(self, task_id: str) -> None: - """Wait until the task status changes. - - This blocks until either: - 1. The task status changes - 2. The wait is cancelled - - Used by tasks/result to wait for task completion or status changes. - - Args: - task_id: The task identifier - - Raises: - ValueError: If task not found - """ - - @abstractmethod - async def notify_update(self, task_id: str) -> None: - """Signal that a task has been updated. - - This wakes up any coroutines waiting in wait_for_update(). - - Args: - task_id: The task identifier - """ diff --git a/src/mcp/shared/response_router.py b/src/mcp/shared/response_router.py deleted file mode 100644 index fe24b016f1..0000000000 --- a/src/mcp/shared/response_router.py +++ /dev/null @@ -1,61 +0,0 @@ -"""ResponseRouter - Protocol for pluggable response routing. - -This module defines a protocol for routing JSON-RPC responses to alternative -handlers before falling back to the default response stream mechanism. - -The primary use case is task-augmented requests: when a TaskSession enqueues -a request (like elicitation), the response needs to be routed back to the -waiting resolver instead of the normal response stream. - -Design: -- Protocol-based for testability and flexibility -- Returns bool to indicate if response was handled -- Supports both success responses and errors -""" - -from typing import Any, Protocol - -from mcp.types import ErrorData, RequestId - - -class ResponseRouter(Protocol): - """Protocol for routing responses to alternative handlers. - - Implementations check if they have a pending request for the given ID - and deliver the response/error to the appropriate handler. 
- - Example: - ```python - class TaskResultHandler(ResponseRouter): - def route_response(self, request_id, response): - resolver = self._pending_requests.pop(request_id, None) - if resolver: - resolver.set_result(response) - return True - return False - ``` - """ - - def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: - """Try to route a response to a pending request handler. - - Args: - request_id: The JSON-RPC request ID from the response - response: The response result data - - Returns: - True if the response was handled, False otherwise - """ - ... # pragma: no cover - - def route_error(self, request_id: RequestId, error: ErrorData) -> bool: - """Try to route an error to a pending request handler. - - Args: - request_id: The JSON-RPC request ID from the error response - error: The error data - - Returns: - True if the error was handled, False otherwise - """ - ... # pragma: no cover diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 9c72a23844..ea5d8833bd 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -17,7 +17,6 @@ from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage -from mcp.shared.response_router import ResponseRouter from mcp.types import ( CONNECTION_CLOSED, INVALID_PARAMS, @@ -183,7 +182,6 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] - _response_routers: list[ResponseRouter] def __init__( self, @@ -199,24 +197,8 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} - self._response_routers = [] self._exit_stack = AsyncExitStack() - def add_response_router(self, router: ResponseRouter) -> None: - """Register a response router to handle responses for 
non-standard requests. - - Response routers are checked in order before falling back to the default - response stream mechanism. This is used by TaskResultHandler to route - responses for queued task requests back to their resolvers. - - !!! warning - This is an experimental API that may change without notice. - - Args: - router: A ResponseRouter implementation - """ - self._response_routers.append(router) - async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() await self._task_group.__aenter__() @@ -477,11 +459,7 @@ def _normalize_request_id(self, response_id: RequestId) -> RequestId: return response_id async def _handle_response(self, message: SessionMessage) -> None: - """Handle an incoming response or error message. - - Checks response routers first (e.g., for task-related responses), - then falls back to the normal response stream mechanism. - """ + """Handle an incoming response or error message.""" # This check is always true at runtime: the caller (_receive_loop) only invokes # this method in the else branch after checking for JSONRPCRequest and # JSONRPCNotification. 
However, the type checker can't infer this from the @@ -498,20 +476,6 @@ async def _handle_response(self, message: SessionMessage) -> None: # Normalize response ID to handle type mismatches (e.g., "0" vs 0) response_id = self._normalize_request_id(message.message.id) - # First, check response routers (e.g., TaskResultHandler) - if isinstance(message.message, JSONRPCError): - # Route error to routers - for router in self._response_routers: - if router.route_error(response_id, message.message.error): - return # Handled - else: - # Route success response to routers - response_data: dict[str, Any] = message.message.result or {} - for router in self._response_routers: - if router.route_response(response_id, response_data): - return # Handled - - # Fall back to normal response streams stream = self._response_streams.pop(response_id, None) if stream: await stream.send(message.message) diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index b442303937..b2d537fb70 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -8,14 +8,6 @@ from mcp.types._types import ( DEFAULT_NEGOTIATED_VERSION, LATEST_PROTOCOL_VERSION, - TASK_FORBIDDEN, - TASK_OPTIONAL, - TASK_REQUIRED, - TASK_STATUS_CANCELLED, - TASK_STATUS_COMPLETED, - TASK_STATUS_FAILED, - TASK_STATUS_INPUT_REQUIRED, - TASK_STATUS_WORKING, Annotations, AudioContent, BaseMetadata, @@ -25,15 +17,10 @@ CallToolResult, CancelledNotification, CancelledNotificationParams, - CancelTaskRequest, - CancelTaskRequestParams, - CancelTaskResult, ClientCapabilities, ClientNotification, ClientRequest, ClientResult, - ClientTasksCapability, - ClientTasksRequestsCapability, CompleteRequest, CompleteRequestParams, CompleteResult, @@ -46,7 +33,6 @@ CreateMessageRequestParams, CreateMessageResult, CreateMessageResultWithTools, - CreateTaskResult, ElicitationCapability, ElicitationRequiredErrorData, ElicitCompleteNotification, @@ -63,12 +49,6 @@ GetPromptRequest, GetPromptRequestParams, GetPromptResult, - 
GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequest, - GetTaskRequestParams, - GetTaskResult, Icon, IconTheme, ImageContent, @@ -86,8 +66,6 @@ ListResourceTemplatesResult, ListRootsRequest, ListRootsResult, - ListTasksRequest, - ListTasksResult, ListToolsRequest, ListToolsResult, LoggingCapability, @@ -114,7 +92,6 @@ ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, - RelatedTaskMetadata, Request, RequestParams, RequestParamsMeta, @@ -142,33 +119,16 @@ ServerNotification, ServerRequest, ServerResult, - ServerTasksCapability, - ServerTasksRequestsCapability, SetLevelRequest, SetLevelRequestParams, StopReason, SubscribeRequest, SubscribeRequestParams, - Task, - TaskExecutionMode, - TaskMetadata, - TasksCallCapability, - TasksCancelCapability, - TasksCreateElicitationCapability, - TasksCreateMessageCapability, - TasksElicitationCapability, - TasksListCapability, - TasksSamplingCapability, - TaskStatus, - TaskStatusNotification, - TaskStatusNotificationParams, - TasksToolsCapability, TextContent, TextResourceContents, Tool, ToolAnnotations, ToolChoice, - ToolExecution, ToolListChangedNotification, ToolResultContent, ToolsCapability, @@ -208,16 +168,6 @@ # Protocol version constants "LATEST_PROTOCOL_VERSION", "DEFAULT_NEGOTIATED_VERSION", - # Task execution mode constants - "TASK_FORBIDDEN", - "TASK_OPTIONAL", - "TASK_REQUIRED", - # Task status constants - "TASK_STATUS_CANCELLED", - "TASK_STATUS_COMPLETED", - "TASK_STATUS_FAILED", - "TASK_STATUS_INPUT_REQUIRED", - "TASK_STATUS_WORKING", # Type aliases and variables "ContentBlock", "ElicitRequestedSchema", @@ -229,8 +179,6 @@ "SamplingContent", "SamplingMessageContentBlock", "StopReason", - "TaskExecutionMode", - "TaskStatus", # Base classes "BaseMetadata", "Request", @@ -245,8 +193,6 @@ "EmptyResult", # Capabilities "ClientCapabilities", - "ClientTasksCapability", - "ClientTasksRequestsCapability", "CompletionsCapability", "ElicitationCapability", 
"FormElicitationCapability", @@ -258,16 +204,6 @@ "SamplingContextCapability", "SamplingToolsCapability", "ServerCapabilities", - "ServerTasksCapability", - "ServerTasksRequestsCapability", - "TasksCancelCapability", - "TasksCallCapability", - "TasksCreateElicitationCapability", - "TasksCreateMessageCapability", - "TasksElicitationCapability", - "TasksListCapability", - "TasksSamplingCapability", - "TasksToolsCapability", "ToolsCapability", "UrlElicitationCapability", # Content types @@ -300,18 +236,12 @@ "ResourceTemplateReference", "Root", "SamplingMessage", - "Task", - "TaskMetadata", - "RelatedTaskMetadata", "Tool", "ToolAnnotations", "ToolChoice", - "ToolExecution", # Requests "CallToolRequest", "CallToolRequestParams", - "CancelTaskRequest", - "CancelTaskRequestParams", "CompleteRequest", "CompleteRequestParams", "CreateMessageRequest", @@ -321,17 +251,12 @@ "ElicitRequestURLParams", "GetPromptRequest", "GetPromptRequestParams", - "GetTaskPayloadRequest", - "GetTaskPayloadRequestParams", - "GetTaskRequest", - "GetTaskRequestParams", "InitializeRequest", "InitializeRequestParams", "ListPromptsRequest", "ListResourcesRequest", "ListResourceTemplatesRequest", "ListRootsRequest", - "ListTasksRequest", "ListToolsRequest", "PingRequest", "ReadResourceRequest", @@ -344,22 +269,17 @@ "UnsubscribeRequestParams", # Results "CallToolResult", - "CancelTaskResult", "CompleteResult", "CreateMessageResult", "CreateMessageResultWithTools", - "CreateTaskResult", "ElicitResult", "ElicitationRequiredErrorData", "GetPromptResult", - "GetTaskPayloadResult", - "GetTaskResult", "InitializeResult", "ListPromptsResult", "ListResourcesResult", "ListResourceTemplatesResult", "ListRootsResult", - "ListTasksResult", "ListToolsResult", "ReadResourceResult", # Notifications @@ -377,8 +297,6 @@ "ResourceUpdatedNotification", "ResourceUpdatedNotificationParams", "RootsListChangedNotification", - "TaskStatusNotification", - "TaskStatusNotificationParams", "ToolListChangedNotification", # 
Union types for request/response routing "ClientNotification", diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index 9005d253af..34800ba12e 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -1,7 +1,6 @@ from __future__ import annotations -from datetime import datetime -from typing import Annotated, Any, Final, Generic, Literal, TypeAlias, TypeVar +from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, Field, FileUrl, TypeAdapter from pydantic.alias_generators import to_camel @@ -30,11 +29,6 @@ IconTheme = Literal["light", "dark"] -TaskExecutionMode = Literal["forbidden", "optional", "required"] -TASK_FORBIDDEN: Final[Literal["forbidden"]] = "forbidden" -TASK_OPTIONAL: Final[Literal["optional"]] = "optional" -TASK_REQUIRED: Final[Literal["required"]] = "required" - class MCPModel(BaseModel): """Base class for all MCP protocol types.""" @@ -55,27 +49,7 @@ class RequestParamsMeta(TypedDict, extra_items=Any): """ -class TaskMetadata(MCPModel): - """Metadata for augmenting a request with task execution. - - Include this in the `task` field of the request parameters. - """ - - ttl: Annotated[int, Field(strict=True)] | None = None - """Requested duration in milliseconds to retain task from creation.""" - - class RequestParams(MCPModel): - task: TaskMetadata | None = None - """ - If specified, the caller is requesting task-augmented execution for this request. - The request will return a CreateTaskResult immediately, and the actual result can be - retrieved later via tasks/result. - - Task augmentation is subject to capability negotiation - receivers MUST declare support - for task augmentation of specific request types in their capabilities. 
- """ - meta: RequestParamsMeta | None = Field(alias="_meta", default=None) @@ -258,55 +232,6 @@ class SamplingCapability(MCPModel): """ -class TasksListCapability(MCPModel): - """Capability for tasks listing operations.""" - - -class TasksCancelCapability(MCPModel): - """Capability for tasks cancel operations.""" - - -class TasksCreateMessageCapability(MCPModel): - """Capability for tasks create messages.""" - - -class TasksSamplingCapability(MCPModel): - """Capability for tasks sampling operations.""" - - create_message: TasksCreateMessageCapability | None = None - - -class TasksCreateElicitationCapability(MCPModel): - """Capability for tasks create elicitation operations.""" - - -class TasksElicitationCapability(MCPModel): - """Capability for tasks elicitation operations.""" - - create: TasksCreateElicitationCapability | None = None - - -class ClientTasksRequestsCapability(MCPModel): - """Capability for tasks requests operations.""" - - sampling: TasksSamplingCapability | None = None - - elicitation: TasksElicitationCapability | None = None - - -class ClientTasksCapability(MCPModel): - """Capability for client tasks operations.""" - - list: TasksListCapability | None = None - """Whether this client supports tasks/list.""" - - cancel: TasksCancelCapability | None = None - """Whether this client supports tasks/cancel.""" - - requests: ClientTasksRequestsCapability | None = None - """Specifies which request types can be augmented with tasks.""" - - class ClientCapabilities(MCPModel): """Capabilities a client may support.""" @@ -321,8 +246,6 @@ class ClientCapabilities(MCPModel): """Present if the client supports elicitation from the user.""" roots: RootsCapability | None = None """Present if the client supports listing roots.""" - tasks: ClientTasksCapability | None = None - """Present if the client supports task-augmented requests.""" class PromptsCapability(MCPModel): @@ -356,30 +279,6 @@ class CompletionsCapability(MCPModel): """Capability for completions 
operations.""" -class TasksCallCapability(MCPModel): - """Capability for tasks call operations.""" - - -class TasksToolsCapability(MCPModel): - """Capability for tasks tools operations.""" - - call: TasksCallCapability | None = None - - -class ServerTasksRequestsCapability(MCPModel): - """Capability for tasks requests operations.""" - - tools: TasksToolsCapability | None = None - - -class ServerTasksCapability(MCPModel): - """Capability for server tasks operations.""" - - list: TasksListCapability | None = None - cancel: TasksCancelCapability | None = None - requests: ServerTasksRequestsCapability | None = None - - class ServerCapabilities(MCPModel): """Capabilities that a server may support.""" @@ -401,146 +300,6 @@ class ServerCapabilities(MCPModel): completions: CompletionsCapability | None = None """Present if the server offers autocompletion suggestions for prompts and resources.""" - tasks: ServerTasksCapability | None = None - """Present if the server supports task-augmented requests.""" - - -TaskStatus = Literal["working", "input_required", "completed", "failed", "cancelled"] - -# Task status constants -TASK_STATUS_WORKING: Final[Literal["working"]] = "working" -TASK_STATUS_INPUT_REQUIRED: Final[Literal["input_required"]] = "input_required" -TASK_STATUS_COMPLETED: Final[Literal["completed"]] = "completed" -TASK_STATUS_FAILED: Final[Literal["failed"]] = "failed" -TASK_STATUS_CANCELLED: Final[Literal["cancelled"]] = "cancelled" - - -class RelatedTaskMetadata(MCPModel): - """Metadata for associating messages with a task. - - Include this in the `_meta` field under the key `io.modelcontextprotocol/related-task`. - """ - - task_id: str - """The task identifier this message is associated with.""" - - -class Task(MCPModel): - """Data associated with a task.""" - - task_id: str - """The task identifier.""" - - status: TaskStatus - """Current task state.""" - - status_message: str | None = None - """Optional human-readable message describing the current task state. 
- - This can provide context for any status, including: - - Reasons for "cancelled" status - - Summaries for "completed" status - - Diagnostic information for "failed" status (e.g., error details, what went wrong) - """ - - created_at: datetime # Pydantic will enforce ISO 8601 and re-serialize as a string later - """ISO 8601 timestamp when the task was created.""" - - last_updated_at: datetime - """ISO 8601 timestamp when the task was last updated.""" - - ttl: Annotated[int, Field(strict=True)] | None - """Actual retention duration from creation in milliseconds, null for unlimited.""" - - poll_interval: Annotated[int, Field(strict=True)] | None = None - """Suggested polling interval in milliseconds.""" - - -class CreateTaskResult(Result): - """A response to a task-augmented request.""" - - task: Task - - -class GetTaskRequestParams(RequestParams): - task_id: str - """The task identifier to query.""" - - -class GetTaskRequest(Request[GetTaskRequestParams, Literal["tasks/get"]]): - """A request to retrieve the state of a task.""" - - method: Literal["tasks/get"] = "tasks/get" - - params: GetTaskRequestParams - - -class GetTaskResult(Result, Task): - """The response to a tasks/get request.""" - - -class GetTaskPayloadRequestParams(RequestParams): - task_id: str - """The task identifier to retrieve results for.""" - - -class GetTaskPayloadRequest(Request[GetTaskPayloadRequestParams, Literal["tasks/result"]]): - """A request to retrieve the result of a completed task.""" - - method: Literal["tasks/result"] = "tasks/result" - params: GetTaskPayloadRequestParams - - -class GetTaskPayloadResult(Result): - """The response to a tasks/result request. - - The structure matches the result type of the original request. - For example, a tools/call task would return the CallToolResult structure. 
- """ - - model_config = ConfigDict(extra="allow", alias_generator=to_camel, populate_by_name=True) - - -class CancelTaskRequestParams(RequestParams): - task_id: str - """The task identifier to cancel.""" - - -class CancelTaskRequest(Request[CancelTaskRequestParams, Literal["tasks/cancel"]]): - """A request to cancel a task.""" - - method: Literal["tasks/cancel"] = "tasks/cancel" - params: CancelTaskRequestParams - - -class CancelTaskResult(Result, Task): - """The response to a tasks/cancel request.""" - - -class ListTasksRequest(PaginatedRequest[Literal["tasks/list"]]): - """A request to retrieve a list of tasks.""" - - method: Literal["tasks/list"] = "tasks/list" - - -class ListTasksResult(PaginatedResult): - """The response to a tasks/list request.""" - - tasks: list[Task] - - -class TaskStatusNotificationParams(NotificationParams, Task): - """Parameters for a `notifications/tasks/status` notification.""" - - -class TaskStatusNotification(Notification[TaskStatusNotificationParams, Literal["notifications/tasks/status"]]): - """An optional notification from the receiver to the requestor, informing them that a task's status has changed. - Receivers are not required to send these notifications. - """ - - method: Literal["notifications/tasks/status"] = "notifications/tasks/status" - params: TaskStatusNotificationParams - class InitializeRequestParams(RequestParams): """Parameters for the initialize request.""" @@ -1133,23 +892,6 @@ class ToolAnnotations(MCPModel): """ -class ToolExecution(MCPModel): - """Execution-related properties for a tool.""" - - task_support: TaskExecutionMode | None = None - """ - Indicates whether this tool supports task-augmented execution. - This allows clients to handle long-running operations through polling - the task system. 
- - - "forbidden": Tool does not support task-augmented execution (default when absent) - - "optional": Tool may support task-augmented execution - - "required": Tool requires task-augmented execution - - Default: "forbidden" - """ - - class Tool(BaseMetadata): """Definition for a tool the client can call.""" @@ -1172,8 +914,6 @@ class Tool(BaseMetadata): for notes on _meta usage. """ - execution: ToolExecution | None = None - class ListToolsResult(PaginatedResult): """The server's response to a tools/list request from the client.""" @@ -1554,8 +1294,6 @@ class CancelledNotificationParams(NotificationParams): The ID of the request to cancel. This MUST correspond to the ID of a request previously issued in the same direction. - This MUST be provided for cancelling non-task requests. - This MUST NOT be used for cancelling tasks (use the `tasks/cancel` request instead). """ reason: str | None = None """An optional string describing the reason for the cancellation.""" @@ -1607,20 +1345,12 @@ class ElicitCompleteNotification( | UnsubscribeRequest | CallToolRequest | ListToolsRequest - | GetTaskRequest - | GetTaskPayloadRequest - | ListTasksRequest - | CancelTaskRequest ) client_request_adapter = TypeAdapter[ClientRequest](ClientRequest) ClientNotification = ( - CancelledNotification - | ProgressNotification - | InitializedNotification - | RootsListChangedNotification - | TaskStatusNotification + CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification ) client_notification_adapter = TypeAdapter[ClientNotification](ClientNotification) @@ -1716,31 +1446,11 @@ class ElicitationRequiredErrorData(MCPModel): """List of URL mode elicitations that must be completed.""" -ClientResult = ( - EmptyResult - | CreateMessageResult - | CreateMessageResultWithTools - | ListRootsResult - | ElicitResult - | GetTaskResult - | GetTaskPayloadResult - | ListTasksResult - | CancelTaskResult - | CreateTaskResult -) +ClientResult = EmptyResult | 
CreateMessageResult | CreateMessageResultWithTools | ListRootsResult | ElicitResult client_result_adapter = TypeAdapter[ClientResult](ClientResult) -ServerRequest = ( - PingRequest - | CreateMessageRequest - | ListRootsRequest - | ElicitRequest - | GetTaskRequest - | GetTaskPayloadRequest - | ListTasksRequest - | CancelTaskRequest -) +ServerRequest = PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest server_request_adapter = TypeAdapter[ServerRequest](ServerRequest) @@ -1753,7 +1463,6 @@ class ElicitationRequiredErrorData(MCPModel): | ToolListChangedNotification | PromptListChangedNotification | ElicitCompleteNotification - | TaskStatusNotification ) server_notification_adapter = TypeAdapter[ServerNotification](ServerNotification) @@ -1769,10 +1478,5 @@ class ElicitationRequiredErrorData(MCPModel): | ReadResourceResult | CallToolResult | ListToolsResult - | GetTaskResult - | GetTaskPayloadResult - | ListTasksResult - | CancelTaskResult - | CreateTaskResult ) server_result_adapter = TypeAdapter[ServerResult](ServerResult) diff --git a/tests/experimental/__init__.py b/tests/experimental/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/experimental/tasks/__init__.py b/tests/experimental/tasks/__init__.py deleted file mode 100644 index 6e8649d283..0000000000 --- a/tests/experimental/tasks/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for MCP task support.""" diff --git a/tests/experimental/tasks/client/__init__.py b/tests/experimental/tasks/client/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py deleted file mode 100644 index 1ea2199e8c..0000000000 --- a/tests/experimental/tasks/client/test_capabilities.py +++ /dev/null @@ -1,312 +0,0 @@ -"""Tests for client task capabilities declaration during initialization.""" - -import anyio -import pytest - -from mcp import 
ClientCapabilities, types -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.client.session import ClientSession -from mcp.shared._context import RequestContext -from mcp.shared.message import SessionMessage -from mcp.types import ( - LATEST_PROTOCOL_VERSION, - Implementation, - InitializeRequest, - InitializeResult, - JSONRPCRequest, - JSONRPCResponse, - ServerCapabilities, - client_request_adapter, -) - - -@pytest.mark.anyio -async def test_client_capabilities_without_tasks(): - """Test that tasks capability is None when not provided.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - - received_capabilities = None - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request = client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await client_to_server_receive.receive() - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - 
tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability is None when not provided - assert received_capabilities is not None - assert received_capabilities.tasks is None - - -@pytest.mark.anyio -async def test_client_capabilities_with_tasks(): - """Test that tasks capability is properly set when handlers are provided.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - - received_capabilities: ClientCapabilities | None = None - - # Define custom handlers to trigger capability building (never actually called) - async def my_list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> types.ListTasksResult | types.ErrorData: - raise NotImplementedError - - async def my_cancel_task_handler( - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, - ) -> types.CancelTaskResult | types.ErrorData: - raise NotImplementedError - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request = client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await 
client_to_server_receive.receive() - - # Create handlers container - task_handlers = ExperimentalTaskHandlers( - list_tasks=my_list_tasks_handler, - cancel_task=my_cancel_task_handler, - ) - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability is properly set from handlers - assert received_capabilities is not None - assert received_capabilities.tasks is not None - assert isinstance(received_capabilities.tasks, types.ClientTasksCapability) - assert received_capabilities.tasks.list is not None - assert received_capabilities.tasks.cancel is not None - - -@pytest.mark.anyio -async def test_client_capabilities_auto_built_from_handlers(): - """Test that tasks capability is automatically built from provided handlers.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - - received_capabilities: ClientCapabilities | None = None - - # Define custom handlers (not defaults) - async def my_list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> types.ListTasksResult | types.ErrorData: - raise NotImplementedError - - async def my_cancel_task_handler( - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, - ) -> types.CancelTaskResult | types.ErrorData: - raise NotImplementedError - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request 
= client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await client_to_server_receive.receive() - - # Provide handlers via ExperimentalTaskHandlers - task_handlers = ExperimentalTaskHandlers( - list_tasks=my_list_tasks_handler, - cancel_task=my_cancel_task_handler, - ) - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability was auto-built from handlers - assert received_capabilities is not None - assert received_capabilities.tasks is not None - assert received_capabilities.tasks.list is not None - assert received_capabilities.tasks.cancel is not None - # requests should be None since we didn't provide task-augmented handlers - assert received_capabilities.tasks.requests is None - - -@pytest.mark.anyio -async def test_client_capabilities_with_task_augmented_handlers(): - """Test that requests capability is built when augmented handlers are provided.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - 
- received_capabilities: ClientCapabilities | None = None - - # Define task-augmented handler - async def my_augmented_sampling_handler( - context: RequestContext[ClientSession], - params: types.CreateMessageRequestParams, - task_metadata: types.TaskMetadata, - ) -> types.CreateTaskResult | types.ErrorData: - raise NotImplementedError - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request = client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await client_to_server_receive.receive() - - # Provide task-augmented sampling handler - task_handlers = ExperimentalTaskHandlers( - augmented_sampling=my_augmented_sampling_handler, - ) - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability includes requests.sampling - assert received_capabilities is not None - assert received_capabilities.tasks is not None - assert received_capabilities.tasks.requests is not None - assert 
received_capabilities.tasks.requests.sampling is not None - assert received_capabilities.tasks.requests.elicitation is None # Not provided diff --git a/tests/experimental/tasks/client/test_handlers.py b/tests/experimental/tasks/client/test_handlers.py deleted file mode 100644 index 137ff80106..0000000000 --- a/tests/experimental/tasks/client/test_handlers.py +++ /dev/null @@ -1,874 +0,0 @@ -"""Tests for client-side task management handlers (server -> client requests). - -These tests verify that clients can handle task-related requests from servers: -- GetTaskRequest - server polling client's task status -- GetTaskPayloadRequest - server getting result from client's task -- ListTasksRequest - server listing client's tasks -- CancelTaskRequest - server cancelling client's task - -This is the inverse of the existing tests in test_tasks.py, which test -client -> server task requests. -""" - -from collections.abc import AsyncIterator -from dataclasses import dataclass - -import anyio -import pytest -from anyio import Event -from anyio.abc import TaskGroup -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - -from mcp import types -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.client.session import ClientSession -from mcp.shared._context import RequestContext -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ( - CancelTaskRequest, - CancelTaskRequestParams, - CancelTaskResult, - ClientResult, - CreateMessageRequest, - CreateMessageRequestParams, - CreateMessageResult, - CreateTaskResult, - ElicitRequest, - ElicitRequestFormParams, - ElicitRequestParams, - ElicitResult, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequest, - GetTaskRequestParams, - GetTaskResult, - ListTasksRequest, - 
ListTasksResult, - SamplingMessage, - ServerNotification, - ServerRequest, - TaskMetadata, - TextContent, -) - -# Buffer size for test streams -STREAM_BUFFER_SIZE = 10 - - -@dataclass -class ClientTestStreams: - """Bidirectional message streams for client/server communication in tests.""" - - server_send: MemoryObjectSendStream[SessionMessage] - server_receive: MemoryObjectReceiveStream[SessionMessage] - client_send: MemoryObjectSendStream[SessionMessage] - client_receive: MemoryObjectReceiveStream[SessionMessage] - - -@pytest.fixture -async def client_streams() -> AsyncIterator[ClientTestStreams]: - """Create bidirectional message streams for client tests. - - Automatically closes all streams after the test completes. - """ - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage]( - STREAM_BUFFER_SIZE - ) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage]( - STREAM_BUFFER_SIZE - ) - - streams = ClientTestStreams( - server_send=server_to_client_send, - server_receive=client_to_server_receive, - client_send=client_to_server_send, - client_receive=server_to_client_receive, - ) - - yield streams - - # Cleanup - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -async def _default_message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, -) -> None: - """Default message handler that ignores messages (tests handle them explicitly).""" - ... 
- - -@pytest.mark.anyio -async def test_client_handles_get_task_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to GetTaskRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - received_task_id: str | None = None - - async def get_task_handler( - context: RequestContext[ClientSession], - params: GetTaskRequestParams, - ) -> GetTaskResult | ErrorData: - nonlocal received_task_id - received_task_id = params.task_id - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-123") - - task_handlers = ExperimentalTaskHandlers(get_task=get_task_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskRequest(params=GetTaskRequestParams(task_id="test-task-123")) - request = types.JSONRPCRequest(jsonrpc="2.0", id="req-1", **typed_request.model_dump(by_alias=True)) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - assert response.id == "req-1" - - result = GetTaskResult.model_validate(response.result) - assert result.task_id == "test-task-123" - assert result.status == "working" - assert 
received_task_id == "test-task-123" - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_handles_get_task_result_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to GetTaskPayloadRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - - async def get_task_result_handler( - context: RequestContext[ClientSession], - params: GetTaskPayloadRequestParams, - ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, types.CallToolResult) - return GetTaskPayloadResult(**result.model_dump()) - - await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-456") - await store.store_result( - "test-task-456", - types.CallToolResult(content=[TextContent(type="text", text="Task completed successfully!")]), - ) - await store.update_task("test-task-456", status="completed") - - task_handlers = ExperimentalTaskHandlers(get_task_result=get_task_result_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="test-task-456")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-2", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - assert 
isinstance(response.result, dict) - result_dict = response.result - assert "content" in result_dict - assert len(result_dict["content"]) == 1 - assert result_dict["content"][0]["text"] == "Task completed successfully!" - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_handles_list_tasks_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to ListTasksRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - - async def list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> ListTasksResult | ErrorData: - cursor = params.cursor if params else None - tasks_list, next_cursor = await store.list_tasks(cursor=cursor) - return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) - - await store.create_task(TaskMetadata(ttl=60000), task_id="task-1") - await store.create_task(TaskMetadata(ttl=60000), task_id="task-2") - - task_handlers = ExperimentalTaskHandlers(list_tasks=list_tasks_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = ListTasksRequest() - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-3", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - result = ListTasksResult.model_validate(response.result) - assert len(result.tasks) == 2 - - tg.cancel_scope.cancel() - - store.cleanup() - - 
-@pytest.mark.anyio -async def test_client_handles_cancel_task_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to CancelTaskRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - - async def cancel_task_handler( - context: RequestContext[ClientSession], - params: CancelTaskRequestParams, - ) -> CancelTaskResult | ErrorData: - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - await store.update_task(params.task_id, status="cancelled") - updated = await store.get_task(params.task_id) - assert updated is not None - return CancelTaskResult( - task_id=updated.task_id, - status=updated.status, - created_at=updated.created_at, - last_updated_at=updated.last_updated_at, - ttl=updated.ttl, - ) - - await store.create_task(TaskMetadata(ttl=60000), task_id="task-to-cancel") - - task_handlers = ExperimentalTaskHandlers(cancel_task=cancel_task_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = CancelTaskRequest(params=CancelTaskRequestParams(task_id="task-to-cancel")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-4", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - result = CancelTaskResult.model_validate(response.result) - assert result.task_id == "task-to-cancel" - assert result.status == "cancelled" - - 
tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_task_augmented_sampling(client_streams: ClientTestStreams) -> None: - """Test that client can handle task-augmented sampling request from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - sampling_completed = Event() - created_task_id: list[str | None] = [None] - background_tg: list[TaskGroup | None] = [None] - - async def task_augmented_sampling_callback( - context: RequestContext[ClientSession], - params: CreateMessageRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult: - task = await store.create_task(task_metadata) - created_task_id[0] = task.task_id - - async def do_sampling() -> None: - result = CreateMessageResult( - role="assistant", - content=TextContent(type="text", text="Sampled response"), - model="test-model", - stop_reason="endTurn", - ) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") - sampling_completed.set() - - assert background_tg[0] is not None - background_tg[0].start_soon(do_sampling) - return CreateTaskResult(task=task) - - async def get_task_handler( - context: RequestContext[ClientSession], - params: GetTaskRequestParams, - ) -> GetTaskResult | ErrorData: - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def get_task_result_handler( - context: RequestContext[ClientSession], - params: GetTaskPayloadRequestParams, - ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, 
CreateMessageResult) - return GetTaskPayloadResult(**result.model_dump()) - - task_handlers = ExperimentalTaskHandlers( - augmented_sampling=task_augmented_sampling_callback, - get_task=get_task_handler, - get_task_result=get_task_result_handler, - ) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - background_tg[0] = tg - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Step 1: Server sends task-augmented CreateMessageRequest - typed_request = CreateMessageRequest( - params=CreateMessageRequestParams( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-sampling", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - # Step 2: Client responds with CreateTaskResult - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - task_result = CreateTaskResult.model_validate(response.result) - task_id = task_result.task.task_id - assert task_id == created_task_id[0] - - # Step 3: Wait for background sampling - await sampling_completed.wait() - - # Step 4: Server polls task status - typed_poll = GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)) - poll_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-poll", - **typed_poll.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(poll_request)) - - poll_response_msg = await client_streams.server_receive.receive() - poll_response = poll_response_msg.message - 
assert isinstance(poll_response, types.JSONRPCResponse) - - status = GetTaskResult.model_validate(poll_response.result) - assert status.status == "completed" - - # Step 5: Server gets result - typed_result_req = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id)) - result_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-result", - **typed_result_req.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(result_request)) - - result_response_msg = await client_streams.server_receive.receive() - result_response = result_response_msg.message - assert isinstance(result_response, types.JSONRPCResponse) - - assert isinstance(result_response.result, dict) - assert result_response.result["role"] == "assistant" - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_task_augmented_elicitation(client_streams: ClientTestStreams) -> None: - """Test that client can handle task-augmented elicitation request from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - elicitation_completed = Event() - created_task_id: list[str | None] = [None] - background_tg: list[TaskGroup | None] = [None] - - async def task_augmented_elicitation_callback( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult | ErrorData: - task = await store.create_task(task_metadata) - created_task_id[0] = task.task_id - - async def do_elicitation() -> None: - # Simulate user providing elicitation response - result = ElicitResult(action="accept", content={"name": "Test User"}) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") - elicitation_completed.set() - - assert background_tg[0] is not None - background_tg[0].start_soon(do_elicitation) - return CreateTaskResult(task=task) - - async def get_task_handler( - context: RequestContext[ClientSession], - params: 
GetTaskRequestParams, - ) -> GetTaskResult | ErrorData: - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def get_task_result_handler( - context: RequestContext[ClientSession], - params: GetTaskPayloadRequestParams, - ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, ElicitResult) - return GetTaskPayloadResult(**result.model_dump()) - - task_handlers = ExperimentalTaskHandlers( - augmented_elicitation=task_augmented_elicitation_callback, - get_task=get_task_handler, - get_task_result=get_task_result_handler, - ) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - background_tg[0] = tg - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Step 1: Server sends task-augmented ElicitRequest - typed_request = ElicitRequest( - params=ElicitRequestFormParams( - message="What is your name?", - requested_schema={"type": "object", "properties": {"name": {"type": "string"}}}, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-elicit", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - # Step 2: Client responds with CreateTaskResult - response_msg = await 
client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - task_result = CreateTaskResult.model_validate(response.result) - task_id = task_result.task.task_id - assert task_id == created_task_id[0] - - # Step 3: Wait for background elicitation - await elicitation_completed.wait() - - # Step 4: Server polls task status - typed_poll = GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)) - poll_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-poll", - **typed_poll.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(poll_request)) - - poll_response_msg = await client_streams.server_receive.receive() - poll_response = poll_response_msg.message - assert isinstance(poll_response, types.JSONRPCResponse) - - status = GetTaskResult.model_validate(poll_response.result) - assert status.status == "completed" - - # Step 5: Server gets result - typed_result_req = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id)) - result_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-result", - **typed_result_req.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(result_request)) - - result_response_msg = await client_streams.server_receive.receive() - result_response = result_response_msg.message - assert isinstance(result_response, types.JSONRPCResponse) - - # Verify the elicitation result - assert isinstance(result_response.result, dict) - assert result_response.result["action"] == "accept" - assert result_response.result["content"] == {"name": "Test User"} - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error when no handler is registered for task request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with 
anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskRequest(params=GetTaskRequestParams(task_id="nonexistent")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-unhandled", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert ( - "not supported" in response.error.message.lower() - or "method not found" in response.error.message.lower() - ) - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_result_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error for unhandled tasks/result request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="nonexistent")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-result", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - 
tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_list_tasks_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error for unhandled tasks/list request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = ListTasksRequest() - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-list", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_cancel_task_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error for unhandled tasks/cancel request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = CancelTaskRequest(params=CancelTaskRequestParams(task_id="nonexistent")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-cancel", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await 
client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_augmented_sampling(client_streams: ClientTestStreams) -> None: - """Test that client returns error for task-augmented sampling without handler.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - # No task handlers provided - uses defaults - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Send task-augmented sampling request - typed_request = CreateMessageRequest( - params=CreateMessageRequestParams( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-sampling", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_augmented_elicitation( - client_streams: ClientTestStreams, -) -> None: - """Test that client returns error for task-augmented elicitation without handler.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - # No task handlers provided - uses defaults 
- async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Send task-augmented elicitation request - typed_request = ElicitRequest( - params=ElicitRequestFormParams( - message="What is your name?", - requested_schema={"type": "object", "properties": {"name": {"type": "string"}}}, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-elicit", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() diff --git a/tests/experimental/tasks/client/test_poll_task.py b/tests/experimental/tasks/client/test_poll_task.py deleted file mode 100644 index 5e3158d955..0000000000 --- a/tests/experimental/tasks/client/test_poll_task.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Tests for poll_task async iterator.""" - -from collections.abc import Callable, Coroutine -from datetime import datetime, timezone -from typing import Any -from unittest.mock import AsyncMock - -import pytest - -from mcp.client.experimental.tasks import ExperimentalClientFeatures -from mcp.types import GetTaskResult, TaskStatus - - -def make_task_result( - status: TaskStatus = "working", - poll_interval: int = 0, - task_id: str = "test-task", - status_message: str | None = None, -) -> GetTaskResult: - """Create GetTaskResult with sensible defaults.""" - now = datetime.now(timezone.utc) - return GetTaskResult( - task_id=task_id, - status=status, - status_message=status_message, - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=poll_interval, - ) - - -def 
make_status_sequence( - *statuses: TaskStatus, - task_id: str = "test-task", -) -> Callable[[str], Coroutine[Any, Any, GetTaskResult]]: - """Create mock get_task that returns statuses in sequence.""" - status_iter = iter(statuses) - - async def mock_get_task(tid: str) -> GetTaskResult: - return make_task_result(status=next(status_iter), task_id=tid) - - return mock_get_task - - -@pytest.fixture -def mock_session() -> AsyncMock: - return AsyncMock() - - -@pytest.fixture -def features(mock_session: AsyncMock) -> ExperimentalClientFeatures: - return ExperimentalClientFeatures(mock_session) - - -@pytest.mark.anyio -async def test_poll_task_yields_until_completed(features: ExperimentalClientFeatures) -> None: - """poll_task yields each status until terminal.""" - features.get_task = make_status_sequence("working", "working", "completed") # type: ignore[method-assign] - - statuses = [s.status async for s in features.poll_task("test-task")] - - assert statuses == ["working", "working", "completed"] - - -@pytest.mark.anyio -@pytest.mark.parametrize("terminal_status", ["completed", "failed", "cancelled"]) -async def test_poll_task_exits_on_terminal(features: ExperimentalClientFeatures, terminal_status: TaskStatus) -> None: - """poll_task exits immediately when task is already terminal.""" - features.get_task = make_status_sequence(terminal_status) # type: ignore[method-assign] - - statuses = [s.status async for s in features.poll_task("test-task")] - - assert statuses == [terminal_status] - - -@pytest.mark.anyio -async def test_poll_task_continues_through_input_required(features: ExperimentalClientFeatures) -> None: - """poll_task yields input_required and continues (non-terminal).""" - features.get_task = make_status_sequence("working", "input_required", "working", "completed") # type: ignore[method-assign] - - statuses = [s.status async for s in features.poll_task("test-task")] - - assert statuses == ["working", "input_required", "working", "completed"] - - 
-@pytest.mark.anyio -async def test_poll_task_passes_task_id(features: ExperimentalClientFeatures) -> None: - """poll_task passes correct task_id to get_task.""" - received_ids: list[str] = [] - - async def mock_get_task(task_id: str) -> GetTaskResult: - received_ids.append(task_id) - return make_task_result(status="completed", task_id=task_id) - - features.get_task = mock_get_task # type: ignore[method-assign] - - _ = [s async for s in features.poll_task("my-task-123")] - - assert received_ids == ["my-task-123"] - - -@pytest.mark.anyio -async def test_poll_task_yields_full_result(features: ExperimentalClientFeatures) -> None: - """poll_task yields complete GetTaskResult objects.""" - - async def mock_get_task(task_id: str) -> GetTaskResult: - return make_task_result( - status="completed", - task_id=task_id, - status_message="All done!", - ) - - features.get_task = mock_get_task # type: ignore[method-assign] - - results = [r async for r in features.poll_task("test-task")] - - assert len(results) == 1 - assert results[0].status == "completed" - assert results[0].status_message == "All done!" 
- assert results[0].task_id == "test-task" diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py deleted file mode 100644 index 613c794ebf..0000000000 --- a/tests/experimental/tasks/client/test_tasks.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Tests for the experimental client task methods (session.experimental).""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from dataclasses import dataclass, field - -import anyio -import pytest -from anyio import Event -from anyio.abc import TaskGroup - -from mcp import Client -from mcp.server import Server, ServerRequestContext -from mcp.shared.experimental.tasks.helpers import task_execution -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import ( - CallToolRequest, - CallToolRequestParams, - CallToolResult, - CancelTaskRequestParams, - CancelTaskResult, - CreateTaskResult, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - ListToolsResult, - PaginatedRequestParams, - TaskMetadata, - TextContent, -) - -pytestmark = pytest.mark.anyio - - -@dataclass -class AppContext: - """Application context passed via lifespan_context.""" - - task_group: TaskGroup - store: InMemoryTaskStore - task_done_events: dict[str, Event] = field(default_factory=lambda: {}) - - -async def _handle_list_tools( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None -) -> ListToolsResult: - raise NotImplementedError - - -async def _handle_call_tool_with_done_event( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams, *, result_text: str = "Done" -) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - 
app.task_done_events[task.task_id] = done_event - - async def do_work() -> None: - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - -def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): - @asynccontextmanager - async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: - async with anyio.create_task_group() as tg: - yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) - - return app_lifespan - - -async def test_session_experimental_get_task() -> None: - """Test session.experimental.get_task() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=_handle_call_tool_with_done_event, - ) - server.experimental.enable_tasks(on_get_task=handle_get_task) - - async with Client(server) as client: - # Create a task - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task 
to complete - await task_done_events[task_id].wait() - - # Use session.experimental to get task status - task_status = await client.session.experimental.get_task(task_id) - - assert task_status.task_id == task_id - assert task_status.status == "completed" - - -async def test_session_experimental_get_task_result() -> None: - """Test session.experimental.get_task_result() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_call_tool( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - return await _handle_call_tool_with_done_event(ctx, params, result_text="Task result content") - - async def handle_get_task_result( - ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - app = ctx.lifespan_context - result = await app.store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, CallToolResult) - return GetTaskPayloadResult(**result.model_dump()) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks(on_task_result=handle_get_task_result) - - async with Client(server) as client: - # Create a task - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await task_done_events[task_id].wait() - - # Use TaskClient to get task result - task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - 
assert content.text == "Task result content" - - -async def test_session_experimental_list_tasks() -> None: - """Test TaskClient.list_tasks() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_list_tasks( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListTasksResult: - app = ctx.lifespan_context - cursor = params.cursor if params else None - tasks_list, next_cursor = await app.store.list_tasks(cursor=cursor) - return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=_handle_call_tool_with_done_event, - ) - server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) - - async with Client(server) as client: - # Create two tasks - for _ in range(2): - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - await task_done_events[create_result.task.task_id].wait() - - # Use TaskClient to list tasks - list_result = await client.session.experimental.list_tasks() - - assert len(list_result.tasks) == 2 - - -async def test_session_experimental_cancel_task() -> None: - """Test TaskClient.cancel_task() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_call_tool_no_work( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - # Don't start any work - task stays in "working" status - return CreateTaskResult(task=task) - raise NotImplementedError - - async def 
handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def handle_cancel_task( - ctx: ServerRequestContext[AppContext], params: CancelTaskRequestParams - ) -> CancelTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - await app.store.update_task(params.task_id, status="cancelled") - updated_task = await app.store.get_task(params.task_id) - assert updated_task is not None - return CancelTaskResult( - task_id=updated_task.task_id, - status=updated_task.status, - created_at=updated_task.created_at, - last_updated_at=updated_task.last_updated_at, - ttl=updated_task.ttl, - ) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=handle_call_tool_no_work, - ) - server.experimental.enable_tasks(on_get_task=handle_get_task, on_cancel_task=handle_cancel_task) - - async with Client(server) as client: - # Create a task (but don't complete it) - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Verify task is working - status_before = await client.session.experimental.get_task(task_id) - assert status_before.status == "working" - - # Cancel the task - await client.session.experimental.cancel_task(task_id) - - # 
Verify task is cancelled - status_after = await client.session.experimental.get_task(task_id) - assert status_after.status == "cancelled" diff --git a/tests/experimental/tasks/server/__init__.py b/tests/experimental/tasks/server/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/experimental/tasks/server/test_context.py b/tests/experimental/tasks/server/test_context.py deleted file mode 100644 index a0f1a190d2..0000000000 --- a/tests/experimental/tasks/server/test_context.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Tests for TaskContext and helper functions.""" - -import pytest - -from mcp.shared.experimental.tasks.context import TaskContext -from mcp.shared.experimental.tasks.helpers import create_task_state, task_execution -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import CallToolResult, TaskMetadata, TextContent - - -@pytest.mark.anyio -async def test_task_context_properties() -> None: - """Test TaskContext basic properties.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - assert ctx.task_id == task.task_id - assert ctx.task.task_id == task.task_id - assert ctx.task.status == "working" - assert ctx.is_cancelled is False - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_update_status() -> None: - """Test TaskContext.update_status.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - await ctx.update_status("Processing step 1...") - - # Check status message was updated - updated = await store.get_task(task.task_id) - assert updated is not None - assert updated.status_message == "Processing step 1..." 
- - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_complete() -> None: - """Test TaskContext.complete.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - result = CallToolResult(content=[TextContent(type="text", text="Done!")]) - await ctx.complete(result) - - # Check task status - updated = await store.get_task(task.task_id) - assert updated is not None - assert updated.status == "completed" - - # Check result is stored - stored_result = await store.get_result(task.task_id) - assert stored_result is not None - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_fail() -> None: - """Test TaskContext.fail.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - await ctx.fail("Something went wrong!") - - # Check task status - updated = await store.get_task(task.task_id) - assert updated is not None - assert updated.status == "failed" - assert updated.status_message == "Something went wrong!" 
- - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_cancellation() -> None: - """Test TaskContext cancellation request.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - assert ctx.is_cancelled is False - - ctx.request_cancellation() - - assert ctx.is_cancelled is True - - store.cleanup() - - -def test_create_task_state_generates_id() -> None: - """create_task_state generates a unique task ID when none provided.""" - task1 = create_task_state(TaskMetadata(ttl=60000)) - task2 = create_task_state(TaskMetadata(ttl=60000)) - - assert task1.task_id != task2.task_id - - -def test_create_task_state_uses_provided_id() -> None: - """create_task_state uses the provided task ID.""" - task = create_task_state(TaskMetadata(ttl=60000), task_id="my-task-123") - assert task.task_id == "my-task-123" - - -def test_create_task_state_null_ttl() -> None: - """create_task_state handles null TTL.""" - task = create_task_state(TaskMetadata(ttl=None)) - assert task.ttl is None - - -def test_create_task_state_has_created_at() -> None: - """create_task_state sets createdAt timestamp.""" - task = create_task_state(TaskMetadata(ttl=60000)) - assert task.created_at is not None - - -@pytest.mark.anyio -async def test_task_execution_provides_context() -> None: - """task_execution provides a TaskContext for the task.""" - store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-test-1") - - async with task_execution("exec-test-1", store) as ctx: - assert ctx.task_id == "exec-test-1" - assert ctx.task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_execution_auto_fails_on_exception() -> None: - """task_execution automatically fails task on unhandled exception.""" - store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-fail-1") - - async with task_execution("exec-fail-1", store): - 
raise RuntimeError("Oops!") - - # Task should be failed - failed_task = await store.get_task("exec-fail-1") - assert failed_task is not None - assert failed_task.status == "failed" - assert "Oops!" in (failed_task.status_message or "") - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_execution_doesnt_fail_if_already_terminal() -> None: - """task_execution doesn't re-fail if task already terminal.""" - store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-term-1") - - async with task_execution("exec-term-1", store) as ctx: - # Complete the task first - await ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) - # Then raise - shouldn't change status - raise RuntimeError("This shouldn't matter") - - # Task should remain completed - final_task = await store.get_task("exec-term-1") - assert final_task is not None - assert final_task.status == "completed" - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_execution_not_found() -> None: - """task_execution raises ValueError for non-existent task.""" - store = InMemoryTaskStore() - - with pytest.raises(ValueError, match="not found"): - async with task_execution("nonexistent", store): - ... diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py deleted file mode 100644 index b5b79033d0..0000000000 --- a/tests/experimental/tasks/server/test_integration.py +++ /dev/null @@ -1,247 +0,0 @@ -"""End-to-end integration tests for tasks functionality. - -These tests demonstrate the full task lifecycle: -1. Client sends task-augmented request (tools/call with task metadata) -2. Server creates task and returns CreateTaskResult immediately -3. Background work executes (using task_execution context manager) -4. Client polls with tasks/get -5. 
Client retrieves result with tasks/result -""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from dataclasses import dataclass, field - -import anyio -import pytest -from anyio import Event -from anyio.abc import TaskGroup - -from mcp import Client -from mcp.server import Server, ServerRequestContext -from mcp.shared.experimental.tasks.helpers import task_execution -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import ( - CallToolRequest, - CallToolRequestParams, - CallToolResult, - CreateTaskResult, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - ListToolsResult, - PaginatedRequestParams, - TaskMetadata, - TextContent, -) - -pytestmark = pytest.mark.anyio - - -@dataclass -class AppContext: - """Application context passed via lifespan_context.""" - - task_group: TaskGroup - store: InMemoryTaskStore - task_done_events: dict[str, Event] = field(default_factory=lambda: {}) - - -def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): - @asynccontextmanager - async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: - async with anyio.create_task_group() as tg: - yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) - - return app_lifespan - - -async def test_task_lifecycle_with_task_execution() -> None: - """Test the complete task lifecycle using the task_execution pattern.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_list_tools( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if params.name == "process_data" and 
ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work() -> None: - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.update_status("Processing input...") - input_value = (params.arguments or {}).get("input", "") - result_text = f"Processed: {input_value.upper()}" - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def handle_get_task_result( - ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - app = ctx.lifespan_context - result = await app.store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, CallToolResult) - return GetTaskPayloadResult(**result.model_dump()) - - async def handle_list_tasks( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListTasksResult: - raise NotImplementedError - - server: Server[AppContext] = Server( - "test-tasks", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - 
server.experimental.enable_tasks( - on_get_task=handle_get_task, - on_task_result=handle_get_task_result, - on_list_tasks=handle_list_tasks, - ) - - async with Client(server) as client: - # Step 1: Send task-augmented tool call - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="process_data", - arguments={"input": "hello world"}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - assert isinstance(create_result, CreateTaskResult) - assert create_result.task.status == "working" - task_id = create_result.task.task_id - - # Step 2: Wait for task to complete - await task_done_events[task_id].wait() - - task_status = await client.session.experimental.get_task(task_id) - assert task_status.task_id == task_id - assert task_status.status == "completed" - - # Step 3: Retrieve the actual result - task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - assert content.text == "Processed: HELLO WORLD" - - -async def test_task_auto_fails_on_exception() -> None: - """Test that task_execution automatically fails the task on unhandled exception.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_list_tools( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if params.name == "failing_task" and ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_failing_work() 
-> None: - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.update_status("About to fail...") - raise RuntimeError("Something went wrong!") - # This line is reached because task_execution suppresses the exception - done_event.set() - - app.task_group.start_soon(do_failing_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - server: Server[AppContext] = Server( - "test-tasks-failure", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks(on_get_task=handle_get_task) - - async with Client(server) as client: - # Send task request - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="failing_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - task_id = create_result.task.task_id - - # Wait for task to complete (even though it fails) - await task_done_events[task_id].wait() - - # Check that task was auto-failed - task_status = await client.session.experimental.get_task(task_id) - - assert task_status.status == "failed" - assert task_status.status_message == "Something went wrong!" 
diff --git a/tests/experimental/tasks/server/test_run_task_flow.py b/tests/experimental/tasks/server/test_run_task_flow.py deleted file mode 100644 index 027382e69e..0000000000 --- a/tests/experimental/tasks/server/test_run_task_flow.py +++ /dev/null @@ -1,367 +0,0 @@ -"""Tests for the simplified task API: enable_tasks() + run_task() - -This tests the recommended user flow: -1. server.experimental.enable_tasks() - one-line setup -2. ctx.experimental.run_task(work) - spawns work, returns CreateTaskResult -3. work function uses ServerTaskContext for elicit/create_message - -These are integration tests that verify the complete flow works end-to-end. -""" - -from unittest.mock import Mock - -import anyio -import pytest -from anyio import Event - -from mcp import Client -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.request_context import Experimental -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.experimental.task_support import TaskSupport -from mcp.server.lowlevel import NotificationOptions -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue -from mcp.types import ( - TASK_REQUIRED, - CallToolRequestParams, - CallToolResult, - CreateTaskResult, - GetTaskRequestParams, - GetTaskResult, - ListToolsResult, - PaginatedRequestParams, - TextContent, -) - -pytestmark = pytest.mark.anyio - - -async def _handle_list_tools_simple_task( - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - raise NotImplementedError - - -async def test_run_task_basic_flow() -> None: - """Test the basic run_task flow without elicitation.""" - work_completed = Event() - received_meta: list[str | None] = [None] - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - 
ctx.experimental.validate_task_mode(TASK_REQUIRED) - - if ctx.meta is not None: # pragma: no branch - received_meta[0] = ctx.meta.get("custom_field") - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Working...") - input_val = (params.arguments or {}).get("input", "default") - result = CallToolResult(content=[TextContent(type="text", text=f"Processed: {input_val}")]) - work_completed.set() - return result - - return await ctx.experimental.run_task(work) - - server = Server( - "test-run-task", - on_list_tools=_handle_list_tools_simple_task, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task( - "simple_task", - {"input": "hello"}, - meta={"custom_field": "test_value"}, - ) - - task_id = result.task.task_id - assert result.task.status == "working" - - with anyio.fail_after(5): - await work_completed.wait() - - with anyio.fail_after(5): - while True: - task_status = await client.session.experimental.get_task(task_id) - if task_status.status == "completed": # pragma: no branch - break - - assert received_meta[0] == "test_value" - - -async def test_run_task_auto_fails_on_exception() -> None: - """Test that run_task automatically fails the task when work raises.""" - work_failed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - work_failed.set() - raise RuntimeError("Something went wrong!") - - return await ctx.experimental.run_task(work) - - server = Server( - "test-run-task-fail", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - 
server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("failing_task", {}) - task_id = result.task.task_id - - with anyio.fail_after(5): - await work_failed.wait() - - with anyio.fail_after(5): - while True: - task_status = await client.session.experimental.get_task(task_id) - if task_status.status == "failed": # pragma: no branch - break - - assert "Something went wrong" in (task_status.status_message or "") - - -async def test_enable_tasks_auto_registers_handlers() -> None: - """Test that enable_tasks() auto-registers get_task, list_tasks, cancel_task handlers.""" - server = Server("test-enable-tasks") - - # Before enable_tasks, no task capabilities - caps_before = server.get_capabilities(NotificationOptions(), {}) - assert caps_before.tasks is None - - # Enable tasks - server.experimental.enable_tasks() - - # After enable_tasks, should have task capabilities - caps_after = server.get_capabilities(NotificationOptions(), {}) - assert caps_after.tasks is not None - assert caps_after.tasks.list is not None - assert caps_after.tasks.cancel is not None - assert caps_after.tasks.requests is not None - assert caps_after.tasks.requests.tools is not None - assert caps_after.tasks.requests.tools.call is not None - - -async def test_enable_tasks_with_custom_store_and_queue() -> None: - """Test that enable_tasks() uses provided store and queue instead of defaults.""" - server = Server("test-custom-store-queue") - - custom_store = InMemoryTaskStore() - custom_queue = InMemoryTaskMessageQueue() - - task_support = server.experimental.enable_tasks(store=custom_store, queue=custom_queue) - - assert task_support.store is custom_store - assert task_support.queue is custom_queue - - -async def test_enable_tasks_skips_default_handlers_when_custom_registered() -> None: - """Test that enable_tasks() doesn't override already-registered handlers.""" - server = Server("test-custom-handlers") - - # Register 
custom handlers via enable_tasks kwargs - async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - raise NotImplementedError - - server.experimental.enable_tasks(on_get_task=custom_get_task) - - # Verify handler is registered - assert server._has_handler("tasks/get") - assert server._has_handler("tasks/list") - assert server._has_handler("tasks/cancel") - assert server._has_handler("tasks/result") - - -async def test_run_task_without_enable_tasks_raises() -> None: - """Test that run_task raises when enable_tasks() wasn't called.""" - experimental = Experimental( - task_metadata=None, - _client_capabilities=None, - _session=None, - _task_support=None, # Not enabled - ) - - async def work(task: ServerTaskContext) -> CallToolResult: - raise NotImplementedError - - with pytest.raises(RuntimeError, match="Task support not enabled"): - await experimental.run_task(work) - - -async def test_task_support_task_group_before_run_raises() -> None: - """Test that accessing task_group before run() raises RuntimeError.""" - task_support = TaskSupport.in_memory() - - with pytest.raises(RuntimeError, match="TaskSupport not running"): - _ = task_support.task_group - - -async def test_run_task_without_session_raises() -> None: - """Test that run_task raises when session is not available.""" - task_support = TaskSupport.in_memory() - - experimental = Experimental( - task_metadata=None, - _client_capabilities=None, - _session=None, # No session - _task_support=task_support, - ) - - async def work(task: ServerTaskContext) -> CallToolResult: - raise NotImplementedError - - with pytest.raises(RuntimeError, match="Session not available"): - await experimental.run_task(work) - - -async def test_run_task_without_task_metadata_raises() -> None: - """Test that run_task raises when request is not task-augmented.""" - task_support = TaskSupport.in_memory() - mock_session = Mock() - - experimental = Experimental( - task_metadata=None, # Not a 
task-augmented request - _client_capabilities=None, - _session=mock_session, - _task_support=task_support, - ) - - async def work(task: ServerTaskContext) -> CallToolResult: - raise NotImplementedError - - with pytest.raises(RuntimeError, match="Request is not task-augmented"): - await experimental.run_task(work) - - -async def test_run_task_with_model_immediate_response() -> None: - """Test that run_task includes model_immediate_response in CreateTaskResult._meta.""" - work_completed = Event() - immediate_response_text = "Processing your request..." - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="Done")]) - - return await ctx.experimental.run_task(work, model_immediate_response=immediate_response_text) - - server = Server( - "test-run-task-immediate", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("task_with_immediate", {}) - - assert result.meta is not None - assert "io.modelcontextprotocol/model-immediate-response" in result.meta - assert result.meta["io.modelcontextprotocol/model-immediate-response"] == immediate_response_text - - with anyio.fail_after(5): - await work_completed.wait() - - -async def test_run_task_doesnt_complete_if_already_terminal() -> None: - """Test that run_task doesn't auto-complete if work manually completed the task.""" - work_completed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> 
ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - manual_result = CallToolResult(content=[TextContent(type="text", text="Manually completed")]) - await task.complete(manual_result, notify=False) - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="This should be ignored")]) - - return await ctx.experimental.run_task(work) - - server = Server( - "test-already-complete", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("manual_complete_task", {}) - task_id = result.task.task_id - - with anyio.fail_after(5): - await work_completed.wait() - - with anyio.fail_after(5): - while True: - status = await client.session.experimental.get_task(task_id) - if status.status == "completed": # pragma: no branch - break - - -async def test_run_task_doesnt_fail_if_already_terminal() -> None: - """Test that run_task doesn't auto-fail if work manually failed/cancelled the task.""" - work_completed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.fail("Manually failed", notify=False) - work_completed.set() - raise RuntimeError("This error should not change status") - - return await ctx.experimental.run_task(work) - - server = Server( - "test-already-failed", - 
on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("manual_cancel_task", {}) - task_id = result.task.task_id - - with anyio.fail_after(5): - await work_completed.wait() - - with anyio.fail_after(5): - while True: - status = await client.session.experimental.get_task(task_id) - if status.status == "failed": # pragma: no branch - break - - assert status.status_message == "Manually failed" diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py deleted file mode 100644 index 6a28b274ea..0000000000 --- a/tests/experimental/tasks/server/test_server.py +++ /dev/null @@ -1,797 +0,0 @@ -"""Tests for server-side task support (handlers, capabilities, integration).""" - -from datetime import datetime, timezone -from typing import Any - -import anyio -import pytest - -from mcp import Client -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.exceptions import MCPError -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.response_router import ResponseRouter -from mcp.shared.session import RequestResponder -from mcp.types import ( - INVALID_REQUEST, - TASK_FORBIDDEN, - TASK_OPTIONAL, - TASK_REQUIRED, - CallToolRequest, - CallToolRequestParams, - CallToolResult, - CancelTaskRequestParams, - CancelTaskResult, - ClientResult, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - JSONRPCError, - JSONRPCNotification, - JSONRPCResponse, - ListTasksResult, - ListToolsResult, - PaginatedRequestParams, - SamplingMessage, - ServerCapabilities, - 
ServerNotification, - ServerRequest, - Task, - TaskMetadata, - TextContent, - Tool, - ToolExecution, -) - -pytestmark = pytest.mark.anyio - - -async def test_list_tasks_handler() -> None: - """Test that experimental list_tasks handler works via Client.""" - now = datetime.now(timezone.utc) - test_tasks = [ - Task(task_id="task-1", status="working", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), - Task(task_id="task-2", status="completed", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), - ] - - async def handle_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - return ListTasksResult(tasks=test_tasks) - - server = Server("test") - server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) - - async with Client(server) as client: - result = await client.session.experimental.list_tasks() - assert len(result.tasks) == 2 - assert result.tasks[0].task_id == "task-1" - assert result.tasks[1].task_id == "task-2" - - -async def test_get_task_handler() -> None: - """Test that experimental get_task handler works via Client.""" - - async def handle_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - now = datetime.now(timezone.utc) - return GetTaskResult( - task_id=params.task_id, - status="working", - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=1000, - ) - - server = Server("test") - server.experimental.enable_tasks(on_get_task=handle_get_task) - - async with Client(server) as client: - result = await client.session.experimental.get_task("test-task-123") - assert result.task_id == "test-task-123" - assert result.status == "working" - - -async def test_get_task_result_handler() -> None: - """Test that experimental get_task_result handler works via Client.""" - - async def handle_get_task_result( - ctx: ServerRequestContext, params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - return GetTaskPayloadResult() - - 
server = Server("test") - server.experimental.enable_tasks(on_task_result=handle_get_task_result) - - async with Client(server) as client: - result = await client.session.send_request( - GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="test-task-123")), - GetTaskPayloadResult, - ) - assert isinstance(result, GetTaskPayloadResult) - - -async def test_cancel_task_handler() -> None: - """Test that experimental cancel_task handler works via Client.""" - - async def handle_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: - now = datetime.now(timezone.utc) - return CancelTaskResult( - task_id=params.task_id, - status="cancelled", - created_at=now, - last_updated_at=now, - ttl=60000, - ) - - server = Server("test") - server.experimental.enable_tasks(on_cancel_task=handle_cancel_task) - - async with Client(server) as client: - result = await client.session.experimental.cancel_task("test-task-123") - assert result.task_id == "test-task-123" - assert result.status == "cancelled" - - -async def test_server_capabilities_include_tasks() -> None: - """Test that server capabilities include tasks when handlers are registered.""" - server = Server("test") - - async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - raise NotImplementedError - - async def noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: - raise NotImplementedError - - server.experimental.enable_tasks(on_list_tasks=noop_list_tasks, on_cancel_task=noop_cancel_task) - - capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) - - assert capabilities.tasks is not None - assert capabilities.tasks.list is not None - assert capabilities.tasks.cancel is not None - assert capabilities.tasks.requests is not None - assert capabilities.tasks.requests.tools is not None - - -@pytest.mark.skip( - 
reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " - "so partial capabilities aren't possible yet. Low-level API should support " - "selectively enabling/disabling task capabilities." -) -async def test_server_capabilities_partial_tasks() -> None: # pragma: no cover - """Test capabilities with only some task handlers registered.""" - server = Server("test") - - async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - raise NotImplementedError - - # Only list_tasks registered, not cancel_task - server.experimental.enable_tasks(on_list_tasks=noop_list_tasks) - - capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) - - assert capabilities.tasks is not None - assert capabilities.tasks.list is not None - assert capabilities.tasks.cancel is None # Not registered - - -async def test_tool_with_task_execution_metadata() -> None: - """Test that tools can declare task execution mode.""" - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="quick_tool", - description="Fast tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_FORBIDDEN), - ), - Tool( - name="long_tool", - description="Long running tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ), - Tool( - name="flexible_tool", - description="Can be either", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_OPTIONAL), - ), - ] - ) - - server = Server("test", on_list_tools=handle_list_tools) - - async with Client(server) as client: - result = await client.list_tools() - tools = result.tools - - assert tools[0].execution is not None - assert tools[0].execution.task_support == TASK_FORBIDDEN - assert 
tools[1].execution is not None - assert tools[1].execution.task_support == TASK_REQUIRED - assert tools[2].execution is not None - assert tools[2].execution.task_support == TASK_OPTIONAL - - -async def test_task_metadata_in_call_tool_request() -> None: - """Test that task metadata is accessible via ctx when calling a tool.""" - captured_task_metadata: TaskMetadata | None = None - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - nonlocal captured_task_metadata - captured_task_metadata = ctx.experimental.task_metadata - return CallToolResult(content=[TextContent(type="text", text="done")]) - - server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - - async with Client(server) as client: - # Call tool with task metadata - await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="long_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CallToolResult, - ) - - assert captured_task_metadata is not None - assert captured_task_metadata.ttl == 60000 - - -async def test_task_metadata_is_task_property() -> None: - """Test that ctx.experimental.is_task works correctly.""" - is_task_values: list[bool] = [] - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - is_task_values.append(ctx.experimental.is_task) - return CallToolResult(content=[TextContent(type="text", text="done")]) - - server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - - async with Client(server) as client: - # Call without task metadata - await client.session.send_request( - 
CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), - CallToolResult, - ) - - # Call with task metadata - await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}, task=TaskMetadata(ttl=60000)), - ), - CallToolResult, - ) - - assert len(is_task_values) == 2 - assert is_task_values[0] is False # First call without task - assert is_task_values[1] is True # Second call with task - - -async def test_update_capabilities_no_handlers() -> None: - """Test that update_capabilities returns early when no task handlers are registered.""" - server = Server("test-no-handlers") - _ = server.experimental - - caps = server.get_capabilities(NotificationOptions(), {}) - assert caps.tasks is None - - -async def test_update_capabilities_partial_handlers() -> None: - """Test that update_capabilities skips list/cancel when only tasks/get is registered.""" - server = Server("test-partial") - # Access .experimental to create the ExperimentalHandlers instance - exp = server.experimental - # Second access returns the same cached instance - assert server.experimental is exp - - async def noop_get(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - raise NotImplementedError - - server._add_request_handler("tasks/get", noop_get) - - caps = server.get_capabilities(NotificationOptions(), {}) - assert caps.tasks is not None - assert caps.tasks.list is None - assert caps.tasks.cancel is None - - -async def test_default_task_handlers_via_enable_tasks() -> None: - """Test that enable_tasks() auto-registers working default handlers.""" - server = Server("test-default-handlers") - task_support = server.experimental.enable_tasks() - store = task_support.store - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def 
message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server() -> None: - async with task_support.run(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - task_support.configure_session(server_session) - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, {}, False) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task directly in the store for testing - task = await store.create_task(TaskMetadata(ttl=60000)) - - # Test list_tasks (default handler) - list_result = await client_session.experimental.list_tasks() - assert len(list_result.tasks) == 1 - assert list_result.tasks[0].task_id == task.task_id - - # Test get_task (default handler - found) - get_result = await client_session.experimental.get_task(task.task_id) - assert get_result.task_id == task.task_id - assert get_result.status == "working" - - # Test get_task (default handler - not found path) - with pytest.raises(MCPError, match="not found"): - await client_session.experimental.get_task("nonexistent-task") - - # Create a completed task to test get_task_result - completed_task = await store.create_task(TaskMetadata(ttl=60000)) - await store.store_result( - completed_task.task_id, CallToolResult(content=[TextContent(type="text", text="Test result")]) - ) - await store.update_task(completed_task.task_id, status="completed") - - # Test get_task_result (default handler) - 
payload_result = await client_session.send_request( - GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=completed_task.task_id)), - GetTaskPayloadResult, - ) - # The result should have the related-task metadata - assert payload_result.meta is not None - assert "io.modelcontextprotocol/related-task" in payload_result.meta - - # Test cancel_task (default handler) - cancel_result = await client_session.experimental.cancel_task(task.task_id) - assert cancel_result.task_id == task.task_id - assert cancel_result.status == "cancelled" - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_build_elicit_form_request() -> None: - """Test that _build_elicit_form_request builds a proper elicitation request.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions(server_name="test-server", server_version="1.0.0", capabilities=ServerCapabilities()), - ) as server_session: - # Test without task_id - request = server_session._build_elicit_form_request( - message="Test message", - requested_schema={"type": "object", "properties": {"answer": {"type": "string"}}}, - ) - assert request.method == "elicitation/create" - assert request.params is not None - assert request.params["message"] == "Test message" - - # Test with related_task_id (adds related-task metadata) - request_with_task = server_session._build_elicit_form_request( - message="Task message", - requested_schema={"type": "object"}, - related_task_id="test-task-123", - ) - assert request_with_task.method == "elicitation/create" - assert request_with_task.params is not None - assert "_meta" in request_with_task.params - assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] - assert ( - 
request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] == "test-task-123" - ) - finally: - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_build_elicit_url_request() -> None: - """Test that _build_elicit_url_request builds a proper URL mode elicitation request.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions(server_name="test-server", server_version="1.0.0", capabilities=ServerCapabilities()), - ) as server_session: - # Test without related_task_id - request = server_session._build_elicit_url_request( - message="Please authorize with GitHub", - url="https://github.com/login/oauth/authorize", - elicitation_id="oauth-123", - ) - assert request.method == "elicitation/create" - assert request.params is not None - assert request.params["message"] == "Please authorize with GitHub" - assert request.params["url"] == "https://github.com/login/oauth/authorize" - assert request.params["elicitationId"] == "oauth-123" - assert request.params["mode"] == "url" - - # Test with related_task_id (adds related-task metadata) - request_with_task = server_session._build_elicit_url_request( - message="OAuth required", - url="https://example.com/oauth", - elicitation_id="oauth-456", - related_task_id="test-task-789", - ) - assert request_with_task.method == "elicitation/create" - assert request_with_task.params is not None - assert "_meta" in request_with_task.params - assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] - assert ( - 
request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] == "test-task-789" - ) - finally: - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_build_create_message_request() -> None: - """Test that _build_create_message_request builds a proper sampling request.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - messages = [ - SamplingMessage(role="user", content=TextContent(type="text", text="Hello")), - ] - - # Test without task_id - request = server_session._build_create_message_request( - messages=messages, - max_tokens=100, - system_prompt="You are helpful", - ) - assert request.method == "sampling/createMessage" - assert request.params is not None - assert request.params["maxTokens"] == 100 - - # Test with related_task_id (adds related-task metadata) - request_with_task = server_session._build_create_message_request( - messages=messages, - max_tokens=50, - related_task_id="sampling-task-456", - ) - assert request_with_task.method == "sampling/createMessage" - assert request_with_task.params is not None - assert "_meta" in request_with_task.params - assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] - assert ( - request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] - == "sampling-task-456" - ) - finally: - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await 
client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_send_message() -> None: - """Test that send_message sends a raw session message.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - # Create a test message - notification = JSONRPCNotification(jsonrpc="2.0", method="test/notification") - message = SessionMessage( - message=notification, - metadata=ServerMessageMetadata(related_request_id="test-req-1"), - ) - - # Send the message - await server_session.send_message(message) - - # Verify it was sent to the stream - received = await server_to_client_receive.receive() - assert isinstance(received.message, JSONRPCNotification) - assert received.message.method == "test/notification" - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_response_routing_success() -> None: - """Test that response routing works for success responses.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track routed responses with event for synchronization - routed_responses: list[dict[str, Any]] = [] - response_received = anyio.Event() - - class TestRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - routed_responses.append({"id": request_id, "response": response}) - 
response_received.set() - return True # Handled - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - raise NotImplementedError - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - router = TestRouter() - server_session.add_response_router(router) - - # Simulate receiving a response from client - response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) - message = SessionMessage(message=response) - - # Send from "client" side - await client_to_server_send.send(message) - - # Wait for response to be routed - with anyio.fail_after(5): - await response_received.wait() - - # Verify response was routed - assert len(routed_responses) == 1 - assert routed_responses[0]["id"] == "test-req-1" - assert routed_responses[0]["response"]["status"] == "ok" - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_response_routing_error() -> None: - """Test that error routing works for error responses.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track routed errors with event for synchronization - routed_errors: list[dict[str, Any]] = [] - error_received = anyio.Event() - - class TestRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - raise NotImplementedError - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - routed_errors.append({"id": request_id, "error": error}) - error_received.set() - return True # Handled - - 
try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - router = TestRouter() - server_session.add_response_router(router) - - # Simulate receiving an error response from client - error_data = ErrorData(code=INVALID_REQUEST, message="Test error") - error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) - message = SessionMessage(message=error_response) - - # Send from "client" side - await client_to_server_send.send(message) - - # Wait for error to be routed - with anyio.fail_after(5): - await error_received.wait() - - # Verify error was routed - assert len(routed_errors) == 1 - assert routed_errors[0]["id"] == "test-req-2" - assert routed_errors[0]["error"].message == "Test error" - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_response_routing_skips_non_matching_routers() -> None: - """Test that routing continues to next router when first doesn't match.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track which routers were called - router_calls: list[str] = [] - response_received = anyio.Event() - - class NonMatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - router_calls.append("non_matching_response") - return False # Doesn't handle it - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - raise NotImplementedError - - class MatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: 
dict[str, Any]) -> bool: - router_calls.append("matching_response") - response_received.set() - return True # Handles it - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - raise NotImplementedError - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - # Add non-matching router first, then matching router - server_session.add_response_router(NonMatchingRouter()) - server_session.add_response_router(MatchingRouter()) - - # Send a response - should skip first router and be handled by second - response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) - message = SessionMessage(message=response) - await client_to_server_send.send(message) - - with anyio.fail_after(5): - await response_received.wait() - - # Verify both routers were called (first returned False, second returned True) - assert router_calls == ["non_matching_response", "matching_response"] - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_error_routing_skips_non_matching_routers() -> None: - """Test that error routing continues to next router when first doesn't match.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track which routers were called - router_calls: list[str] = [] - error_received = anyio.Event() - - class NonMatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - raise NotImplementedError - - def route_error(self, request_id: str | int, error: 
ErrorData) -> bool: - router_calls.append("non_matching_error") - return False # Doesn't handle it - - class MatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - raise NotImplementedError - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - router_calls.append("matching_error") - error_received.set() - return True # Handles it - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - # Add non-matching router first, then matching router - server_session.add_response_router(NonMatchingRouter()) - server_session.add_response_router(MatchingRouter()) - - # Send an error - should skip first router and be handled by second - error_data = ErrorData(code=INVALID_REQUEST, message="Test error") - error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) - message = SessionMessage(message=error_response) - await client_to_server_send.send(message) - - with anyio.fail_after(5): - await error_received.wait() - - # Verify both routers were called (first returned False, second returned True) - assert router_calls == ["non_matching_error", "matching_error"] - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() diff --git a/tests/experimental/tasks/server/test_server_task_context.py b/tests/experimental/tasks/server/test_server_task_context.py deleted file mode 100644 index e23299698c..0000000000 --- a/tests/experimental/tasks/server/test_server_task_context.py +++ /dev/null @@ -1,709 +0,0 @@ -"""Tests for ServerTaskContext.""" - -import asyncio -from unittest.mock import AsyncMock, Mock - -import anyio -import pytest - -from 
mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue -from mcp.types import ( - CallToolResult, - ClientCapabilities, - ClientTasksCapability, - ClientTasksRequestsCapability, - Implementation, - InitializeRequestParams, - JSONRPCRequest, - SamplingMessage, - TaskMetadata, - TasksCreateElicitationCapability, - TasksCreateMessageCapability, - TasksElicitationCapability, - TasksSamplingCapability, - TextContent, -) - - -@pytest.mark.anyio -async def test_server_task_context_properties() -> None: - """Test ServerTaskContext property accessors.""" - store = InMemoryTaskStore() - mock_session = Mock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-123") - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - assert ctx.task_id == "test-123" - assert ctx.task.task_id == "test-123" - assert ctx.is_cancelled is False - - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_request_cancellation() -> None: - """Test ServerTaskContext.request_cancellation().""" - store = InMemoryTaskStore() - mock_session = Mock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - assert ctx.is_cancelled is False - ctx.request_cancellation() - assert ctx.is_cancelled is True - - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_update_status_with_notify() -> None: - """Test update_status sends notification when notify=True.""" - store = InMemoryTaskStore() - mock_session = Mock() - 
mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - await ctx.update_status("Working...", notify=True) - - mock_session.send_notification.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_update_status_without_notify() -> None: - """Test update_status skips notification when notify=False.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - await ctx.update_status("Working...", notify=False) - - mock_session.send_notification.assert_not_called() - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_complete_with_notify() -> None: - """Test complete sends notification when notify=True.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - result = CallToolResult(content=[TextContent(type="text", text="Done")]) - await ctx.complete(result, notify=True) - - mock_session.send_notification.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_fail_with_notify() -> None: - """Test fail sends notification when notify=True.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - 
session=mock_session, - queue=queue, - ) - - await ctx.fail("Something went wrong", notify=True) - - mock_session.send_notification.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_raises_when_client_lacks_capability() -> None: - """Test that elicit() raises MCPError when client doesn't support elicitation.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=False) - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - with pytest.raises(MCPError) as exc_info: - await ctx.elicit(message="Test?", requested_schema={"type": "object"}) - - assert "elicitation capability" in exc_info.value.error.message - mock_session.check_client_capability.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_raises_when_client_lacks_capability() -> None: - """Test that create_message() raises MCPError when client doesn't support sampling.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=False) - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - with pytest.raises(MCPError) as exc_info: - await ctx.create_message(messages=[], max_tokens=100) - - assert "sampling capability" in exc_info.value.error.message - mock_session.check_client_capability.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_raises_without_handler() -> None: - """Test that elicit() raises when handler is not provided.""" - store = InMemoryTaskStore() - mock_session = 
Mock() - mock_session.check_client_capability = Mock(return_value=True) - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required"): - await ctx.elicit(message="Test?", requested_schema={"type": "object"}) - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_url_raises_without_handler() -> None: - """Test that elicit_url() raises when handler is not provided.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required for elicit_url"): - await ctx.elicit_url( - message="Please authorize", - url="https://example.com/oauth", - elicitation_id="oauth-123", - ) - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_raises_without_handler() -> None: - """Test that create_message() raises when handler is not provided.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required"): - await ctx.create_message(messages=[], max_tokens=100) - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_queues_request_and_waits_for_response() -> None: - """Test that elicit() queues request and waits for response.""" - store = InMemoryTaskStore() - queue = 
InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - mock_session._build_elicit_form_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-1", - method="elicitation/create", - params={"message": "Test?", "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - elicit_result = None - - async def run_elicit() -> None: - nonlocal elicit_result - elicit_result = await ctx.elicit( - message="Test?", - requested_schema={"type": "object"}, - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_elicit) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Dequeue and simulate response - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Resolve with mock elicitation response - msg.resolver.set_result({"action": "accept", "content": {"name": "Alice"}}) - - # Verify result - assert elicit_result is not None - assert elicit_result.action == "accept" - assert elicit_result.content == {"name": "Alice"} - - # Verify task is back to working - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_url_queues_request_and_waits_for_response() -> None: - """Test that elicit_url() queues request and waits for response.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - 
mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - mock_session._build_elicit_url_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-url-req-1", - method="elicitation/create", - params={"message": "Authorize", "url": "https://example.com", "elicitationId": "123", "mode": "url"}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - elicit_result = None - - async def run_elicit_url() -> None: - nonlocal elicit_result - elicit_result = await ctx.elicit_url( - message="Authorize", - url="https://example.com/oauth", - elicitation_id="oauth-123", - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_elicit_url) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Dequeue and simulate response - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Resolve with mock elicitation response (URL mode just returns action) - msg.resolver.set_result({"action": "accept"}) - - # Verify result - assert elicit_result is not None - assert elicit_result.action == "accept" - - # Verify task is back to working - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_queues_request_and_waits_for_response() -> None: - """Test that create_message() queues request and waits for response.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - mock_session.check_client_capability = 
Mock(return_value=True) - mock_session._build_create_message_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-2", - method="sampling/createMessage", - params={"messages": [], "maxTokens": 100, "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - sampling_result = None - - async def run_sampling() -> None: - nonlocal sampling_result - sampling_result = await ctx.create_message( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_sampling) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Dequeue and simulate response - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Resolve with mock sampling response - msg.resolver.set_result( - { - "role": "assistant", - "content": {"type": "text", "text": "Hello back!"}, - "model": "test-model", - "stopReason": "endTurn", - } - ) - - # Verify result - assert sampling_result is not None - assert sampling_result.role == "assistant" - assert sampling_result.model == "test-model" - - # Verify task is back to working - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_restores_status_on_cancellation() -> None: - """Test that elicit() restores task status to working when cancelled.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - 
mock_session.check_client_capability = Mock(return_value=True) - mock_session._build_elicit_form_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-cancel", - method="elicitation/create", - params={"message": "Test?", "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - cancelled_error_raised = False - - async with anyio.create_task_group() as tg: - - async def do_elicit() -> None: - nonlocal cancelled_error_raised - try: - await ctx.elicit( - message="Test?", - requested_schema={"type": "object"}, - ) - except anyio.get_cancelled_exc_class(): - cancelled_error_raised = True - # Don't re-raise - let the test continue - - tg.start_soon(do_elicit) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Get the queued message and set cancellation exception on its resolver - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Trigger cancellation by setting exception (use asyncio.CancelledError directly) - msg.resolver.set_exception(asyncio.CancelledError()) - - # Verify task is back to working after cancellation - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - assert cancelled_error_raised - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_restores_status_on_cancellation() -> None: - """Test that create_message() restores task status to working when cancelled.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - mock_session.check_client_capability 
= Mock(return_value=True) - mock_session._build_create_message_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-cancel-2", - method="sampling/createMessage", - params={"messages": [], "maxTokens": 100, "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - cancelled_error_raised = False - - async with anyio.create_task_group() as tg: - - async def do_sampling() -> None: - nonlocal cancelled_error_raised - try: - await ctx.create_message( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ) - except anyio.get_cancelled_exc_class(): - cancelled_error_raised = True - # Don't re-raise - - tg.start_soon(do_sampling) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Get the queued message and set cancellation exception on its resolver - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Trigger cancellation by setting exception (use asyncio.CancelledError directly) - msg.resolver.set_exception(asyncio.CancelledError()) - - # Verify task is back to working after cancellation - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - assert cancelled_error_raised - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_as_task_raises_without_handler() -> None: - """Test that elicit_as_task() raises when handler is not provided.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - # Create mock session with proper client capabilities - mock_session = Mock() - 
mock_session.client_params = InitializeRequestParams( - protocol_version="2025-01-01", - capabilities=ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - ), - client_info=Implementation(name="test", version="1.0"), - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required for elicit_as_task"): - await ctx.elicit_as_task(message="Test?", requested_schema={"type": "object"}) - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_as_task_raises_without_handler() -> None: - """Test that create_message_as_task() raises when handler is not provided.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - # Create mock session with proper client capabilities - mock_session = Mock() - mock_session.client_params = InitializeRequestParams( - protocol_version="2025-01-01", - capabilities=ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - ), - client_info=Implementation(name="test", version="1.0"), - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required for create_message_as_task"): - await ctx.create_message_as_task( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ) - - store.cleanup() diff --git a/tests/experimental/tasks/server/test_store.py b/tests/experimental/tasks/server/test_store.py deleted file mode 100644 index 0d431899c8..0000000000 --- a/tests/experimental/tasks/server/test_store.py +++ 
/dev/null @@ -1,406 +0,0 @@ -"""Tests for InMemoryTaskStore.""" - -from collections.abc import AsyncIterator -from datetime import datetime, timedelta, timezone - -import pytest - -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import cancel_task -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import INVALID_PARAMS, CallToolResult, TaskMetadata, TextContent - - -@pytest.fixture -async def store() -> AsyncIterator[InMemoryTaskStore]: - """Provide a clean InMemoryTaskStore for each test with automatic cleanup.""" - store = InMemoryTaskStore() - yield store - store.cleanup() - - -@pytest.mark.anyio -async def test_create_and_get(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore create and get operations.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - assert task.task_id is not None - assert task.status == "working" - assert task.ttl == 60000 - - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - assert retrieved.task_id == task.task_id - assert retrieved.status == "working" - - -@pytest.mark.anyio -async def test_create_with_custom_id(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore create with custom task ID.""" - task = await store.create_task( - metadata=TaskMetadata(ttl=60000), - task_id="my-custom-id", - ) - - assert task.task_id == "my-custom-id" - assert task.status == "working" - - retrieved = await store.get_task("my-custom-id") - assert retrieved is not None - assert retrieved.task_id == "my-custom-id" - - -@pytest.mark.anyio -async def test_create_duplicate_id_raises(store: InMemoryTaskStore) -> None: - """Test that creating a task with duplicate ID raises.""" - await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") - - with pytest.raises(ValueError, match="already exists"): - await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") - - 
-@pytest.mark.anyio -async def test_get_nonexistent_returns_none(store: InMemoryTaskStore) -> None: - """Test that getting a nonexistent task returns None.""" - retrieved = await store.get_task("nonexistent") - assert retrieved is None - - -@pytest.mark.anyio -async def test_update_status(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore status updates.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - updated = await store.update_task(task.task_id, status="completed", status_message="All done!") - - assert updated.status == "completed" - assert updated.status_message == "All done!" - - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - assert retrieved.status == "completed" - assert retrieved.status_message == "All done!" - - -@pytest.mark.anyio -async def test_update_nonexistent_raises(store: InMemoryTaskStore) -> None: - """Test that updating a nonexistent task raises.""" - with pytest.raises(ValueError, match="not found"): - await store.update_task("nonexistent", status="completed") - - -@pytest.mark.anyio -async def test_store_and_get_result(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore result storage and retrieval.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - # Store result - result = CallToolResult(content=[TextContent(type="text", text="Result data")]) - await store.store_result(task.task_id, result) - - # Retrieve result - retrieved_result = await store.get_result(task.task_id) - assert retrieved_result == result - - -@pytest.mark.anyio -async def test_get_result_nonexistent_returns_none(store: InMemoryTaskStore) -> None: - """Test that getting result for nonexistent task returns None.""" - result = await store.get_result("nonexistent") - assert result is None - - -@pytest.mark.anyio -async def test_get_result_no_result_returns_none(store: InMemoryTaskStore) -> None: - """Test that getting result when none stored returns None.""" - task = await 
store.create_task(metadata=TaskMetadata(ttl=60000)) - result = await store.get_result(task.task_id) - assert result is None - - -@pytest.mark.anyio -async def test_list_tasks(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore list operation.""" - # Create multiple tasks - for _ in range(3): - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - tasks, next_cursor = await store.list_tasks() - assert len(tasks) == 3 - assert next_cursor is None # Less than page size - - -@pytest.mark.anyio -async def test_list_tasks_pagination() -> None: - """Test InMemoryTaskStore pagination.""" - # Needs custom page_size, can't use fixture - store = InMemoryTaskStore(page_size=2) - - # Create 5 tasks - for _ in range(5): - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - # First page - tasks, next_cursor = await store.list_tasks() - assert len(tasks) == 2 - assert next_cursor is not None - - # Second page - tasks, next_cursor = await store.list_tasks(cursor=next_cursor) - assert len(tasks) == 2 - assert next_cursor is not None - - # Third page (last) - tasks, next_cursor = await store.list_tasks(cursor=next_cursor) - assert len(tasks) == 1 - assert next_cursor is None - - store.cleanup() - - -@pytest.mark.anyio -async def test_list_tasks_invalid_cursor(store: InMemoryTaskStore) -> None: - """Test that invalid cursor raises.""" - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - with pytest.raises(ValueError, match="Invalid cursor"): - await store.list_tasks(cursor="invalid-cursor") - - -@pytest.mark.anyio -async def test_delete_task(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore delete operation.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - deleted = await store.delete_task(task.task_id) - assert deleted is True - - retrieved = await store.get_task(task.task_id) - assert retrieved is None - - # Delete non-existent - deleted = await store.delete_task(task.task_id) - assert deleted is False - - 
-@pytest.mark.anyio -async def test_get_all_tasks_helper(store: InMemoryTaskStore) -> None: - """Test the get_all_tasks debugging helper.""" - await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - all_tasks = store.get_all_tasks() - assert len(all_tasks) == 2 - - -@pytest.mark.anyio -async def test_store_result_nonexistent_raises(store: InMemoryTaskStore) -> None: - """Test that storing result for nonexistent task raises ValueError.""" - result = CallToolResult(content=[TextContent(type="text", text="Result")]) - - with pytest.raises(ValueError, match="not found"): - await store.store_result("nonexistent-id", result) - - -@pytest.mark.anyio -async def test_create_task_with_null_ttl(store: InMemoryTaskStore) -> None: - """Test creating task with null TTL (never expires).""" - task = await store.create_task(metadata=TaskMetadata(ttl=None)) - - assert task.ttl is None - - # Task should persist (not expire) - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - - -@pytest.mark.anyio -async def test_task_expiration_cleanup(store: InMemoryTaskStore) -> None: - """Test that expired tasks are cleaned up lazily.""" - # Create a task with very short TTL - task = await store.create_task(metadata=TaskMetadata(ttl=1)) # 1ms TTL - - # Manually force the expiry to be in the past - stored = store._tasks.get(task.task_id) - assert stored is not None - stored.expires_at = datetime.now(timezone.utc) - timedelta(seconds=10) - - # Task should still exist in internal dict but be expired - assert task.task_id in store._tasks - - # Any access operation should clean up expired tasks - # list_tasks triggers cleanup - tasks, _ = await store.list_tasks() - - # Expired task should be cleaned up - assert task.task_id not in store._tasks - assert len(tasks) == 0 - - -@pytest.mark.anyio -async def test_task_with_null_ttl_never_expires(store: InMemoryTaskStore) -> None: - """Test that tasks with null TTL 
never expire during cleanup.""" - # Create task with null TTL - task = await store.create_task(metadata=TaskMetadata(ttl=None)) - - # Verify internal storage has no expiry - stored = store._tasks.get(task.task_id) - assert stored is not None - assert stored.expires_at is None - - # Access operations should NOT remove this task - await store.list_tasks() - await store.get_task(task.task_id) - - # Task should still exist - assert task.task_id in store._tasks - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - - -@pytest.mark.anyio -async def test_terminal_task_ttl_reset(store: InMemoryTaskStore) -> None: - """Test that TTL is reset when task enters terminal state.""" - # Create task with short TTL - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) # 60s - - # Get the initial expiry - stored = store._tasks.get(task.task_id) - assert stored is not None - initial_expiry = stored.expires_at - assert initial_expiry is not None - - # Update to terminal state (completed) - await store.update_task(task.task_id, status="completed") - - # Expiry should be reset to a new time (from now + TTL) - new_expiry = stored.expires_at - assert new_expiry is not None - assert new_expiry >= initial_expiry - - -@pytest.mark.anyio -async def test_terminal_status_transition_rejected(store: InMemoryTaskStore) -> None: - """Test that transitions from terminal states are rejected. - - Per spec: Terminal states (completed, failed, cancelled) MUST NOT - transition to any other status. 
- """ - # Test each terminal status - for terminal_status in ("completed", "failed", "cancelled"): - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - # Move to terminal state - await store.update_task(task.task_id, status=terminal_status) - - # Attempting to transition to any other status should raise - with pytest.raises(ValueError, match="Cannot transition from terminal status"): - await store.update_task(task.task_id, status="working") - - # Also test transitioning to another terminal state - other_terminal = "failed" if terminal_status != "failed" else "completed" - with pytest.raises(ValueError, match="Cannot transition from terminal status"): - await store.update_task(task.task_id, status=other_terminal) - - -@pytest.mark.anyio -async def test_terminal_status_allows_same_status(store: InMemoryTaskStore) -> None: - """Test that setting the same terminal status doesn't raise. - - This is not a transition, so it should be allowed (no-op). - """ - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="completed") - - # Setting the same status should not raise - updated = await store.update_task(task.task_id, status="completed") - assert updated.status == "completed" - - # Updating just the message should also work - updated = await store.update_task(task.task_id, status_message="Updated message") - assert updated.status_message == "Updated message" - - -@pytest.mark.anyio -async def test_wait_for_update_nonexistent_raises(store: InMemoryTaskStore) -> None: - """Test that wait_for_update raises for nonexistent task.""" - with pytest.raises(ValueError, match="not found"): - await store.wait_for_update("nonexistent-task-id") - - -@pytest.mark.anyio -async def test_cancel_task_succeeds_for_working_task(store: InMemoryTaskStore) -> None: - """Test cancel_task helper succeeds for a working task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - assert task.status == 
"working" - - result = await cancel_task(store, task.task_id) - - assert result.task_id == task.task_id - assert result.status == "cancelled" - - # Verify store is updated - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - assert retrieved.status == "cancelled" - - -@pytest.mark.anyio -async def test_cancel_task_rejects_nonexistent_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for nonexistent task.""" - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, "nonexistent-task-id") - - assert exc_info.value.error.code == INVALID_PARAMS - assert "not found" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_rejects_completed_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for completed task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="completed") - - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) - - assert exc_info.value.error.code == INVALID_PARAMS - assert "terminal state 'completed'" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_rejects_failed_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for failed task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="failed") - - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) - - assert exc_info.value.error.code == INVALID_PARAMS - assert "terminal state 'failed'" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_rejects_already_cancelled_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for already cancelled task.""" - task = await 
store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="cancelled") - - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) - - assert exc_info.value.error.code == INVALID_PARAMS - assert "terminal state 'cancelled'" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_succeeds_for_input_required_task(store: InMemoryTaskStore) -> None: - """Test cancel_task helper succeeds for a task in input_required status.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="input_required") - - result = await cancel_task(store, task.task_id) - - assert result.task_id == task.task_id - assert result.status == "cancelled" diff --git a/tests/experimental/tasks/server/test_task_result_handler.py b/tests/experimental/tasks/server/test_task_result_handler.py deleted file mode 100644 index 8b5a03ce2b..0000000000 --- a/tests/experimental/tasks/server/test_task_result_handler.py +++ /dev/null @@ -1,354 +0,0 @@ -"""Tests for TaskResultHandler.""" - -from collections.abc import AsyncIterator -from typing import Any -from unittest.mock import AsyncMock, Mock - -import anyio -import pytest - -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, QueuedMessage -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.shared.message import SessionMessage -from mcp.types import ( - INVALID_REQUEST, - CallToolResult, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - JSONRPCRequest, - TaskMetadata, - TextContent, -) - - -@pytest.fixture -async def store() -> AsyncIterator[InMemoryTaskStore]: - """Provide a clean store for each test.""" - s = 
InMemoryTaskStore() - yield s - s.cleanup() - - -@pytest.fixture -def queue() -> InMemoryTaskMessageQueue: - """Provide a clean queue for each test.""" - return InMemoryTaskMessageQueue() - - -@pytest.fixture -def handler(store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue) -> TaskResultHandler: - """Provide a handler for each test.""" - return TaskResultHandler(store, queue) - - -@pytest.mark.anyio -async def test_handle_returns_result_for_completed_task( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() returns the stored result for a completed task.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - result = CallToolResult(content=[TextContent(type="text", text="Done!")]) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - response = await handler.handle(request, mock_session, "req-1") - - assert response is not None - assert response.meta is not None - assert "io.modelcontextprotocol/related-task" in response.meta - - -@pytest.mark.anyio -async def test_handle_raises_for_nonexistent_task( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() raises MCPError for nonexistent task.""" - mock_session = Mock() - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="nonexistent")) - - with pytest.raises(MCPError) as exc_info: - await handler.handle(request, mock_session, "req-1") - - assert "not found" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_handle_returns_empty_result_when_no_result_stored( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() 
returns minimal result when task completed without stored result.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - await store.update_task(task.task_id, status="completed") - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - response = await handler.handle(request, mock_session, "req-1") - - assert response is not None - assert response.meta is not None - assert "io.modelcontextprotocol/related-task" in response.meta - - -@pytest.mark.anyio -async def test_handle_delivers_queued_messages( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() delivers queued messages before returning.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - queued_msg = QueuedMessage( - type="notification", - message=JSONRPCRequest( - jsonrpc="2.0", - id="notif-1", - method="test/notification", - params={}, - ), - ) - await queue.enqueue(task.task_id, queued_msg) - await store.update_task(task.task_id, status="completed") - - sent_messages: list[SessionMessage] = [] - - async def track_send(msg: SessionMessage) -> None: - sent_messages.append(msg) - - mock_session = Mock() - mock_session.send_message = track_send - - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - await handler.handle(request, mock_session, "req-1") - - assert len(sent_messages) == 1 - - -@pytest.mark.anyio -async def test_handle_waits_for_task_completion( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() waits for task to complete before returning.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - request = 
GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - result_holder: list[GetTaskPayloadResult | None] = [None] - - async def run_handle() -> None: - result_holder[0] = await handler.handle(request, mock_session, "req-1") - - async with anyio.create_task_group() as tg: - tg.start_soon(run_handle) - - # Wait for handler to start waiting (event gets created when wait starts) - while task.task_id not in store._update_events: - await anyio.sleep(0) - - await store.store_result(task.task_id, CallToolResult(content=[TextContent(type="text", text="Done")])) - await store.update_task(task.task_id, status="completed") - - assert result_holder[0] is not None - - -@pytest.mark.anyio -async def test_route_response_resolves_pending_request( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_response() resolves a pending request.""" - resolver: Resolver[dict[str, Any]] = Resolver() - handler._pending_requests["req-123"] = resolver - - result = handler.route_response("req-123", {"status": "ok"}) - - assert result is True - assert resolver.done() - assert await resolver.wait() == {"status": "ok"} - - -@pytest.mark.anyio -async def test_route_response_returns_false_for_unknown_request( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_response() returns False for unknown request ID.""" - result = handler.route_response("unknown-req", {"status": "ok"}) - assert result is False - - -@pytest.mark.anyio -async def test_route_response_returns_false_for_already_done_resolver( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_response() returns False if resolver already completed.""" - resolver: Resolver[dict[str, Any]] = Resolver() - resolver.set_result({"already": "done"}) - handler._pending_requests["req-123"] = resolver - - result = 
handler.route_response("req-123", {"new": "data"}) - - assert result is False - - -@pytest.mark.anyio -async def test_route_error_resolves_pending_request_with_exception( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_error() sets exception on pending request.""" - resolver: Resolver[dict[str, Any]] = Resolver() - handler._pending_requests["req-123"] = resolver - - error = ErrorData(code=INVALID_REQUEST, message="Something went wrong") - result = handler.route_error("req-123", error) - - assert result is True - assert resolver.done() - - with pytest.raises(MCPError) as exc_info: - await resolver.wait() - assert exc_info.value.error.message == "Something went wrong" - - -@pytest.mark.anyio -async def test_route_error_returns_false_for_unknown_request( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_error() returns False for unknown request ID.""" - error = ErrorData(code=INVALID_REQUEST, message="Error") - result = handler.route_error("unknown-req", error) - assert result is False - - -@pytest.mark.anyio -async def test_deliver_registers_resolver_for_request_messages( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _deliver_queued_messages registers resolvers for request messages.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - resolver: Resolver[dict[str, Any]] = Resolver() - queued_msg = QueuedMessage( - type="request", - message=JSONRPCRequest( - jsonrpc="2.0", - id="inner-req-1", - method="elicitation/create", - params={}, - ), - resolver=resolver, - original_request_id="inner-req-1", - ) - await queue.enqueue(task.task_id, queued_msg) - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - await handler._deliver_queued_messages(task.task_id, mock_session, "outer-req-1") - - assert "inner-req-1" in 
handler._pending_requests - assert handler._pending_requests["inner-req-1"] is resolver - - -@pytest.mark.anyio -async def test_deliver_skips_resolver_registration_when_no_original_id( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _deliver_queued_messages skips resolver registration when original_request_id is None.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - resolver: Resolver[dict[str, Any]] = Resolver() - queued_msg = QueuedMessage( - type="request", - message=JSONRPCRequest( - jsonrpc="2.0", - id="inner-req-1", - method="elicitation/create", - params={}, - ), - resolver=resolver, - original_request_id=None, # No original request ID - ) - await queue.enqueue(task.task_id, queued_msg) - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - await handler._deliver_queued_messages(task.task_id, mock_session, "outer-req-1") - - # Resolver should NOT be registered since original_request_id is None - assert len(handler._pending_requests) == 0 - # But the message should still be sent - mock_session.send_message.assert_called_once() - - -@pytest.mark.anyio -async def test_wait_for_task_update_handles_store_exception( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _wait_for_task_update handles store exception gracefully.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - # Make wait_for_update raise an exception - async def failing_wait(task_id: str) -> None: - raise RuntimeError("Store error") - - store.wait_for_update = failing_wait # type: ignore[method-assign] - - # Queue a message to unblock the race via the queue path - async def enqueue_later() -> None: - # Wait for queue to start waiting (event gets created when wait starts) - while task.task_id not in queue._events: - await anyio.sleep(0) - await queue.enqueue( - task.task_id, - 
QueuedMessage( - type="notification", - message=JSONRPCRequest( - jsonrpc="2.0", - id="notif-1", - method="test/notification", - params={}, - ), - ), - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(enqueue_later) - # This should complete via the queue path even though store raises - await handler._wait_for_task_update(task.task_id) - - -@pytest.mark.anyio -async def test_wait_for_task_update_handles_queue_exception( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _wait_for_task_update handles queue exception gracefully.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - # Make wait_for_message raise an exception - async def failing_wait(task_id: str) -> None: - raise RuntimeError("Queue error") - - queue.wait_for_message = failing_wait # type: ignore[method-assign] - - # Update the store to unblock the race via the store path - async def update_later() -> None: - # Wait for store to start waiting (event gets created when wait starts) - while task.task_id not in store._update_events: - await anyio.sleep(0) - await store.update_task(task.task_id, status="completed") - - async with anyio.create_task_group() as tg: - tg.start_soon(update_later) - # This should complete via the store path even though queue raises - await handler._wait_for_task_update(task.task_id) diff --git a/tests/experimental/tasks/test_capabilities.py b/tests/experimental/tasks/test_capabilities.py deleted file mode 100644 index 90a8656ba0..0000000000 --- a/tests/experimental/tasks/test_capabilities.py +++ /dev/null @@ -1,283 +0,0 @@ -"""Tests for tasks capability checking utilities.""" - -import pytest - -from mcp import MCPError -from mcp.shared.experimental.tasks.capabilities import ( - check_tasks_capability, - has_task_augmented_elicitation, - has_task_augmented_sampling, - require_task_augmented_elicitation, - require_task_augmented_sampling, -) -from mcp.types import ( - 
ClientCapabilities, - ClientTasksCapability, - ClientTasksRequestsCapability, - TasksCreateElicitationCapability, - TasksCreateMessageCapability, - TasksElicitationCapability, - TasksSamplingCapability, -) - - -class TestCheckTasksCapability: - """Tests for check_tasks_capability function.""" - - def test_required_requests_none_returns_true(self) -> None: - """When required.requests is None, should return True.""" - required = ClientTasksCapability() - client = ClientTasksCapability() - assert check_tasks_capability(required, client) is True - - def test_client_requests_none_returns_false(self) -> None: - """When client.requests is None but required.requests is set, should return False.""" - required = ClientTasksCapability(requests=ClientTasksRequestsCapability()) - client = ClientTasksCapability() - assert check_tasks_capability(required, client) is False - - def test_elicitation_required_but_client_missing(self) -> None: - """When elicitation is required but client doesn't have it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability(elicitation=TasksElicitationCapability()) - ) - client = ClientTasksCapability(requests=ClientTasksRequestsCapability()) - assert check_tasks_capability(required, client) is False - - def test_elicitation_create_required_but_client_missing(self) -> None: - """When elicitation.create is required but client doesn't have it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability() # No create - ) - ) - assert check_tasks_capability(required, client) is False - - def test_elicitation_create_present(self) -> None: - """When elicitation.create is required and client has it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - 
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - def test_sampling_required_but_client_missing(self) -> None: - """When sampling is required but client doesn't have it.""" - required = ClientTasksCapability(requests=ClientTasksRequestsCapability(sampling=TasksSamplingCapability())) - client = ClientTasksCapability(requests=ClientTasksRequestsCapability()) - assert check_tasks_capability(required, client) is False - - def test_sampling_create_message_required_but_client_missing(self) -> None: - """When sampling.createMessage is required but client doesn't have it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability() # No createMessage - ) - ) - assert check_tasks_capability(required, client) is False - - def test_sampling_create_message_present(self) -> None: - """When sampling.createMessage is required and client has it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - def test_both_elicitation_and_sampling_present(self) -> None: - """When both elicitation.create and sampling.createMessage are required and client has both.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - 
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()), - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()), - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()), - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()), - ) - ) - assert check_tasks_capability(required, client) is True - - def test_elicitation_without_create_required(self) -> None: - """When elicitation is required but not create specifically.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability() # No create - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - def test_sampling_without_create_message_required(self) -> None: - """When sampling is required but not createMessage specifically.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability() # No createMessage - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - -class TestHasTaskAugmentedElicitation: - """Tests for has_task_augmented_elicitation function.""" - - def test_tasks_none(self) -> None: - """Returns False when caps.tasks is None.""" - caps = ClientCapabilities() - assert has_task_augmented_elicitation(caps) is False - - def test_requests_none(self) -> None: - """Returns False when caps.tasks.requests is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability()) - assert has_task_augmented_elicitation(caps) is False - - def 
test_elicitation_none(self) -> None: - """Returns False when caps.tasks.requests.elicitation is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability())) - assert has_task_augmented_elicitation(caps) is False - - def test_create_none(self) -> None: - """Returns False when caps.tasks.requests.elicitation.create is None.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability(elicitation=TasksElicitationCapability()) - ) - ) - assert has_task_augmented_elicitation(caps) is False - - def test_create_present(self) -> None: - """Returns True when full capability path is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - ) - assert has_task_augmented_elicitation(caps) is True - - -class TestHasTaskAugmentedSampling: - """Tests for has_task_augmented_sampling function.""" - - def test_tasks_none(self) -> None: - """Returns False when caps.tasks is None.""" - caps = ClientCapabilities() - assert has_task_augmented_sampling(caps) is False - - def test_requests_none(self) -> None: - """Returns False when caps.tasks.requests is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability()) - assert has_task_augmented_sampling(caps) is False - - def test_sampling_none(self) -> None: - """Returns False when caps.tasks.requests.sampling is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability())) - assert has_task_augmented_sampling(caps) is False - - def test_create_message_none(self) -> None: - """Returns False when caps.tasks.requests.sampling.createMessage is None.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability(sampling=TasksSamplingCapability())) - ) - assert has_task_augmented_sampling(caps) is False - - def 
test_create_message_present(self) -> None: - """Returns True when full capability path is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - ) - assert has_task_augmented_sampling(caps) is True - - -class TestRequireTaskAugmentedElicitation: - """Tests for require_task_augmented_elicitation function.""" - - def test_raises_when_none(self) -> None: - """Raises MCPError when client_caps is None.""" - with pytest.raises(MCPError) as exc_info: - require_task_augmented_elicitation(None) - assert "task-augmented elicitation" in str(exc_info.value) - - def test_raises_when_missing(self) -> None: - """Raises MCPError when capability is missing.""" - caps = ClientCapabilities() - with pytest.raises(MCPError) as exc_info: - require_task_augmented_elicitation(caps) - assert "task-augmented elicitation" in str(exc_info.value) - - def test_passes_when_present(self) -> None: - """Does not raise when capability is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - ) - require_task_augmented_elicitation(caps) - - -class TestRequireTaskAugmentedSampling: - """Tests for require_task_augmented_sampling function.""" - - def test_raises_when_none(self) -> None: - """Raises MCPError when client_caps is None.""" - with pytest.raises(MCPError) as exc_info: - require_task_augmented_sampling(None) - assert "task-augmented sampling" in str(exc_info.value) - - def test_raises_when_missing(self) -> None: - """Raises MCPError when capability is missing.""" - caps = ClientCapabilities() - with pytest.raises(MCPError) as exc_info: - require_task_augmented_sampling(caps) - assert "task-augmented sampling" in str(exc_info.value) - - def test_passes_when_present(self) -> None: - """Does not 
raise when capability is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - ) - require_task_augmented_sampling(caps) diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py deleted file mode 100644 index 2d0378a9ce..0000000000 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ /dev/null @@ -1,695 +0,0 @@ -"""Tests for the four elicitation scenarios with tasks. - -This tests all combinations of tool call types and elicitation types: -1. Normal tool call + Normal elicitation (session.elicit) -2. Normal tool call + Task-augmented elicitation (session.experimental.elicit_as_task) -3. Task-augmented tool call + Normal elicitation (task.elicit) -4. Task-augmented tool call + Task-augmented elicitation (task.elicit_as_task) - -And the same for sampling (create_message). 
-""" - -from typing import Any - -import anyio -import pytest -from anyio import Event - -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.lowlevel import NotificationOptions -from mcp.shared._context import RequestContext -from mcp.shared.experimental.tasks.helpers import is_terminal -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.types import ( - TASK_REQUIRED, - CallToolRequestParams, - CallToolResult, - CreateMessageRequestParams, - CreateMessageResult, - CreateTaskResult, - ElicitRequestParams, - ElicitResult, - ErrorData, - GetTaskPayloadResult, - GetTaskResult, - ListToolsResult, - PaginatedRequestParams, - SamplingMessage, - TaskMetadata, - TextContent, - Tool, -) - - -def create_client_task_handlers( - client_task_store: InMemoryTaskStore, - elicit_received: Event, -) -> ExperimentalTaskHandlers: - """Create task handlers for client to handle task-augmented elicitation from server.""" - - elicit_response = ElicitResult(action="accept", content={"confirm": True}) - task_complete_events: dict[str, Event] = {} - - async def handle_augmented_elicitation( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult: - """Handle task-augmented elicitation by creating a client-side task.""" - elicit_received.set() - task = await client_task_store.create_task(task_metadata) - task_complete_events[task.task_id] = Event() - - async def complete_task() -> None: - # Store result before updating status to avoid race condition - await client_task_store.store_result(task.task_id, elicit_response) - await client_task_store.update_task(task.task_id, status="completed") - task_complete_events[task.task_id].set() 
- - context.session._task_group.start_soon(complete_task) # pyright: ignore[reportPrivateUsage] - return CreateTaskResult(task=task) - - async def handle_get_task( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskResult: - """Handle tasks/get from server.""" - task = await client_task_store.get_task(params.task_id) - assert task is not None, f"Task not found: {params.task_id}" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=100, - ) - - async def handle_get_task_result( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskPayloadResult | ErrorData: - """Handle tasks/result from server.""" - event = task_complete_events.get(params.task_id) - assert event is not None, f"No completion event for task: {params.task_id}" - await event.wait() - result = await client_task_store.get_result(params.task_id) - assert result is not None, f"Result not found for task: {params.task_id}" - return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True)) - - return ExperimentalTaskHandlers( - augmented_elicitation=handle_augmented_elicitation, - get_task=handle_get_task, - get_task_result=handle_get_task_result, - ) - - -def create_sampling_task_handlers( - client_task_store: InMemoryTaskStore, - sampling_received: Event, -) -> ExperimentalTaskHandlers: - """Create task handlers for client to handle task-augmented sampling from server.""" - - sampling_response = CreateMessageResult( - role="assistant", - content=TextContent(type="text", text="Hello from the model!"), - model="test-model", - ) - task_complete_events: dict[str, Event] = {} - - async def handle_augmented_sampling( - context: RequestContext[ClientSession], - params: CreateMessageRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult: - """Handle task-augmented sampling by creating a client-side 
task.""" - sampling_received.set() - task = await client_task_store.create_task(task_metadata) - task_complete_events[task.task_id] = Event() - - async def complete_task() -> None: - # Store result before updating status to avoid race condition - await client_task_store.store_result(task.task_id, sampling_response) - await client_task_store.update_task(task.task_id, status="completed") - task_complete_events[task.task_id].set() - - context.session._task_group.start_soon(complete_task) # pyright: ignore[reportPrivateUsage] - return CreateTaskResult(task=task) - - async def handle_get_task( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskResult: - """Handle tasks/get from server.""" - task = await client_task_store.get_task(params.task_id) - assert task is not None, f"Task not found: {params.task_id}" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=100, - ) - - async def handle_get_task_result( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskPayloadResult | ErrorData: - """Handle tasks/result from server.""" - event = task_complete_events.get(params.task_id) - assert event is not None, f"No completion event for task: {params.task_id}" - await event.wait() - result = await client_task_store.get_result(params.task_id) - assert result is not None, f"Result not found for task: {params.task_id}" - return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True)) - - return ExperimentalTaskHandlers( - augmented_sampling=handle_augmented_sampling, - get_task=handle_get_task, - get_task_result=handle_get_task_result, - ) - - -@pytest.mark.anyio -async def test_scenario1_normal_tool_normal_elicitation() -> None: - """Scenario 1: Normal tool call with normal elicitation. - - Server calls session.elicit() directly, client responds immediately. 
- """ - elicit_received = Event() - tool_result: list[str] = [] - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - # Normal elicitation - expects immediate response - result = await ctx.session.elicit( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - tool_result.append("confirmed" if confirmed else "cancelled") - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - server = Server("test-scenario1", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - - # Elicitation callback for client - async def elicitation_callback( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - ) -> ElicitResult: - elicit_received.set() - return ElicitResult(action="accept", content={"confirm": True}) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - elicitation_callback=elicitation_callback, - ) as client_session: - await client_session.initialize() - - # Call tool normally (not as 
task) - result = await client_session.call_tool("confirm_action", {}) - - # Verify elicitation was received and tool completed - assert elicit_received.is_set() - assert len(result.content) > 0 - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert tool_result[0] == "confirmed" - - -@pytest.mark.anyio -async def test_scenario2_normal_tool_task_augmented_elicitation() -> None: - """Scenario 2: Normal tool call with task-augmented elicitation. - - Server calls session.experimental.elicit_as_task(), client creates a task - for the elicitation and returns CreateTaskResult. Server polls client. - """ - elicit_received = Event() - tool_result: list[str] = [] - - # Client-side task store for handling task-augmented elicitation - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - # Task-augmented elicitation - server polls client - result = await ctx.session.experimental.elicit_as_task( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ttl=60000, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - tool_result.append("confirmed" if confirmed else "cancelled") - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - server = Server("test-scenario2", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - task_handlers = create_client_task_handlers(client_task_store, elicit_received) 
- - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await client_session.initialize() - - # Call tool normally (not as task) - result = await client_session.call_tool("confirm_action", {}) - - # Verify elicitation was received and tool completed - assert elicit_received.is_set() - assert len(result.content) > 0 - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert tool_result[0] == "confirmed" - client_task_store.cleanup() - - -@pytest.mark.anyio -async def test_scenario3_task_augmented_tool_normal_elicitation() -> None: - """Scenario 3: Task-augmented tool call with normal elicitation. - - Client calls tool as task. Inside the task, server uses task.elicit() - which queues the request and delivers via tasks/result. 
- """ - elicit_received = Event() - work_completed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Normal elicitation within task - queued and delivered via tasks/result - result = await task.elicit( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - return await ctx.experimental.run_task(work) - - server = Server("test-scenario3", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - server.experimental.enable_tasks() - - # Elicitation callback for client - async def elicitation_callback( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - ) -> ElicitResult: - elicit_received.set() - return ElicitResult(action="accept", content={"confirm": True}) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - elicitation_callback=elicitation_callback, - ) as client_session: - await 
client_session.initialize() - - # Call tool as task - create_result = await client_session.experimental.call_tool_as_task("confirm_action", {}) - task_id = create_result.task.task_id - assert create_result.task.status == "working" - - # Poll until input_required, then call tasks/result - found_input_required = False - async for status in client_session.experimental.poll_task(task_id): # pragma: no branch - if status.status == "input_required": # pragma: no branch - found_input_required = True - break - assert found_input_required, "Expected to see input_required status" - - # This will deliver the elicitation and get the response - final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) - - # Verify - assert elicit_received.is_set() - assert len(final_result.content) > 0 - assert isinstance(final_result.content[0], TextContent) - assert final_result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert work_completed.is_set() - - -@pytest.mark.anyio -async def test_scenario4_task_augmented_tool_task_augmented_elicitation() -> None: - """Scenario 4: Task-augmented tool call with task-augmented elicitation. - - Client calls tool as task. Inside the task, server uses task.elicit_as_task() - which sends task-augmented elicitation. Client creates its own task for the - elicitation, and server polls the client. - - This tests the full bidirectional flow where: - 1. Client calls tasks/result on server (for tool task) - 2. Server delivers task-augmented elicitation through that stream - 3. Client creates its own task and returns CreateTaskResult - 4. Server polls the client's task while the client's tasks/result is still open - 5. Server gets the ElicitResult and completes the tool task - 6. 
Client's tasks/result returns with the CallToolResult - """ - elicit_received = Event() - work_completed = Event() - - # Client-side task store for handling task-augmented elicitation - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Task-augmented elicitation within task - server polls client - result = await task.elicit_as_task( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ttl=60000, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - return await ctx.experimental.run_task(work) - - server = Server("test-scenario4", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - server.experimental.enable_tasks() - task_handlers = create_client_task_handlers(client_task_store, elicit_received) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await 
client_session.initialize() - - # Call tool as task - create_result = await client_session.experimental.call_tool_as_task("confirm_action", {}) - task_id = create_result.task.task_id - assert create_result.task.status == "working" - - # Poll until input_required or terminal, then call tasks/result - found_expected_status = False - async for status in client_session.experimental.poll_task(task_id): # pragma: no branch - if status.status == "input_required" or is_terminal(status.status): # pragma: no branch - found_expected_status = True - break - assert found_expected_status, "Expected to see input_required or terminal status" - - # This will deliver the task-augmented elicitation, - # server will poll client, and eventually return the tool result - final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) - - # Verify - assert elicit_received.is_set() - assert len(final_result.content) > 0 - assert isinstance(final_result.content[0], TextContent) - assert final_result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert work_completed.is_set() - client_task_store.cleanup() - - -@pytest.mark.anyio -async def test_scenario2_sampling_normal_tool_task_augmented_sampling() -> None: - """Scenario 2 for sampling: Normal tool call with task-augmented sampling. - - Server calls session.experimental.create_message_as_task(), client creates - a task for the sampling and returns CreateTaskResult. Server polls client. 
- """ - sampling_received = Event() - tool_result: list[str] = [] - - # Client-side task store for handling task-augmented sampling - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - ) - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - # Task-augmented sampling - server polls client - result = await ctx.session.experimental.create_message_as_task( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ttl=60000, - ) - - assert isinstance(result.content, TextContent), "Expected TextContent response" - response_text = result.content.text - - tool_result.append(response_text) - return CallToolResult(content=[TextContent(type="text", text=response_text)]) - - server = Server("test-scenario2-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await client_session.initialize() - - # Call tool normally (not as task) - result = await 
client_session.call_tool("generate_text", {}) - - # Verify sampling was received and tool completed - assert sampling_received.is_set() - assert len(result.content) > 0 - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Hello from the model!" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert tool_result[0] == "Hello from the model!" - client_task_store.cleanup() - - -@pytest.mark.anyio -async def test_scenario4_sampling_task_augmented_tool_task_augmented_sampling() -> None: - """Scenario 4 for sampling: Task-augmented tool call with task-augmented sampling. - - Client calls tool as task. Inside the task, server uses task.create_message_as_task() - which sends task-augmented sampling. Client creates its own task for the sampling, - and server polls the client. - """ - sampling_received = Event() - work_completed = Event() - - # Client-side task store for handling task-augmented sampling - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Task-augmented sampling within task - server polls client - result = await task.create_message_as_task( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ttl=60000, - ) - - assert isinstance(result.content, TextContent), "Expected TextContent response" - response_text = result.content.text - - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text=response_text)]) - - return await ctx.experimental.run_task(work) - - server = Server("test-scenario4-sampling", 
on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - server.experimental.enable_tasks() - task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await client_session.initialize() - - # Call tool as task - create_result = await client_session.experimental.call_tool_as_task("generate_text", {}) - task_id = create_result.task.task_id - assert create_result.task.status == "working" - - # Poll until input_required or terminal - found_expected_status = False - async for status in client_session.experimental.poll_task(task_id): # pragma: no branch - if status.status == "input_required" or is_terminal(status.status): # pragma: no branch - found_expected_status = True - break - assert found_expected_status, "Expected to see input_required or terminal status" - - final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) - - # Verify - assert sampling_received.is_set() - assert len(final_result.content) > 0 - assert isinstance(final_result.content[0], TextContent) - assert final_result.content[0].text == "Hello from the model!" 
- - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert work_completed.is_set() - client_task_store.cleanup() diff --git a/tests/experimental/tasks/test_message_queue.py b/tests/experimental/tasks/test_message_queue.py deleted file mode 100644 index eca113d5b4..0000000000 --- a/tests/experimental/tasks/test_message_queue.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Tests for TaskMessageQueue and InMemoryTaskMessageQueue.""" - -from collections import deque -from datetime import datetime, timezone - -import anyio -import pytest - -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, QueuedMessage -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.types import JSONRPCNotification, JSONRPCRequest - - -@pytest.fixture -def queue() -> InMemoryTaskMessageQueue: - return InMemoryTaskMessageQueue() - - -def make_request(id: int = 1, method: str = "test/method") -> JSONRPCRequest: - return JSONRPCRequest(jsonrpc="2.0", id=id, method=method) - - -def make_notification(method: str = "test/notify") -> JSONRPCNotification: - return JSONRPCNotification(jsonrpc="2.0", method=method) - - -class TestInMemoryTaskMessageQueue: - @pytest.mark.anyio - async def test_enqueue_and_dequeue(self, queue: InMemoryTaskMessageQueue) -> None: - """Test basic enqueue and dequeue operations.""" - task_id = "task-1" - msg = QueuedMessage(type="request", message=make_request()) - - await queue.enqueue(task_id, msg) - result = await queue.dequeue(task_id) - - assert result is not None - assert result.type == "request" - assert result.message.method == "test/method" - - @pytest.mark.anyio - async def test_dequeue_empty_returns_none(self, queue: InMemoryTaskMessageQueue) -> None: - """Dequeue from empty queue returns None.""" - result = await queue.dequeue("nonexistent-task") - assert result is None - - @pytest.mark.anyio - async def test_fifo_ordering(self, queue: InMemoryTaskMessageQueue) -> None: - 
"""Messages are dequeued in FIFO order.""" - task_id = "task-1" - - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(1, "first"))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(2, "second"))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(3, "third"))) - - msg1 = await queue.dequeue(task_id) - msg2 = await queue.dequeue(task_id) - msg3 = await queue.dequeue(task_id) - - assert msg1 is not None and msg1.message.method == "first" - assert msg2 is not None and msg2.message.method == "second" - assert msg3 is not None and msg3.message.method == "third" - - @pytest.mark.anyio - async def test_separate_queues_per_task(self, queue: InMemoryTaskMessageQueue) -> None: - """Each task has its own queue.""" - await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1, "task1-msg"))) - await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2, "task2-msg"))) - - msg1 = await queue.dequeue("task-1") - msg2 = await queue.dequeue("task-2") - - assert msg1 is not None and msg1.message.method == "task1-msg" - assert msg2 is not None and msg2.message.method == "task2-msg" - - @pytest.mark.anyio - async def test_peek_does_not_remove(self, queue: InMemoryTaskMessageQueue) -> None: - """Peek returns message without removing it.""" - task_id = "task-1" - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) - - peeked = await queue.peek(task_id) - dequeued = await queue.dequeue(task_id) - - assert peeked is not None - assert dequeued is not None - assert isinstance(peeked.message, JSONRPCRequest) - assert isinstance(dequeued.message, JSONRPCRequest) - assert peeked.message.id == dequeued.message.id - - @pytest.mark.anyio - async def test_is_empty(self, queue: InMemoryTaskMessageQueue) -> None: - """Test is_empty method.""" - task_id = "task-1" - - assert await queue.is_empty(task_id) is True - - await 
queue.enqueue(task_id, QueuedMessage(type="notification", message=make_notification())) - assert await queue.is_empty(task_id) is False - - await queue.dequeue(task_id) - assert await queue.is_empty(task_id) is True - - @pytest.mark.anyio - async def test_clear_returns_all_messages(self, queue: InMemoryTaskMessageQueue) -> None: - """Clear removes and returns all messages.""" - task_id = "task-1" - - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(1))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(2))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(3))) - - messages = await queue.clear(task_id) - - assert len(messages) == 3 - assert await queue.is_empty(task_id) is True - - @pytest.mark.anyio - async def test_clear_empty_queue(self, queue: InMemoryTaskMessageQueue) -> None: - """Clear on empty queue returns empty list.""" - messages = await queue.clear("nonexistent") - assert messages == [] - - @pytest.mark.anyio - async def test_notification_messages(self, queue: InMemoryTaskMessageQueue) -> None: - """Test queuing notification messages.""" - task_id = "task-1" - msg = QueuedMessage(type="notification", message=make_notification("log/message")) - - await queue.enqueue(task_id, msg) - result = await queue.dequeue(task_id) - - assert result is not None - assert result.type == "notification" - assert result.message.method == "log/message" - - @pytest.mark.anyio - async def test_message_timestamp(self, queue: InMemoryTaskMessageQueue) -> None: - """Messages have timestamps.""" - before = datetime.now(timezone.utc) - msg = QueuedMessage(type="request", message=make_request()) - after = datetime.now(timezone.utc) - - assert before <= msg.timestamp <= after - - @pytest.mark.anyio - async def test_message_with_resolver(self, queue: InMemoryTaskMessageQueue) -> None: - """Messages can have resolvers.""" - task_id = "task-1" - resolver: Resolver[dict[str, str]] = 
Resolver() - - msg = QueuedMessage( - type="request", - message=make_request(), - resolver=resolver, - original_request_id=42, - ) - - await queue.enqueue(task_id, msg) - result = await queue.dequeue(task_id) - - assert result is not None - assert result.resolver is resolver - assert result.original_request_id == 42 - - @pytest.mark.anyio - async def test_cleanup_specific_task(self, queue: InMemoryTaskMessageQueue) -> None: - """Cleanup removes specific task's data.""" - await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1))) - await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2))) - - queue.cleanup("task-1") - - assert await queue.is_empty("task-1") is True - assert await queue.is_empty("task-2") is False - - @pytest.mark.anyio - async def test_cleanup_all(self, queue: InMemoryTaskMessageQueue) -> None: - """Cleanup without task_id removes all data.""" - await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1))) - await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2))) - - queue.cleanup() - - assert await queue.is_empty("task-1") is True - assert await queue.is_empty("task-2") is True - - @pytest.mark.anyio - async def test_wait_for_message_returns_immediately_if_message_exists( - self, queue: InMemoryTaskMessageQueue - ) -> None: - """wait_for_message returns immediately if queue not empty.""" - task_id = "task-1" - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) - - # Should return immediately, not block - with anyio.fail_after(1): - await queue.wait_for_message(task_id) - - @pytest.mark.anyio - async def test_wait_for_message_blocks_until_message(self, queue: InMemoryTaskMessageQueue) -> None: - """wait_for_message blocks until a message is enqueued.""" - task_id = "task-1" - received = False - waiter_started = anyio.Event() - - async def enqueue_when_ready() -> None: - # Wait until the waiter has started 
before enqueueing - await waiter_started.wait() - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) - - async def wait_for_msg() -> None: - nonlocal received - # Signal that we're about to start waiting - waiter_started.set() - await queue.wait_for_message(task_id) - received = True - - async with anyio.create_task_group() as tg: - tg.start_soon(wait_for_msg) - tg.start_soon(enqueue_when_ready) - - assert received is True - - @pytest.mark.anyio - async def test_notify_message_available_wakes_waiter(self, queue: InMemoryTaskMessageQueue) -> None: - """notify_message_available wakes up waiting coroutines.""" - task_id = "task-1" - notified = False - waiter_started = anyio.Event() - - async def notify_when_ready() -> None: - # Wait until the waiter has started before notifying - await waiter_started.wait() - await queue.notify_message_available(task_id) - - async def wait_for_notification() -> None: - nonlocal notified - # Signal that we're about to start waiting - waiter_started.set() - await queue.wait_for_message(task_id) - notified = True - - async with anyio.create_task_group() as tg: - tg.start_soon(wait_for_notification) - tg.start_soon(notify_when_ready) - - assert notified is True - - @pytest.mark.anyio - async def test_peek_empty_queue_returns_none(self, queue: InMemoryTaskMessageQueue) -> None: - """Peek on empty queue returns None.""" - result = await queue.peek("nonexistent-task") - assert result is None - - @pytest.mark.anyio - async def test_wait_for_message_double_check_race_condition(self, queue: InMemoryTaskMessageQueue) -> None: - """wait_for_message returns early if message arrives after event creation but before wait.""" - task_id = "task-1" - - # To test the double-check path (lines 223-225), we need a message to arrive - # after the event is created (line 220) but before event.wait() (line 228). - # We simulate this by injecting a message before is_empty is called the second time. 
- - original_is_empty = queue.is_empty - call_count = 0 - - async def is_empty_with_injection(tid: str) -> bool: - nonlocal call_count - call_count += 1 - if call_count == 2 and tid == task_id: - # Before second check, inject a message - this simulates a message - # arriving between event creation and the double-check - queue._queues[task_id] = deque([QueuedMessage(type="request", message=make_request())]) - return await original_is_empty(tid) - - queue.is_empty = is_empty_with_injection # type: ignore[method-assign] - - # Should return immediately due to double-check finding the message - with anyio.fail_after(1): - await queue.wait_for_message(task_id) - - -class TestResolver: - @pytest.mark.anyio - async def test_set_result_and_wait(self) -> None: - """Test basic set_result and wait flow.""" - resolver: Resolver[str] = Resolver() - - resolver.set_result("hello") - result = await resolver.wait() - - assert result == "hello" - assert resolver.done() - - @pytest.mark.anyio - async def test_set_exception_and_wait(self) -> None: - """Test set_exception raises on wait.""" - resolver: Resolver[str] = Resolver() - - resolver.set_exception(ValueError("test error")) - - with pytest.raises(ValueError, match="test error"): - await resolver.wait() - - assert resolver.done() - - @pytest.mark.anyio - async def test_set_result_when_already_completed_raises(self) -> None: - """Test that set_result raises if resolver already completed.""" - resolver: Resolver[str] = Resolver() - resolver.set_result("first") - - with pytest.raises(RuntimeError, match="already completed"): - resolver.set_result("second") - - @pytest.mark.anyio - async def test_set_exception_when_already_completed_raises(self) -> None: - """Test that set_exception raises if resolver already completed.""" - resolver: Resolver[str] = Resolver() - resolver.set_result("done") - - with pytest.raises(RuntimeError, match="already completed"): - resolver.set_exception(ValueError("too late")) - - @pytest.mark.anyio - async 
def test_done_returns_false_before_completion(self) -> None: - """Test done() returns False before any result is set.""" - resolver: Resolver[str] = Resolver() - assert resolver.done() is False diff --git a/tests/experimental/tasks/test_request_context.py b/tests/experimental/tasks/test_request_context.py deleted file mode 100644 index ad4023389e..0000000000 --- a/tests/experimental/tasks/test_request_context.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Tests for the RequestContext.experimental (Experimental class) task validation helpers.""" - -import pytest - -from mcp.server.experimental.request_context import Experimental -from mcp.shared.exceptions import MCPError -from mcp.types import ( - METHOD_NOT_FOUND, - TASK_FORBIDDEN, - TASK_OPTIONAL, - TASK_REQUIRED, - ClientCapabilities, - ClientTasksCapability, - TaskMetadata, - Tool, - ToolExecution, -) - - -def test_is_task_true_when_metadata_present() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - assert exp.is_task is True - - -def test_is_task_false_when_no_metadata() -> None: - exp = Experimental(task_metadata=None) - assert exp.is_task is False - - -def test_client_supports_tasks_true() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability())) - assert exp.client_supports_tasks is True - - -def test_client_supports_tasks_false_no_tasks() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.client_supports_tasks is False - - -def test_client_supports_tasks_false_no_capabilities() -> None: - exp = Experimental(_client_capabilities=None) - assert exp.client_supports_tasks is False - - -def test_validate_task_mode_required_with_task_is_valid() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(TASK_REQUIRED, raise_error=False) - assert error is None - - -def test_validate_task_mode_required_without_task_returns_error() -> None: - exp = Experimental(task_metadata=None) - 
error = exp.validate_task_mode(TASK_REQUIRED, raise_error=False) - assert error is not None - assert error.code == METHOD_NOT_FOUND - assert "requires task-augmented" in error.message - - -def test_validate_task_mode_required_without_task_raises_by_default() -> None: - exp = Experimental(task_metadata=None) - with pytest.raises(MCPError) as exc_info: - exp.validate_task_mode(TASK_REQUIRED) - assert exc_info.value.error.code == METHOD_NOT_FOUND - - -def test_validate_task_mode_forbidden_without_task_is_valid() -> None: - exp = Experimental(task_metadata=None) - error = exp.validate_task_mode(TASK_FORBIDDEN, raise_error=False) - assert error is None - - -def test_validate_task_mode_forbidden_with_task_returns_error() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(TASK_FORBIDDEN, raise_error=False) - assert error is not None - assert error.code == METHOD_NOT_FOUND - assert "does not support task-augmented" in error.message - - -def test_validate_task_mode_forbidden_with_task_raises_by_default() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - with pytest.raises(MCPError) as exc_info: - exp.validate_task_mode(TASK_FORBIDDEN) - assert exc_info.value.error.code == METHOD_NOT_FOUND - - -def test_validate_task_mode_none_treated_as_forbidden() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(None, raise_error=False) - assert error is not None - assert "does not support task-augmented" in error.message - - -def test_validate_task_mode_optional_with_task_is_valid() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(TASK_OPTIONAL, raise_error=False) - assert error is None - - -def test_validate_task_mode_optional_without_task_is_valid() -> None: - exp = Experimental(task_metadata=None) - error = exp.validate_task_mode(TASK_OPTIONAL, raise_error=False) - assert error is None - - -def 
test_validate_for_tool_with_execution_required() -> None: - exp = Experimental(task_metadata=None) - tool = Tool( - name="test", - description="test", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - error = exp.validate_for_tool(tool, raise_error=False) - assert error is not None - assert "requires task-augmented" in error.message - - -def test_validate_for_tool_without_execution() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - tool = Tool( - name="test", - description="test", - input_schema={"type": "object"}, - execution=None, - ) - error = exp.validate_for_tool(tool, raise_error=False) - assert error is not None - assert "does not support task-augmented" in error.message - - -def test_validate_for_tool_optional_with_task() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - tool = Tool( - name="test", - description="test", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_OPTIONAL), - ) - error = exp.validate_for_tool(tool, raise_error=False) - assert error is None - - -def test_can_use_tool_required_with_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability())) - assert exp.can_use_tool(TASK_REQUIRED) is True - - -def test_can_use_tool_required_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.can_use_tool(TASK_REQUIRED) is False - - -def test_can_use_tool_optional_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.can_use_tool(TASK_OPTIONAL) is True - - -def test_can_use_tool_forbidden_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.can_use_tool(TASK_FORBIDDEN) is True - - -def test_can_use_tool_none_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert 
exp.can_use_tool(None) is True diff --git a/tests/experimental/tasks/test_spec_compliance.py b/tests/experimental/tasks/test_spec_compliance.py deleted file mode 100644 index 38d7d0a664..0000000000 --- a/tests/experimental/tasks/test_spec_compliance.py +++ /dev/null @@ -1,717 +0,0 @@ -"""Tasks Spec Compliance Tests -=========================== - -Test structure mirrors: https://modelcontextprotocol.io/specification/draft/basic/utilities/tasks.md - -Each section contains tests for normative requirements (MUST/SHOULD/MAY). -""" - -from datetime import datetime, timezone - -import pytest - -from mcp.server import Server, ServerRequestContext -from mcp.server.lowlevel import NotificationOptions -from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY -from mcp.types import ( - CancelTaskRequestParams, - CancelTaskResult, - CreateTaskResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - PaginatedRequestParams, - ServerCapabilities, - Task, -) - -# Shared test datetime -TEST_DATETIME = datetime(2025, 1, 1, tzinfo=timezone.utc) - - -def _get_capabilities(server: Server) -> ServerCapabilities: - """Helper to get capabilities from a server.""" - return server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ) - - -def test_server_without_task_handlers_has_no_tasks_capability() -> None: - """Server without any task handlers has no tasks capability.""" - server: Server = Server("test") - caps = _get_capabilities(server) - assert caps.tasks is None - - -async def _noop_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - raise NotImplementedError - - -async def _noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - raise NotImplementedError - - -async def _noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: - raise NotImplementedError - - -def 
test_server_with_list_tasks_handler_declares_list_capability() -> None: - """Server with list_tasks handler declares tasks.list capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_list_tasks=_noop_list_tasks) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.list is not None - - -def test_server_with_cancel_task_handler_declares_cancel_capability() -> None: - """Server with cancel_task handler declares tasks.cancel capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_cancel_task=_noop_cancel_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.cancel is not None - - -def test_server_with_get_task_handler_declares_requests_tools_call_capability() -> None: - """Server with get_task handler declares tasks.requests.tools.call capability. - (get_task is required for task-augmented tools/call support) - """ - server: Server = Server("test") - server.experimental.enable_tasks(on_get_task=_noop_get_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.requests is not None - assert caps.tasks.requests.tools is not None - - -@pytest.mark.skip( - reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " - "so partial capabilities aren't possible yet. Low-level API should support " - "selectively enabling/disabling task capabilities." -) -def test_server_without_list_handler_has_no_list_capability() -> None: # pragma: no cover - """Server without list_tasks handler has no tasks.list capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_get_task=_noop_get_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.list is None - - -@pytest.mark.skip( - reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " - "so partial capabilities aren't possible yet. 
Low-level API should support " - "selectively enabling/disabling task capabilities." -) -def test_server_without_cancel_handler_has_no_cancel_capability() -> None: # pragma: no cover - """Server without cancel_task handler has no tasks.cancel capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_get_task=_noop_get_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.cancel is None - - -def test_server_with_all_task_handlers_has_full_capability() -> None: - """Server with all task handlers declares complete tasks capability.""" - server: Server = Server("test") - server.experimental.enable_tasks( - on_list_tasks=_noop_list_tasks, - on_cancel_task=_noop_cancel_task, - on_get_task=_noop_get_task, - ) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.list is not None - assert caps.tasks.cancel is not None - assert caps.tasks.requests is not None - assert caps.tasks.requests.tools is not None - - -class TestClientCapabilities: - """Clients declare: - - tasks.list — supports listing operations - - tasks.cancel — supports cancellation - - tasks.requests.sampling.createMessage — task-augmented sampling - - tasks.requests.elicitation.create — task-augmented elicitation - """ - - def test_client_declares_tasks_capability(self) -> None: - """Client can declare tasks capability.""" - pytest.skip("TODO") - - -class TestToolLevelNegotiation: - """Tools in tools/list responses include execution.taskSupport with values: - - Not present or "forbidden": No task augmentation allowed - - "optional": Task augmentation allowed at requestor discretion - - "required": Task augmentation is mandatory - """ - - def test_tool_execution_task_forbidden_rejects_task_augmented_call(self) -> None: - """Tool with execution.taskSupport="forbidden" MUST reject task-augmented calls (-32601).""" - pytest.skip("TODO") - - def test_tool_execution_task_absent_rejects_task_augmented_call(self) 
-> None: - """Tool without execution.taskSupport MUST reject task-augmented calls (-32601).""" - pytest.skip("TODO") - - def test_tool_execution_task_optional_accepts_normal_call(self) -> None: - """Tool with execution.taskSupport="optional" accepts normal calls.""" - pytest.skip("TODO") - - def test_tool_execution_task_optional_accepts_task_augmented_call(self) -> None: - """Tool with execution.taskSupport="optional" accepts task-augmented calls.""" - pytest.skip("TODO") - - def test_tool_execution_task_required_rejects_normal_call(self) -> None: - """Tool with execution.taskSupport="required" MUST reject non-task calls (-32601).""" - pytest.skip("TODO") - - def test_tool_execution_task_required_accepts_task_augmented_call(self) -> None: - """Tool with execution.taskSupport="required" accepts task-augmented calls.""" - pytest.skip("TODO") - - -class TestCapabilityNegotiation: - """Requestors SHOULD only augment requests with a task if the corresponding - capability has been declared by the receiver. - - Receivers that do not declare the task capability for a request type - MUST process requests of that type normally, ignoring any task-augmentation - metadata if present. - """ - - def test_receiver_without_capability_ignores_task_metadata(self) -> None: - """Receiver without task capability MUST process request normally, - ignoring task-augmentation metadata. - """ - pytest.skip("TODO") - - def test_receiver_with_capability_may_require_task_augmentation(self) -> None: - """Receivers that declare task capability MAY return error (-32600) - for non-task-augmented requests, requiring task augmentation. 
- """ - pytest.skip("TODO") - - -class TestTaskStatusLifecycle: - """Tasks begin in working status and follow valid transitions: - working → input_required → working → terminal - working → terminal (directly) - input_required → terminal (directly) - - Terminal states (no further transitions allowed): - - completed - - failed - - cancelled - """ - - def test_task_begins_in_working_status(self) -> None: - """Tasks MUST begin in working status.""" - pytest.skip("TODO") - - def test_working_to_completed_transition(self) -> None: - """working → completed is valid.""" - pytest.skip("TODO") - - def test_working_to_failed_transition(self) -> None: - """working → failed is valid.""" - pytest.skip("TODO") - - def test_working_to_cancelled_transition(self) -> None: - """working → cancelled is valid.""" - pytest.skip("TODO") - - def test_working_to_input_required_transition(self) -> None: - """working → input_required is valid.""" - pytest.skip("TODO") - - def test_input_required_to_working_transition(self) -> None: - """input_required → working is valid.""" - pytest.skip("TODO") - - def test_input_required_to_terminal_transition(self) -> None: - """input_required → terminal is valid.""" - pytest.skip("TODO") - - def test_terminal_state_no_further_transitions(self) -> None: - """Terminal states allow no further transitions.""" - pytest.skip("TODO") - - def test_completed_is_terminal(self) -> None: - """completed is a terminal state.""" - pytest.skip("TODO") - - def test_failed_is_terminal(self) -> None: - """failed is a terminal state.""" - pytest.skip("TODO") - - def test_cancelled_is_terminal(self) -> None: - """cancelled is a terminal state.""" - pytest.skip("TODO") - - -class TestInputRequiredStatus: - """When a receiver needs information to proceed, it moves the task to input_required. - The requestor should call tasks/result to retrieve input requests. - The task must include io.modelcontextprotocol/related-task metadata in associated requests. 
- """ - - def test_input_required_status_retrievable_via_tasks_get(self) -> None: - """Task in input_required status is retrievable via tasks/get.""" - pytest.skip("TODO") - - def test_input_required_related_task_metadata_in_requests(self) -> None: - """Task MUST include io.modelcontextprotocol/related-task metadata - in associated requests. - """ - pytest.skip("TODO") - - -class TestCreatingTask: - """Request structure: - {"method": "tools/call", "params": {"name": "...", "arguments": {...}, "task": {"ttl": 60000}}} - - Response (CreateTaskResult): - {"result": {"task": {"taskId": "...", "status": "working", ...}}} - - Receivers may include io.modelcontextprotocol/model-immediate-response in _meta. - """ - - def test_task_augmented_request_returns_create_task_result(self) -> None: - """Task-augmented request MUST return CreateTaskResult immediately.""" - pytest.skip("TODO") - - def test_create_task_result_contains_task_id(self) -> None: - """CreateTaskResult MUST contain taskId.""" - pytest.skip("TODO") - - def test_create_task_result_contains_status_working(self) -> None: - """CreateTaskResult MUST have status=working initially.""" - pytest.skip("TODO") - - def test_create_task_result_contains_created_at(self) -> None: - """CreateTaskResult MUST contain createdAt timestamp.""" - pytest.skip("TODO") - - def test_create_task_result_created_at_is_iso8601(self) -> None: - """createdAt MUST be ISO 8601 formatted.""" - pytest.skip("TODO") - - def test_create_task_result_may_contain_ttl(self) -> None: - """CreateTaskResult MAY contain ttl.""" - pytest.skip("TODO") - - def test_create_task_result_may_contain_poll_interval(self) -> None: - """CreateTaskResult MAY contain pollInterval.""" - pytest.skip("TODO") - - def test_create_task_result_may_contain_status_message(self) -> None: - """CreateTaskResult MAY contain statusMessage.""" - pytest.skip("TODO") - - def test_receiver_may_override_requested_ttl(self) -> None: - """Receiver MAY override requested ttl but MUST 
return actual value.""" - pytest.skip("TODO") - - def test_model_immediate_response_in_meta(self) -> None: - """Receiver MAY include io.modelcontextprotocol/model-immediate-response - in _meta to provide immediate response while task executes. - """ - # Verify the constant has the correct value per spec - assert MODEL_IMMEDIATE_RESPONSE_KEY == "io.modelcontextprotocol/model-immediate-response" - - # CreateTaskResult can include model-immediate-response in _meta - task = Task( - task_id="test-123", - status="working", - created_at=TEST_DATETIME, - last_updated_at=TEST_DATETIME, - ttl=60000, - ) - immediate_msg = "Task started, processing your request..." - # Note: Must use _meta= (alias) not meta= due to Pydantic alias handling - result = CreateTaskResult( - task=task, - **{"_meta": {MODEL_IMMEDIATE_RESPONSE_KEY: immediate_msg}}, - ) - - # Verify the metadata is present and correct - assert result.meta is not None - assert MODEL_IMMEDIATE_RESPONSE_KEY in result.meta - assert result.meta[MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg - - # Verify it serializes correctly with _meta alias - serialized = result.model_dump(by_alias=True) - assert "_meta" in serialized - assert MODEL_IMMEDIATE_RESPONSE_KEY in serialized["_meta"] - assert serialized["_meta"][MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg - - -class TestGettingTaskStatus: - """Request: {"method": "tasks/get", "params": {"taskId": "..."}} - Response: Returns full Task object with current status and pollInterval. 
- """ - - def test_tasks_get_returns_task_object(self) -> None: - """tasks/get MUST return full Task object.""" - pytest.skip("TODO") - - def test_tasks_get_returns_current_status(self) -> None: - """tasks/get MUST return current status.""" - pytest.skip("TODO") - - def test_tasks_get_may_return_poll_interval(self) -> None: - """tasks/get MAY return pollInterval.""" - pytest.skip("TODO") - - def test_tasks_get_invalid_task_id_returns_error(self) -> None: - """tasks/get with invalid taskId MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_get_nonexistent_task_id_returns_error(self) -> None: - """tasks/get with nonexistent taskId MUST return -32602.""" - pytest.skip("TODO") - - -class TestRetrievingResults: - """Request: {"method": "tasks/result", "params": {"taskId": "..."}} - Response: The actual operation result structure (e.g., CallToolResult). - - This call blocks until terminal status. - """ - - def test_tasks_result_returns_underlying_result(self) -> None: - """tasks/result MUST return exactly what underlying request would return.""" - pytest.skip("TODO") - - def test_tasks_result_blocks_until_terminal(self) -> None: - """tasks/result MUST block for non-terminal tasks.""" - pytest.skip("TODO") - - def test_tasks_result_unblocks_on_terminal(self) -> None: - """tasks/result MUST unblock upon reaching terminal status.""" - pytest.skip("TODO") - - def test_tasks_result_includes_related_task_metadata(self) -> None: - """tasks/result MUST include io.modelcontextprotocol/related-task in _meta.""" - pytest.skip("TODO") - - def test_tasks_result_returns_error_for_failed_task(self) -> None: - """tasks/result returns the same error the underlying request - would have produced for failed tasks. 
- """ - pytest.skip("TODO") - - def test_tasks_result_invalid_task_id_returns_error(self) -> None: - """tasks/result with invalid taskId MUST return -32602.""" - pytest.skip("TODO") - - -class TestListingTasks: - """Request: {"method": "tasks/list", "params": {"cursor": "optional"}} - Response: Array of tasks with pagination support via nextCursor. - """ - - def test_tasks_list_returns_array_of_tasks(self) -> None: - """tasks/list MUST return array of tasks.""" - pytest.skip("TODO") - - def test_tasks_list_pagination_with_cursor(self) -> None: - """tasks/list supports pagination via cursor.""" - pytest.skip("TODO") - - def test_tasks_list_returns_next_cursor_when_more_results(self) -> None: - """tasks/list MUST return nextCursor when more results available.""" - pytest.skip("TODO") - - def test_tasks_list_cursors_are_opaque(self) -> None: - """Implementers MUST treat cursors as opaque tokens.""" - pytest.skip("TODO") - - def test_tasks_list_invalid_cursor_returns_error(self) -> None: - """tasks/list with invalid cursor MUST return -32602.""" - pytest.skip("TODO") - - -class TestCancellingTasks: - """Request: {"method": "tasks/cancel", "params": {"taskId": "..."}} - Response: Returns the task object with status: "cancelled". 
- """ - - def test_tasks_cancel_returns_cancelled_task(self) -> None: - """tasks/cancel MUST return task with status=cancelled.""" - pytest.skip("TODO") - - def test_tasks_cancel_terminal_task_returns_error(self) -> None: - """Cancelling already-terminal task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_completed_task_returns_error(self) -> None: - """Cancelling completed task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_failed_task_returns_error(self) -> None: - """Cancelling failed task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_already_cancelled_task_returns_error(self) -> None: - """Cancelling already-cancelled task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_invalid_task_id_returns_error(self) -> None: - """tasks/cancel with invalid taskId MUST return -32602.""" - pytest.skip("TODO") - - -class TestStatusNotifications: - """Receivers MAY send: {"method": "notifications/tasks/status", "params": {...}} - These are optional; requestors MUST NOT rely on them and SHOULD continue polling. 
- """ - - def test_receiver_may_send_status_notification(self) -> None: - """Receiver MAY send notifications/tasks/status.""" - pytest.skip("TODO") - - def test_status_notification_contains_task_id(self) -> None: - """Status notification MUST contain taskId.""" - pytest.skip("TODO") - - def test_status_notification_contains_status(self) -> None: - """Status notification MUST contain status.""" - pytest.skip("TODO") - - -class TestTaskManagement: - """- Receivers generate unique task IDs as strings - - Tasks must begin in working status - - createdAt timestamps must be ISO 8601 formatted - - Receivers may override requested ttl but must return actual value - - Receivers may delete tasks after TTL expires - - All task-related messages must include io.modelcontextprotocol/related-task - in _meta except for tasks/get, tasks/list, tasks/cancel operations - """ - - def test_task_ids_are_unique_strings(self) -> None: - """Receivers MUST generate unique task IDs as strings.""" - pytest.skip("TODO") - - def test_multiple_tasks_have_unique_ids(self) -> None: - """Multiple tasks MUST have unique IDs.""" - pytest.skip("TODO") - - def test_receiver_may_delete_tasks_after_ttl(self) -> None: - """Receivers MAY delete tasks after TTL expires.""" - pytest.skip("TODO") - - def test_related_task_metadata_in_task_messages(self) -> None: - """All task-related messages MUST include io.modelcontextprotocol/related-task - in _meta. 
- """ - pytest.skip("TODO") - - def test_tasks_get_does_not_require_related_task_metadata(self) -> None: - """tasks/get does not require related-task metadata.""" - pytest.skip("TODO") - - def test_tasks_list_does_not_require_related_task_metadata(self) -> None: - """tasks/list does not require related-task metadata.""" - pytest.skip("TODO") - - def test_tasks_cancel_does_not_require_related_task_metadata(self) -> None: - """tasks/cancel does not require related-task metadata.""" - pytest.skip("TODO") - - -class TestResultHandling: - """- Receivers must return CreateTaskResult immediately upon accepting task-augmented requests - - tasks/result must return exactly what the underlying request would return - - tasks/result blocks for non-terminal tasks; must unblock upon reaching terminal status - """ - - def test_create_task_result_returned_immediately(self) -> None: - """Receiver MUST return CreateTaskResult immediately (not after work completes).""" - pytest.skip("TODO") - - def test_tasks_result_matches_underlying_result_structure(self) -> None: - """tasks/result MUST return same structure as underlying request.""" - pytest.skip("TODO") - - def test_tasks_result_for_tool_call_returns_call_tool_result(self) -> None: - """tasks/result for tools/call returns CallToolResult.""" - pytest.skip("TODO") - - -class TestProgressTracking: - """Task-augmented requests support progress notifications using the progressToken - mechanism, which remains valid throughout the task lifetime. 
- """ - - def test_progress_token_valid_throughout_task_lifetime(self) -> None: - """progressToken remains valid throughout task lifetime.""" - pytest.skip("TODO") - - def test_progress_notifications_sent_during_task_execution(self) -> None: - """Progress notifications can be sent during task execution.""" - pytest.skip("TODO") - - -class TestProtocolErrors: - """Protocol Errors (JSON-RPC standard codes): - - -32600 (Invalid request): Non-task requests to endpoint requiring task augmentation - - -32602 (Invalid params): Invalid/nonexistent taskId, invalid cursor, cancel terminal task - - -32603 (Internal error): Server-side execution failures - """ - - def test_invalid_request_for_required_task_augmentation(self) -> None: - """Non-task request to task-required endpoint returns -32600.""" - pytest.skip("TODO") - - def test_invalid_params_for_invalid_task_id(self) -> None: - """Invalid taskId returns -32602.""" - pytest.skip("TODO") - - def test_invalid_params_for_nonexistent_task_id(self) -> None: - """Nonexistent taskId returns -32602.""" - pytest.skip("TODO") - - def test_invalid_params_for_invalid_cursor(self) -> None: - """Invalid cursor in tasks/list returns -32602.""" - pytest.skip("TODO") - - def test_invalid_params_for_cancel_terminal_task(self) -> None: - """Attempt to cancel terminal task returns -32602.""" - pytest.skip("TODO") - - def test_internal_error_for_server_failure(self) -> None: - """Server-side execution failure returns -32603.""" - pytest.skip("TODO") - - -class TestTaskExecutionErrors: - """When underlying requests fail, the task moves to failed status. 
- - tasks/get response should include statusMessage explaining failure - - tasks/result returns same error the underlying request would have produced - - For tool calls, isError: true moves task to failed status - """ - - def test_underlying_failure_moves_task_to_failed(self) -> None: - """Underlying request failure moves task to failed status.""" - pytest.skip("TODO") - - def test_failed_task_has_status_message(self) -> None: - """Failed task SHOULD include statusMessage explaining failure.""" - pytest.skip("TODO") - - def test_tasks_result_returns_underlying_error(self) -> None: - """tasks/result returns same error underlying request would produce.""" - pytest.skip("TODO") - - def test_tool_call_is_error_true_moves_to_failed(self) -> None: - """Tool call with isError: true moves task to failed status.""" - pytest.skip("TODO") - - -class TestTaskObject: - """Task Object fields: - - taskId: String identifier - - status: Current execution state - - statusMessage: Optional human-readable description - - createdAt: ISO 8601 timestamp of creation - - ttl: Milliseconds before potential deletion - - pollInterval: Suggested milliseconds between polls - """ - - def test_task_has_task_id_string(self) -> None: - """Task MUST have taskId as string.""" - pytest.skip("TODO") - - def test_task_has_status(self) -> None: - """Task MUST have status.""" - pytest.skip("TODO") - - def test_task_status_message_is_optional(self) -> None: - """Task statusMessage is optional.""" - pytest.skip("TODO") - - def test_task_has_created_at(self) -> None: - """Task MUST have createdAt.""" - pytest.skip("TODO") - - def test_task_ttl_is_optional(self) -> None: - """Task ttl is optional.""" - pytest.skip("TODO") - - def test_task_poll_interval_is_optional(self) -> None: - """Task pollInterval is optional.""" - pytest.skip("TODO") - - -class TestRelatedTaskMetadata: - """Related Task Metadata structure: - {"_meta": {"io.modelcontextprotocol/related-task": {"taskId": "..."}}} - """ - - def 
test_related_task_metadata_structure(self) -> None: - """Related task metadata has correct structure.""" - pytest.skip("TODO") - - def test_related_task_metadata_contains_task_id(self) -> None: - """Related task metadata contains taskId.""" - pytest.skip("TODO") - - -class TestAccessAndIsolation: - """- Task IDs enable access to sensitive results - - Authorization context binding is essential where available - - For non-authorized environments: strong entropy IDs, strict TTL limits - """ - - def test_task_bound_to_authorization_context(self) -> None: - """Receivers receiving authorization context MUST bind tasks to that context.""" - pytest.skip("TODO") - - def test_reject_task_operations_outside_authorization_context(self) -> None: - """Receivers MUST reject task operations for tasks outside - requestor's authorization context. - """ - pytest.skip("TODO") - - def test_non_authorized_environments_use_secure_ids(self) -> None: - """For non-authorized environments, receivers SHOULD use - cryptographically secure IDs. 
- """ - pytest.skip("TODO") - - def test_non_authorized_environments_use_shorter_ttls(self) -> None: - """For non-authorized environments, receivers SHOULD use shorter TTLs.""" - pytest.skip("TODO") - - -class TestResourceLimits: - """Receivers should: - - Enforce concurrent task limits per requestor - - Implement maximum TTL constraints - - Clean up expired tasks promptly - """ - - def test_concurrent_task_limit_enforced(self) -> None: - """Receiver SHOULD enforce concurrent task limits per requestor.""" - pytest.skip("TODO") - - def test_maximum_ttl_constraint_enforced(self) -> None: - """Receiver SHOULD implement maximum TTL constraints.""" - pytest.skip("TODO") - - def test_expired_tasks_cleaned_up(self) -> None: - """Receiver SHOULD clean up expired tasks promptly.""" - pytest.skip("TODO") diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 109b30fc77..caed8905d0 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -49,8 +49,9 @@ _SOURCE_PATTERN = re.compile(r"https://modelcontextprotocol\.io/specification/.+|sdk|issue:#\d+") _TASKS_DEFERRAL = ( - "Tasks are experimental and the spec is being substantially revised; python task behaviour is " - "covered by tests/experimental/tasks/ until the next spec revision settles." + "Tasks have been removed from the draft spec and from this SDK; they are expected to return " + "as a separate MCP extension. These 2025-11-25 requirements are tracked but intentionally " + "unimplemented." 
) diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 5d5f8b8fc9..bef44928ac 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -3,7 +3,6 @@ import pytest from mcp.server.context import ServerRequestContext -from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context pytestmark = pytest.mark.anyio @@ -22,7 +21,6 @@ async def test_progress_token_zero_first_call(): session=mock_session, meta={"progress_token": 0}, lifespan_context=None, - experimental=Experimental(), ) # Create context with our mocks diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3457ec944a..21352b5f2f 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -11,7 +11,6 @@ from mcp.client import Client from mcp.server.context import ServerRequestContext -from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.prompts.base import Message, UserMessage @@ -1502,7 +1501,6 @@ async def test_report_progress_passes_related_request_id(): session=mock_session, meta={"progress_token": "tok-1"}, lifespan_context=None, - experimental=Experimental(), ) ctx = Context(request_context=request_context, mcp_server=MagicMock()) diff --git a/tests/server/test_session.py b/tests/server/test_session.py index a2786d865d..6116a7c7f5 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -76,6 +76,49 @@ async def run_server(): assert received_initialized +@pytest.mark.anyio +async def test_check_client_capability(): + """check_client_capability reflects the capabilities sent by the client at initialize.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + 
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + + initialized = anyio.Event() + + async def list_roots_callback(context: Any) -> types.ListRootsResult: # pragma: no cover + return types.ListRootsResult(roots=[]) + + async def run_server(server_session: ServerSession): + async for message in server_session.incoming_messages: # pragma: no branch + if isinstance(message, ClientNotification) and isinstance( + message, InitializedNotification + ): # pragma: no branch + initialized.set() + return + + async with ( + ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions(server_name="mcp", server_version="0.1.0", capabilities=ServerCapabilities()), + ) as server_session, + ClientSession( + server_to_client_receive, + client_to_server_send, + list_roots_callback=list_roots_callback, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server, server_session) + await client_session.initialize() + with anyio.fail_after(5): + await initialized.wait() + + # ClientSession advertises roots when a list_roots_callback is provided. + assert server_session.check_client_capability(types.ClientCapabilities(roots=types.RootsCapability())) + # ClientSession does not advertise sampling without a sampling_callback. 
+ assert not server_session.check_client_capability(types.ClientCapabilities(sampling=types.SamplingCapability())) + + @pytest.mark.anyio async def test_server_capabilities(): notification_options = NotificationOptions() diff --git a/uv.lock b/uv.lock index 5b72e97fce..df63607f40 100644 --- a/uv.lock +++ b/uv.lock @@ -18,10 +18,6 @@ members = [ "mcp-simple-resource", "mcp-simple-streamablehttp", "mcp-simple-streamablehttp-stateless", - "mcp-simple-task", - "mcp-simple-task-client", - "mcp-simple-task-interactive", - "mcp-simple-task-interactive-client", "mcp-simple-tool", "mcp-snippets", "mcp-sse-polling-client", @@ -1268,126 +1264,6 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] -[[package]] -name = "mcp-simple-task" -version = "0.1.0" -source = { editable = "examples/servers/simple-task" } -dependencies = [ - { name = "anyio" }, - { name = "click" }, - { name = "mcp" }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "anyio", specifier = ">=4.5" }, - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - -[[package]] -name = "mcp-simple-task-client" -version = "0.1.0" -source = { editable = "examples/clients/simple-task-client" } -dependencies = [ - { name = "click" }, - { name = "mcp" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." 
}, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - -[[package]] -name = "mcp-simple-task-interactive" -version = "0.1.0" -source = { editable = "examples/servers/simple-task-interactive" } -dependencies = [ - { name = "anyio" }, - { name = "click" }, - { name = "mcp" }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "anyio", specifier = ">=4.5" }, - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - -[[package]] -name = "mcp-simple-task-interactive-client" -version = "0.1.0" -source = { editable = "examples/clients/simple-task-interactive-client" } -dependencies = [ - { name = "click" }, - { name = "mcp" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." 
}, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - [[package]] name = "mcp-simple-tool" version = "0.1.0" From c91f4069386ad724b275f537ee0f0af2af5a6a1c Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:27:41 +0100 Subject: [PATCH 74/84] Require protocol_version to be a string (#2763) --- src/mcp/client/streamable_http.py | 2 +- src/mcp/types/_types.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index aa3e50e07e..9cdf717c73 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -120,7 +120,7 @@ def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) try: # Parse the result as InitializeResult for type safety init_result = InitializeResult.model_validate(message.result, by_name=False) - self.protocol_version = str(init_result.protocol_version) + self.protocol_version = init_result.protocol_version logger.info(f"Negotiated protocol version: {self.protocol_version}") except Exception: # pragma: no cover logger.warning("Failed to parse initialization response as InitializeResult", exc_info=True) diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index 34800ba12e..e9d39ef6f3 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -304,7 +304,7 @@ class ServerCapabilities(MCPModel): class InitializeRequestParams(RequestParams): """Parameters for the initialize request.""" - protocol_version: str | int + protocol_version: str """The latest version of the Model Context Protocol that the client supports.""" capabilities: ClientCapabilities client_info: Implementation @@ -322,7 +322,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]) class InitializeResult(Result): """After receiving an initialize request from the 
client, the server sends this.""" - protocol_version: str | int + protocol_version: str """The version of the Model Context Protocol that the server wants to use.""" capabilities: ServerCapabilities server_info: Implementation From 60c04207750f15cb75a0bd6f9eed78cfe06bc492 Mon Sep 17 00:00:00 2001 From: Aryan Motgi <85900811+aryanmotgi@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:28:19 -0700 Subject: [PATCH 75/84] docs: correct create_mcp_http_client default timeout docstring (#2683) Co-authored-by: aryanmotgi Co-authored-by: Marcelo Trylesinski --- src/mcp/shared/_httpx_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index 251469eaa1..6a121aff6d 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -27,14 +27,12 @@ def create_mcp_http_client( ) -> httpx.AsyncClient: """Create a standardized httpx AsyncClient with MCP defaults. - This function provides common defaults used throughout the MCP codebase: - - follow_redirects=True (always enabled) - - Default timeout of 30 seconds if not specified + Always enables follow_redirects and applies an SSE-friendly default timeout. Args: headers: Optional headers to include with all requests. - timeout: Request timeout as httpx.Timeout object. - Defaults to 30 seconds if not specified. + timeout: Request timeout as httpx.Timeout object. Defaults to 30s for + connect/write/pool and 300s for read (for long-lived SSE streams). auth: Optional authentication handler. 
Returns: From 4f6f0e8bf5544253e77b4f8d4cad804db5e2e5ef Mon Sep 17 00:00:00 2001 From: Zach Leventer Date: Tue, 2 Jun 2026 12:30:57 -0400 Subject: [PATCH 76/84] Remove dead commented-out code in register_client (#2500) --- src/mcp/client/auth/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index d75324f2f0..780a24e859 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -240,8 +240,6 @@ async def handle_registration_response(response: Response) -> OAuthClientInforma content = await response.aread() client_info = OAuthClientInformationFull.model_validate_json(content) return client_info - # self.context.client_info = client_info - # await self.context.storage.set_client_info(client_info) except ValidationError as e: # pragma: no cover raise OAuthRegistrationError(f"Invalid registration response: {e}") From a9381263275a86257002c1e34101c1dce70cbc05 Mon Sep 17 00:00:00 2001 From: yukawithdata <90426808+yuka-with-data@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:32:53 -0700 Subject: [PATCH 77/84] Doc: Clarify MCP Client-Server model in What is MCP section (#2459) --- README.v2.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.v2.md b/README.v2.md index a888f21bda..bae230c3f9 100644 --- a/README.v2.md +++ b/README.v2.md @@ -207,7 +207,11 @@ In the inspector UI, connect to `http://localhost:8000/mcp`. ## What is MCP? -The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you build servers that expose data and functionality to LLM applications in a secure, standardized way. Think of it like a web API, but specifically designed for LLM interactions. MCP servers can: +The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you build servers that expose data and functionality to LLM applications in a secure, standardized way. Think of it like a web API, but specifically designed for LLM interactions. 
+ +MCP follows a **client-server model**, where LLM applications act as clients and connect to MCP servers to access capabilities such as data retrieval and tool execution in a consistent format. + +MCP servers can: - Expose data through **Resources** (think of these sort of like GET endpoints; they are used to load information into the LLM's context) - Provide functionality through **Tools** (sort of like POST endpoints; they are used to execute code or otherwise produce a side effect) From ed6adeee7e698ba66ee6434ee2bd69d0ec8729ea Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 18:34:39 +0100 Subject: [PATCH 78/84] docs: require a passing conformance test for new 2026-07-28 spec features (#2761) --- AGENTS.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index 307bd81b3e..3d4c43dd72 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -65,6 +65,11 @@ tests. Don't silence warnings from your own code; fix the underlying cause. Scoped `ignore::` entries for upstream libraries are acceptable in `pyproject.toml` with a comment explaining why. +- New features from the 2026-07-28 spec must have a matching test in the + [conformance suite](https://github.com/modelcontextprotocol/conformance) + that passes against this SDK (CI runs it via + `.github/workflows/conformance.yml`). If no matching test exists, stop and + tell the user so they can raise an issue on the conformance repo. 
### Coverage From b3025f93b4ba63c6d0d04b1357947ecdb78eb0bd Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 20:30:17 +0100 Subject: [PATCH 79/84] Run transport security tests in process instead of over sockets (#2764) --- src/mcp/server/streamable_http.py | 2 +- tests/interaction/transports/__init__.py | 9 + tests/server/test_sse_security.py | 309 ++++++----------- tests/server/test_streamable_http_manager.py | 29 +- tests/server/test_streamable_http_security.py | 317 +++++------------- 5 files changed, 212 insertions(+), 454 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f2f4407cea..98948ff999 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -669,7 +669,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - if not await self._validate_request_headers(request, send): # pragma: no cover + if not await self._validate_request_headers(request, send): return # Handle resumability: check for Last-Event-ID header diff --git a/tests/interaction/transports/__init__.py b/tests/interaction/transports/__init__.py index e69de29bb2..b5bbb633c2 100644 --- a/tests/interaction/transports/__init__.py +++ b/tests/interaction/transports/__init__.py @@ -0,0 +1,9 @@ +"""Transport-specific interaction tests, and the in-process streaming bridge they are built on. + +`StreamingASGITransport` is re-exported here as the sanctioned import point for test code +outside this suite (the bridge module itself is suite-private). 
+""" + +from tests.interaction.transports._bridge import StreamingASGITransport + +__all__ = ["StreamingASGITransport"] diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index e95dc51b31..e77bd5e2c2 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,15 +1,12 @@ """Tests for SSE server request validation.""" import logging -import multiprocessing import re -import socket import anyio import httpx import pytest import sse_starlette.sse -import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -23,12 +20,16 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._stream_protocols import WriteStream from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCRequest, JSONRPCResponse, Tool -from tests.test_helpers import wait_for_server +from mcp.types import JSONRPCRequest, JSONRPCResponse +from tests.interaction.transports import StreamingASGITransport logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" +# The in-process app is mounted at this origin purely so URLs are well-formed and the default +# Host header is a localhost form; nothing listens here. 
+BASE_URL = "http://127.0.0.1:8000" + @pytest.fixture(autouse=True) def reset_sse_starlette_exit_event() -> None: @@ -39,275 +40,161 @@ def reset_sse_starlette_exit_event() -> None: app_status.should_exit_event = None -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - async def on_list_tools(self) -> list[Tool]: - return [] - - -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the SSE server with specified security settings.""" - app = SecurityTestServer() +def sse_security_client(security_settings: TransportSecuritySettings | None = None) -> httpx.AsyncClient: + """An httpx client whose requests are served in process by an SSE app with the given settings.""" + server = Server(SERVER_NAME) sse_transport = SseServerTransport("/messages/", security_settings) - async def handle_sse(request: Request): + async def handle_sse(request: Request) -> Response: try: - async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: - if streams: - await app.run(streams[0], streams[1], app.create_initialization_options()) + async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (read, write): + await server.run(read, write, server.create_initialization_options()) except ValueError as e: - # Validation error was already handled inside connect_sse + # Validation error was already handled inside connect_sse, which sent the rejection + # response itself; its non-empty body checkpoints, so the test reads the rejection + # status before the trailing Response() below sends a second response start. 
logger.debug(f"SSE connection failed validation: {e}") return Response() - routes = [ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse_transport.handle_post_message), - ] - - starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") - - -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse_transport.handle_post_message), + ] + ) + # The SSE GET runs until it observes a disconnect, so the bridge must let the application + # drain on close rather than cancelling it. + transport = StreamingASGITransport(app, cancel_on_close=False) + return httpx.AsyncClient(transport=transport, base_url=BASE_URL) @pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int): - """Test SSE with default security settings (protection disabled).""" - process = start_server_process(server_port) +async def test_sse_security_default_settings() -> None: + """With default security settings (protection disabled), any Host and Origin connect.""" + headers = {"Host": "evil.com", "Origin": "http://evil.com"} - try: - headers = {"Host": "evil.com", "Origin": "http://evil.com"} - - async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with sse_security_client() as client: + async with client.stream("GET", "/sse", headers=headers) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def 
test_sse_security_invalid_host_header(server_port: int): - """Test SSE with invalid Host header.""" - # Enable security by providing settings with an empty allowed_hosts list +async def test_sse_security_invalid_host_header() -> None: + """A Host header outside allowed_hosts is rejected with 421.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) - try: - # Test with invalid host header - headers = {"Host": "evil.com"} - - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + async with sse_security_client(security_settings) as client: + response = await client.get("/sse", headers={"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): - """Test SSE with invalid Origin header.""" - # Configure security to allow the host but restrict origins +async def test_sse_security_invalid_origin_header() -> None: + """An Origin header outside allowed_origins is rejected with 403.""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = {"Origin": "http://evil.com"} - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + async with sse_security_client(security_settings) 
as client: + response = await client.get("/sse", headers={"Origin": "http://evil.com"}) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): - """Test POST endpoint with invalid Content-Type header.""" - # Configure security to allow the host +async def test_sse_security_post_invalid_content_type() -> None: + """A POST whose Content-Type is not application/json (or is missing) is rejected with 400.""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": "text/plain"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" + fake_session_id = "12345678123456781234567812345678" - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" + async with sse_security_client(security_settings) as client: + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": "text/plain"}, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" - finally: - process.terminate() - process.join() + response = await client.post(f"/messages/?session_id={fake_session_id}", content="test") + assert response.status_code == 400 + assert response.text == 
"Invalid Content-Type header" @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): - """Test SSE with security disabled.""" +async def test_sse_security_disabled() -> None: + """With protection explicitly disabled, a disallowed Host still connects.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = {"Host": "evil.com"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + async with sse_security_client(settings) as client: + async with client.stream("GET", "/sse", headers={"Host": "evil.com"}) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): - """Test SSE with custom allowed hosts.""" +async def test_sse_security_custom_allowed_hosts() -> None: + """A custom entry in allowed_hosts connects; hosts outside the list are still rejected.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: - # Test with custom allowed host - headers = {"Host": "custom.host"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with custom host - assert response.status_code == 200 - # Test 
with non-allowed host - headers = {"Host": "evil.com"} + async with sse_security_client(settings) as client: + async with client.stream("GET", "/sse", headers={"Host": "custom.host"}) as response: + assert response.status_code == 200 - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + response = await client.get("/sse", headers={"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): - """Test SSE with wildcard port patterns.""" +async def test_sse_security_wildcard_ports() -> None: + """A `host:*` pattern accepts that host with any port, for Host and Origin alike.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) - try: - # Test with various port numbers + async with sse_security_client(settings) as client: for test_port in [8080, 3000, 9999]: - headers = {"Host": f"localhost:{test_port}"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 - - headers = {"Origin": f"http://localhost:{test_port}"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert 
response.status_code == 200 + async with client.stream("GET", "/sse", headers={"Host": f"localhost:{test_port}"}) as response: + assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with client.stream("GET", "/sse", headers={"Origin": f"http://localhost:{test_port}"}) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): - """Test POST endpoint with valid Content-Type headers.""" - # Configure security to allow the host +async def test_sse_security_post_valid_content_type() -> None: + """Every application/json Content-Type variant passes validation (reaching the session lookup).""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient() as client: - # Test with various valid content types - valid_content_types = [ - "application/json", - "application/json; charset=utf-8", - "application/json;charset=utf-8", - "APPLICATION/JSON", # Case insensitive - ] - - for content_type in valid_content_types: - # Use a valid UUID format (even though session won't exist) - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": content_type}, - json={"test": "data"}, - ) - # Will get 404 because session doesn't exist, but that's OK - # We're testing that it passes the content-type check - assert response.status_code == 404 - assert response.text == "Could not find session" - - finally: - process.terminate() - process.join() + valid_content_types = [ + "application/json", + "application/json; charset=utf-8", + "application/json;charset=utf-8", + "APPLICATION/JSON", # Case insensitive + ] + # A well-formed session ID that 
no live session owns. + fake_session_id = "12345678123456781234567812345678" + + async with sse_security_client(security_settings) as client: + for content_type in valid_content_types: + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": content_type}, + json={"test": "data"}, + ) + # 404 proves the request passed the content-type check and reached the session lookup. + assert response.status_code == 404 + assert response.text == "Could not find session" def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index ba75547964..f02e520eea 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -340,12 +340,33 @@ async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestP await client.list_tools() +class _IdleTimeoutObserver(logging.Handler): + """Resolves `reaped` when the manager logs that a session's idle timeout fired.""" + + def __init__(self) -> None: + super().__init__() + self.reaped = anyio.Event() + + def emit(self, record: logging.LogRecord) -> None: + if "idle timeout" in record.getMessage(): + self.reaped.set() + + @pytest.mark.anyio -async def test_idle_session_is_reaped(): +async def test_idle_session_is_reaped(caplog: pytest.LogCaptureFixture, request: pytest.FixtureRequest): """After idle timeout fires, the session returns 404.""" app = Server("test-idle-reap") manager = StreamableHTTPSessionManager(app=app, session_idle_timeout=0.05) + # The reap is observed through the manager's own "idle timeout" log record: the manager pops + # the session synchronously after emitting it, before its next await, so a waiter woken by + # the record always finds the session gone. caplog.set_level enables INFO so it is created. 
+ observer = _IdleTimeoutObserver() + manager_logger = logging.getLogger(streamable_http_manager.__name__) + manager_logger.addHandler(observer) + request.addfinalizer(lambda: manager_logger.removeHandler(observer)) + caplog.set_level(logging.INFO, logger=streamable_http_manager.__name__) + async with manager.run(): sent_messages: list[Message] = [] @@ -376,8 +397,10 @@ async def mock_receive(): # pragma: no cover assert session_id is not None, "Session ID not found in response headers" - # Wait for the 50ms idle timeout to fire and cleanup to complete - await anyio.sleep(0.1) + # Wait for the 50ms idle timeout to fire and the session to be unregistered. Re-requesting + # the session to poll for the 404 would push its idle deadline forward and keep it alive. + with anyio.fail_after(5): + await observer.reaped.wait() # Verify via public API: old session ID now returns 404 response_messages: list[Message] = [] diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353e..f13bb4a9bb 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -1,291 +1,130 @@ """Tests for StreamableHTTP server DNS rebinding protection.""" -import multiprocessing -import socket -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager import httpx import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import Mount -from starlette.types import Receive, Scope, Send from mcp.server import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport SERVER_NAME = "test_streamable_http_security_server" +# The in-process app is mounted at 
this origin purely so URLs are well-formed and the default +# Host header is a localhost form; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +@asynccontextmanager +async def streamable_http_security_client( + security_settings: TransportSecuritySettings | None = None, +) -> AsyncIterator[httpx.AsyncClient]: + """Yield an httpx client served in process by a StreamableHTTP app with the given settings.""" + session_manager = StreamableHTTPSessionManager(app=Server(SERVER_NAME), security_settings=security_settings) + app = Starlette(routes=[Mount("/", app=session_manager.handle_request)]) -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" + async with session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as client: + yield client -class SecurityTestServer(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) +def _base_headers() -> dict[str, str]: + """Headers every well-formed request carries, so each test varies only the header under test.""" + return {"Accept": "application/json, text/event-stream", "Content-Type": "application/json"} - async def on_list_tools(self) -> list[Tool]: - return [] - -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the StreamableHTTP server with specified security settings.""" - app = SecurityTestServer() - - # Create session manager with security settings - session_manager = StreamableHTTPSessionManager( - app=app, - json_response=False, - stateless=False, - security_settings=security_settings, - ) - - # Create the ASGI handler - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, 
send) - - # Create Starlette app with lifespan - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - async with session_manager.run(): - yield - - routes = [ - Mount("/", app=handle_streamable_http), - ] - - starlette_app = Starlette(routes=routes, lifespan=lifespan) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") - - -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process +def _initialize_body() -> dict[str, object]: + """A minimal initialize POST body; these tests assert header validation, not the handshake.""" + return {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} @pytest.mark.anyio -async def test_streamable_http_security_default_settings(server_port: int): - """Test StreamableHTTP with default security settings (protection enabled).""" - process = start_server_process(server_port) - - try: - # Test with valid localhost headers - async with httpx.AsyncClient(timeout=5.0) as client: - # POST request to initialize session - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - ) - assert response.status_code == 200 - assert "mcp-session-id" in response.headers - - finally: - process.terminate() - process.join() +async def test_streamable_http_security_default_settings() -> None: + """With default security settings, a request with localhost headers is served.""" + async with streamable_http_security_client() as client: + response = await client.post("/", json=_initialize_body(), 
headers=_base_headers()) + assert response.status_code == 200 + assert "mcp-session-id" in response.headers @pytest.mark.anyio -async def test_streamable_http_security_invalid_host_header(server_port: int): - """Test StreamableHTTP with invalid Host header.""" +async def test_streamable_http_security_invalid_host_header() -> None: + """A Host header outside allowed_hosts is rejected with 421.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid host header - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_origin_header(server_port: int): - """Test StreamableHTTP with invalid Origin header.""" +async def test_streamable_http_security_invalid_origin_header() -> None: + """An Origin header outside allowed_origins is rejected with 403.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = { - "Origin": "http://evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": 
"application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.post( + "/", json=_initialize_body(), headers=_base_headers() | {"Origin": "http://evil.com"} + ) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_content_type(server_port: int): - """Test StreamableHTTP POST with invalid Content-Type header.""" - process = start_server_process(server_port) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={ - "Content-Type": "text/plain", - "Accept": "application/json, text/event-stream", - }, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={"Accept": "application/json, text/event-stream"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - finally: - process.terminate() - process.join() +async def test_streamable_http_security_invalid_content_type() -> None: + """A POST whose Content-Type is not application/json (or is missing) is rejected with 400.""" + async with streamable_http_security_client() as client: + response = await client.post("/", headers=_base_headers() | {"Content-Type": "text/plain"}, content="test") + assert response.status_code == 400 + assert 
response.text == "Invalid Content-Type header" + + response = await client.post("/", headers={"Accept": "application/json, text/event-stream"}, content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_streamable_http_security_disabled(server_port: int): - """Test StreamableHTTP with security disabled.""" +async def test_streamable_http_security_disabled() -> None: + """With protection explicitly disabled, a disallowed Host is still served.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "evil.com"}) + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_custom_allowed_hosts(server_port: int): - """Test StreamableHTTP with custom allowed hosts.""" +async def test_streamable_http_security_custom_allowed_hosts() -> None: + """A custom entry in allowed_hosts is served.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: 
- # Test with custom allowed host - headers = { - "Host": "custom.host", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully with custom host - assert response.status_code == 200 - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "custom.host"}) + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_get_request(server_port: int): - """Test StreamableHTTP GET request with security.""" +async def test_streamable_http_security_get_request() -> None: + """GET requests pass the same Host validation before any session handling.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) - process = start_server_process(server_port, security_settings) - - try: - # Test GET request with invalid host header - headers = { - "Host": "evil.com", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - # Test GET request with valid host header - headers = { - "Host": "127.0.0.1", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - # GET requests need a session ID in StreamableHTTP - # So it will fail with "Missing session ID" not security error - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - # This should pass security but fail on session 
validation - assert response.status_code == 400 - body = response.json() - assert "Missing session ID" in body["error"]["message"] - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + + response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "127.0.0.1"}) + # An allowed host passes security and fails on session validation instead. + assert response.status_code == 400 + body = response.json() + assert "Missing session ID" in body["error"]["message"] From ed39e73c0ba16cd3ff9c3996421a17c21b4ed05b Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 21:46:53 +0100 Subject: [PATCH 80/84] Run SSE and Unicode transport tests in process instead of over sockets (#2765) --- tests/client/test_http_unicode.py | 288 ++++++++------------- tests/shared/test_sse.py | 409 ++++++++++-------------------- 2 files changed, 242 insertions(+), 455 deletions(-) diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index cc2e14e469..585a142617 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -4,11 +4,10 @@ (server→client and client→server) using the streamable HTTP transport. 
""" -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager +import httpx import pytest from starlette.applications import Starlette from starlette.routing import Mount @@ -19,7 +18,10 @@ from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport + +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" # Test constants with various Unicode characters UNICODE_TEST_STRINGS = { @@ -41,74 +43,62 @@ } -def run_unicode_server(port: int) -> None: # pragma: no cover - """Run the Unicode test server in a separate process.""" - import uvicorn - - # Need to recreate the server setup in this process - async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - Tool( - name="echo_unicode", - description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to echo back"}, - }, - "required": ["text"], +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + Tool( + name="echo_unicode", + description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", + input_schema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "Text to echo back"}, }, - ), - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: - 
if params.name == "echo_unicode": - text = params.arguments.get("text", "") if params.arguments else "" - return types.CallToolResult( - content=[ - TextContent( - type="text", - text=f"Echo: {text}", - ) - ] + "required": ["text"], + }, + ), + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + assert params.name == "echo_unicode" + assert params.arguments is not None + return types.CallToolResult(content=[TextContent(type="text", text=f"Echo: {params.arguments['text']}")]) + + +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="unicode_prompt", + description="Unicode prompt - Слой хранилища, где располагаются", + arguments=[], ) - else: - raise ValueError(f"Unknown tool: {params.name}") - - async def handle_list_prompts( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListPromptsResult: - return types.ListPromptsResult( - prompts=[ - types.Prompt( - name="unicode_prompt", - description="Unicode prompt - Слой хранилища, где располагаются", - arguments=[], - ) - ] - ) - - async def handle_get_prompt( - ctx: ServerRequestContext, params: types.GetPromptRequestParams - ) -> types.GetPromptResult: - if params.name == "unicode_prompt": - return types.GetPromptResult( - messages=[ - types.PromptMessage( - role="user", - content=types.TextContent( - type="text", - text="Hello世界🌍Привет안녕مرحباשלום", - ), - ) - ] + ] + ) + + +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: + assert params.name == "unicode_prompt" + return types.GetPromptResult( + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text="Hello世界🌍Привет안녕مرحباשלום"), ) - raise ValueError(f"Unknown prompt: {params.name}") + ] + ) + 
+@asynccontextmanager +async def unicode_session() -> AsyncIterator[ClientSession]: + """Yield an initialized ClientSession speaking streamable HTTP (SSE responses) to the + Unicode test server, entirely in process.""" server = Server( name="unicode_test_server", on_list_tools=handle_list_tools, @@ -116,122 +106,68 @@ async def handle_get_prompt( on_list_prompts=handle_list_prompts, on_get_prompt=handle_get_prompt, ) - - # Create the session manager - session_manager = StreamableHTTPSessionManager( - app=server, - json_response=False, # Use SSE for testing - ) - - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - async with session_manager.run(): - yield - - # Create an ASGI application - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lifespan, - ) - - # Run the server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - uvicorn_server = uvicorn.Server(config) - uvicorn_server.run() - - -@pytest.fixture -def unicode_server_port() -> int: - """Find an available port for the Unicode test server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]: - """Start a Unicode test server in a separate process.""" - proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True) - proc.start() - - # Wait for server to be ready - wait_for_server(unicode_server_port) - - try: - yield f"http://127.0.0.1:{unicode_server_port}" - finally: - # Clean up - try graceful termination first - proc.terminate() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - proc.kill() - proc.join(timeout=1) + # SSE response mode, so Unicode rides the SSE event encoding rather than a plain JSON body. 
+ session_manager = StreamableHTTPSessionManager(app=server, json_response=False) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + + async with ( + session_manager.run(), + # follow_redirects matches the SDK's own client factory; Starlette's Mount 307-redirects + # the bare /mcp path to /mcp/. + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, follow_redirects=True + ) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_unicode_tool_call(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_tool_call() -> None: """Test that Unicode text is correctly handled in tool calls via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List tools (server→client Unicode in descriptions) - tools = await session.list_tools() - assert len(tools.tools) == 1 + async with unicode_session() as session: + # Test 1: List tools (server→client Unicode in descriptions) + tools = await session.list_tools() + assert len(tools.tools) == 1 - # Check Unicode in tool descriptions - echo_tool = tools.tools[0] - assert echo_tool.name == "echo_unicode" - assert echo_tool.description is not None - assert "🔤" in echo_tool.description - assert "👋" in echo_tool.description + # Check Unicode in tool descriptions + echo_tool = tools.tools[0] + assert echo_tool.name == "echo_unicode" + assert echo_tool.description is not None + assert "🔤" in echo_tool.description + assert "👋" in echo_tool.description - # Test 2: Send Unicode text in tool call 
(client→server→client) - for test_name, test_string in UNICODE_TEST_STRINGS.items(): - result = await session.call_tool("echo_unicode", arguments={"text": test_string}) + # Test 2: Send Unicode text in tool call (client→server→client) + for test_name, test_string in UNICODE_TEST_STRINGS.items(): + result = await session.call_tool("echo_unicode", arguments={"text": test_string}) - # Verify server correctly received and echoed back Unicode - assert len(result.content) == 1 - content = result.content[0] - assert content.type == "text" - assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" + # Verify server correctly received and echoed back Unicode + assert len(result.content) == 1 + content = result.content[0] + assert content.type == "text" + assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" @pytest.mark.anyio -async def test_streamable_http_client_unicode_prompts(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_prompts() -> None: """Test that Unicode text is correctly handled in prompts via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List prompts (server→client Unicode in descriptions) - prompts = await session.list_prompts() - assert len(prompts.prompts) == 1 - - prompt = prompts.prompts[0] - assert prompt.name == "unicode_prompt" - assert prompt.description is not None - assert "Слой хранилища, где располагаются" in prompt.description - - # Test 2: Get prompt with Unicode content (server→client) - result = await session.get_prompt("unicode_prompt", arguments={}) - assert len(result.messages) == 1 - - message = result.messages[0] - assert message.role == "user" - assert message.content.type == "text" - assert message.content.text == 
"Hello世界🌍Привет안녕مرحباשלום" + async with unicode_session() as session: + # Test 1: List prompts (server→client Unicode in descriptions) + prompts = await session.list_prompts() + assert len(prompts.prompts) == 1 + + prompt = prompts.prompts[0] + assert prompt.name == "unicode_prompt" + assert prompt.description is not None + assert "Слой хранилища, где располагаются" in prompt.description + + # Test 2: Get prompt with Unicode content (server→client) + result = await session.get_prompt("unicode_prompt", arguments={}) + assert len(result.messages) == 1 + + message = result.messages[0] + assert message.role == "user" + assert message.content.type == "text" + assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 5629a5707b..675a4acb16 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,7 +1,7 @@ +"""Tests for the SSE client and server transports, driven entirely in process.""" + import json -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch from urllib.parse import urlparse @@ -9,7 +9,6 @@ import anyio import httpx import pytest -import uvicorn from httpx_sse import ServerSentEvent from inline_snapshot import snapshot from starlette.applications import Starlette @@ -24,6 +23,7 @@ from mcp.server import Server, ServerRequestContext from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._httpx_utils import McpHttpClientFactory from mcp.shared.exceptions import MCPError from mcp.types import ( CallToolRequestParams, @@ -41,171 +41,114 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport SERVER_NAME = "test_server_for_SSE" +# The in-process 
app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + + +def in_process_client_factory(app: Starlette) -> McpHttpClientFactory: + """An httpx_client_factory for sse_client whose clients are served in process by `app`.""" + + def factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + # The SSE GET runs until it observes a disconnect, so the bridge must let the + # application drain on close rather than cancelling it. follow_redirects matches + # create_mcp_http_client, the factory this one stands in for. + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + follow_redirects=True, + ) -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - + return factory -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" - -async def _handle_read_resource( # pragma: no cover - ctx: ServerRequestContext, params: ReadResourceRequestParams -) -> ReadResourceResult: +async def _handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: uri = str(params.uri) parsed = urlparse(uri) if parsed.scheme == "foobar": - text = f"Read {parsed.netloc}" - elif parsed.scheme == "slow": - await anyio.sleep(2.0) - text = f"Slow response from {parsed.netloc}" - else: - raise MCPError(code=404, message="OOPS! 
no resource with that URI was found") - return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) - - -async def _handle_list_tools( # pragma: no cover - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) - - -async def _handle_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) - - -def _create_server() -> Server: # pragma: no cover - return Server( - SERVER_NAME, - on_read_resource=_handle_read_resource, - on_list_tools=_handle_list_tools, - on_call_tool=_handle_call_tool, - ) + return ReadResourceResult( + contents=[TextResourceContents(uri=uri, text=f"Read {parsed.netloc}", mime_type="text/plain")] + ) + raise MCPError(code=404, message="OOPS! no resource with that URI was found") -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover - """Create test Starlette app with SSE transport""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] +def make_app(server: Server) -> Starlette: + """Mount `server` on a Starlette app exposing the SSE transport at /sse and /messages/.""" + # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot + # exist for an in-process app; the transport security behaviour itself is pinned by + # tests/server/test_sse_security.py. 
+ sse = SseServerTransport( + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - server = _create_server() async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await server.run(streams[0], streams[1], server.create_initialization_options()) + async with sse.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) return Response() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - return app - - -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) +def make_server_app() -> Starlette: + return make_app(Server(SERVER_NAME, on_read_resource=_handle_read_resource)) - yield - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: - """Create test client""" - async with httpx.AsyncClient(base_url=server_url) as client: - yield client +@pytest.mark.anyio +async def 
test_raw_sse_connection() -> None: + """The SSE GET responds 200 with an event-stream content type, announcing the session + endpoint as its first event.""" + http_client = httpx.AsyncClient( + transport=StreamingASGITransport(make_server_app(), cancel_on_close=False), base_url=BASE_URL + ) + with anyio.fail_after(5): + async with http_client, http_client.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" -# Tests -@pytest.mark.anyio -async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: - """Test the SSE connection establishment simply with an HTTP client.""" - async with anyio.create_task_group(): - - async def connection_test() -> None: - async with http_client.stream("GET", "/sse") as response: - assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" - - line_number = 0 - async for line in response.aiter_lines(): # pragma: no branch - if line_number == 0: - assert line == "event: endpoint" - elif line_number == 1: - assert line.startswith("data: /messages/?session_id=") - else: - return - line_number += 1 - - # Add timeout to prevent test from hanging if it fails - with anyio.fail_after(3): - await connection_test() + lines = response.aiter_lines() + assert await anext(lines) == "event: endpoint" + assert (await anext(lines)).startswith("data: /messages/?session_id=") @pytest.mark.anyio -async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse") as streams: +async def test_sse_client_basic_connection() -> None: + """A client initializes against, and pings, a server over the SSE transport.""" + factory = in_process_client_factory(make_server_app()) + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await 
session.initialize() assert isinstance(result, InitializeResult) assert result.server_info.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) @pytest.mark.anyio -async def test_sse_client_on_session_created(server: None, server_url: str) -> None: +async def test_sse_client_on_session_created() -> None: + """The session-created callback receives the new session ID before sse_client yields.""" + factory = in_process_client_factory(make_server_app()) captured: list[str] = [] - async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with sse_client( + f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=captured.append + ) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -226,13 +169,14 @@ async def test_sse_client_on_session_created(server: None, server_url: str) -> N ], ) def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | None) -> None: + """The session ID is read from the endpoint URL's sessionId/session_id query parameters.""" assert _extract_session_id_from_endpoint(endpoint_url) == expected @pytest.mark.anyio -async def test_sse_client_on_session_created_not_called_when_no_session_id( - server: None, server_url: str, monkeypatch: pytest.MonkeyPatch -) -> None: +async def test_sse_client_on_session_created_not_called_when_no_session_id(monkeypatch: pytest.MonkeyPatch) -> None: + """No session-created callback fires when the endpoint URL carries no session ID.""" + factory = in_process_client_factory(make_server_app()) callback_mock = Mock() def mock_extract(url: str) -> None: @@ -240,7 +184,7 @@ def mock_extract(url: str) -> None: monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract) - async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams: + async with 
sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=callback_mock) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -250,8 +194,9 @@ def mock_extract(url: str) -> None: @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: +async def initialized_sse_client_session() -> AsyncGenerator[ClientSession, None]: + factory = in_process_client_factory(make_server_app()) + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: await session.initialize() yield session @@ -261,6 +206,7 @@ async def initialized_sse_client_session(server: None, server_url: str) -> Async async def test_sse_client_happy_request_and_response( initialized_sse_client_session: ClientSession, ) -> None: + """A resource read round-trips its arguments and the handler's content over SSE.""" session = initialized_sse_client_session response = await session.read_resource(uri="foobar://should-work") assert len(response.contents) == 1 @@ -272,93 +218,45 @@ async def test_sse_client_happy_request_and_response( async def test_sse_client_exception_handling( initialized_sse_client_session: ClientSession, ) -> None: + """A server-side MCPError reaches the client with its message intact.""" session = initialized_sse_client_session with pytest.raises(MCPError, match="OOPS! 
no resource with that URI was found"): await session.read_resource(uri="xxx://will-not-work") @pytest.mark.anyio -@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling") -async def test_sse_client_timeout( # pragma: no cover - initialized_sse_client_session: ClientSession, -) -> None: - session = initialized_sse_client_session - - # sanity check that normal, fast responses are working - response = await session.read_resource(uri="foobar://1") - assert isinstance(response, ReadResourceResult) - - with anyio.move_on_after(3): - with pytest.raises(MCPError, match="Read timed out"): - response = await session.read_resource(uri="slow://2") - # we should receive an error here - return - - pytest.fail("the client should have timed out and returned an error already") - - -def run_mounted_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) +async def test_sse_client_basic_connection_mounted_app() -> None: + """The SSE transport works unchanged when its app is mounted under a sub-path.""" + main_app = Starlette(routes=[Mount("/mounted_app", app=make_server_app())]) + factory = in_process_client_factory(main_app) - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.mark.anyio 
-async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: + async with sse_client(f"{BASE_URL}/mounted_app/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.server_info.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) -async def _handle_context_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - headers_info: dict[str, Any] = {} - if ctx.request: - headers_info = dict(ctx.request.headers) +async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name in ("echo_headers", "echo_context") + assert ctx.request is not None + headers_info = dict(ctx.request.headers) if params.name == "echo_headers": return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) - elif params.name == "echo_context": - context_data = { - "request_id": (params.arguments or {}).get("request_id"), - "headers": headers_info, - } - return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + assert params.arguments is not None + context_data = { + "request_id": params.arguments.get("request_id"), + "headers": headers_info, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -381,115 +279,65 @@ async def _handle_context_list_tools( 
# pragma: no cover ) -def run_context_server(server_port: int) -> None: # pragma: no cover - """Run a server that captures request context""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = Server( - "request_context_server", - on_call_tool=_handle_context_call_tool, - on_list_tools=_handle_context_list_tools, - ) - - async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) - return Response() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ] +def make_context_server_app() -> Starlette: + return make_app( + Server( + "request_context_server", + on_call_tool=_handle_context_call_tool, + on_list_tools=_handle_context_list_tools, + ) ) - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() - - -@pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: - """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() - - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("context server process failed to terminate") - @pytest.mark.anyio 
-async def test_request_context_propagation(context_server: None, server_url: str) -> None: - """Test that request context is properly propagated through SSE transport.""" - # Test with custom headers +async def test_request_context_propagation() -> None: + """Custom HTTP headers on the SSE connection are visible to server handlers via ctx.request.""" + factory = in_process_client_factory(make_context_server_app()) + custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, headers=custom_headers) as streams: + async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - # Call the tool that echoes headers back tool_result = await session.call_tool("echo_headers", {}) - # Parse the JSON response - assert len(tool_result.content) == 1 - headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") + content = tool_result.content[0] + assert isinstance(content, TextContent) + headers_data = json.loads(content.text) - # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" assert headers_data.get("x-custom-header") == "test-value" assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: - """Test that request contexts are isolated between different SSE clients.""" +async def test_request_context_isolation() -> None: + """Each SSE connection's handlers see only that connection's request headers.""" + factory = in_process_client_factory(make_context_server_app()) contexts: 
list[dict[str, Any]] = [] - # Create multiple clients with different headers + # Connect three clients in turn, each with its own headers. for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, headers=headers) as streams: + async with ClientSession(*streams) as session: await session.initialize() - # Call the tool that echoes context tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) assert len(tool_result.content) == 1 - context_data = json.loads( - tool_result.content[0].text if tool_result.content[0].type == "text" else "{}" - ) - contexts.append(context_data) + content = tool_result.content[0] + assert isinstance(content, TextContent) + contexts.append(json.loads(content.text)) - # Verify each request had its own context assert len(contexts) == 3 for i, ctx in enumerate(contexts): assert ctx["request_id"] == f"request-{i}" @@ -497,7 +345,7 @@ async def test_request_context_isolation(context_server: None, server_url: str) assert ctx["headers"].get("x-custom-value") == f"value-{i}" -def test_sse_message_id_coercion(): +def test_sse_message_id_coercion() -> None: """Previously, the `RequestId` would coerce a string that looked like an integer into an integer. See for more details. 
@@ -531,7 +379,7 @@ def test_sse_message_id_coercion(): ("/messages/#fragment", ValueError), ], ) -def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]): +def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]) -> None: """Test that SseServerTransport properly validates and normalizes endpoints.""" if isinstance(expected_result, type): # Test invalid endpoints that should raise an exception @@ -605,7 +453,7 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: @pytest.mark.anyio -async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None: +async def test_sse_session_cleanup_on_disconnect() -> None: """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227 When a client disconnects, the server should remove the session from @@ -613,18 +461,21 @@ async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) POST requests to disconnected sessions return 202 Accepted followed by a ClosedResourceError when the server tries to write to the dead stream. 
""" + factory = in_process_client_factory(make_server_app()) captured: list[str] = [] # Connect a client session, then disconnect - async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with sse_client( + f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=captured.append + ) as streams: async with ClientSession(*streams) as session: await session.initialize() # After disconnect, POST to the stale session should return 404 # (not 202 as it did before the fix) - async with httpx.AsyncClient() as client: + async with factory() as client: response = await client.post( - f"{server_url}/messages/?session_id={captured[0]}", + f"/messages/?session_id={captured[0]}", json={"jsonrpc": "2.0", "method": "ping", "id": 99}, headers={"Content-Type": "application/json"}, ) From 19fe9faec82cd2b901ecdc5d45300f2ac9189ec6 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Wed, 3 Jun 2026 12:45:00 +0100 Subject: [PATCH 81/84] Run StreamableHTTP transport tests in process instead of over sockets (#2767) --- src/mcp/server/streamable_http.py | 8 +- tests/shared/test_streamable_http.py | 2110 ++++++++++++-------------- tests/test_helpers.py | 28 - 3 files changed, 969 insertions(+), 1177 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 98948ff999..2cb4c0748e 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -207,7 +207,7 @@ def close_sse_stream(self, request_id: RequestId) -> None: send_stream.close() receive_stream.close() - def close_standalone_sse_stream(self) -> None: # pragma: no cover + def close_standalone_sse_stream(self) -> None: """Close the standalone GET SSE stream, triggering client reconnection. 
This method closes the HTTP connection for the standalone GET stream used @@ -221,8 +221,6 @@ def close_standalone_sse_stream(self) -> None: # pragma: no cover This is a no-op if there is no active standalone SSE stream. Requires event_store to be configured for events to be stored during the disconnect. - Currently, client reconnection for standalone GET streams is NOT - implemented - this is a known gap (see test_standalone_get_stream_reconnection). """ self.close_sse_stream(GET_STREAM_KEY) @@ -245,7 +243,7 @@ def _create_session_message( async def close_stream_callback() -> None: self.close_sse_stream(request_id) - async def close_standalone_stream_callback() -> None: # pragma: no cover + async def close_standalone_stream_callback() -> None: self.close_standalone_sse_stream() metadata = ServerMessageMetadata( @@ -421,7 +419,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se has_json, has_sse = self._check_accept_headers(request) if self.is_json_response_enabled: # For JSON-only responses, only require application/json - if not has_json: # pragma: no cover + if not has_json: response = self._create_error_response( "Not Acceptable: Client must accept application/json", HTTPStatus.NOT_ACCEPTABLE, diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..b43a3361c9 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1,16 +1,14 @@ """Tests for the StreamableHTTP server and client transport. -Contains tests for both server and client sides of the StreamableHTTP transport. +Contains tests for both server and client sides of the StreamableHTTP transport, driven +entirely in process. 
""" from __future__ import annotations as _annotations import json -import multiprocessing -import socket import time -import traceback -from collections.abc import AsyncIterator, Generator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field from typing import Any @@ -20,8 +18,6 @@ import anyio import httpx import pytest -import requests -import uvicorn from httpx_sse import ServerSentEvent from starlette.applications import Starlette from starlette.requests import Request @@ -46,11 +42,6 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._context import RequestContext from mcp.shared._context_streams import create_context_streams -from mcp.shared._httpx_utils import ( - MCP_DEFAULT_SSE_READ_TIMEOUT, - MCP_DEFAULT_TIMEOUT, - create_mcp_http_client, -) from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( @@ -66,11 +57,10 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport # Test constants SERVER_NAME = "test_streamable_http_server" -TEST_SESSION_ID = "test-session-id-12345" INIT_REQUEST = { "jsonrpc": "2.0", "method": "initialize", @@ -82,9 +72,12 @@ "id": "init-1", } +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. 
+BASE_URL = "http://127.0.0.1:8000" + # Helper functions -def extract_protocol_version_from_sse(response: requests.Response) -> str: +def extract_protocol_version_from_sse(response: httpx.Response) -> str: """Extract the negotiated protocol version from an SSE initialization response.""" assert response.headers.get("Content-Type") == "text/event-stream" for line in response.text.splitlines(): @@ -109,32 +102,23 @@ async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | self._events.append((stream_id, event_id, message)) return event_id - async def replay_events_after( # pragma: no cover + async def replay_events_after( self, last_event_id: EventId, send_callback: EventCallback, ) -> StreamId | None: """Replay events after the specified ID.""" - # Find the stream ID of the last event - target_stream_id = None - for stream_id, event_id, _ in self._events: - if event_id == last_event_id: - target_stream_id = stream_id - break - - if target_stream_id is None: - # If event ID not found, return None - return None + # Find the stream ID of the last event; clients always resume from a stored event. + target_stream_id = next(stream_id for stream_id, event_id, _ in self._events if event_id == last_event_id) # Convert last_event_id to int for comparison last_event_id_int = int(last_event_id) - # Replay only events from the same stream with ID > last_event_id + # Replay only events from the same stream with ID > last_event_id, skipping priming + # events (None message). 
for stream_id, event_id, message in self._events: - if stream_id == target_stream_id and int(event_id) > last_event_id_int: - # Skip priming events (None message) - if message is not None: - await send_callback(EventMessage(message, event_id)) + if stream_id == target_stream_id and message is not None and int(event_id) > last_event_id_int: + await send_callback(EventMessage(message, event_id)) return target_stream_id @@ -145,26 +129,23 @@ class ServerState: @asynccontextmanager -async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: no cover +async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: yield ServerState() -async def _handle_read_resource( # pragma: no cover +async def _handle_read_resource( ctx: ServerRequestContext[ServerState], params: ReadResourceRequestParams ) -> ReadResourceResult: uri = str(params.uri) parsed = urlparse(uri) if parsed.scheme == "foobar": - text = f"Read {parsed.netloc}" - elif parsed.scheme == "slow": - await anyio.sleep(2.0) - text = f"Slow response from {parsed.netloc}" - else: - raise ValueError(f"Unknown resource: {uri}") - return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) + return ReadResourceResult( + contents=[TextResourceContents(uri=uri, text=f"Read {parsed.netloc}", mime_type="text/plain")] + ) + raise ValueError(f"Unknown resource: {uri}") -async def _handle_list_tools( # pragma: no cover +async def _handle_list_tools( ctx: ServerRequestContext[ServerState], params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -179,11 +160,6 @@ async def _handle_list_tools( # pragma: no cover description="A test tool that sends a notification", input_schema={"type": "object", "properties": {}}, ), - Tool( - name="long_running_with_checkpoints", - description="A long-running tool that sends periodic notifications", - input_schema={"type": "object", "properties": {}}, - ), Tool( 
name="test_sampling_tool", description="A tool that triggers server-side sampling", @@ -209,17 +185,6 @@ async def _handle_list_tools( # pragma: no cover description="Tool that sends notification1, closes stream, sends notification2, notification3", input_schema={"type": "object", "properties": {}}, ), - Tool( - name="tool_with_multiple_stream_closes", - description="Tool that closes SSE stream multiple times during execution", - input_schema={ - "type": "object", - "properties": { - "checkpoints": {"type": "integer", "default": 3}, - "sleep_time": {"type": "number", "default": 0.2}, - }, - }, - ), Tool( name="tool_with_standalone_stream_close", description="Tool that closes standalone GET stream mid-operation", @@ -229,36 +194,14 @@ async def _handle_list_tools( # pragma: no cover ) -async def _handle_call_tool( # pragma: no cover - ctx: ServerRequestContext[ServerState], params: CallToolRequestParams -) -> CallToolResult: +async def _handle_call_tool(ctx: ServerRequestContext[ServerState], params: CallToolRequestParams) -> CallToolResult: name = params.name - args = params.arguments or {} # When the tool is called, send a notification to test GET stream if name == "test_tool_with_standalone_notification": await ctx.session.send_resource_updated(uri="http://test_resource") return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - elif name == "long_running_with_checkpoints": - await ctx.session.send_log_message( - level="info", - data="Tool started", - logger="tool", - related_request_id=ctx.request_id, - ) - - await anyio.sleep(0.1) - - await ctx.session.send_log_message( - level="info", - data="Tool is almost done", - logger="tool", - related_request_id=ctx.request_id, - ) - - return CallToolResult(content=[TextContent(type="text", text="Completed!")]) - elif name == "test_sampling_tool": sampling_result = await ctx.session.create_message( messages=[ @@ -271,15 +214,12 @@ async def _handle_call_tool( # pragma: no cover 
related_request_id=ctx.request_id, ) - if sampling_result.content.type == "text": - response = sampling_result.content.text - else: - response = str(sampling_result.content) + assert sampling_result.content.type == "text" return CallToolResult( content=[ TextContent( type="text", - text=f"Response from sampling: {response}", + text=f"Response from sampling: {sampling_result.content.text}", ) ] ) @@ -349,31 +289,12 @@ async def _handle_call_tool( # pragma: no cover ) return CallToolResult(content=[TextContent(type="text", text="All notifications sent")]) - elif name == "tool_with_multiple_stream_closes": - num_checkpoints = args.get("checkpoints", 3) - sleep_time = args.get("sleep_time", 0.2) - - for i in range(num_checkpoints): - await ctx.session.send_log_message( - level="info", - data=f"checkpoint_{i}", - logger="multi_close_tool", - related_request_id=ctx.request_id, - ) - - if ctx.close_sse_stream: - await ctx.close_sse_stream() - - await anyio.sleep(sleep_time) - - return CallToolResult(content=[TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")]) - elif name == "tool_with_standalone_stream_close": await ctx.session.send_resource_updated(uri="http://notification_1") await anyio.sleep(0.1) - if ctx.close_standalone_sse_stream: - await ctx.close_standalone_sse_stream() + assert ctx.close_standalone_sse_stream is not None + await ctx.close_standalone_sse_stream() await anyio.sleep(1.5) await ctx.session.send_resource_updated(uri="http://notification_2") @@ -383,7 +304,7 @@ async def _handle_call_tool( # pragma: no cover return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) -def _create_server() -> Server[ServerState]: # pragma: no cover +def _create_server() -> Server[ServerState]: return Server( SERVER_NAME, lifespan=_server_lifespan, @@ -393,113 +314,60 @@ def _create_server() -> Server[ServerState]: # pragma: no cover ) -def create_app( +@asynccontextmanager +async def running_app( is_json_response_enabled: 
bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, -) -> Starlette: # pragma: no cover - """Create a Starlette application for testing using the session manager. + server: Server[Any] | None = None, +) -> AsyncIterator[Starlette]: + """Serve the test server's streamable HTTP app in process for the duration. Args: is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. retry_interval: Retry interval in milliseconds for SSE polling. + server: Server to mount; defaults to the file's shared test server. """ - # Create server instance - server = _create_server() - - # Create the session manager - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) + # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot + # exist for an in-process app; the protection itself is pinned by + # tests/server/test_streamable_http_security.py. session_manager = StreamableHTTPSessionManager( - app=server, + app=server if server is not None else _create_server(), event_store=event_store, json_response=is_json_response_enabled, - security_settings=security_settings, + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), retry_interval=retry_interval, ) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + async with session_manager.run(): + yield app - # Create an ASGI application that uses the session manager - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lambda app: session_manager.run(), - ) - - return app +def make_client(app: Starlette, headers: dict[str, str] | None = None) -> httpx.AsyncClient: + """An httpx client served in process by `app`, with create_mcp_http_client's redirect default. 
-def run_server( - port: int, - is_json_response_enabled: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, -) -> None: # pragma: no cover - """Run the test server. - - Args: - port: Port to listen on. - is_json_response_enabled: If True, use JSON responses instead of SSE streams. - event_store: Optional event store for testing resumability. - retry_interval: Retry interval in milliseconds for SSE polling. + (Starlette's Mount 307-redirects the bare /mcp path to /mcp/, which the SDK's own client + factory follows.) """ - - app = create_app(is_json_response_enabled, event_store, retry_interval) - # Configure server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="info", - limit_concurrency=10, - timeout_keep_alive=5, - access_log=False, + return httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, headers=headers, follow_redirects=True ) - # Start the server - server = uvicorn.Server(config=config) - - # This is important to catch exceptions and prevent test hangs - try: - server.run() - except Exception: - traceback.print_exc() - -# Test fixtures - using same approach as SSE tests +# Test fixtures @pytest.fixture -def basic_server_port() -> int: - """Find an available port for the basic server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +async def basic_app() -> AsyncIterator[Starlette]: + """The test server's app with SSE response mode.""" + async with running_app() as app: + yield app @pytest.fixture -def json_server_port() -> int: - """Find an available port for the JSON response server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" - proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) - proc.start() - - # 
Wait for server to be running - wait_for_server(basic_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) +async def json_app() -> AsyncIterator[Starlette]: + """The test server's app with JSON response mode.""" + async with running_app(is_json_response_enabled=True) as app: + yield app @pytest.fixture @@ -509,82 +377,29 @@ def event_store() -> SimpleEventStore: @pytest.fixture -def event_server_port() -> int: - """Find an available port for the event store server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def event_server( - event_server_port: int, event_store: SimpleEventStore -) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store and retry_interval enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(event_server_port) - - yield event_store, f"http://127.0.0.1:{event_server_port}" - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def json_response_server(json_server_port: int) -> Generator[None, None, None]: - """Start a server with JSON response enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": json_server_port, "is_json_response_enabled": True}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(json_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def basic_server_url(basic_server_port: int) -> str: - """Get the URL for the basic test server.""" - return f"http://127.0.0.1:{basic_server_port}" - - -@pytest.fixture -def json_server_url(json_server_port: int) -> str: - """Get the URL for the JSON response test server.""" - return f"http://127.0.0.1:{json_server_port}" +async def event_app(event_store: 
SimpleEventStore) -> AsyncIterator[tuple[SimpleEventStore, Starlette]]: + """The test server's app with an event store and retry_interval enabled.""" + async with running_app(event_store=event_store, retry_interval=500) as app: + yield event_store, app # Basic request validation tests -def test_accept_header_validation(basic_server: None, basic_server_url: str): - """Test that Accept header is properly validated.""" - # Test without Accept header (suppress requests library default Accept: */*) - session = requests.Session() - session.headers.pop("Accept") - response = session.post( - f"{basic_server_url}/mcp", - headers={"Content-Type": "application/json"}, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text +@pytest.mark.anyio +async def test_accept_header_validation(basic_app: Starlette) -> None: + """A POST without an Accept header is rejected with 406.""" + async with make_client(basic_app) as client: + # Suppress the httpx client default Accept: */* header + del client.headers["accept"] + response = await client.post( + "/mcp", + headers={"Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text +@pytest.mark.anyio @pytest.mark.parametrize( "accept_header", [ @@ -596,19 +411,21 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): "application/*;q=0.9, text/*;q=0.8", ], ) -def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): - """Test that wildcard Accept headers are accepted per RFC 7231.""" - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": accept_header, - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 +async def test_accept_header_wildcard(basic_app: Starlette, accept_header: str) -> None: + 
"""Wildcard Accept headers are accepted per RFC 7231.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 +@pytest.mark.anyio @pytest.mark.parametrize( "accept_header", [ @@ -617,100 +434,104 @@ def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accep "text/*", ], ) -def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): - """Test that incompatible Accept headers are rejected for SSE mode.""" - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": accept_header, - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - -def test_content_type_validation(basic_server: None, basic_server_url: str): - """Test that Content-Type header is properly validated.""" - # Test with incorrect Content-Type - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "text/plain", - }, - data="This is not JSON", - ) +async def test_accept_header_incompatible(basic_app: Starlette, accept_header: str) -> None: + """Accept headers that cannot cover both response representations are rejected for SSE mode.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text - assert response.status_code == 400 - assert "Invalid Content-Type" in response.text +@pytest.mark.anyio +async def test_content_type_validation(basic_app: Starlette) -> None: + """A POST whose Content-Type is not application/json is rejected with 400.""" + async with 
make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "text/plain", + }, + content="This is not JSON", + ) + + assert response.status_code == 400 + assert "Invalid Content-Type" in response.text + + +@pytest.mark.anyio +async def test_json_validation(basic_app: Starlette) -> None: + """A POST body that is not valid JSON is rejected with a parse error.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + content="this is not valid json", + ) + assert response.status_code == 400 + assert "Parse error" in response.text -def test_json_validation(basic_server: None, basic_server_url: str): - """Test that JSON content is properly validated.""" - # Test with invalid JSON - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - data="this is not valid json", - ) - assert response.status_code == 400 - assert "Parse error" in response.text - - -def test_json_parsing(basic_server: None, basic_server_url: str): - """Test that JSON content is properly parse.""" - # Test with valid JSON but invalid JSON-RPC - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"foo": "bar"}, - ) - assert response.status_code == 400 - assert "Validation error" in response.text - - -def test_method_not_allowed(basic_server: None, basic_server_url: str): - """Test that unsupported HTTP methods are rejected.""" - # Test with unsupported method (PUT) - response = requests.put( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"jsonrpc": "2.0", 
"method": "initialize", "id": 1}, - ) - assert response.status_code == 405 - assert "Method Not Allowed" in response.text - - -def test_session_validation(basic_server: None, basic_server_url: str): - """Test session ID validation.""" - # session_id not used directly in this test - - # Test without session ID - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, - ) - assert response.status_code == 400 - assert "Missing session ID" in response.text + +@pytest.mark.anyio +async def test_json_parsing(basic_app: Starlette) -> None: + """Valid JSON that is not a JSON-RPC message is rejected with a validation error.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"foo": "bar"}, + ) + assert response.status_code == 400 + assert "Validation error" in response.text -def test_session_id_pattern(): - """Test that SESSION_ID_PATTERN correctly validates session IDs.""" +@pytest.mark.anyio +async def test_method_not_allowed(basic_app: Starlette) -> None: + """Unsupported HTTP methods are rejected with 405.""" + async with make_client(basic_app) as client: + response = await client.put( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 405 + assert "Method Not Allowed" in response.text + + +@pytest.mark.anyio +async def test_session_validation(basic_app: Starlette) -> None: + """A non-initialize request without a session ID is rejected with 400.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", 
+ "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, + ) + assert response.status_code == 400 + assert "Missing session ID" in response.text + + +def test_session_id_pattern() -> None: + """SESSION_ID_PATTERN accepts visible ASCII (0x21-0x7E) and rejects everything else.""" # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) valid_session_ids = [ "test-session-id", @@ -744,8 +565,8 @@ def test_session_id_pattern(): assert SESSION_ID_PATTERN.fullmatch(session_id) is None -def test_streamable_http_transport_init_validation(): - """Test that StreamableHTTPServerTransport validates session ID on init.""" +def test_streamable_http_transport_init_validation() -> None: + """StreamableHTTPServerTransport accepts valid or absent session IDs and rejects invalid ones.""" # Valid session ID should initialize without errors valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") assert valid_transport.mcp_session_id == "valid-id" @@ -767,144 +588,153 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server: None, basic_server_url: str): - """Test session termination via DELETE and subsequent request handling.""" - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 +@pytest.mark.anyio +async def test_session_termination(basic_app: Starlette) -> None: + """DELETE terminates the session, after which requests for it return 404.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 - # Extract negotiated protocol version from SSE 
response - negotiated_version = extract_protocol_version_from_sse(response) + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) - # Now terminate the session - session_id = response.headers.get(MCP_SESSION_ID_HEADER) - response = requests.delete( - f"{basic_server_url}/mcp", - headers={ - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - ) - assert response.status_code == 200 - - # Try to use the terminated session - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - }, - json={"jsonrpc": "2.0", "method": "ping", "id": 2}, - ) - assert response.status_code == 404 - assert "Session has been terminated" in response.text - - -def test_response(basic_server: None, basic_server_url: str): - """Test response handling for a valid request.""" - mcp_url = f"{basic_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + response = await client.delete( + "/mcp", + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 200 + + # Try to use the terminated session + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "ping", "id": 2}, + ) + assert response.status_code == 404 + assert "Session has been terminated" in response.text - # Extract negotiated protocol version from SSE response - 
negotiated_version = extract_protocol_version_from_sse(response) - # Now get the session ID - session_id = response.headers.get(MCP_SESSION_ID_HEADER) +@pytest.mark.anyio +async def test_response(basic_app: Starlette) -> None: + """A request on an initialized session is answered on a text/event-stream response.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 - # Try to use the session with proper headers - tools_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, - stream=True, - ) - assert tools_response.status_code == 200 - assert tools_response.headers.get("Content-Type") == "text/event-stream" - - -def test_json_response(json_response_server: None, json_server_url: str): - """Test response handling when is_json_response_enabled is True.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): - """Test that json_response servers only require application/json in Accept header.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert 
response.headers.get("Content-Type") == "application/json" - - -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests without Accept header.""" - mcp_url = f"{json_server_url}/mcp" - # Suppress requests library default Accept: */* header - session = requests.Session() - session.headers.pop("Accept") - response = session.post( - mcp_url, - headers={ - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) + # Now get the session ID + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Try to use the session with proper headers + async with client.stream( + "POST", + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, + ) as tools_response: + assert tools_response.status_code == 200 + assert tools_response.headers.get("Content-Type") == "text/event-stream" -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests with incorrect Accept header.""" - mcp_url = f"{json_server_url}/mcp" - # Test with only text/event-stream (wrong for JSON server) - response = requests.post( - mcp_url, - headers={ - "Accept": "text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text +@pytest.mark.anyio +async def test_json_response(json_app: Starlette) -> None: + """With JSON response mode 
enabled, requests are answered with application/json bodies.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" + + +@pytest.mark.anyio +async def test_json_response_accept_json_only(json_app: Starlette) -> None: + """JSON response mode only requires application/json in the Accept header.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" + + +@pytest.mark.anyio +async def test_json_response_missing_accept_header(json_app: Starlette) -> None: + """JSON response mode still rejects requests without an Accept header.""" + async with make_client(json_app) as client: + # Suppress the httpx client default Accept: */* header + del client.headers["accept"] + response = await client.post( + "/mcp", + headers={ + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + +@pytest.mark.anyio +async def test_json_response_incorrect_accept_header(json_app: Starlette) -> None: + """JSON response mode rejects an Accept header that does not cover application/json.""" + async with make_client(json_app) as client: + # Test with only text/event-stream (wrong for JSON server) + response = await client.post( + "/mcp", + headers={ + "Accept": "text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + +@pytest.mark.anyio @pytest.mark.parametrize( "accept_header", [ 
@@ -913,167 +743,134 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ "application/*;q=0.9", ], ) -def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str): - """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": accept_header, - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_get_sse_stream(basic_server: None, basic_server_url: str): - """Test establishing an SSE stream via GET request.""" - # First, we need to initialize a session - mcp_url = f"{basic_server_url}/mcp" - init_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 +async def test_json_response_wildcard_accept_header(json_app: Starlette, accept_header: str) -> None: + """JSON response mode accepts wildcard Accept headers per RFC 7231.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" - # Get the session ID - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - assert session_id is not None - # Extract negotiated protocol version from SSE response - init_data = None - assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): - init_data = json.loads(line[6:]) - break - assert init_data is not None - negotiated_version = 
init_data["result"]["protocolVersion"] - - # Now attempt to establish an SSE stream via GET - get_response = requests.get( - mcp_url, - headers={ - "Accept": "text/event-stream", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) +@pytest.mark.anyio +async def test_get_sse_stream(basic_app: Starlette) -> None: + """GET establishes the standalone SSE stream, and a second GET is rejected with 409.""" + async with make_client(basic_app) as client: + # First, we need to initialize a session + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 - # Verify we got a successful response with the right content type - assert get_response.status_code == 200 - assert get_response.headers.get("Content-Type") == "text/event-stream" + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + negotiated_version = extract_protocol_version_from_sse(init_response) - # Test that a second GET request gets rejected (only one stream allowed) - second_get = requests.get( - mcp_url, - headers={ + # Now attempt to establish an SSE stream via GET + get_headers = { "Accept": "text/event-stream", MCP_SESSION_ID_HEADER: session_id, MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) + } + # The streams enter in order, so the second GET arrives while the first is held open. 
+ async with ( + client.stream("GET", "/mcp", headers=get_headers) as get_response, + client.stream("GET", "/mcp", headers=get_headers) as second_get, + ): + # Verify we got a successful response with the right content type + assert get_response.status_code == 200 + assert get_response.headers.get("Content-Type") == "text/event-stream" - # Should get CONFLICT (409) since there's already a stream - # Note: This might fail if the first stream fully closed before this runs, - # but generally it should work in the test environment where it runs quickly - assert second_get.status_code == 409 - - -def test_get_validation(basic_server: None, basic_server_url: str): - """Test validation for GET requests.""" - # First, we need to initialize a session - mcp_url = f"{basic_server_url}/mcp" - init_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 + # The second GET gets CONFLICT (409): only one standalone stream is allowed per session. 
+ assert second_get.status_code == 409 - # Get the session ID - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - assert session_id is not None - # Extract negotiated protocol version from SSE response - init_data = None - assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): - init_data = json.loads(line[6:]) - break - assert init_data is not None - negotiated_version = init_data["result"]["protocolVersion"] - - # Test without Accept header (suppress requests library default Accept: */*) - session = requests.Session() - session.headers.pop("Accept") - response = session.get( - mcp_url, - headers={ - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - # Test with wrong Accept header - response = requests.get( - mcp_url, - headers={ - "Accept": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text +@pytest.mark.anyio +async def test_get_validation(basic_app: Starlette) -> None: + """A GET without an Accept header covering text/event-stream is rejected with 406.""" + async with make_client(basic_app) as client: + # First, we need to initialize a session + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + negotiated_version = extract_protocol_version_from_sse(init_response) + + # Test without Accept header (suppress the httpx client default Accept: */*) + del 
client.headers["accept"] + response = await client.get( + "/mcp", + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test with wrong Accept header + response = await client.get( + "/mcp", + headers={ + "Accept": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text # Client-specific fixtures @pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover - """Create test client matching the SSE test pattern.""" - async with httpx.AsyncClient(base_url=basic_server_url) as client: - yield client - - -@pytest.fixture -async def initialized_client_session(basic_server: None, basic_server_url: str): +async def initialized_client_session(basic_app: Starlette) -> AsyncIterator[ClientSession]: """Create initialized StreamableHTTP client session.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - yield session + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): - """Test basic client connection with initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert 
result.server_info.name == SERVER_NAME +async def test_streamable_http_client_basic_connection(basic_app: Starlette) -> None: + """A client initializes against a server over the StreamableHTTP transport.""" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.server_info.name == SERVER_NAME @pytest.mark.anyio -async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession): - """Test client resource read functionality.""" +async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession) -> None: + """A resource read round-trips its arguments and the handler's content.""" response = await initialized_client_session.read_resource(uri="foobar://test-resource") assert len(response.contents) == 1 assert response.contents[0].uri == "foobar://test-resource" @@ -1082,11 +879,11 @@ async def test_streamable_http_client_resource_read(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession): - """Test client tool invocation.""" +async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession) -> None: + """A tool call reaches the handler and returns its content.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 10 + assert len(tools.tools) == 8 assert tools.tools[0].name == "test_tool" # Call the tool @@ -1097,8 +894,8 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session @pytest.mark.anyio -async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession): - """Test error handling in client.""" +async def 
test_streamable_http_client_error_handling(initialized_client_session: ClientSession) -> None: + """A server-side error reaches the client as an MCPError with the handler's message.""" with pytest.raises(MCPError) as exc_info: await initialized_client_session.read_resource(uri="unknown://test-error") assert exc_info.value.error.code == 0 @@ -1106,50 +903,56 @@ async def test_streamable_http_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): - """Test that session ID persists across requests.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - - # Make multiple requests to verify session persistence - tools = await session.list_tools() - assert len(tools.tools) == 10 - - # Read a resource - resource = await session.read_resource(uri="foobar://test-persist") - assert isinstance(resource.contents[0], TextResourceContents) is True - content = resource.contents[0] - assert isinstance(content, TextResourceContents) - assert content.text == "Read test-persist" +async def test_streamable_http_client_session_persistence(basic_app: Starlette) -> None: + """The session persists across multiple requests on one connection.""" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make multiple requests to verify session persistence + tools = await session.list_tools() + assert len(tools.tools) == 8 + + # Read a resource + resource = await 
session.read_resource(uri="foobar://test-persist") + assert isinstance(resource.contents[0], TextResourceContents) is True + content = resource.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Read test-persist" @pytest.mark.anyio -async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): - """Test client with JSON response mode.""" - async with streamable_http_client(f"{json_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == SERVER_NAME - - # Check tool listing - tools = await session.list_tools() - assert len(tools.tools) == 10 - - # Call a tool and verify JSON response handling - result = await session.call_tool("test_tool", {}) - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Called test_tool" +async def test_streamable_http_client_json_response(json_app: Starlette) -> None: + """The client works identically against a server in JSON response mode.""" + async with ( + make_client(json_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.server_info.name == SERVER_NAME + + # Check tool listing + tools = await session.list_tools() + assert len(tools.tools) == 8 + + # Call a tool and verify JSON response handling + result = await session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" @pytest.mark.anyio -async def 
test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): - """Test GET stream functionality for server-initiated messages.""" +async def test_streamable_http_client_get_stream(basic_app: Starlette) -> None: + """A server-initiated notification reaches the client on the standalone GET stream.""" notifications_received: list[types.ServerNotification] = [] # Define message handler to capture notifications @@ -1159,30 +962,33 @@ async def message_handler( # pragma: no branch if isinstance(message, types.ServerNotification): # pragma: no branch notifications_received.append(message) - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Initialize the session - this triggers the GET stream setup - result = await session.initialize() - assert isinstance(result, InitializeResult) + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + # Initialize the session - this triggers the GET stream setup + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Call the special tool that sends a notification - await session.call_tool("test_tool_with_standalone_notification", {}) + # Call the special tool that sends a notification + await session.call_tool("test_tool_with_standalone_notification", {}) - # Verify we received the notification - assert len(notifications_received) > 0 + # Verify we received the notification + assert len(notifications_received) > 0 - # Verify the notification is a ResourceUpdatedNotification - resource_update_found = False - for notif in notifications_received: - if isinstance(notif, types.ResourceUpdatedNotification): # pragma: no branch - assert 
str(notif.params.uri) == "http://test_resource" - resource_update_found = True + # Verify the notification is a ResourceUpdatedNotification + resource_update_found = False + for notif in notifications_received: + if isinstance(notif, types.ResourceUpdatedNotification): # pragma: no branch + assert str(notif.params.uri) == "http://test_resource" + resource_update_found = True - assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" + assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" -def create_session_id_capturing_client() -> tuple[httpx.AsyncClient, list[str]]: - """Create an httpx client that captures the session ID from responses.""" +def create_session_id_capturing_client(app: Starlette) -> tuple[httpx.AsyncClient, list[str]]: + """Create an in-process httpx client that captures the session ID from responses.""" captured_ids: list[str] = [] async def capture_session_id(response: httpx.Response) -> None: @@ -1191,21 +997,22 @@ async def capture_session_id(response: httpx.Response) -> None: captured_ids.append(session_id) client = httpx.AsyncClient( + transport=StreamingASGITransport(app), + base_url=BASE_URL, follow_redirects=True, - timeout=httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT), event_hooks={"response": [capture_session_id]}, ) return client, captured_ids @pytest.mark.anyio -async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): - """Test client session termination functionality.""" +async def test_streamable_http_client_session_termination(basic_app: Starlette) -> None: + """After the client terminates its session on close, a new connection with that session ID fails.""" # Use httpx client with event hooks to capture session ID - httpx_client, captured_ids = create_session_id_capturing_client() + httpx_client, captured_ids = create_session_id_capturing_client(basic_app) async with httpx_client: - async with 
streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1220,10 +1027,10 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 10 + assert len(tools.tools) == 8 - async with create_mcp_http_client(headers=headers) as httpx_client2: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( + async with make_client(basic_app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, ): @@ -1235,9 +1042,9 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba @pytest.mark.anyio async def test_streamable_http_client_session_termination_204( - basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch -): - """Test client session termination functionality with a 204 response. + basic_app: Starlette, monkeypatch: pytest.MonkeyPatch +) -> None: + """Session termination also succeeds when the server answers the DELETE with 204. This test patches the httpx client to return a 204 response for DELETEs. 
""" @@ -1263,10 +1070,10 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt monkeypatch.setattr(httpx.AsyncClient, "delete", mock_delete) # Use httpx client with event hooks to capture session ID - httpx_client, captured_ids = create_session_id_capturing_client() + httpx_client, captured_ids = create_session_id_capturing_client(basic_app) async with httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1281,10 +1088,10 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 10 + assert len(tools.tools) == 8 - async with create_mcp_http_client(headers=headers) as httpx_client2: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( + async with make_client(basic_app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, ): @@ -1295,14 +1102,15 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt @pytest.mark.anyio -async def test_streamable_http_client_resumption(event_server: tuple[SimpleEventStore, str]): - """Test client session resumption using sync primitives for reliable coordination.""" - _, server_url = event_server +async def test_streamable_http_client_resumption(event_app: tuple[SimpleEventStore, Starlette]) -> None: + """A second client resumes an interrupted request with a resumption token and receives the rest.""" + _, app = event_app # Variables to track the state captured_resumption_token: str | None = None captured_notifications: list[types.ServerNotification] = [] - first_notification_received = False + first_notification_received = anyio.Event() 
+ resumption_token_received = anyio.Event() async def message_handler( # pragma: no branch message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, @@ -1312,19 +1120,19 @@ async def message_handler( # pragma: no branch # Look for our first notification if isinstance(message, types.LoggingMessageNotification): # pragma: no branch if message.params.data == "First notification before lock": - nonlocal first_notification_received - first_notification_received = True + first_notification_received.set() async def on_resumption_token_update(token: str) -> None: nonlocal captured_resumption_token captured_resumption_token = token + resumption_token_received.set() # Use httpx client with event hooks to capture session ID - httpx_client, captured_ids = create_session_id_capturing_client() + httpx_client, captured_ids = create_session_id_capturing_client(app) # First, start the client session and begin the tool that waits on lock async with httpx_client: - async with streamable_http_client(f"{server_url}/mcp", terminate_on_close=False, http_client=httpx_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", terminate_on_close=False, http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1363,13 +1171,13 @@ async def run_tool(): tg.start_soon(run_tool) # Wait for the first notification and resumption token - while not first_notification_received or not captured_resumption_token: - await anyio.sleep(0.1) + with anyio.fail_after(5): + await first_notification_received.wait() + await resumption_token_received.wait() - # The while loop only exits after first_notification_received=True, - # which is set by message_handler immediately after appending to - # captured_notifications. The server tool is blocked on its lock, - # so nothing else can arrive before we cancel. + # first_notification_received is set by message_handler immediately + # after appending to captured_notifications. 
The server tool is + # blocked on its lock, so nothing else can arrive before we cancel. assert len(captured_notifications) == 1 assert isinstance(captured_notifications[0], types.LoggingMessageNotification) assert captured_notifications[0].params.data == "First notification before lock" @@ -1379,8 +1187,8 @@ async def run_tool(): # Kill the client session while tool is waiting on lock tg.cancel_scope.cancel() - async with create_mcp_http_client(headers=headers) as httpx_client2: - async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client2) as ( + async with make_client(app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, ): @@ -1413,8 +1221,8 @@ async def run_tool(): @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): - """Test server-initiated sampling request through streamable HTTP transport.""" +async def test_streamablehttp_server_sampling(basic_app: Starlette) -> None: + """A server-initiated sampling request reaches the client callback and its result the tool.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False captured_message_params = None @@ -1441,29 +1249,32 @@ async def sampling_callback( ) # Create client with sampling callback - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - - # Call the tool that triggers server-side sampling - tool_result = await session.call_tool("test_sampling_tool", {}) - - # Verify the tool result contains the expected content - assert len(tool_result.content) == 1 - assert tool_result.content[0].type == "text" - assert "Response from sampling: Received 
message from server" in tool_result.content[0].text - - # Verify sampling callback was invoked - assert sampling_callback_invoked - assert captured_message_params is not None - assert len(captured_message_params.messages) == 1 - assert captured_message_params.messages[0].content.text == "Server needs client sampling" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session, + ): + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the tool that triggers server-side sampling + tool_result = await session.call_tool("test_sampling_tool", {}) + + # Verify the tool result contains the expected content + assert len(tool_result.content) == 1 + assert tool_result.content[0].type == "text" + assert "Response from sampling: Received message from server" in tool_result.content[0].text + + # Verify sampling callback was invoked + assert sampling_callback_invoked + assert captured_message_params is not None + assert len(captured_message_params.messages) == 1 + assert captured_message_params.messages[0].content.text == "Server needs client sampling" # Context-aware server implementation for testing request context propagation -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -1488,97 +1299,51 @@ async def _handle_context_list_tools( # pragma: no cover ) -async def _handle_context_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - name = params.name - args = params.arguments or {} - - if name == "echo_headers": - headers_info: dict[str, Any] = {} - if ctx.request and isinstance(ctx.request, Request): - 
headers_info = dict(ctx.request.headers) - return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) - - elif name == "echo_context": - context_data: dict[str, Any] = { - "request_id": args.get("request_id"), - "headers": {}, - "method": None, - "path": None, - } - if ctx.request and isinstance(ctx.request, Request): - request = ctx.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) +async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name in ("echo_headers", "echo_context") + assert isinstance(ctx.request, Request) + + if params.name == "echo_headers": + return CallToolResult(content=[TextContent(type="text", text=json.dumps(dict(ctx.request.headers)))]) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) + assert params.arguments is not None + context_data: dict[str, Any] = { + "request_id": params.arguments.get("request_id"), + "headers": dict(ctx.request.headers), + "method": ctx.request.method, + "path": ctx.request.url.path, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) -# Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover - """Run the context-aware test server.""" +@pytest.fixture +async def context_app() -> AsyncIterator[Starlette]: + """An app whose server echoes request context, served in process.""" server = Server( "ContextAwareServer", on_list_tools=_handle_context_list_tools, on_call_tool=_handle_context_call_tool, ) - session_manager = StreamableHTTPSessionManager( app=server, - event_store=None, - json_response=False, + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), ) - - app = Starlette( - 
debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lambda app: session_manager.run(), - ) - - server_instance = uvicorn.Server( - config=uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - ) - server_instance.run() - - -@pytest.fixture -def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + async with session_manager.run(): + yield app @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: - """Test that request context is properly propagated through StreamableHTTP.""" +async def test_streamablehttp_request_context_propagation(context_app: Starlette) -> None: + """Custom HTTP headers on the connection are visible to server handlers via ctx.request.""" custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with make_client(context_app, headers=custom_headers) as httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1602,11 +1367,11 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio -async def 
test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: - """Test that request contexts are isolated between StreamableHTTP clients.""" +async def test_streamablehttp_request_context_isolation(context_app: Starlette) -> None: + """Each connection's handlers see only that connection's request headers.""" contexts: list[dict[str, Any]] = [] - # Create multiple clients with different headers + # Connect three clients in turn, each with its own headers. for i in range(3): headers = { "X-Request-Id": f"request-{i}", @@ -1614,8 +1379,8 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No "Authorization": f"Bearer token-{i}", } - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with make_client(context_app, headers=headers) as httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1640,145 +1405,160 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): - """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Initialize and get the negotiated version - init_result = await session.initialize() - negotiated_version = init_result.protocol_version - - # Call a tool that echoes headers to verify the header is present - tool_result = await session.call_tool("echo_headers", {}) - - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - headers_data = json.loads(tool_result.content[0].text) - - # Verify protocol 
version header is present - assert "mcp-protocol-version" in headers_data - assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version - - -def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): - """Test that server returns 400 Bad Request version if header unsupported or invalid.""" - # First initialize a session to get a valid session ID - init_response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - - # Test request with invalid protocol version (should fail) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: "invalid-version", - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, - ) - assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() - - # Test request with unsupported protocol version (should fail) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, - ) - assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() +async def test_client_includes_protocol_version_header_after_init(context_app: Starlette) -> None: + """After initialization, every client request carries the negotiated protocol version header.""" + async with ( + 
make_client(context_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + # Initialize and get the negotiated version + init_result = await session.initialize() + negotiated_version = init_result.protocol_version - # Test request with valid protocol version (should succeed) - negotiated_version = extract_protocol_version_from_sse(init_response) + # Call a tool that echoes headers to verify the header is present + tool_result = await session.call_tool("echo_headers", {}) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, - ) - assert response.status_code == 200 - - -def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): - """Test server accepts requests without protocol version header.""" - # First initialize a session to get a valid session ID - init_response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - - # Test request without mcp-protocol-version header (backwards compatibility) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, - stream=True, - ) - assert response.status_code == 200 # Should succeed for backwards compatibility - assert 
response.headers.get("Content-Type") == "text/event-stream" + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + # Verify protocol version header is present + assert "mcp-protocol-version" in headers_data + assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version + + +@pytest.mark.anyio +async def test_server_validates_protocol_version_header(basic_app: Starlette) -> None: + """An invalid or unsupported protocol version header is rejected with 400; the negotiated one passes.""" + async with make_client(basic_app) as client: + # First initialize a session to get a valid session ID + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test request with invalid protocol version (should fail) + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "invalid-version", + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with unsupported protocol version (should fail) + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, + ) + assert response.status_code == 400 + assert 
MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with valid protocol version (should succeed) + negotiated_version = extract_protocol_version_from_sse(init_response) + + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, + ) + assert response.status_code == 200 @pytest.mark.anyio -async def test_client_crash_handled(basic_server: None, basic_server_url: str): - """Test that cases where the client crashes are handled gracefully.""" +async def test_server_backwards_compatibility_no_protocol_version(basic_app: Starlette) -> None: + """A request without a protocol version header is accepted for backwards compatibility.""" + async with make_client(basic_app) as client: + # First initialize a session to get a valid session ID + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test request without mcp-protocol-version header (backwards compatibility) + async with client.stream( + "POST", + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, + ) as response: + assert response.status_code == 200 # Should succeed for backwards compatibility + assert response.headers.get("Content-Type") == "text/event-stream" + + +@pytest.mark.anyio +async def test_client_crash_handled(basic_app: Starlette) -> None: + """A client crashing 
mid-session does not prevent later clients from connecting.""" # Simulate bad client that crashes after init async def bad_client(): """Client that triggers ClosedResourceError""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - raise Exception("client crash") + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + raise Exception("client crash") - # Run bad client a few times to trigger the crash + # Run bad client a few times to trigger the crash. The crash surfaces wrapped in exception + # groups whose exact shape is not the subject here — what matters is that the server survives. for _ in range(3): try: await bad_client() except Exception: pass - await anyio.sleep(0.1) # Try a good client, it should still be able to connect and list tools - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - result = await session.initialize() - assert isinstance(result, InitializeResult) - tools = await session.list_tools() - assert tools.tools + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + result = await session.initialize() + assert isinstance(result, InitializeResult) + tools = await session.list_tools() + assert tools.tools @pytest.mark.anyio -async def test_handle_sse_event_skips_empty_data(): - """Test that _handle_sse_event skips empty SSE data (keep-alive pings).""" +async def test_handle_sse_event_skips_empty_data() -> None: + """_handle_sse_event skips empty SSE 
data (keep-alive pings) without writing to the stream.""" transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") # Create a mock SSE event with empty data (keep-alive ping) @@ -1804,8 +1584,8 @@ async def test_handle_sse_event_skips_empty_data(): @pytest.mark.anyio -async def test_priming_event_not_sent_for_old_protocol_version(): - """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" +async def test_priming_event_not_sent_for_old_protocol_version() -> None: + """_maybe_send_priming_event skips for old protocol versions (backwards compat).""" # Create a transport with an event store transport = StreamableHTTPServerTransport( "/mcp", @@ -1833,8 +1613,8 @@ async def test_priming_event_not_sent_for_old_protocol_version(): @pytest.mark.anyio -async def test_priming_event_not_sent_without_event_store(): - """Test that _maybe_send_priming_event returns early when no event_store is configured.""" +async def test_priming_event_not_sent_without_event_store() -> None: + """_maybe_send_priming_event returns early when no event_store is configured.""" # Create a transport WITHOUT an event store transport = StreamableHTTPServerTransport("/mcp") @@ -1853,8 +1633,8 @@ async def test_priming_event_not_sent_without_event_store(): @pytest.mark.anyio -async def test_priming_event_includes_retry_interval(): - """Test that _maybe_send_priming_event includes retry field when retry_interval is set.""" +async def test_priming_event_includes_retry_interval() -> None: + """_maybe_send_priming_event includes the retry field when retry_interval is set.""" # Create a transport with an event store AND retry_interval transport = StreamableHTTPServerTransport( "/mcp", @@ -1882,8 +1662,8 @@ async def test_priming_event_includes_retry_interval(): @pytest.mark.anyio -async def test_close_sse_stream_callback_not_provided_for_old_protocol_version(): - """Test that close_sse_stream callbacks are NOT provided for old protocol versions.""" +async def 
test_close_sse_stream_callback_not_provided_for_old_protocol_version() -> None: + """close_sse_stream callbacks are only provided for protocol versions that support polling.""" # Create a transport with an event store transport = StreamableHTTPServerTransport( "/mcp", @@ -1915,71 +1695,76 @@ async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() @pytest.mark.anyio async def test_streamable_http_client_receives_priming_event( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client should receive priming event (resumption token update) on POST SSE stream.""" - _, server_url = event_server + _, app = event_app captured_resumption_tokens: list[str] = [] async def on_resumption_token_update(token: str) -> None: captured_resumption_tokens.append(token) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() - # Call tool with resumption token callback via send_request - metadata = ClientMessageMetadata( - on_resumption_token_update=on_resumption_token_update, - ) - result = await session.send_request( - types.CallToolRequest(params=types.CallToolRequestParams(name="test_tool", arguments={})), - types.CallToolResult, - metadata=metadata, - ) - assert result is not None - - # Should have received priming event token BEFORE response data - # Priming event = 1 token (empty data, id only) - # Response = 1 token (actual JSON-RPC response) - # Total = 2 tokens minimum - assert len(captured_resumption_tokens) >= 2, ( - f"Server must send priming event before response. 
" - f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" - ) - assert captured_resumption_tokens[0] is not None + # Call tool with resumption token callback via send_request + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + result = await session.send_request( + types.CallToolRequest(params=types.CallToolRequestParams(name="test_tool", arguments={})), + types.CallToolResult, + metadata=metadata, + ) + assert result is not None + + # Should have received priming event token BEFORE response data + # Priming event = 1 token (empty data, id only) + # Response = 1 token (actual JSON-RPC response) + # Total = 2 tokens minimum + assert len(captured_resumption_tokens) >= 2, ( + f"Server must send priming event before response. " + f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" + ) + assert captured_resumption_tokens[0] is not None @pytest.mark.anyio async def test_server_close_sse_stream_via_context( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Server tool can call ctx.close_sse_stream() to close connection.""" - _, server_url = event_server + _, app = event_app - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() - # Call tool that closes stream mid-operation - # This should NOT raise NotImplementedError when fully implemented - result = await session.call_tool("tool_with_stream_close", {}) + # Call tool that closes stream mid-operation + result = await session.call_tool("tool_with_stream_close", {}) - # Client should still 
receive complete response (via auto-reconnect) - assert result is not None - assert len(result.content) > 0 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + # Client should still receive complete response (via auto-reconnect) + assert result is not None + assert len(result.content) > 0 + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_client_auto_reconnects( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client should auto-reconnect with Last-Event-ID when server closes after priming event.""" - _, server_url = event_server + _, app = event_app captured_notifications: list[str] = [] async def message_handler( @@ -1991,59 +1776,63 @@ async def message_handler( if isinstance(message, types.LoggingMessageNotification): # pragma: no branch captured_notifications.append(str(message.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() - - # Call tool that: - # 1. Sends notification - # 2. Closes SSE stream - # 3. Sends more notifications (stored in event_store) - # 4. 
Returns response - result = await session.call_tool("tool_with_stream_close", {}) - - # Client should have auto-reconnected and received ALL notifications - assert len(captured_notifications) >= 2, ( - "Client should auto-reconnect and receive notifications sent both before and after stream close" - ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() + + # Call tool that: + # 1. Sends notification + # 2. Closes SSE stream + # 3. Sends more notifications (stored in event_store) + # 4. Returns response + result = await session.call_tool("tool_with_stream_close", {}) + + # Client should have auto-reconnected and received ALL notifications + assert len(captured_notifications) >= 2, ( + "Client should auto-reconnect and receive notifications sent both before and after stream close" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_client_respects_retry_interval( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client MUST respect retry field, waiting specified ms before reconnecting.""" - _, server_url = event_server + _, app = event_app - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + 
): + await session.initialize() - start_time = time.monotonic() - result = await session.call_tool("tool_with_stream_close", {}) - elapsed = time.monotonic() - start_time + start_time = time.monotonic() + result = await session.call_tool("tool_with_stream_close", {}) + elapsed = time.monotonic() - start_time - # Verify result was received - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + # Verify result was received + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" - # The elapsed time should include at least the retry interval - # if reconnection occurred. This test may be flaky depending on - # implementation details, but demonstrates the expected behavior. - # Note: This assertion may need adjustment based on actual implementation - assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" + # The elapsed time should include at least the retry interval (500ms) before + # the client reconnected; the tool's own work only accounts for ~100ms. 
+ assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" @pytest.mark.anyio async def test_streamable_http_sse_polling_full_cycle( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """End-to-end test: server closes stream, client reconnects, receives all events.""" - _, server_url = event_server + _, app = event_app all_notifications: list[str] = [] async def message_handler( @@ -2055,35 +1844,38 @@ async def message_handler( if isinstance(message, types.LoggingMessageNotification): # pragma: no branch all_notifications.append(str(message.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() - - # Call tool that simulates polling pattern: - # 1. Server sends priming event - # 2. Server sends "Before close" notification - # 3. Server closes stream (calls close_sse_stream) - # 4. (client reconnects automatically) - # 5. Server sends "After close" notification - # 6. Server sends final response - result = await session.call_tool("tool_with_stream_close", {}) - - # Verify all notifications received in order - assert "Before close" in all_notifications, "Should receive notification sent before stream close" - assert "After close" in all_notifications, ( - "Should receive notification sent after stream close (via auto-reconnect)" - ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() + + # Call tool that simulates polling pattern: + # 1. 
Server sends priming event + # 2. Server sends "Before close" notification + # 3. Server closes stream (calls close_sse_stream) + # 4. (client reconnects automatically) + # 5. Server sends "After close" notification + # 6. Server sends final response + result = await session.call_tool("tool_with_stream_close", {}) + + # Verify all notifications received in order + assert "Before close" in all_notifications, "Should receive notification sent before stream close" + assert "After close" in all_notifications, ( + "Should receive notification sent after stream close (via auto-reconnect)" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_events_replayed_after_disconnect( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Events sent while client is disconnected should be replayed on reconnect.""" - _, server_url = event_server + _, app = event_app notification_data: list[str] = [] async def message_handler( @@ -2095,37 +1887,43 @@ async def message_handler( if isinstance(message, types.LoggingMessageNotification): # pragma: no branch notification_data.append(str(message.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() - # Tool sends: notification1, close_stream, notification2, notification3, response - # Client should receive all notifications even though 2&3 were sent during disconnect - result = await 
session.call_tool("tool_with_multiple_notifications_and_close", {}) + # Tool sends: notification1, close_stream, notification2, notification3, response + # Client should receive all notifications even though 2&3 were sent during disconnect + result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) - assert "notification1" in notification_data, "Should receive notification1 (sent before close)" - assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" - assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" + assert "notification1" in notification_data, "Should receive notification1 (sent before close)" + assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" + assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" - # Verify order: notification1 should come before notification2 and notification3 - idx1 = notification_data.index("notification1") - idx2 = notification_data.index("notification2") - idx3 = notification_data.index("notification3") - assert idx1 < idx2 < idx3, "Notifications should be received in order" + # Verify order: notification1 should come before notification2 and notification3 + idx1 = notification_data.index("notification1") + idx2 = notification_data.index("notification2") + idx3 = notification_data.index("notification3") + assert idx1 < idx2 < idx3, "Notifications should be received in order" - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "All notifications sent" + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "All notifications sent" @pytest.mark.anyio -async def test_streamable_http_multiple_reconnections( - event_server: tuple[SimpleEventStore, str], -): - 
"""Verify multiple close_sse_stream() calls each trigger a client reconnect. +async def test_streamable_http_multiple_reconnections() -> None: + """Every close_sse_stream() severs a live connection and triggers its own client reconnect. - Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure - client has time to reconnect before the next checkpoint. + The tool closes its SSE stream three times; before each next cycle it waits until the + client has observed the previous cycle's two new resumption tokens (the checkpoint and the + new connection's priming event). The priming event is sent only after the server has + re-registered the resumed stream, so once the client holds its token the next close is + guaranteed to sever a live connection rather than silently no-op — making the exact token + count below a consequence of causality, not timing margins. This pins reconnect-per-close + accounting; reconnect *latency* is pinned by test_streamable_http_client_respects_retry_interval. With 3 checkpoints, we expect 8 resumption tokens: - 1 priming (initial POST connection) @@ -2133,45 +1931,72 @@ async def test_streamable_http_multiple_reconnections( - 3 priming (one per reconnect after each close) - 1 response """ - _, server_url = event_server resumption_tokens: list[str] = [] + # milestones[n] fires when the client has observed n tokens. After the initial priming + # (token 1), each completed cycle i contributes exactly two tokens — checkpoint_i and the + # reconnect's priming, in either order — so cycle i is complete at 3 + 2i tokens. 
+ milestones = {3: anyio.Event(), 5: anyio.Event(), 7: anyio.Event()} async def on_resumption_token(token: str) -> None: resumption_tokens.append(token) + milestone = milestones.get(len(resumption_tokens)) + if milestone is not None: + milestone.set() - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Use send_request with metadata to track resumption tokens - metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) - result = await session.send_request( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name="tool_with_multiple_stream_closes", - # retry_interval=500ms, so sleep 600ms to ensure reconnect completes - arguments={"checkpoints": 3, "sleep_time": 0.6}, - ), - ), - types.CallToolResult, - metadata=metadata, + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "multi_close_tool" + for i, milestone in enumerate(milestones.values()): + await ctx.session.send_log_message( + level="info", + data=f"checkpoint_{i}", + logger="multi_close_tool", + related_request_id=ctx.request_id, ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + # Client and server share one event loop, so the tool can wait directly on the + # client-side callback observing the reconnect. + with anyio.fail_after(5): + await milestone.wait() + return CallToolResult(content=[TextContent(type="text", text="Completed 3 checkpoints")]) + + server = Server("multi_reconnect_server", on_call_tool=handle_call_tool) + + async with ( + # retry_interval is small to keep the test fast, but nonzero so each dying connection + # finishes unwinding before its replacement registers. 
+ running_app(event_store=SimpleEventStore(), retry_interval=50, server=server) as app, + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + + # Use send_request with metadata to track resumption tokens + metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) + result = await session.send_request( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams(name="multi_close_tool", arguments={}), + ), + types.CallToolResult, + metadata=metadata, + ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Completed 3 checkpoints" in result.content[0].text + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert "Completed 3 checkpoints" in result.content[0].text - # 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are - # captured before send_request returns, so this is safe to check here. - assert len(resumption_tokens) == 8, ( - f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " - f"got {len(resumption_tokens)}: {resumption_tokens}" - ) + # 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are + # captured before send_request returns, so this is safe to check here. + assert len(resumption_tokens) == 8, ( + f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " + f"got {len(resumption_tokens)}: {resumption_tokens}" + ) @pytest.mark.anyio -async def test_standalone_get_stream_reconnection(event_server: tuple[SimpleEventStore, str]) -> None: +async def test_standalone_get_stream_reconnection(event_app: tuple[SimpleEventStore, Starlette]) -> None: """Test that standalone GET stream automatically reconnects after server closes it. 
Verifies: @@ -2180,10 +2005,10 @@ async def test_standalone_get_stream_reconnection(event_server: tuple[SimpleEven 3. Client reconnects with Last-Event-ID 4. Client receives notification 2 on new connection - Note: Requires event_server fixture (with event store) because close_standalone_sse_stream + Note: Requires the event store app because close_standalone_sse_stream callback is only provided when event_store is configured and protocol version >= 2025-11-25. """ - _, server_url = event_server + _, app = event_app received_notifications: list[str] = [] async def message_handler( @@ -2195,45 +2020,46 @@ async def message_handler( if isinstance(message, types.ResourceUpdatedNotification): # pragma: no branch received_notifications.append(str(message.params.uri)) - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - await session.initialize() - - # Call tool that: - # 1. Sends notification_1 via GET stream - # 2. Closes standalone GET stream - # 3. Sends notification_2 (stored in event_store) - # 4. 
Returns response - result = await session.call_tool("tool_with_standalone_stream_close", {}) - - # Verify the tool completed - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Standalone stream close test done" - - # Verify both notifications were received - assert "http://notification_1" in received_notifications, ( - f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" - ) - assert "http://notification_2" in received_notifications, ( - f"Should receive notification 2 after reconnect, got: {received_notifications}" - ) + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() + + # Call tool that: + # 1. Sends notification_1 via GET stream + # 2. Closes standalone GET stream + # 3. Sends notification_2 (stored in event_store) + # 4. 
Returns response + result = await session.call_tool("tool_with_standalone_stream_close", {}) + + # Verify the tool completed + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Standalone stream close test done" + + # Verify both notifications were received + assert "http://notification_1" in received_notifications, ( + f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" + ) + assert "http://notification_2" in received_notifications, ( + f"Should receive notification 2 after reconnect, got: {received_notifications}" + ) @pytest.mark.anyio -async def test_streamable_http_client_does_not_mutate_provided_client( - basic_server: None, basic_server_url: str -) -> None: - """Test that streamable_http_client does not mutate the provided httpx client's headers.""" +async def test_streamable_http_client_does_not_mutate_provided_client(basic_app: Starlette) -> None: + """streamable_http_client does not mutate the provided httpx client's headers.""" # Create a client with custom headers original_headers = { "X-Custom-Header": "custom-value", "Authorization": "Bearer test-token", } - async with httpx.AsyncClient(headers=original_headers, follow_redirects=True) as custom_client: + async with make_client(basic_app, headers=original_headers) as custom_client: # Use the client with streamable_http_client - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=custom_client) as ( read_stream, write_stream, ): @@ -2254,18 +2080,16 @@ async def test_streamable_http_client_does_not_mutate_provided_client( @pytest.mark.anyio -async def test_streamable_http_client_mcp_headers_override_defaults( - context_aware_server: None, basic_server_url: str -) -> None: - """Test that MCP protocol headers override httpx.AsyncClient default headers.""" +async def 
test_streamable_http_client_mcp_headers_override_defaults(context_app: Starlette) -> None: + """MCP protocol headers override the httpx client's default headers in actual requests.""" # httpx.AsyncClient has default "accept: */*" header # We need to verify that our MCP accept header overrides it in actual requests - async with httpx.AsyncClient(follow_redirects=True) as client: + async with make_client(context_app) as client: # Verify client has default accept header assert client.headers.get("accept") == "*/*" - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2285,18 +2109,16 @@ async def test_streamable_http_client_mcp_headers_override_defaults( @pytest.mark.anyio -async def test_streamable_http_client_preserves_custom_with_mcp_headers( - context_aware_server: None, basic_server_url: str -) -> None: - """Test that both custom headers and MCP protocol headers are sent in requests.""" +async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_app: Starlette) -> None: + """Custom client headers and MCP protocol headers are both sent in requests.""" custom_headers = { "X-Custom-Header": "custom-value", "X-Request-Id": "req-123", "Authorization": "Bearer test-token", } - async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with make_client(context_app, headers=custom_headers) as client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() diff --git 
a/tests/test_helpers.py b/tests/test_helpers.py index 810c72820b..0038b18905 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -2,7 +2,6 @@ import socket import threading -import time from collections.abc import Generator from contextlib import contextmanager from typing import Any @@ -56,30 +55,3 @@ def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None finally: server.should_exit = True thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S) - - -def wait_for_server(port: int, timeout: float = 20.0) -> None: - """Wait for server to be ready to accept connections. - - Polls the server port until it accepts connections or timeout is reached. - This eliminates race conditions without arbitrary sleeps. - - Args: - port: The port number to check - timeout: Maximum time to wait in seconds (default 5.0) - - Raises: - TimeoutError: If server doesn't start within the timeout period - """ - start_time = time.time() - while time.time() - start_time < timeout: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.settimeout(0.1) - s.connect(("127.0.0.1", port)) - # Server is ready - return - except (ConnectionRefusedError, OSError): - # Server not ready yet, retry quickly - time.sleep(0.01) - raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") # pragma: no cover From bdc48e98b1984bd3eb49ee1cf29b6ea4b0db1b86 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Fri, 5 Jun 2026 16:15:43 +0100 Subject: [PATCH 82/84] Fix stdio client shutdown bugs and rebuild the stdio test suite (#2773) --- docs/migration.md | 40 + src/mcp/client/stdio.py | 368 ++-- src/mcp/os/posix/utilities.py | 84 +- src/mcp/os/win32/utilities.py | 314 ++-- tests/client/test_stdio.py | 1650 +++++++++++++---- tests/interaction/transports/test_stdio.py | 71 +- .../test_1027_win_unreachable_cleanup.py | 240 --- tests/issues/test_552_windows_hang.py | 55 +- 
tests/server/mcpserver/test_elicitation.py | 14 +- tests/server/test_stdio.py | 145 +- tests/shared/test_win32_utils.py | 10 - tests/transports/__init__.py | 0 tests/transports/stdio/__init__.py | 0 tests/transports/stdio/_liveness.py | 80 + tests/transports/stdio/conftest.py | 77 + tests/transports/stdio/test_lifecycle.py | 276 +++ tests/transports/stdio/test_posix.py | 116 ++ tests/transports/stdio/test_windows.py | 235 +++ 18 files changed, 2645 insertions(+), 1130 deletions(-) delete mode 100644 tests/issues/test_1027_win_unreachable_cleanup.py delete mode 100644 tests/shared/test_win32_utils.py create mode 100644 tests/transports/__init__.py create mode 100644 tests/transports/stdio/__init__.py create mode 100644 tests/transports/stdio/_liveness.py create mode 100644 tests/transports/stdio/conftest.py create mode 100644 tests/transports/stdio/test_lifecycle.py create mode 100644 tests/transports/stdio/test_posix.py create mode 100644 tests/transports/stdio/test_windows.py diff --git a/docs/migration.md b/docs/migration.md index 9850f74cd4..0f5fc91c3d 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -105,6 +105,46 @@ The `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters have been re Note: `sse_client` retains its `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters — only the streamable HTTP transport changed. +### `terminate_windows_process` removed + +The deprecated `mcp.os.win32.utilities.terminate_windows_process` function has been +removed. Process termination is handled internally by the `stdio_client` context +manager; there is no replacement API. The Windows tree-termination helper +`terminate_windows_process_tree` no longer accepts a `timeout_seconds` argument — +the value was never used (Job Object termination is immediate). 
+ +### `stdio_client` no longer kills children of a gracefully-exited server on POSIX + +When a server exits on its own after `stdio_client` closes its stdin, background +child processes the server leaves behind are no longer killed on POSIX — their +lifetime is the server's business. The old behavior was a side effect of a shutdown +wait gated on the stdio pipes closing rather than on process exit: a child holding +an inherited pipe made a well-behaved server look hung, so its whole process tree +was killed. (That gating is an asyncio behavior specific to Python 3.11+ — on +Python 3.10 and the trio backend the old wait already resolved on process exit, so +the spurious kill never fired there.) A server that does not exit within the grace +period is still terminated +along with its entire process group. On Windows, children stay in the server's Job +Object and are still killed at shutdown — now deterministically when the job handle +is closed, rather than whenever the handle happened to be garbage-collected. + +If you relied on `stdio_client` killing everything the server spawned, make the +server terminate its own children on shutdown (its stdin reaching EOF is the +shutdown signal), or clean up the process tree from the host application after +`stdio_client` exits. + +Two related shutdown refinements: `stdio_client` now closes its end of the pipes +deterministically at shutdown, so a surviving child that keeps writing to an +inherited stdout receives `EPIPE`/`SIGPIPE` once the client is gone (previously the +pipe lingered until garbage collection); and a failed write to a server that is +still running now surfaces as a closed connection (`CONNECTION_CLOSED`) on the read +side instead of leaving requests waiting indefinitely. + +`terminate_posix_process_tree` now requires the process to lead its own process +group (spawned with `start_new_session=True`); the `getpgid()` lookup and the +per-process terminate/kill fallback are gone. 
The win32 utilities logger is now +named `mcp.os.win32.utilities` (was `client.stdio.win32`). + ### Removed type aliases and classes The following deprecated type aliases and classes have been removed from `mcp.types`: diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 902dc8576c..baf7ad1ca1 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -1,21 +1,33 @@ +"""stdio client transport. + +Runs an MCP server as a subprocess and exchanges newline-delimited JSON-RPC +messages with it over stdin/stdout. Two pipe tasks bridge the server's pipes +to the session's in-memory streams; shutdown follows the MCP spec sequence +(close stdin, wait, then kill the process tree) inside a cancellation shield +with every wait bounded, so a cancelled caller can neither leak a live server +process nor hang on one. +""" + import logging import os import sys -from contextlib import asynccontextmanager +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager, suppress from pathlib import Path from typing import Literal, TextIO import anyio import anyio.lowlevel -from anyio.abc import Process -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.abc import AsyncResource, Process from anyio.streams.text import TextReceiveStream from pydantic import BaseModel, Field from mcp import types +from mcp.client._transport import TransportStreams from mcp.os.posix.utilities import terminate_posix_process_tree from mcp.os.win32.utilities import ( - FallbackProcess, + ServerProcess, + close_process_job, create_windows_process, get_windows_executable_command, terminate_windows_process_tree, @@ -44,14 +56,24 @@ else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"] ) -# Timeout for process termination before falling back to force kill +# Grace period for the server to exit on its own after its stdin closes. 
PROCESS_TERMINATION_TIMEOUT = 2.0 +# Extra time after SIGTERM before SIGKILL; POSIX only (Windows kills hard). +FORCE_KILL_TIMEOUT = 2.0 + +# Time for the event loop to observe a kill; only an unkillable process runs this out. +_KILL_REAP_TIMEOUT = 2.0 + +# Time for the writer to flush accepted messages before stdin closes. +_WRITER_FLUSH_TIMEOUT = 0.5 + +# How often to poll returncode while waiting for the process to die. +_EXIT_POLL_INTERVAL = 0.01 + def get_default_environment() -> dict[str, str]: - """Returns a default environment object including only environment variables deemed - safe to inherit. - """ + """Returns only the environment variables that are safe to inherit.""" env: dict[str, str] = {} for key in DEFAULT_INHERITED_ENV_VARS: @@ -76,150 +98,227 @@ class StdioServerParameters(BaseModel): """Command line arguments to pass to the executable.""" env: dict[str, str] | None = None - """ - The environment to use when spawning the process. - - If not specified, the result of get_default_environment() will be used. - """ + """Extra environment variables, merged over get_default_environment().""" cwd: str | Path | None = None """The working directory to use when spawning the process.""" encoding: str = "utf-8" - """ - The text encoding used when sending/receiving messages to the server. - - Defaults to utf-8. - """ + """Text encoding for messages to and from the server.""" encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict" - """ - The text encoding error handler. - - See https://docs.python.org/3/library/codecs.html#codec-base-classes for - explanations of possible values. - """ + """Encoding error handler; see https://docs.python.org/3/library/codecs.html#error-handlers.""" @asynccontextmanager -async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr): - """Client transport for stdio: this will connect to a server by spawning a - process and communicating with it over stdin/stdout. 
+async def stdio_client( + server: StdioServerParameters, errlog: TextIO = sys.stderr +) -> AsyncGenerator[TransportStreams, None]: + """Spawns an MCP server subprocess and connects to it over stdin/stdout. + + Raises: + OSError: If the server process cannot be spawned. + ValueError: If the spawn parameters are invalid (embedded NUL bytes). """ - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + command = _get_executable_command(server.command) - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] + process = await _create_platform_compatible_process( + command=command, + args=server.args, + env=get_default_environment() | (server.env or {}), + errlog=errlog, + cwd=server.cwd, + ) - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + # The spawn succeeded; no awaits until the task group is entered, or a + # cancellation delivered in the gap would leak the live process. 
+ read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - try: - command = _get_executable_command(server.command) - - # Open process with stderr piped for capture - process = await _create_platform_compatible_process( - command=command, - args=server.args, - env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()), - errlog=errlog, - cwd=server.cwd, - ) - except OSError: - # Clean up streams if process creation fails - await read_stream.aclose() - await write_stream.aclose() - await read_stream_writer.aclose() - await write_stream_reader.aclose() - raise - - async def stdout_reader(): + shutting_down = False + writer_done = anyio.Event() + + async def stdout_reader() -> None: assert process.stdout, "Opened process is missing stdout" + stdout = TextReceiveStream(process.stdout, encoding=server.encoding, errors=server.encoding_error_handler) try: async with read_stream_writer: - buffer = "" - async for chunk in TextReceiveStream( - process.stdout, - encoding=server.encoding, - errors=server.encoding_error_handler, - ): - lines = (buffer + chunk).split("\n") - buffer = lines.pop() - - for line in lines: - try: - message = types.jsonrpc_message_adapter.validate_json(line, by_name=False) - except Exception as exc: # pragma: no cover - logger.exception("Failed to parse JSONRPC message from server") - await read_stream_writer.send(exc) - continue - - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - except anyio.ClosedResourceError: # pragma: lax no cover - await anyio.lowlevel.checkpoint() - - async def stdin_writer(): + try: + # One line at a time; no read-ahead while a delivery is blocked. 
+ buffer = "" + async for chunk in stdout: + lines = (buffer + chunk).split("\n") + buffer = lines.pop() + for line in lines: + try: + await read_stream_writer.send(_parse_line(line)) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): + return # the session is gone; only the drain below remains + finally: + await _drain_stdout(process) + except anyio.ClosedResourceError: + pass # our own shutdown closed the stdout stream under the read + except (anyio.BrokenResourceError, ConnectionError): + # Teardown noise during shutdown, a real failure otherwise; either way + # the session sees clean closure when the read stream closes. + if not shutting_down: + logger.exception("Reading from the MCP server's stdout failed mid-session") + + async def stdin_writer() -> None: assert process.stdin, "Opened process is missing stdin" try: async with write_stream_reader: async for session_message in write_stream_reader: json = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) - await process.stdin.send( - (json + "\n").encode( - encoding=server.encoding, - errors=server.encoding_error_handler, - ) - ) - except anyio.ClosedResourceError: # pragma: no cover - await anyio.lowlevel.checkpoint() - - async with anyio.create_task_group() as tg, process: + data = (json + "\n").encode(encoding=server.encoding, errors=server.encoding_error_handler) + await process.stdin.send(data) + except (anyio.ClosedResourceError, anyio.BrokenResourceError, OSError): + # The server may still be alive: close the read stream so the session + # sees the connection end instead of a request hanging forever. + await read_stream_writer.aclose() + finally: + writer_done.set() + + async def shutdown() -> None: + """Winds the transport down: stop traffic, flush, stop the server, release the streams.""" + # Unblock the reader into its drain: a server stuck writing stdout cannot + # read its stdin, so draining is what lets the flush below complete. 
+ read_stream.close() + # Bounded window for the writer to flush already-accepted messages. + write_stream.close() + with anyio.move_on_after(_WRITER_FLUSH_TIMEOUT) as flush_scope: + await writer_done.wait() + if flush_scope.cancelled_caught: + await anyio.lowlevel.cancel_shielded_checkpoint() # resync coverage on 3.11 (gh-106749) + await _stop_server_process(process) + await _aclose_all(read_stream, write_stream, read_stream_writer, write_stream_reader) + # One pass so unblocked tasks exit via their except paths before the cancel. + await anyio.lowlevel.checkpoint() + + async with anyio.create_task_group() as tg: tg.start_soon(stdout_reader) tg.start_soon(stdin_writer) try: yield read_stream, write_stream finally: - # MCP spec: stdio shutdown sequence - # 1. Close input stream to server - # 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time - # 3. Send SIGKILL if still not exited - if process.stdin: # pragma: no branch - try: - await process.stdin.aclose() - except Exception: # pragma: no cover - # stdin might already be closed, which is fine - pass - - try: - # Give the process time to exit gracefully after stdin closes - with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT): - await process.wait() - except TimeoutError: - # Process didn't exit from stdin closure, use platform-specific termination - # which handles SIGTERM -> SIGKILL escalation - await _terminate_process_tree(process) - except ProcessLookupError: # pragma: no cover - # Process already exited, which is fine - pass - await read_stream.aclose() - await write_stream.aclose() - await read_stream_writer.aclose() - await write_stream_reader.aclose() + shutting_down = True + # Shutdown must finish even under caller cancellation, or the server + # process would leak; every wait inside is bounded. (Native + # task.cancel() and the fallback's worker threads can still defeat it.) 
+ with anyio.CancelScope(shield=True): + await shutdown() + # Unstick pipe tasks a kill survivor's open pipe end could still block. + tg.cancel_scope.cancel() + # The cancel lands via throw(); one yield resyncs 3.11 coverage (gh-106749). + await anyio.lowlevel.cancel_shielded_checkpoint() + + +def _parse_line(line: str) -> SessionMessage | Exception: + """Parses one stdout line, returning parse errors as values for the session to surface.""" + try: + message = types.jsonrpc_message_adapter.validate_json(line, by_name=False) + except ValueError as exc: + logger.exception("Failed to parse JSONRPC message from server") + return exc + return SessionMessage(message) -def _get_executable_command(command: str) -> str: - """Get the correct executable command normalized for the current platform. +async def _drain_stdout(process: ServerProcess) -> None: + """Consumes and discards the server's remaining stdout. - Args: - command: Base command (e.g., 'uvx', 'npx') + Keeps a server flushing buffered output from blocking on a full pipe and + missing its chance to exit; shielded, raw bytes, ends when shutdown closes + the pipe. + """ + assert process.stdout + with anyio.CancelScope(shield=True): + with suppress( + anyio.EndOfStream, + anyio.ClosedResourceError, + anyio.BrokenResourceError, + ConnectionError, + OSError, + ): + while True: + await process.stdout.receive() + + +async def _stop_server_process(process: ServerProcess) -> None: + """Closes stdin, waits out the grace period, then kills the whole tree. 
+ + The escalation order is spec text; timeouts and tree-wide scope are SDK policy: + https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle#shutdown + """ + assert process.stdin and process.stdout, "server process is spawned with pipes" - Returns: - str: Platform-appropriate command + await _close_pipe(process.stdin) + if not await _wait_for_process_exit(process, PROCESS_TERMINATION_TIMEOUT): + await _terminate_process_tree(process) + # Until the event loop observes the death, the transport cannot close. + if not await _wait_for_process_exit(process, _KILL_REAP_TIMEOUT): + logger.warning("MCP server process %d is still alive after the kill escalation; abandoning it", process.pid) + + # Reaps surviving Windows job members now, not at GC; no-op on POSIX. + close_process_job(process) + # A kill survivor can hold the stdout pipe open; poison the reader anyway. + await _close_pipe(process.stdout) + _close_subprocess_transport(process) + + +async def _close_pipe(stream: AsyncResource) -> None: + """Closes a pipe stream, tolerating one already closed, broken, or contended.""" + with suppress(OSError, anyio.BrokenResourceError, anyio.ClosedResourceError): + await stream.aclose() + + +async def _wait_for_process_exit(process: ServerProcess, timeout: float) -> bool: + """Returns whether the process died within the timeout, by polling returncode. + + Not process.wait(): on asyncio 3.11+ it also waits for pipe EOF, and a + child that inherited the pipes makes an exited server look hung. + """ + deadline = anyio.current_time() + timeout + while process.returncode is None: + if anyio.current_time() >= deadline: + return False + await anyio.sleep(_EXIT_POLL_INTERVAL) + return True + + +async def _terminate_process_tree(process: ServerProcess) -> None: + """Kills the process and all its descendants. + + POSIX: SIGTERM to the process group, SIGKILL after FORCE_KILL_TIMEOUT. + Windows: immediate Job Object termination (already a hard kill). 
""" + if sys.platform == "win32": # pragma: no cover + await terminate_windows_process_tree(process) + else: # pragma: lax no cover + # The Windows-only FallbackProcess never reaches the POSIX path. + assert isinstance(process, Process) + await terminate_posix_process_tree(process, FORCE_KILL_TIMEOUT) + + +def _close_subprocess_transport(process: ServerProcess) -> None: + """Closes the asyncio subprocess transport, if there is one. + + The transport otherwise stays open (and warns at GC) while a surviving + descendant holds a pipe end; nothing public exposes it, hence the attribute + walk. No-op on trio and the Windows fallback. + """ + transport = getattr(getattr(process, "_process", None), "_transport", None) + # Duck-typed: uvloop's UVProcessTransport is not an asyncio.SubprocessTransport. + close = getattr(transport, "close", None) + if callable(close): + # close() on <=3.12 can raise PermissionError re-killing a setuid child. + with suppress(PermissionError): + close() + + +def _get_executable_command(command: str) -> str: + """Normalizes the command for the current platform.""" if sys.platform == "win32": # pragma: no cover return get_windows_executable_command(command) else: # pragma: lax no cover @@ -232,16 +331,15 @@ async def _create_platform_compatible_process( env: dict[str, str] | None = None, errlog: TextIO = sys.stderr, cwd: Path | str | None = None, -): - """Creates a subprocess in a platform-compatible way. +) -> ServerProcess: + """Spawns the server in its own kill scope. - Unix: Creates process in a new session/process group for killpg support - Windows: Creates process in a Job Object for reliable child termination + A new session/process group on POSIX, a Job Object on Windows. 
""" if sys.platform == "win32": # pragma: no cover - process = await create_windows_process(command, args, env, errlog, cwd) + return await create_windows_process(command, args, env, errlog, cwd) else: # pragma: lax no cover - process = await anyio.open_process( + return await anyio.open_process( [command, *args], env=env, stderr=errlog, @@ -249,22 +347,8 @@ async def _create_platform_compatible_process( start_new_session=True, ) - return process - -async def _terminate_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None: - """Terminate a process and all its children using platform-specific methods. - - Unix: Uses os.killpg() for atomic process group termination - Windows: Uses Job Objects via pywin32 for reliable child process cleanup - - Args: - process: The process to terminate - timeout_seconds: Timeout in seconds before force killing (default: 2.0) - """ - if sys.platform == "win32": # pragma: no cover - await terminate_windows_process_tree(process, timeout_seconds) - else: # pragma: lax no cover - # FallbackProcess should only be used for Windows compatibility - assert isinstance(process, Process) - await terminate_posix_process_tree(process, timeout_seconds) +async def _aclose_all(*streams: AsyncResource) -> None: + """Closes every given stream.""" + for stream in streams: + await stream.aclose() diff --git a/src/mcp/os/posix/utilities.py b/src/mcp/os/posix/utilities.py index 0e9d74cf3c..d15be17194 100644 --- a/src/mcp/os/posix/utilities.py +++ b/src/mcp/os/posix/utilities.py @@ -3,55 +3,61 @@ import logging import os import signal +from contextlib import suppress import anyio from anyio.abc import Process logger = logging.getLogger(__name__) +# How often to probe for surviving group members between SIGTERM and SIGKILL. +_GROUP_POLL_INTERVAL = 0.01 -async def terminate_posix_process_tree(process: Process, timeout_seconds: float = 2.0) -> None: - """Terminate a process and all its children on POSIX systems. 
- - Uses os.killpg() for atomic process group termination. - Args: - process: The process to terminate - timeout_seconds: Timeout in seconds before force killing (default: 2.0) +async def terminate_posix_process_tree(process: Process, timeout_seconds: float = 2.0) -> None: + """Terminates a process and all its descendants on POSIX. + + SIGTERMs the process group, waits up to timeout_seconds for it to + disappear, then SIGKILLs whatever remains. killpg reaches every descendant + atomically, even ones whose parent already exited; daemonizers that left + the group escape by design. A group only disappears once every member is + dead and reaped, so a client running as PID 1 should reap orphans (e.g. + docker run --init) or the wait below runs its full timeout. """ - pid = getattr(process, "pid", None) or getattr(getattr(process, "popen", None), "pid", None) - if not pid: - # No PID means there's no process to terminate - it either never started, - # already exited, or we have an invalid process object - return + # The leader's pid is the pgid (start_new_session). Never use getpgid(): + # it fails once the leader is reaped, even with live members left. + pgid = process.pid try: - pgid = os.getpgid(pid) os.killpg(pgid, signal.SIGTERM) + except ProcessLookupError: + return # the whole group is already gone + except PermissionError: + # EPERM never proves the group is gone (macOS raises it for zombie or + # foreign-euid members), so keep waiting and escalating. + logger.warning( + "No permission to signal some of process group %d; waiting for it to exit anyway", pgid, exc_info=True + ) + + with anyio.move_on_after(timeout_seconds): + while _group_alive(pgid): + # Reading returncode reaps the leader on trio; a zombie leader would + # otherwise keep the group alive for the full timeout. + _ = process.returncode + await anyio.sleep(_GROUP_POLL_INTERVAL) + return + + # ESRCH: died since the last probe. EPERM: we killed what we were allowed to. 
+ with suppress(ProcessLookupError, PermissionError): + os.killpg(pgid, signal.SIGKILL) - with anyio.move_on_after(timeout_seconds): - while True: - try: - # Check if process group still exists (signal 0 = check only) - os.killpg(pgid, 0) - await anyio.sleep(0.1) - except ProcessLookupError: - return - - try: - os.killpg(pgid, signal.SIGKILL) - except ProcessLookupError: - pass - - except (ProcessLookupError, PermissionError, OSError) as e: - logger.warning(f"Process group termination failed for PID {pid}: {e}, falling back to simple terminate") - try: - process.terminate() - with anyio.fail_after(timeout_seconds): - await process.wait() - except Exception: - logger.warning(f"Process termination failed for PID {pid}, attempting force kill") - try: - process.kill() - except Exception: - logger.exception(f"Failed to kill process {pid}") + +def _group_alive(pgid: int) -> bool: + """Probes the group with signal 0; only ESRCH proves it is gone.""" + try: + os.killpg(pgid, 0) + except ProcessLookupError: + return False + except PermissionError: + pass # unsignalable survivors or unreaped zombies; EPERM is ambiguous + return True diff --git a/src/mcp/os/win32/utilities.py b/src/mcp/os/win32/utilities.py index 6f68405f78..1cc867d4fa 100644 --- a/src/mcp/os/win32/utilities.py +++ b/src/mcp/os/win32/utilities.py @@ -4,16 +4,16 @@ import shutil import subprocess import sys +import weakref +from contextlib import suppress from pathlib import Path -from typing import BinaryIO, TextIO, cast +from typing import BinaryIO, TextIO, TypeAlias, cast import anyio -from anyio import to_thread from anyio.abc import Process from anyio.streams.file import FileReadStream, FileWriteStream -from typing_extensions import deprecated -logger = logging.getLogger("client.stdio.win32") +logger = logging.getLogger(__name__) # Windows-specific imports for Job Objects if sys.platform == "win32": @@ -28,110 +28,86 @@ win32job = None pywintypes = None -JobHandle = int +# How often FallbackProcess polls 
the underlying Popen for exit. +_EXIT_POLL_INTERVAL = 0.01 +# Job Object handle per spawned process, for tree termination at shutdown. +# Values stay pywin32 PyHANDLEs: if no pop site ever runs, the dying weak entry +# drops the last reference and the PyHANDLE destructor closes the handle, which +# is what makes KILL_ON_JOB_CLOSE reap an abandoned tree. +_process_jobs: "weakref.WeakKeyDictionary[Process | FallbackProcess, object]" = weakref.WeakKeyDictionary() -def get_windows_executable_command(command: str) -> str: - """Get the correct executable command normalized for Windows. - - On Windows, commands might exist with specific extensions (.exe, .cmd, etc.) - that need to be located for proper execution. - Args: - command: Base command (e.g., 'uvx', 'npx') +def get_windows_executable_command(command: str) -> str: + """Resolves the command to a Windows executable path. - Returns: - str: Windows-appropriate command path + Tries the bare name first, then the common script extensions (.cmd, .bat, + .exe, .ps1). """ try: - # First check if command exists in PATH as-is if command_path := shutil.which(command): return command_path - # Check for Windows-specific extensions for ext in [".cmd", ".bat", ".exe", ".ps1"]: ext_version = f"{command}{ext}" if ext_path := shutil.which(ext_version): return ext_path - # For regular commands or if we couldn't find special versions return command except OSError: - # Handle file system errors during path resolution - # (permissions, broken symlinks, etc.) - return command + return command # path probing failed (permissions, broken symlinks) class FallbackProcess: - """A fallback process wrapper for Windows to handle async I/O - when using subprocess.Popen, which provides sync-only FileIO objects. + """Async wrapper around subprocess.Popen for SelectorEventLoop. - This wraps stdin and stdout into async-compatible - streams (FileReadStream, FileWriteStream), - so that MCP clients expecting async streams can work properly. 
+ Windows event loops without async subprocess support get this Popen-backed + fallback, with anyio file streams wrapping the pipes. """ - def __init__(self, popen_obj: subprocess.Popen[bytes]): + def __init__(self, popen_obj: subprocess.Popen[bytes]) -> None: self.popen: subprocess.Popen[bytes] = popen_obj - self.stdin_raw = popen_obj.stdin # type: ignore[assignment] - self.stdout_raw = popen_obj.stdout # type: ignore[assignment] - self.stderr = popen_obj.stderr # type: ignore[assignment] - - self.stdin = FileWriteStream(cast(BinaryIO, self.stdin_raw)) if self.stdin_raw else None - self.stdout = FileReadStream(cast(BinaryIO, self.stdout_raw)) if self.stdout_raw else None - - async def __aenter__(self): - """Support async context manager entry.""" - return self - - async def __aexit__( - self, - exc_type: BaseException | None, - exc_val: BaseException | None, - exc_tb: object | None, - ) -> None: - """Terminate and wait on process exit inside a thread.""" + stdin = popen_obj.stdin + stdout = popen_obj.stdout + + self.stdin = FileWriteStream(cast(BinaryIO, stdin)) if stdin else None + self.stdout = FileReadStream(cast(BinaryIO, stdout)) if stdout else None + + async def wait(self) -> int: + """Waits for exit by polling the Popen. + + A thread blocked in Popen.wait() cannot be cancelled by anyio, which + would defeat every timeout placed around this call. 
+ """ + while (returncode := self.popen.poll()) is None: + await anyio.sleep(_EXIT_POLL_INTERVAL) + return returncode + + def terminate(self) -> None: + """Terminates the subprocess.""" self.popen.terminate() - await to_thread.run_sync(self.popen.wait) - - # Close the file handles to prevent ResourceWarning - if self.stdin: - await self.stdin.aclose() - if self.stdout: - await self.stdout.aclose() - if self.stdin_raw: - self.stdin_raw.close() - if self.stdout_raw: - self.stdout_raw.close() - if self.stderr: - self.stderr.close() - - async def wait(self): - """Async wait for process completion.""" - return await to_thread.run_sync(self.popen.wait) - - def terminate(self): - """Terminate the subprocess immediately.""" - return self.popen.terminate() def kill(self) -> None: - """Kill the subprocess immediately (alias for terminate).""" - self.terminate() + """Kills the subprocess (on Windows the same hard kill as terminate).""" + self.popen.kill() @property def pid(self) -> int: - """Return the process ID.""" + """Returns the process ID.""" return self.popen.pid @property def returncode(self) -> int | None: - """Return the exit code, or ``None`` if the process has not yet terminated.""" - return self.popen.returncode + """The exit code, or None while the process is still running. + + Polls the Popen so death is observable without anyone calling wait(). + """ + return self.popen.poll() -# ------------------------ -# Updated function -# ------------------------ +# The process handle stdio_client drives: anyio's Process, or the Popen-backed +# fallback used on Windows event loops without async subprocess support. +ServerProcess: TypeAlias = Process | FallbackProcess async def create_windows_process( @@ -141,53 +117,35 @@ async def create_windows_process( errlog: TextIO | None = sys.stderr, cwd: Path | str | None = None, ) -> Process | FallbackProcess: - """Creates a subprocess in a Windows-compatible way with Job Object support. 
+ """Creates a subprocess with Job Object support for tree termination. - Attempts to use anyio's open_process for async subprocess creation. - In some cases this will throw NotImplementedError on Windows, e.g., - when using the SelectorEventLoop, which does not support async subprocesses. - In that case, we fall back to using subprocess.Popen. - - The process is automatically added to a Job Object to ensure all child - processes are terminated when the parent is terminated. - - Args: - command (str): The executable to run - args (list[str]): List of command line arguments - env (dict[str, str] | None): Environment variables - errlog (TextIO | None): Where to send stderr output (defaults to sys.stderr) - cwd (Path | str | None): Working directory for the subprocess + Spawns via anyio's open_process; event loops without async subprocess + support (notably the SelectorEventLoop) raise NotImplementedError, in which + case the spawn falls back to a Popen-backed FallbackProcess. Either way the + process is then assigned to a Job Object so its children can be terminated + with it; children spawned before the assignment completes are not captured + (see the inline note below). Returns: - Process | FallbackProcess: Async-compatible subprocess with stdin and stdout streams + Process | FallbackProcess: The spawned process with async stdin/stdout streams. 
""" - job = _create_job_object() - process = None - try: - # First try using anyio with Windows-specific flags to hide console window process = await anyio.open_process( [command, *args], env=env, # Ensure we don't create console windows for each process - creationflags=subprocess.CREATE_NO_WINDOW # type: ignore - if hasattr(subprocess, "CREATE_NO_WINDOW") - else 0, + creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0), stderr=errlog, cwd=cwd, ) except NotImplementedError: - # If Windows doesn't support async subprocess creation, use fallback + # Windows event loops without async subprocess support (SelectorEventLoop) process = await _create_windows_fallback_process(command, args, env, errlog, cwd) - except Exception: - # Try again without creation flags - process = await anyio.open_process( - [command, *args], - env=env, - stderr=errlog, - cwd=cwd, - ) + # Children spawned before the assignment completes land outside the job + # (membership is inherited at CreateProcess, never acquired retroactively); + # if that ever bites, the fix is a CREATE_SUSPENDED spawn -> assign -> resume. + job = _create_job_object() _maybe_assign_process_to_job(process, job) return process @@ -199,41 +157,26 @@ async def _create_windows_fallback_process( errlog: TextIO | None = sys.stderr, cwd: Path | str | None = None, ) -> FallbackProcess: - """Create a subprocess using subprocess.Popen as a fallback when anyio fails. - - This function wraps the sync subprocess.Popen in an async-compatible interface. 
- """ - try: - # Try launching with creationflags to avoid opening a new console window - popen_obj = subprocess.Popen( - [command, *args], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=errlog, - env=env, - cwd=cwd, - bufsize=0, # Unbuffered output - creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0), - ) - except Exception: - # If creationflags failed, fallback without them - popen_obj = subprocess.Popen( - [command, *args], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=errlog, - env=env, - cwd=cwd, - bufsize=0, - ) + """Spawns via subprocess.Popen and wraps it in FallbackProcess.""" + popen_obj = subprocess.Popen( + [command, *args], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=errlog, + env=env, + cwd=cwd, + bufsize=0, # Unbuffered output + creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0), + ) return FallbackProcess(popen_obj) -def _create_job_object() -> int | None: - """Create a Windows Job Object configured to terminate all processes when closed.""" - if sys.platform != "win32" or not win32job: +def _create_job_object() -> object | None: + """Creates a Windows Job Object configured to terminate all its processes when closed.""" + if sys.platform != "win32" or not win32api or not win32job: return None + job = None try: job = win32job.CreateJobObject(None, "") extended_info = win32job.QueryInformationJobObject(job, win32job.JobObjectExtendedLimitInformation) @@ -241,17 +184,20 @@ def _create_job_object() -> int | None: extended_info["BasicLimitInformation"]["LimitFlags"] |= win32job.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE win32job.SetInformationJobObject(job, win32job.JobObjectExtendedLimitInformation, extended_info) return job - except Exception as e: - logger.warning(f"Failed to create Job Object for process tree management: {e}") + except pywintypes.error: + logger.warning("Failed to create Job Object for process tree management", exc_info=True) + # If creation succeeded but configuration failed, close the 
handle now. + if job is not None: + _close_job_handle(job) return None -def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: JobHandle | None) -> None: - """Try to assign a process to a job object. +def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: object | None) -> None: + """Assigns the process to the job and records it for tree termination. - If assignment fails for any reason, the job handle is closed. + On any failure the job handle is closed instead. """ - if not job: + if job is None: return if sys.platform != "win32" or not win32api or not win32con or not win32job: @@ -262,72 +208,62 @@ def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: JobHan win32con.PROCESS_SET_QUOTA | win32con.PROCESS_TERMINATE, False, process.pid ) if not process_handle: - raise Exception("Failed to open process handle") + raise pywintypes.error(0, "OpenProcess", "Failed to open process handle") try: win32job.AssignProcessToJobObject(job, process_handle) - process._job_object = job finally: win32api.CloseHandle(process_handle) - except Exception as e: - logger.warning(f"Failed to assign process {process.pid} to Job Object: {e}") - if win32api: - win32api.CloseHandle(job) + # Record only after the CloseHandle above succeeded: had it failed, the + # except below would close the job and KILL_ON_JOB_CLOSE takes the server. + _process_jobs[process] = job + except pywintypes.error: + logger.warning("Failed to assign process %d to Job Object", process.pid, exc_info=True) + _close_job_handle(job) -async def terminate_windows_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None: - """Terminate a process and all its children on Windows. +def close_process_job(process: Process | FallbackProcess) -> None: + """Closes the process's Job Object handle, if it still has one. - If the process has an associated job object, it will be terminated. - Otherwise, falls back to basic process termination. 
+ KILL_ON_JOB_CLOSE makes the close also kill any members still alive, + deterministically rather than at GC time; a deliberate divergence from + POSIX, where a graceful server's children are left alive. + """ + if sys.platform != "win32": + return + + job = _process_jobs.pop(process, None) + if job is not None: + _close_job_handle(job) - Args: - process: The process to terminate - timeout_seconds: Timeout in seconds before force killing (default: 2.0) + +async def terminate_windows_process_tree(process: Process | FallbackProcess) -> None: + """Terminates the process's job, or just the process if it has no job. + + Job termination is an immediate hard kill of every member. Windows has no + tree-wide SIGTERM; the stdin-close grace period is the server's chance to + exit cleanly. """ if sys.platform != "win32": return - job = getattr(process, "_job_object", None) - if job and win32job: + job = _process_jobs.pop(process, None) + if job is not None and win32job: try: - win32job.TerminateJobObject(job, 1) - except Exception: - # Job might already be terminated - pass + with suppress(pywintypes.error): # the job might already be terminated + win32job.TerminateJobObject(job, 1) finally: - if win32api: - try: - win32api.CloseHandle(job) - except Exception: - pass + _close_job_handle(job) - # Always try to terminate the process itself as well + # The process may have no job (creation or assignment failed); kill it directly too. try: process.terminate() - except Exception: + except OSError: pass -@deprecated( - "terminate_windows_process is deprecated and will be removed in a future version. " - "Process termination is now handled internally by the stdio_client context manager." -) -async def terminate_windows_process(process: Process | FallbackProcess): - """Terminate a Windows process. - - Note: On Windows, terminating a process with process.terminate() doesn't - always guarantee immediate process termination. 
- If the process does not exit within 2 seconds, process.kill() is called - to send a SIGKILL-equivalent signal. - - Args: - process: The process to terminate - """ - try: - process.terminate() - with anyio.fail_after(2.0): - await process.wait() - except TimeoutError: - # Force kill if it doesn't terminate - process.kill() +def _close_job_handle(job: object) -> None: + """Closes a Job Object handle, tolerating one that is already closed.""" + if win32api and pywintypes: + with suppress(pywintypes.error): + win32api.CloseHandle(job) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 06e2cba4b1..f3cb88dc9c 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,253 +1,1120 @@ +"""Tests for the stdio client transport. + +Transport logic (framing, parse errors, shutdown escalation decisions) is tested in +process against a fake process injected through the spawn seam; only real OS behaviour +(process-group kill semantics, SIGKILL after an ignored SIGTERM, exec failure) uses +real subprocesses, synchronized only by kernel-level liveness sockets. The full +client<->server round trip is pinned by tests/interaction/transports/test_stdio.py. 
+""" + import errno -import shutil +import gc +import logging +import math +import os +import signal import sys -import textwrap -import time +from collections.abc import Callable from contextlib import AsyncExitStack, suppress +from pathlib import Path +from typing import TextIO, cast import anyio import anyio.abc +import anyio.lowlevel import pytest +import trio +import trio.testing +from anyio.streams.memory import MemoryObjectReceiveStream +from mcp.client import stdio +from mcp.client._transport import ReadStream from mcp.client.session import ClientSession from mcp.client.stdio import ( + _EXIT_POLL_INTERVAL, StdioServerParameters, _create_platform_compatible_process, _terminate_process_tree, stdio_client, ) +from mcp.os.posix import utilities as posix_utilities +from mcp.os.posix.utilities import terminate_posix_process_tree from mcp.os.win32.utilities import FallbackProcess from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse -# Timeout for cleanup of processes that ignore SIGTERM -# This timeout ensures the test fails quickly if the cleanup logic doesn't have -# proper fallback mechanisms (SIGINT/SIGKILL) for processes that ignore SIGTERM -SIGTERM_IGNORING_PROCESS_TIMEOUT = 5.0 +# --------------------------------------------------------------------------- +# In-process fake of the spawned server process +# --------------------------------------------------------------------------- +# +# Everything between the spawn and the OS kill is pure SDK logic, so it is tested +# against this fake by monkeypatching the spawn and terminate seams. The OS half +# is tested separately below with real processes. 
+ + +class _FakeStdin: + """The fake process's stdin: records what the client writes, signals closure.""" + + def __init__(self, process: "FakeProcess") -> None: + self._process = process + + async def send(self, data: bytes) -> None: + if self._process.stdin_send_gate is not None: + # A full pipe whose reader is busy elsewhere: the write completes + # only once the test's gate opens. + await self._process.stdin_send_gate.wait() + if self._process.stdin_send_blocks: + # A pipe whose reader stopped reading: the write never completes. + await anyio.sleep_forever() + if self._process.stdin_send_error is not None: + raise self._process.stdin_send_error + if self._process.returncode is not None: + # What the asyncio backend surfaces when writing to a dead child's pipe. + raise ConnectionResetError("Connection lost") + self._process.written.append(data) + + async def aclose(self) -> None: + self._process.stdin_closed.set() + if self._process.on_stdin_close is not None: + self._process.on_stdin_close() + if self._process.stdin_aclose_error is not None: + raise self._process.stdin_aclose_error + + +class _FakeStdout: + """The fake process's stdout: delegates to the in-memory stream. + + Optionally surfaces the abrupt-death or close-time errors a real pipe can. + """ + + def __init__( + self, + inner: MemoryObjectReceiveStream[bytes], + *, + eof_error: Exception | None = None, + aclose_error: Exception | None = None, + on_receive: Callable[[], None], + ) -> None: + self._inner = inner + self._eof_error = eof_error + self._aclose_error = aclose_error + self._on_receive = on_receive + + async def receive(self) -> bytes: + try: + chunk = await self._inner.receive() + except anyio.EndOfStream: + if self._eof_error is not None: + # A hard-killed pipe surfaces a reset, not EOF, on the proactor loop. 
+ raise self._eof_error from None + raise + self._on_receive() + return chunk + + async def aclose(self) -> None: + await self._inner.aclose() + if self._aclose_error is not None: + raise self._aclose_error + # Real async closes yield; keeps the fake honest and shutdown scheduling realistic. + await anyio.lowlevel.checkpoint() + + +class FakeProcess: + """In-memory stand-in for the spawned server process. + + `feed`/`close_stdout` drive its stdout, `written` records client writes, `exit` + and the error knobs replay death and pipe failure modes. + """ + + def __init__( + self, + on_stdin_close: Callable[[], None] | None = None, + stdin_aclose_error: Exception | None = None, + stdin_send_error: Exception | None = None, + stdin_send_blocks: bool = False, + stdin_send_gate: anyio.Event | None = None, + stdout_eof_error: Exception | None = None, + stdout_aclose_error: Exception | None = None, + on_stdout_receive: Callable[[], None] | None = None, + ) -> None: + self._stdout_send, stdout_receive = anyio.create_memory_object_stream[bytes](math.inf) + self.stdout = _FakeStdout( + stdout_receive, + eof_error=stdout_eof_error, + aclose_error=stdout_aclose_error, + on_receive=self._dispatch_stdout_receive, + ) + self.pid = 424242 + self.written: list[bytes] = [] + self.stdin_closed = anyio.Event() + self.returncode: int | None = None + self.on_stdin_close = on_stdin_close + self.stdin_aclose_error = stdin_aclose_error + self.stdin_send_error = stdin_send_error + self.stdin_send_blocks = stdin_send_blocks + self.stdin_send_gate = stdin_send_gate + self.on_stdout_receive = on_stdout_receive + self.stdin = _FakeStdin(self) + + def _dispatch_stdout_receive(self) -> None: + # Late-bound so a test can assign `on_stdout_receive` after construction. 
+ if self.on_stdout_receive is not None: + self.on_stdout_receive() + + async def feed(self, data: bytes) -> None: + """Make `data` readable on the fake process's stdout.""" + await self._stdout_send.send(data) + + def close_stdout(self) -> None: + """End the fake process's stdout, as the kernel does when it dies.""" + self._stdout_send.close() + + def exit(self, code: int = 0) -> None: + """Die: set the exit code and EOF stdout, as the kernel does.""" + self.returncode = code + self.close_stdout() + + def pending_stdout_chunks(self) -> int: + """How many fed chunks the client has not yet pulled off the fake stdout.""" + return self._stdout_send.statistics().current_buffer_used + + +def install_fake_process( + monkeypatch: pytest.MonkeyPatch, process: FakeProcess, *, grace_period: float | None = 0.2 +) -> list[FakeProcess]: + """Route stdio_client's spawn and terminate seams to `process`. + + Returns the list of processes the (fake) tree termination was invoked on. + `grace_period=None` keeps the production stdin-close grace (affordable only on a + virtual clock). 
+ """ + terminated: list[FakeProcess] = [] + + async def fake_spawn( + command: str, + args: list[str], + env: dict[str, str] | None = None, + errlog: TextIO = sys.stderr, + cwd: Path | str | None = None, + ) -> FakeProcess: + return process + + async def fake_terminate_tree(proc: FakeProcess) -> None: + terminated.append(proc) + proc.exit(-15) -tee = shutil.which("tee") + monkeypatch.setattr(stdio, "_create_platform_compatible_process", fake_spawn) + monkeypatch.setattr(stdio, "_terminate_process_tree", fake_terminate_tree) + if grace_period is not None: + monkeypatch.setattr(stdio, "PROCESS_TERMINATION_TIMEOUT", grace_period) + return terminated + + +FAKE_PARAMS = StdioServerParameters(command="fake-server") + + +def _line(message: JSONRPCMessage) -> bytes: + """The wire form of `message`: one JSON document on its own line.""" + return (message.model_dump_json(by_alias=True, exclude_unset=True) + "\n").encode() + + +async def _next_message(read_stream: ReadStream[SessionMessage | Exception]) -> JSONRPCMessage: + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + return received.message + + +@pytest.mark.anyio +async def test_messages_split_and_packed_across_chunks_are_reframed(monkeypatch: pytest.MonkeyPatch) -> None: + """Framing survives arbitrary chunk boundaries. + + Split, packed, and CRLF-terminated messages are each delivered exactly once, and a + trailing line without a newline is not delivered. 
+ """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + ping2 = JSONRPCRequest(jsonrpc="2.0", id=2, method="ping") + process = FakeProcess(on_stdin_close=lambda: process.exit(0)) + + install_fake_process(monkeypatch, process) + + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS) as (read_stream, _): + # First message split mid-bytes; its tail packed with the second, a + # CRLF-framed third (the SDK's own server emits \r\n on Windows; jiter + # treats the \r as JSON whitespace), and a partial fourth. + wire = _line(ping) + crlf_wire = ping2.model_dump_json(by_alias=True, exclude_unset=True).encode() + b"\r\n" + await process.feed(wire[:7]) + await process.feed(wire[7:] + _line(pong) + crlf_wire + b'{"jsonrpc": "2.0", "id": 99') + + assert await _next_message(read_stream) == ping + assert await _next_message(read_stream) == pong + assert await _next_message(read_stream) == ping2 + + # The partial trailing message is dropped at EOF, not delivered broken. + # (no branch: coverage mis-traces the exit arc of a `with` whose body + # raises inside a nested async context.) + with pytest.raises(anyio.EndOfStream): # pragma: no branch + process.close_stdout() + await read_stream.receive() @pytest.mark.anyio -@pytest.mark.skipif(tee is None, reason="could not find tee command") -async def test_stdio_context_manager_exiting(): - assert tee is not None - async with stdio_client(StdioServerParameters(command=tee)) as (_, _): - pass +async def test_each_outgoing_message_is_written_as_exactly_one_line(monkeypatch: pytest.MonkeyPatch) -> None: + """Client -> server framing writes one line per message. + + Every sent message reaches the server's stdin as exactly one newline-terminated + JSON document. 
+ """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + process = FakeProcess(on_stdin_close=lambda: process.exit(0)) + + install_fake_process(monkeypatch, process) + + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS) as (_, write_stream): + await write_stream.send(SessionMessage(ping)) + await write_stream.send(SessionMessage(pong)) + # The zero-buffer handoff resumes this task before the writer has + # necessarily written; once all tasks block again, both writes have landed. + await anyio.wait_all_tasks_blocked() + assert process.written == [_line(ping), _line(pong)] @pytest.mark.anyio -@pytest.mark.skipif(tee is None, reason="could not find tee command") -async def test_stdio_client(): - assert tee is not None - server_parameters = StdioServerParameters(command=tee) - - async with stdio_client(server_parameters) as (read_stream, write_stream): - # Test sending and receiving messages - messages = [ - JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"), - JSONRPCResponse(jsonrpc="2.0", id=2, result={}), - ] - - async with write_stream: - for message in messages: - session_message = SessionMessage(message) - await write_stream.send(session_message) - - read_messages: list[JSONRPCMessage] = [] - async with read_stream: - async for message in read_stream: - if isinstance(message, Exception): # pragma: no cover - raise message - - read_messages.append(message.message) - if len(read_messages) == 2: - break - - assert len(read_messages) == 2 - assert read_messages[0] == JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") - assert read_messages[1] == JSONRPCResponse(jsonrpc="2.0", id=2, result={}) +async def test_invalid_json_from_the_server_surfaces_as_an_in_stream_exception( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A line failing JSON-RPC validation is delivered as an Exception on the read stream. + + The messages after it still come through. 
+ """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + process = FakeProcess(on_stdin_close=lambda: process.exit(0)) + + install_fake_process(monkeypatch, process) + + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS) as (read_stream, _): + await process.feed(b"not json\n" + _line(ping)) + + error = await read_stream.receive() + # The transport surfaces parse failures as the underlying validation error. + assert isinstance(error, ValueError) + assert await _next_message(read_stream) == ping @pytest.mark.anyio -async def test_stdio_client_bad_path(): - """Check that the connection doesn't hang if process errors.""" - server_params = StdioServerParameters(command=sys.executable, args=["-c", "non-existent-file.py"]) - async with stdio_client(server_params) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # The session should raise an error when the connection closes +async def test_a_server_that_dies_before_responding_fails_initialize_with_connection_closed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Server death (stdout EOF) is reported to the session as a closed connection. + + The in-flight initialize fails instead of hanging. 
+ """ + process = FakeProcess(on_stdin_close=lambda: process.exit(0)) + process.exit(1) + + install_fake_process(monkeypatch, process) + + with anyio.fail_after(5): + async with ( + stdio_client(FAKE_PARAMS) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): with pytest.raises(MCPError) as exc_info: await session.initialize() - # Check that we got a connection closed error assert exc_info.value.error.code == CONNECTION_CLOSED - assert "Connection closed" in exc_info.value.error.message + assert exc_info.value.error.message == "Connection closed" + + +@pytest.mark.anyio +async def test_a_server_that_exits_on_stdin_close_is_never_terminated(monkeypatch: pytest.MonkeyPatch) -> None: + """Closing stdin (shutdown's first step) suffices for a well-behaved server. + + The escalation is never invoked. The fake's stdin also raises on close, which the + shutdown must tolerate. + """ + + process = FakeProcess( + on_stdin_close=lambda: process.exit(0), + stdin_aclose_error=anyio.ClosedResourceError(), + ) + terminated = install_fake_process(monkeypatch, process) + + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS): + pass + + assert terminated == [] + assert process.stdin_closed.is_set() + + +def test_escalation_fires_once_and_only_after_the_grace_period(monkeypatch: pytest.MonkeyPatch) -> None: + """A server that ignores stdin closure is terminated at the grace deadline exactly. + + The kill lands no earlier than the production `PROCESS_TERMINATION_TIMEOUT` on the + runtime clock, and by the first `returncode` poll after it. + + The suite's only direct trio use: anyio's pytest plugin cannot hand the backend a + clock, so the test calls `trio.run` itself with an autojumping `MockClock`. Every + time primitive rides that one virtual clock, so the production grace elapses + instantly and the bound can be two-sided (a wall-clock upper bound flakes under + load). 
That virtual seconds match wall seconds is the runtime clock's contract, + deliberately not re-tested here. + """ + + class ClockedFakeProcess(FakeProcess): + """Records the virtual time of each death. + + Only the (fake) tree termination calls `exit` here, so these are the + escalation timestamps. + """ + + def __init__(self) -> None: + super().__init__() + self.exit_times: list[float] = [] + + def exit(self, code: int = 0) -> None: + self.exit_times.append(trio.current_time()) + super().exit(code) + + process = ClockedFakeProcess() + terminated = install_fake_process(monkeypatch, process, grace_period=None) + + async def run_client() -> float: + with anyio.fail_after(stdio.PROCESS_TERMINATION_TIMEOUT + 5): # virtual seconds + async with stdio_client(FAKE_PARAMS): + # Evaluated just before the context exits: the moment cleanup begins. + return trio.current_time() + + cleanup_started = trio.run(run_client, clock=trio.testing.MockClock(autojump_threshold=0)) + + assert terminated == [process] + virtual_elapsed = process.exit_times[0] - cleanup_started + # Two-sided: never before the grace deadline, and within one poll interval past it + # (shutdown's writer-flush poll); the epsilon absorbs virtual-sleep float drift. + assert ( + stdio.PROCESS_TERMINATION_TIMEOUT + <= virtual_elapsed + <= stdio.PROCESS_TERMINATION_TIMEOUT + _EXIT_POLL_INTERVAL + 1e-9 + ), virtual_elapsed + + +def test_a_server_dying_in_the_final_poll_interval_is_not_escalated(monkeypatch: pytest.MonkeyPatch) -> None: + """A server exiting in the poll interval the grace deadline cuts short is not escalated. + + Such a server is dead, not hung: the timed-out grace wait must re-check `returncode` + before deciding to escalate, so this server is never terminated. + + Runs on trio's MockClock (see the escalation-bound test above). 
The grace is + set to end mid-interval (0.105 with 0.01 polls) and the fake dies at 0.102 + after its stdin closes, strictly between the last in-window poll (0.10) and + the deadline (0.105), so no two timers collide. + """ + process = FakeProcess() + terminated = install_fake_process(monkeypatch, process, grace_period=0.105) + + async def run_client() -> None: + with anyio.fail_after(5): # virtual seconds + async with anyio.create_task_group() as tg: + + async def die_late() -> None: + await anyio.sleep(0.102) + process.exit(0) + + # The grace wait starts when stdin closes; anchor the death there. + process.on_stdin_close = lambda: tg.start_soon(die_late) + # no branch: the tracer drops this nested async-with's arcs under + # trio's MockClock even though the body runs. + async with stdio_client(FAKE_PARAMS): # pragma: no branch + pass + + trio.run(run_client, clock=trio.testing.MockClock(autojump_threshold=0)) + + assert terminated == [] + assert process.returncode == 0 + + +@pytest.mark.anyio +async def test_cancelling_the_client_still_runs_the_full_shutdown(monkeypatch: pytest.MonkeyPatch) -> None: + """Cancellation (a client timeout, app shutdown) must not skip the shutdown sequence. + + Stdin is still closed and a server ignoring it is still terminated. Without the + shielded shutdown this leaks the process and can deadlock. + """ + process = FakeProcess() + terminated = install_fake_process(monkeypatch, process, grace_period=0.05) + entered = anyio.Event() + # Cancel a scope owned by the client's task, not the test's task group: a host + # self-cancel is delivered by throwing through this test function's suspended + # frames, and Python 3.11's tracer loses coverage events after such a throw() + # traversal (python/cpython#106749). 
+ cancel_scope = anyio.CancelScope() + + async def run_client_until_cancelled() -> None: + with cancel_scope: + async with stdio_client(FAKE_PARAMS): + entered.set() + await anyio.sleep_forever() + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(run_client_until_cancelled) + await entered.wait() + cancel_scope.cancel() + + assert process.stdin_closed.is_set() + assert terminated == [process] @pytest.mark.anyio -async def test_stdio_client_nonexistent_command(): - """Test that stdio_client raises an error for non-existent commands.""" - # Create a server with a non-existent command +async def test_writing_after_the_server_dies_reports_clean_closure(monkeypatch: pytest.MonkeyPatch) -> None: + """A send racing the server's death must not surface a raw backend exception. + + The exception (ConnectionResetError in an exception group) must not escape the + context manager; the transport still shuts down cleanly. + """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + process = FakeProcess(on_stdin_close=lambda: process.exit(0)) + + install_fake_process(monkeypatch, process) + + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS) as (_, write_stream): + process.exit(1) + # The fake's stdin now raises ConnectionResetError, as a dead child's pipe does. + await write_stream.send(SessionMessage(ping)) + + assert process.written == [] + + +@pytest.mark.anyio +async def test_exiting_with_an_unconsumed_server_message_does_not_raise(monkeypatch: pytest.MonkeyPatch) -> None: + """Exiting while a server message is still undelivered must be a clean exit. + + Shutdown closes the read stream under the blocked reader task, and that closure + must not escape the caller as a BrokenResourceError in an exception group. 
+ """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + process = FakeProcess(on_stdin_close=lambda: process.exit(0)) + + install_fake_process(monkeypatch, process) + + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS): + # Feed a message and never receive it: the reader parses it and blocks + # delivering into the zero-buffer read stream until shutdown breaks the send. + await process.feed(_line(ping)) + # Wait until the reader task is genuinely parked on its blocked send + # before shutdown closes the stream out from under it. + await anyio.wait_all_tasks_blocked() + + +@pytest.mark.anyio +async def test_spawn_failure_propagates_the_error_and_leaks_no_streams(monkeypatch: pytest.MonkeyPatch) -> None: + """When the spawn itself fails, the OSError reaches the caller and no streams leak. + + The transport's internal streams are all closed; an unclosed stream would fail the + test through its GC-time ResourceWarning under filterwarnings=error. + """ + + async def failing_spawn( + command: str, + args: list[str], + env: dict[str, str] | None = None, + errlog: TextIO = sys.stderr, + cwd: Path | str | None = None, + ) -> FakeProcess: + raise OSError(errno.EACCES, "Permission denied") + + monkeypatch.setattr(stdio, "_create_platform_compatible_process", failing_spawn) + + with pytest.raises(OSError) as exc_info: + async with stdio_client(FAKE_PARAMS): + pass # pragma: no cover + + assert exc_info.value.errno == errno.EACCES + # Drop the ExceptionInfo before collecting: its traceback references the suspended + # stdio_client frame, which would keep leaked streams alive across the collect. 
+ del exc_info + gc.collect() + + +@pytest.mark.anyio +async def test_a_command_that_cannot_be_execed_raises_enoent() -> None: + """A command that cannot be exec'd raises OSError(ENOENT) out of stdio_client.""" server_params = StdioServerParameters( command="/path/to/nonexistent/command", args=["--help"], ) - # Should raise an error when trying to start the process with pytest.raises(OSError) as exc_info: - async with stdio_client(server_params) as (_, _): + async with stdio_client(server_params): pass # pragma: no cover - # The error should indicate the command was not found (ENOENT: No such file or directory) assert exc_info.value.errno == errno.ENOENT @pytest.mark.anyio -async def test_stdio_client_universal_cleanup(): - """Test that stdio_client completes cleanup within reasonable time - even when connected to processes that exit slowly. +async def test_cancellation_during_spawn_leaks_no_streams(monkeypatch: pytest.MonkeyPatch) -> None: + """Cancellation while the spawn is still in flight must not leak the internal streams. + + A caller timeout can fire mid-spawn (interpreter cold start); an unclosed stream + would fail the test through its GC-time ResourceWarning under filterwarnings=error. 
""" + spawn_started = anyio.Event() - # Use a Python script that simulates a long-running process - # This ensures consistent behavior across platforms - long_running_script = textwrap.dedent( - """ - import time - import sys - - # Simulate a long-running process - for i in range(100): - time.sleep(0.1) - # Flush to ensure output is visible - sys.stdout.flush() - sys.stderr.flush() - """ + async def hanging_spawn( + command: str, + args: list[str], + env: dict[str, str] | None = None, + errlog: TextIO = sys.stderr, + cwd: Path | str | None = None, + ) -> FakeProcess: + spawn_started.set() + await anyio.sleep_forever() + raise NotImplementedError("unreachable: the spawn is cancelled while parked") + + monkeypatch.setattr(stdio, "_create_platform_compatible_process", hanging_spawn) + + # Cancel a scope owned by the client's task, not the test's task group: a host + # self-cancel is delivered by throwing through this test function's suspended + # frames, and Python 3.11's tracer loses coverage events after such a throw() + # traversal (python/cpython#106749). + cancel_scope = anyio.CancelScope() + + async def run_client() -> None: + with cancel_scope: + async with stdio_client(FAKE_PARAMS): + pass # pragma: no cover + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(run_client) + await spawn_started.wait() + cancel_scope.cancel() + + gc.collect() + + +@pytest.mark.anyio +async def test_a_non_oserror_spawn_failure_propagates_and_leaks_no_streams(monkeypatch: pytest.MonkeyPatch) -> None: + """A non-OSError spawn failure also propagates and leaks no streams. + + Spawning can fail with more than OSError (e.g. ValueError for a NUL byte in the + command); the error reaches the caller and the transport's internal streams are + still all closed (checked through GC-time ResourceWarnings, as above). 
+ """ + + async def failing_spawn( + command: str, + args: list[str], + env: dict[str, str] | None = None, + errlog: TextIO = sys.stderr, + cwd: Path | str | None = None, + ) -> FakeProcess: + raise ValueError("embedded null byte") + + monkeypatch.setattr(stdio, "_create_platform_compatible_process", failing_spawn) + + with pytest.raises(ValueError, match="embedded null byte"): + async with stdio_client(FAKE_PARAMS): + pass # pragma: no cover + + gc.collect() + + +@pytest.mark.anyio +async def test_a_message_sent_just_before_exit_is_flushed_to_the_server(monkeypatch: pytest.MonkeyPatch) -> None: + """A message the transport accepted must reach the server even on immediate exit. + + The caller exits right after sending. Once the writer is parked waiting, a send is + a pure handoff that returns before the write lands, so the second message here is + the one shutdown must let the writer flush before closing the server's stdin. + """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + process = FakeProcess(on_stdin_close=lambda: process.exit(0)) + + install_fake_process(monkeypatch, process) + + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS) as (_, write_stream): + await write_stream.send(SessionMessage(ping)) + await write_stream.send(SessionMessage(pong)) + + assert process.written == [_line(ping), _line(pong)] + + +@pytest.mark.anyio +async def test_a_failed_write_to_a_live_server_closes_the_read_stream_instead_of_hanging( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A failed write to a live server ends the read stream instead of hanging the session. + + When a write fails but the server is still alive (stdout never EOFs), the transport + must end the read stream so a session maps the loss to CONNECTION_CLOSED instead of + waiting forever. EIO pins that plain OSError, not just ConnectionError, is handled. + + Steps: + 1. 
A send fails with EIO while the server is alive; the read stream ends. + 2. Output the server produces afterwards is still drained, so it cannot wedge + on a full pipe. + """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + process = FakeProcess( + on_stdin_close=lambda: process.exit(0), + stdin_send_error=OSError(errno.EIO, "I/O error"), ) + terminated = install_fake_process(monkeypatch, process) - server_params = StdioServerParameters( - command=sys.executable, - args=["-c", long_running_script], + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS) as (read_stream, write_stream): + await write_stream.send(SessionMessage(ping)) + + with pytest.raises(anyio.EndOfStream): + await read_stream.receive() + + await process.feed(_line(pong)) + await anyio.wait_all_tasks_blocked() + assert process.pending_stdout_chunks() == 0 + + assert process.written == [] + assert terminated == [] + + +@pytest.mark.anyio +async def test_exit_completes_when_a_write_is_wedged_in_a_pipe_no_one_reads( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Exiting stays bounded even when the writer is parked in a write that cannot complete. + + A kill-surviving descendant can hold the read end without reading; the flush window + expires and the post-shutdown cancellation unparks the writer. + """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + process = FakeProcess(on_stdin_close=lambda: process.exit(0), stdin_send_blocks=True) + terminated = install_fake_process(monkeypatch, process) + monkeypatch.setattr(stdio, "_WRITER_FLUSH_TIMEOUT", 0.05) + + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS) as (_, write_stream): + await write_stream.send(SessionMessage(ping)) + # Wait until the writer task is genuinely parked inside the wedged send. 
+ await anyio.wait_all_tasks_blocked() + + assert process.written == [] + assert terminated == [] + assert process.stdin_closed.is_set() + + +@pytest.mark.anyio +async def test_undelivered_server_output_is_drained_at_shutdown_so_the_server_can_exit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Output the caller never received is consumed during the stdin-close grace period. + + A real server flushing its remaining output on the way out would otherwise block on + a full pipe, never reach its stdin read, and be killed despite being well-behaved. + The fake ignores stdin closure (so it is ultimately terminated); the pin is that its + backlog was drained during the grace window. + """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + process = FakeProcess() + terminated = install_fake_process(monkeypatch, process) + + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS): + # Three separate chunks: the reader parks delivering the first; the other + # two sit unconsumed in the pipe when shutdown begins. + await process.feed(_line(ping)) + await process.feed(_line(pong)) + await process.feed(_line(ping)) + await anyio.wait_all_tasks_blocked() + assert process.pending_stdout_chunks() == 2 + + assert terminated == [process] + assert process.pending_stdout_chunks() == 0 + + +@pytest.mark.anyio +async def test_shutdown_drains_stdout_first_so_a_wedged_writers_flush_can_complete( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Shutdown unblocks the reader's drain before waiting out the writer flush. + + A server wedged writing its stdout cannot get to reading its stdin, so a client + write can sit in a full pipe; the drain is what unwedges the server and lets the + flush complete. 
+ """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + + received = 0 + stdin_gate = anyio.Event() + + def unwedge_once_drained() -> None: + # Accept the client's write only once all three output chunks are consumed, + # like a real server whose blocked stdout write gates its stdin read. + nonlocal received + received += 1 + if received == 3: + stdin_gate.set() + + process = FakeProcess( + on_stdin_close=lambda: process.exit(0), + stdin_send_gate=stdin_gate, + on_stdout_receive=unwedge_once_drained, ) + terminated = install_fake_process(monkeypatch, process) + # A flush wait that never gets unwedged would outlast the whole test budget. + monkeypatch.setattr(stdio, "_WRITER_FLUSH_TIMEOUT", 30.0) - start_time = time.time() + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS) as (_read_stream, write_stream): + # The reader parks delivering a message nobody receives, with more + # chunks backed up behind it; the writer parks in the gated send. + await process.feed(_line(ping)) + await process.feed(_line(pong)) + await process.feed(_line(ping)) + await write_stream.send(SessionMessage(ping)) + await anyio.wait_all_tasks_blocked() - with anyio.move_on_after(8.0) as cancel_scope: - async with stdio_client(server_params) as (_, _): - # Immediately exit - this triggers cleanup while process is still running - pass + assert terminated == [] + assert len(process.written) == 1 + assert process.pending_stdout_chunks() == 0 - end_time = time.time() - elapsed = end_time - start_time - # On Windows: 2s (stdin wait) + 2s (terminate wait) + overhead = ~5s expected - assert elapsed < 6.0, ( - f"stdio_client cleanup took {elapsed:.1f} seconds, expected < 6.0 seconds. " - f"This suggests the timeout mechanism may not be working properly." 
- ) +@pytest.mark.anyio +async def test_cancellation_with_undelivered_backlog_still_drains_and_spares_the_server( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Cancellation must not skip the shutdown drain. + + A well-behaved server that can only exit once its remaining output is consumed (a + real one blocks on a full stdout pipe) still exits within the grace period and is + never terminated. + """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + pong = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + process = FakeProcess() + terminated = install_fake_process(monkeypatch, process) + + def exit_when_flushed() -> None: + # The fake exits only once its stdin has closed AND its output backlog + # has been consumed, like a real server wedged writing its stdout. + if process.stdin_closed.is_set() and process.pending_stdout_chunks() == 0: + process.exit(0) + + process.on_stdin_close = exit_when_flushed + process.on_stdout_receive = exit_when_flushed + + entered = anyio.Event() + # Cancel a scope owned by the client's task, not the test's task group (see + # test_cancelling_the_client_still_runs_the_full_shutdown). + cancel_scope = anyio.CancelScope() + + async def run_client_until_cancelled() -> None: + with cancel_scope: + async with stdio_client(FAKE_PARAMS): + await process.feed(_line(ping)) + await process.feed(_line(pong)) + await process.feed(_line(ping)) + entered.set() + await anyio.sleep_forever() + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(run_client_until_cancelled) + await entered.wait() + cancel_scope.cancel() + + assert process.pending_stdout_chunks() == 0 + assert terminated == [] - # Check if we timed out - if cancel_scope.cancelled_caught: # pragma: no cover - pytest.fail( - "stdio_client cleanup timed out after 8.0 seconds. " - "This indicates the cleanup mechanism is hanging and needs fixing." 
- ) + +@pytest.mark.anyio +async def test_invalid_utf8_flushed_by_a_dying_server_does_not_break_shutdown( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The shutdown drain consumes raw bytes. + + A server flushing non-UTF-8 output (a crash dump, say) on its way out must not + abort the drain or surface a UnicodeDecodeError out of the context manager. + """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + process = FakeProcess(on_stdin_close=lambda: process.exit(0)) + terminated = install_fake_process(monkeypatch, process) + + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS): + # Park the reader delivering a message nobody receives, then queue + # bytes that are not valid UTF-8 behind it. + await process.feed(_line(ping)) + await anyio.wait_all_tasks_blocked() + await process.feed(b"\xff\xfe not utf-8\n") + + assert terminated == [] + assert process.pending_stdout_chunks() == 0 @pytest.mark.anyio -@pytest.mark.skipif(sys.platform == "win32", reason="Windows signal handling is different") -async def test_stdio_client_sigint_only_process(): # pragma: lax no cover - """Test cleanup with a process that ignores SIGTERM but responds to SIGINT.""" - # Create a Python script that ignores SIGTERM but handles SIGINT - script_content = textwrap.dedent( - """ - import signal - import sys - import time +async def test_a_kill_racing_a_pending_stdout_read_is_swallowed_during_shutdown( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + """A hard kill during a pending stdout read must not escape the context manager. + + The read surfaces ConnectionResetError on the proactor backend; being expected + teardown noise, it is not logged as an error either. 
+ """ + process = FakeProcess(stdout_eof_error=ConnectionResetError("read torn down by kill")) + terminated = install_fake_process(monkeypatch, process) - # Ignore SIGTERM (what process.terminate() sends) - signal.signal(signal.SIGTERM, signal.SIG_IGN) + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS): + pass # the fake ignores stdin closure, so shutdown must escalate - # Handle SIGINT (Ctrl+C signal) by exiting cleanly - def sigint_handler(signum, frame): - sys.exit(0) + assert terminated == [process] + assert not [record for record in caplog.records if record.levelno >= logging.ERROR] - signal.signal(signal.SIGINT, sigint_handler) - # Keep running until SIGINT received - while True: - time.sleep(0.1) - """ +@pytest.mark.anyio +async def test_a_mid_session_stdout_failure_is_logged_and_surfaces_as_clean_closure( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + """A mid-session stdout read failure ends the read stream cleanly and is logged. + + A failure outside shutdown surfaces no raw exception out of the context manager and + leaves an error log identifying the failure, unlike the silent shutdown case. + """ + process = FakeProcess( + on_stdin_close=lambda: process.exit(0), + stdout_eof_error=ConnectionResetError("pipe failed mid-session"), ) + install_fake_process(monkeypatch, process) - server_params = StdioServerParameters( - command=sys.executable, - args=["-c", script_content], + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS) as (read_stream, _): + process.exit(1) + # (no branch: coverage mis-traces the exit arc of a `with` whose body + # raises inside a nested async context.) 
+ with pytest.raises(anyio.EndOfStream): # pragma: no branch + await read_stream.receive() + + assert "stdout failed mid-session" in caplog.text + + +@pytest.mark.anyio +async def test_a_failing_stdout_close_still_closes_the_transport_streams(monkeypatch: pytest.MonkeyPatch) -> None: + """A close-time error on the process's stdout must not abort the rest of the shutdown. + + Such an error (a contended pipe handle on the Windows fallback) still leaves the + context exiting cleanly and the internal streams all closed (checked via GC-time + ResourceWarnings). + """ + process = FakeProcess( + on_stdin_close=lambda: process.exit(0), + stdout_aclose_error=OSError(errno.EBADF, "Bad file descriptor"), ) + terminated = install_fake_process(monkeypatch, process) - start_time = time.time() - - try: - # Use anyio timeout to prevent test from hanging forever - with anyio.move_on_after(5.0) as cancel_scope: - async with stdio_client(server_params) as (_, _): - # Let the process start and begin ignoring SIGTERM - await anyio.sleep(0.5) - # Exit context triggers cleanup - this should not hang - pass - - if cancel_scope.cancelled_caught: # pragma: no cover - raise TimeoutError("Test timed out") - - end_time = time.time() - elapsed = end_time - start_time - - # Should complete quickly even with SIGTERM-ignoring process - # This will fail if cleanup only uses process.terminate() without fallback - assert elapsed < SIGTERM_IGNORING_PROCESS_TIMEOUT, ( - f"stdio_client cleanup took {elapsed:.1f} seconds with SIGTERM-ignoring process. " - f"Expected < {SIGTERM_IGNORING_PROCESS_TIMEOUT} seconds. " - "This suggests the cleanup needs SIGINT/SIGKILL fallback." - ) - except (TimeoutError, Exception) as e: # pragma: no cover - if isinstance(e, TimeoutError) or "timed out" in str(e): - pytest.fail( - f"stdio_client cleanup timed out after {SIGTERM_IGNORING_PROCESS_TIMEOUT} seconds " - "with SIGTERM-ignoring process. 
" - "This confirms the cleanup needs SIGINT/SIGKILL fallback for processes that ignore SIGTERM." - ) - else: - raise + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS): + pass + + assert terminated == [] + gc.collect() + + +@pytest.mark.anyio +async def test_a_process_surviving_the_kill_escalation_is_logged_and_abandoned( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + """A process surviving the whole kill escalation is logged and abandoned. + + If the process is still alive after the escalation (D-state, an unsignalable + survivor), shutdown still completes, bounded, and leaves a warning instead of + silently leaking a live process. + """ + process = FakeProcess() # ignores stdin closure and survives "termination" + install_fake_process(monkeypatch, process, grace_period=0.05) + + stubborn: list[FakeProcess] = [] + + async def stubborn_terminate(proc: FakeProcess) -> None: + stubborn.append(proc) # the kill has no effect + + monkeypatch.setattr(stdio, "_terminate_process_tree", stubborn_terminate) + monkeypatch.setattr(stdio, "_KILL_REAP_TIMEOUT", 0.05) + + with anyio.fail_after(5): + async with stdio_client(FAKE_PARAMS): + pass + + assert stubborn == [process] + assert process.returncode is None + assert "still alive after the kill escalation" in caplog.text + # The fake "survived", so nothing ever EOF'd its stdout pipe; release it here + # or its GC-time ResourceWarning would fail a later test. 
+ process.close_stdout() # --------------------------------------------------------------------------- -# TestChildProcessCleanup — socket-based deterministic child liveness probe +# POSIX tree-termination policy, tested through the sanctioned killpg seam # --------------------------------------------------------------------------- # -# These tests verify that `_terminate_process_tree()` kills the *entire* process -# tree (not just the immediate child), which is critical for cleaning up tools -# like `npx` that spawn their own subprocesses. -# -# Mechanism: each subprocess in the tree connects a TCP socket back to a -# listener owned by the test. We then use two kernel-guaranteed blocking-I/O -# signals — neither requires any `sleep()` or polling loop: -# -# 1. `await listener.accept()` blocks until the subprocess connects, -# proving it is running. -# 2. After `_terminate_process_tree()`, `await stream.receive(1)` raises -# `EndOfStream` (clean close / FIN) or `BrokenResourceError` (abrupt -# close / RST — typical on Windows after TerminateJobObject) because the -# kernel closes all file descriptors when a process terminates. Either -# is the direct, OS-level proof that the child is dead. +# `mcp.os.posix.utilities` is coverage-omitted and the sanctioned place to monkeypatch +# OS calls. These pin the EPERM policy without a foreign-euid process: macOS killpg +# raises EPERM when *any* group member cannot be signalled, even if others were. + + +class _StubPosixProcess: + """The two attributes `terminate_posix_process_tree` touches. + + They are the pgid source and the reap-progress probe. + """ + + pid = 54321 + returncode: int | None = None + + +@pytest.mark.anyio +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX killpg semantics") +# lax no cover: Windows CI jobs enforce 100% coverage per job and skip this test. 
+async def test_an_eperm_group_that_dies_during_the_grace_period_is_not_sigkilled( # pragma: lax no cover + monkeypatch: pytest.MonkeyPatch, +) -> None: + """EPERM from the SIGTERM killpg no longer short-circuits termination. + + The grace wait still runs, and a group observed to be gone during it is never + SIGKILLed. + """ + calls: list[tuple[int, int]] = [] + probes = 0 + + def fake_killpg(pgid: int, sig: int) -> None: + nonlocal probes + calls.append((pgid, sig)) + if sig == signal.SIGTERM: + raise PermissionError("one group member has a foreign euid") + if sig == 0: + probes += 1 + if probes == 1: + raise PermissionError("survivors we may not signal") + raise ProcessLookupError("group is gone") + raise NotImplementedError("no other signal should be sent") + + monkeypatch.setattr(posix_utilities.os, "killpg", fake_killpg) + stub = _StubPosixProcess() + + with anyio.fail_after(5): + await terminate_posix_process_tree(cast(anyio.abc.Process, stub)) + + assert calls == [(stub.pid, signal.SIGTERM), (stub.pid, 0), (stub.pid, 0)] + + +@pytest.mark.anyio +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX killpg semantics") +# lax no cover: same Windows-runner coverage reason as above. +async def test_an_eperm_group_that_outlives_the_grace_period_is_still_sigkilled( # pragma: lax no cover + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Even when every probe reports EPERM, the SIGKILL escalation still fires. + + It fires after the grace period, and its own EPERM is tolerated. Pre-fix, EPERM at + SIGTERM abandoned the group escalation for a leader-only kill, leaking every other + group member. The tiny timeout is the time-based grace period under test. 
+ """ + calls: list[tuple[int, int]] = [] + + def fake_killpg(pgid: int, sig: int) -> None: + calls.append((pgid, sig)) + if sig in (signal.SIGTERM, 0, signal.SIGKILL): + raise PermissionError("a foreign-euid member never goes away") + raise NotImplementedError("no other signal should be sent") + + monkeypatch.setattr(posix_utilities.os, "killpg", fake_killpg) + stub = _StubPosixProcess() + + with anyio.fail_after(5): + await terminate_posix_process_tree(cast(anyio.abc.Process, stub), timeout_seconds=0.05) + + assert calls[0] == (stub.pid, signal.SIGTERM) + assert calls[-1] == (stub.pid, signal.SIGKILL) + assert set(calls[1:-1]) == {(stub.pid, 0)} + + +@pytest.mark.anyio +@pytest.mark.parametrize("anyio_backend", ["asyncio", "trio"]) +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX killpg semantics") +# lax no cover: same Windows-runner coverage reason as above. +async def test_the_grace_wait_reads_returncode_so_trio_can_reap_the_leaders_zombie( # pragma: lax no cover + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The wait between SIGTERM and SIGKILL reads `process.returncode` while it polls. + + On trio that property calls `Popen.poll()`, whose reap stops the leader's zombie + from keeping the group alive for the full timeout (see terminate_posix_process_tree). + Regression pin for the read itself, on both backends; the reaping side effect is + trio's documented behaviour, deliberately not re-tested here. + """ + + calls: list[tuple[int, int]] = [] + + def fake_killpg(pgid: int, sig: int) -> None: + # SIGTERM is accepted and every liveness probe reports survivors, so the + # grace wait runs to its (tiny) timeout and the SIGKILL escalation fires. 
+ calls.append((pgid, sig)) + + class _ReadCountingProcess: + """A live-forever leader whose `returncode` property counts its reads.""" + + pid = 54321 + + def __init__(self) -> None: + self.returncode_reads = 0 + + @property + def returncode(self) -> int | None: + self.returncode_reads += 1 + return None + + monkeypatch.setattr(posix_utilities.os, "killpg", fake_killpg) + stub = _ReadCountingProcess() + + with anyio.fail_after(5): + await terminate_posix_process_tree(cast(anyio.abc.Process, stub), timeout_seconds=0.05) + + # The wait ran to its deadline (the escalation fired)... + assert calls[0] == (stub.pid, signal.SIGTERM) + assert calls[-1] == (stub.pid, signal.SIGKILL) + # ...and `returncode` was read while it polled, the read that reaps on trio. + assert stub.returncode_reads >= 1 + + +# --------------------------------------------------------------------------- +# Real-process tests: the OS facts no fake can certify +# --------------------------------------------------------------------------- # -# This replaces an older file-growth-watching approach whose fixed `sleep()` -# durations raced against slow Python interpreter startup on loaded CI runners. +# These pin kernel behaviour (process-group kill semantics, SIGKILL delivery) via a +# socket liveness probe, no sleeps or polls: `accept()` blocks until the subprocess +# connects, proving it runs (and its pre-connect setup ran); after cleanup, `receive(1)` +# raises EndOfStream (FIN) or BrokenResourceError (RST, typical of SIGKILL and Windows +# job termination) because the kernel closes a dead process's file descriptors. def _connect_back_script(port: int) -> str: - """Return a ``python -c`` script body that connects to the given port, - sends ``b'alive'``, then blocks forever. 
Used by TestChildProcessCleanup - subprocesses as a liveness probe.""" + """Return a ``python -c`` liveness-probe body: connect to `port`, send `b'alive'`, block forever.""" return ( f"import socket, time\n" f"s = socket.create_connection(('127.0.0.1', {port}))\n" @@ -256,15 +1123,6 @@ def _connect_back_script(port: int) -> str: ) -def _spawn_then_block(child_script: str) -> str: - """Return a ``python -c`` script body that spawns ``child_script`` as a - subprocess, then blocks forever. The ``!r`` injection avoids nested-quote - escaping for arbitrary child script content.""" - return ( - f"import subprocess, sys, time\nsubprocess.Popen([sys.executable, '-c', {child_script!r}])\ntime.sleep(3600)\n" - ) - - async def _open_liveness_listener() -> tuple[anyio.abc.SocketListener, int]: """Open a TCP listener on localhost and return it along with its port.""" multi = await anyio.create_tcp_listener(local_host="127.0.0.1") @@ -279,9 +1137,8 @@ async def _open_liveness_listener() -> tuple[anyio.abc.SocketListener, int]: async def _accept_alive(sock: anyio.abc.SocketListener) -> anyio.abc.SocketStream: """Accept one connection and assert the peer sent ``b'alive'``. - Blocks deterministically until a subprocess connects (no polling). The - outer test bounds this with ``anyio.fail_after`` to catch the case where - the subprocess chain failed to start. + Blocks until a subprocess connects (the outer test bounds this with + ``anyio.fail_after``). """ stream = await sock.accept() msg = await stream.receive(5) @@ -290,53 +1147,31 @@ async def _accept_alive(sock: anyio.abc.SocketListener) -> anyio.abc.SocketStrea async def _assert_stream_closed(stream: anyio.abc.SocketStream) -> None: - """Assert the peer holding the other end of ``stream`` has terminated. - - When a process dies, the kernel closes its file descriptors including - sockets. 
The next ``receive()`` on the peer socket unblocks with one of: - - - ``anyio.EndOfStream`` — clean close (FIN), typical after graceful exit - or POSIX ``SIGTERM``. - - ``anyio.BrokenResourceError`` — abrupt close (RST), typical after - Windows ``TerminateJobObject`` or POSIX ``SIGKILL``. - - Either is a deterministic, kernel-level signal that the process is dead — - no sleeps or polling required. - """ + """Assert the peer holding the other end of `stream` has terminated.""" with anyio.fail_after(5.0), pytest.raises((anyio.EndOfStream, anyio.BrokenResourceError)): await stream.receive(1) -async def _terminate_and_reap(proc: anyio.abc.Process | FallbackProcess) -> None: - """Terminate the process tree, reap, and tear down pipe transports. +# lax no cover: only called by win32-skipped tests; Windows CI jobs enforce 100% +# coverage per job, where these helpers never execute. +async def _wait_until_exited(proc: anyio.abc.Process) -> None: # pragma: lax no cover + """Poll `returncode` until the process itself dies. - ``_terminate_process_tree`` kills the OS process group / Job Object but does - not call ``process.wait()`` or clean up the asyncio pipe transports. On - Windows those transports leak and emit ``ResourceWarning`` when GC'd in a - later test, causing ``PytestUnraisableExceptionWarning`` knock-on failures. + Not `proc.wait()`: on asyncio that also waits for the pipes to close, conflating + process death with pipe state. + """ + while proc.returncode is None: + await anyio.sleep(0.01) - Production ``stdio.py`` avoids this via its ``stdout_reader`` task which - reads stdout to EOF (triggering ``_ProactorReadPipeTransport._eof_received`` - → ``close()``) plus ``async with process:`` which waits and closes stdin. - These tests call ``_terminate_process_tree`` directly, so they replicate - both parts here: ``wait()`` + close stdin + drain stdout to EOF. 
- The stdout drain is the non-obvious part: anyio's ``StreamReaderWrapper.aclose()`` - only marks the Python-level reader closed — it never touches the underlying - ``_ProactorReadPipeTransport``. That transport starts paused and only detects - pipe EOF when someone reads, so without a drain it lives until ``__del__``. +async def _reap(proc: anyio.abc.Process) -> None: # pragma: lax no cover + """Reap an already-killed process and release its pipe transports. - Idempotent: the ``returncode`` guard skips termination if already reaped - (avoids spurious WARNING/ERROR logs from ``terminate_posix_process_tree``'s - fallback path, visible because ``log_cli = true``); ``wait()`` and stream - ``aclose()`` no-op on subsequent calls; the drain raises ``ClosedResourceError`` - on the second call, caught by the suppress. The tests call this explicitly - as the action under test and ``AsyncExitStack`` calls it again on exit as a - safety net. Bounded by ``move_on_after`` to prevent hangs. + Draining stdout to EOF lets the asyncio pipe transport observe the closure instead + of warning at GC. The bound swallows a hung cleanup on purpose; reaping is just a + safety net. """ with anyio.move_on_after(5.0): - if proc.returncode is None: - await _terminate_process_tree(proc) await proc.wait() assert proc.stdin is not None assert proc.stdout is not None @@ -346,215 +1181,220 @@ async def _terminate_and_reap(proc: anyio.abc.Process | FallbackProcess) -> None await proc.stdout.aclose() -class TestChildProcessCleanup: - """Integration tests for ``_terminate_process_tree`` covering basic, - nested, and early-parent-exit process tree scenarios. See module-level - comment above for the socket-based liveness probe mechanism. 
- """ - - @pytest.mark.anyio - async def test_basic_child_process_cleanup(self): - """Parent spawns one child; terminating the tree kills both.""" - async with AsyncExitStack() as stack: - sock, port = await _open_liveness_listener() - stack.push_async_callback(sock.aclose) - - # Parent spawns a child; the child connects back to us. - parent_script = _spawn_then_block(_connect_back_script(port)) - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) - stack.push_async_callback(_terminate_and_reap, proc) - - # Deterministic: accept() blocks until the child connects. No sleep. - with anyio.fail_after(10.0): - stream = await _accept_alive(sock) - stack.push_async_callback(stream.aclose) +def _record_spawned_processes(monkeypatch: pytest.MonkeyPatch) -> list[anyio.abc.Process | FallbackProcess]: + """Record every process `stdio_client` spawns (the real spawn still runs). - # Terminate, reap and close transports (wraps _terminate_process_tree, - # the behavior under test). - await _terminate_and_reap(proc) + A test can inspect each process afterwards and tear its process group down on + failure. + """ + spawned: list[anyio.abc.Process | FallbackProcess] = [] + + async def recording_spawn( + command: str, + args: list[str], + env: dict[str, str] | None = None, + errlog: TextIO = sys.stderr, + cwd: Path | str | None = None, + ) -> anyio.abc.Process | FallbackProcess: + process = await _create_platform_compatible_process(command, args, env, errlog, cwd) + spawned.append(process) + return process + + monkeypatch.setattr(stdio, "_create_platform_compatible_process", recording_spawn) + return spawned + + +# lax no cover: registered on every platform but a no-op on Windows, whose runners +# enforce 100% coverage per job. +def _kill_spawn_groups(spawned: list[anyio.abc.Process | FallbackProcess]) -> None: # pragma: lax no cover + """Failure-path safety net: SIGKILL each spawn-time process group. 
+ + This stops a test failing mid-body from orphaning its sleep-forever descendants. + A no-op when the test passed, and on Windows (no process group to signal; the Job + Object covers strays). + """ + if sys.platform == "win32": + return + for process in spawned: + # macOS killpg raises EPERM for a group holding only unreaped zombies. + with suppress(ProcessLookupError, PermissionError): + os.killpg(process.pid, signal.SIGKILL) - # Deterministic: kernel closed child's socket when it died. - await _assert_stream_closed(stream) - @pytest.mark.anyio - async def test_nested_process_tree(self): - """Parent → child → grandchild; terminating the tree kills all three.""" - async with AsyncExitStack() as stack: - sock, port = await _open_liveness_listener() - stack.push_async_callback(sock.aclose) - - # Build a three-level chain: parent spawns child, child spawns - # grandchild. Every level connects back to our socket. - grandchild = _connect_back_script(port) - child = ( - f"import subprocess, sys\n" - f"subprocess.Popen([sys.executable, '-c', {grandchild!r}])\n" + _connect_back_script(port) - ) - parent_script = ( - f"import subprocess, sys\n" - f"subprocess.Popen([sys.executable, '-c', {child!r}])\n" + _connect_back_script(port) - ) - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) - stack.push_async_callback(_terminate_and_reap, proc) - - # Deterministic: three blocking accepts, one per tree level. - streams: list[anyio.abc.SocketStream] = [] - with anyio.fail_after(10.0): - for _ in range(3): - stream = await _accept_alive(sock) +@pytest.mark.anyio +async def test_exiting_the_context_terminates_the_entire_process_tree(monkeypatch: pytest.MonkeyPatch) -> None: + """Exiting `stdio_client` kills the server's whole process tree. + + The tree is a parent that exits instantly on SIGTERM (so the group must outlive its + leader), a child, and a grandchild, each death observed through its liveness socket + closing. 
The escalation timing is pinned in process by + test_escalation_fires_once_and_only_after_the_grace_period; the production grace + constant's value is deliberately unpinned. + """ + monkeypatch.setattr(stdio, "PROCESS_TERMINATION_TIMEOUT", 0.2) + spawned = _record_spawned_processes(monkeypatch) + + async with AsyncExitStack() as stack: + stack.callback(_kill_spawn_groups, spawned) + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + grandchild = _connect_back_script(port) + child = ( + f"import subprocess, sys\nsubprocess.Popen([sys.executable, '-c', {grandchild!r}])\n" + + _connect_back_script(port) + ) + # The parent exits immediately on SIGTERM and never reads stdin, so cleanup + # must escalate, and the group kill must work even as its leader dies first. + parent = ( + f"import signal, subprocess, sys, time\n" + f"signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))\n" + f"subprocess.Popen([sys.executable, '-c', {child!r}])\n" + _connect_back_script(port) + ) + server_params = StdioServerParameters(command=sys.executable, args=["-c", parent]) + + # The bound covers three Python interpreter cold starts on a loaded runner; + # a healthy run takes well under a second. + with anyio.fail_after(15.0): + async with stdio_client(server_params): + streams = [await _accept_alive(sock) for _ in range(3)] + for stream in streams: stack.push_async_callback(stream.aclose) - streams.append(stream) - # Terminate the entire tree (wraps _terminate_process_tree). - await _terminate_and_reap(proc) - - # Every level of the tree must be dead: three kernel-level EOFs. - for stream in streams: - await _assert_stream_closed(stream) - - @pytest.mark.anyio - async def test_early_parent_exit(self): - """Parent exits immediately on SIGTERM; process-group termination still - catches the child (exercises the race where the parent dies mid-cleanup). 
- """ - async with AsyncExitStack() as stack: - sock, port = await _open_liveness_listener() - stack.push_async_callback(sock.aclose) - - # Parent installs a SIGTERM handler that exits immediately, spawns a - # child that connects back to us, then blocks. - child = _connect_back_script(port) - parent_script = ( - f"import signal, subprocess, sys, time\n" - f"signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))\n" - f"subprocess.Popen([sys.executable, '-c', {child!r}])\n" - f"time.sleep(3600)\n" - ) - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) - stack.push_async_callback(_terminate_and_reap, proc) - - # Deterministic: child connected means both parent and child are up. - with anyio.fail_after(10.0): - stream = await _accept_alive(sock) - stack.push_async_callback(stream.aclose) - - # Parent will sys.exit(0) on SIGTERM, but the process-group kill - # (POSIX killpg / Windows Job Object) must still terminate the child. - await _terminate_and_reap(proc) - - # Child must be dead despite parent's early exit. + for stream in streams: await _assert_stream_closed(stream) @pytest.mark.anyio -async def test_stdio_client_graceful_stdin_exit(): - """Test that a process exits gracefully when stdin is closed, - without needing SIGTERM or SIGKILL. +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX process-group semantics") +# lax no cover: Windows CI jobs enforce 100% coverage per job and skip this test. +async def test_tree_kill_reaches_children_after_the_leader_has_already_exited() -> None: # pragma: lax no cover + """Killing the tree of an already-exited process still reaches its surviving children. + + The process group outlives its leader, and the group ID is the leader's pid by + construction (start_new_session), not something to look up from the (reaped) + leader. 
""" - # Create a Python script that exits when stdin is closed - script_content = textwrap.dedent( - """ - import sys - - # Read from stdin until it's closed - try: - while True: - line = sys.stdin.readline() - if not line: # EOF/stdin closed - break - except: - pass - - # Exit gracefully - sys.exit(0) - """ - ) + async with AsyncExitStack() as stack: + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + child = _connect_back_script(port) + # The parent spawns the child and exits immediately: the group leader is dead + # (and reaped) by the time the tree is terminated. + parent = f"import subprocess, sys\nsubprocess.Popen([sys.executable, '-c', {child!r}])\n" + proc = await _create_platform_compatible_process(sys.executable, ["-c", parent]) + assert isinstance(proc, anyio.abc.Process) + stack.callback(_kill_spawn_groups, [proc]) + stack.push_async_callback(_reap, proc) + + # Two interpreter cold starts on a loaded runner; healthy runs take ~0.2s. + with anyio.fail_after(10.0): + stream = await _accept_alive(sock) + stack.push_async_callback(stream.aclose) + # The child connecting proves the parent ran; wait for the leader itself + # to be gone so the kill exercises the dead-leader path. + await _wait_until_exited(proc) - server_params = StdioServerParameters( - command=sys.executable, - args=["-c", script_content], - ) + await _terminate_process_tree(proc) - start_time = time.time() + await _assert_stream_closed(stream) - # Use anyio timeout to prevent test from hanging forever - with anyio.move_on_after(5.0) as cancel_scope: - async with stdio_client(server_params) as (_, _): - # Let the process start and begin reading stdin - await anyio.sleep(0.2) - # Exit context triggers cleanup - process should exit from stdin closure - pass - if cancel_scope.cancelled_caught: - pytest.fail( - "stdio_client cleanup timed out after 5.0 seconds. " - "Process should have exited gracefully when stdin was closed." 
- ) # pragma: no cover +@pytest.mark.anyio +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX process-group semantics") +# lax no cover: same Windows-runner coverage reason as above. +async def test_terminating_an_already_exited_process_is_a_no_op() -> None: # pragma: lax no cover + """Once the whole group is gone, tree termination returns without error. - end_time = time.time() - elapsed = end_time - start_time + It does not fall back to signalling a reaped pid. + """ + proc = await _create_platform_compatible_process(sys.executable, ["-c", "pass"]) + assert isinstance(proc, anyio.abc.Process) - # Should complete quickly with just stdin closure (no signals needed) - assert elapsed < 3.0, ( - f"stdio_client cleanup took {elapsed:.1f} seconds for stdin-aware process. " - f"Expected < 3.0 seconds since process should exit on stdin closure." - ) + # The bound covers one interpreter cold start on a loaded runner; a healthy run + # takes well under a second. + with anyio.fail_after(10.0): + await _wait_until_exited(proc) + await _terminate_process_tree(proc) + await _reap(proc) @pytest.mark.anyio -async def test_stdio_client_stdin_close_ignored(): - """Test that when a process ignores stdin closure, the shutdown sequence - properly escalates to SIGTERM. +@pytest.mark.skipif(sys.platform == "win32", reason="Windows signal handling is different") +# lax no cover: Windows CI jobs enforce 100% coverage per job and skip this test. +async def test_escalation_kills_a_process_that_ignores_sigterm( # pragma: lax no cover + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Cleanup escalates past SIGTERM and kills a process that ignores it. + + The child installs SIG_IGN *before* connecting to the liveness socket, so the + ignore is guaranteed in place; SIGKILL delivery is proven by the kernel closing + the socket. The only test of the SIGTERM-then-SIGKILL escalation itself; the + production constants' values are deliberately unpinned. 
""" - # Create a Python script that ignores stdin closure but responds to SIGTERM - script_content = textwrap.dedent( - """ - import signal - import sys - import time - - # Set up SIGTERM handler to exit cleanly - def sigterm_handler(signum, frame): - sys.exit(0) - - signal.signal(signal.SIGTERM, sigterm_handler) - - # Close stdin immediately to simulate ignoring it - sys.stdin.close() + monkeypatch.setattr(stdio, "PROCESS_TERMINATION_TIMEOUT", 0.2) + monkeypatch.setattr(stdio, "FORCE_KILL_TIMEOUT", 0.2) + spawned = _record_spawned_processes(monkeypatch) + + async with AsyncExitStack() as stack: + stack.callback(_kill_spawn_groups, spawned) + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + script = "import signal\nsignal.signal(signal.SIGTERM, signal.SIG_IGN)\n" + _connect_back_script(port) + server_params = StdioServerParameters(command=sys.executable, args=["-c", script]) + + # The bound covers an interpreter cold start on a loaded runner plus the two + # shortened escalation waits; a healthy run takes well under a second. + with anyio.fail_after(15.0): + async with stdio_client(server_params): + stream = await _accept_alive(sock) + stack.push_async_callback(stream.aclose) - # Keep running until SIGTERM - while True: - time.sleep(0.1) - """ - ) + await _assert_stream_closed(stream) - server_params = StdioServerParameters( - command=sys.executable, - args=["-c", script_content], - ) - start_time = time.time() +@pytest.mark.anyio +@pytest.mark.skipif(not Path("/proc/self/fd").is_dir(), reason="needs procfs to enumerate open file descriptors") +# lax no cover: Windows CI jobs enforce 100% coverage per job, have no procfs, and skip this. +async def test_a_graceful_exit_with_a_surviving_child_leaks_no_pipe_fds( # pragma: lax no cover + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A graceful exit with a surviving child must not leak the client's pipe fds. 
+ + A server may exit cleanly on stdin closure while leaving a child holding the + inherited pipe ends (the POSIX policy: survivors are the server's business). The + client must still release its own pipe fds and subprocess transport at shutdown + (on asyncio nothing else ever closes them while the orphan holds the pipe) instead + of leaking them for the orphan's lifetime. + """ + spawned = _record_spawned_processes(monkeypatch) - # Use anyio timeout to prevent test from hanging forever - with anyio.move_on_after(7.0) as cancel_scope: - async with stdio_client(server_params) as (_, _): - # Let the process start - await anyio.sleep(0.2) - # Exit context triggers cleanup - pass + async with AsyncExitStack() as stack: + stack.callback(_kill_spawn_groups, spawned) + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) - if cancel_scope.cancelled_caught: - pytest.fail( - "stdio_client cleanup timed out after 7.0 seconds. " - "Process should have been terminated via SIGTERM escalation." - ) # pragma: no cover + child = _connect_back_script(port) + # The server hands its inherited pipes to a child, then exits as soon as its + # stdin closes: the well-behaved graceful path, so no kill ever happens. + server = f"import subprocess, sys\nsubprocess.Popen([sys.executable, '-c', {child!r}])\nsys.stdin.read()\n" + server_params = StdioServerParameters(command=sys.executable, args=["-c", server]) - end_time = time.time() - elapsed = end_time - start_time + gc.collect() # settle earlier garbage so its collection cannot close fds mid-test + baseline = set(os.listdir("/proc/self/fd")) - # Should take ~2 seconds (stdin close timeout) before SIGTERM is sent - # Total time should be between 2-4 seconds - assert 1.5 < elapsed < 4.5, ( - f"stdio_client cleanup took {elapsed:.1f} seconds for stdin-ignoring process. " - f"Expected between 2-4 seconds (2s stdin timeout + termination time)." 
- ) + # Two interpreter cold starts on a loaded runner; healthy runs take ~0.3s. + with anyio.fail_after(15.0): + async with stdio_client(server_params): + stream = await _accept_alive(sock) + await stream.aclose() + + leader = spawned[0] + assert isinstance(leader, anyio.abc.Process) + # The graceful path: exited on stdin closure, no termination involved. + assert leader.returncode == 0 + # Subset, not equality: other machinery may close fds, but never open new + # ones; a leaked pipe fd would show up as an extra entry. + assert set(os.listdir("/proc/self/fd")) <= baseline diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py index 27cc65de42..1f65996aa3 100644 --- a/tests/interaction/transports/test_stdio.py +++ b/tests/interaction/transports/test_stdio.py @@ -1,18 +1,13 @@ """The stdio transport: one subprocess end-to-end test and one in-process framing test. -Everything else in the suite runs in a single process; the subprocess test exists to prove the same -client↔server round trip works over the stdio transport's real boundary (a child process whose -stdin/stdout carry one newline-delimited JSON-RPC message per line). The server lives in -`_stdio_server.py` and is launched via `python -m` so subprocess coverage measurement applies. - -The framing test drives `stdio_server` in-process by passing it injected text streams instead of the -real stdin/stdout, so the raw lines the transport writes can be asserted directly without a process -boundary. - -stdio is deliberately not a leg of the `connect`-fixture matrix: spawning a subprocess per test -would be slow, and the matrix already proves transport-agnosticism over three in-process -transports. Process-lifecycle edge cases (escalation to terminate/kill, parse errors) are covered by -`tests/client/test_stdio.py` and stay deferred here. 
+The subprocess test proves the client-server round trip over the transport's real process +boundary; its server lives in `_stdio_server.py` and is launched via `python -m` so subprocess +coverage measurement applies. The framing test drives `stdio_server` over injected in-process +streams instead. + +stdio is deliberately not a leg of the `connect`-fixture matrix: a subprocess per test would be +slow, and the matrix already proves transport-agnosticism in-process. Process-lifecycle edge +cases (terminate/kill escalation, parse errors) stay in `tests/client/test_stdio.py`. """ import io @@ -26,6 +21,7 @@ import pytest from inline_snapshot import snapshot +from mcp.client import stdio from mcp.client.client import Client from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.server.stdio import stdio_server @@ -51,10 +47,21 @@ @requirement("transport:stdio") @requirement("transport:stdio:clean-shutdown") @requirement("transport:stdio:stderr-passthrough") -async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess() -> None: - """A Client connected over stdio initializes, calls a tool with arguments, receives the - server's log notification before the call returns, and the server exits when the transport - closes its stdin.""" +async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A stdio-subprocess Client round-trips a tool call, a notification, and a clean exit. + + The Client initializes, calls a tool with arguments, and receives the server's log + notification before the call returns; the server exits when the transport closes its + stdin. + """ + # After stdin closes, the child must unwind, write the clean-exit line, and let coverage's + # atexit hook persist its subprocess data file before escalation. 
The production 2s default + # was too tight on slow Windows runners: the child was killed mid-atexit (test stayed green) + # and the silently missing data file tripped the 100% coverage gate. Not under test. + monkeypatch.setattr(stdio, "PROCESS_TERMINATION_TIMEOUT", 10.0) + received: list[LoggingMessageNotificationParams] = [] async def collect(params: LoggingMessageNotificationParams) -> None: @@ -66,15 +73,18 @@ async def collect(params: LoggingMessageNotificationParams) -> None: command=sys.executable, args=["-m", _stdio_server.__name__], cwd=str(_REPO_ROOT), - # stdio_client deliberately filters the inherited environment to a safe minimum, - # which drops the variables coverage.py's subprocess support uses; pass them through - # so the server module is measured. Empty when not running under coverage. - env={key: value for key, value in os.environ.items() if key.startswith("COVERAGE_")}, + # stdio_client filters the inherited environment, dropping the variables + # coverage.py's subprocess support uses; pass them through so the server module is + # measured. PYTHONWARNINGS: the child recompiles anyio (pytest's pyc tag differs), + # and on 3.14 anyio's return-in-finally SyntaxWarning would land on the snapshot stderr. + env={key: value for key, value in os.environ.items() if key.startswith("COVERAGE_")} + | {"PYTHONWARNINGS": "ignore::SyntaxWarning"}, ), errlog=errlog, ) - with anyio.fail_after(10): + # Must exceed session time plus the patched PROCESS_TERMINATION_TIMEOUT (10s). 
+ with anyio.fail_after(20): async with Client(transport, logging_callback=collect) as client: assert client.initialize_result.server_info.name == "stdio-echo" result = await client.call_tool("echo", {"text": "across\nprocesses"}) @@ -83,28 +93,23 @@ async def collect(params: LoggingMessageNotificationParams) -> None: captured_stderr = errlog.read() assert result == snapshot(CallToolResult(content=[TextContent(text="across\nprocesses")])) - # stdio carries one ordered server→client stream, so the same notification-before-response + # stdio carries one ordered server-to-client stream, so the same notification-before-response # guarantee holds here as for the in-memory transport. assert received == snapshot( [LoggingMessageNotificationParams(level="info", logger="echo", data="echoing across\nprocesses")] ) - # The server writes this line only after its run loop returns, which happens when stdin closes: - # seeing it proves the process exited on its own rather than via the transport's terminate - # escalation, without a timing-based assertion. The capture itself proves stderr passthrough: - # the transport routes the child's stderr to the caller's `errlog` without consuming it. + # The server writes this line only after its run loop returns on stdin close: seeing it proves + # a self-exit, not the terminate escalation. The capture itself proves stderr passthrough. assert captured_stderr == snapshot("stdio-echo: clean exit\n") @requirement("transport:stdio:stream-purity") @requirement("transport:stdio:no-embedded-newlines") async def test_stdio_server_writes_one_jsonrpc_message_per_line() -> None: - """Everything `stdio_server` writes is a valid JSON-RPC message on its own line, and nothing else. + """Every `stdio_server` write is one valid JSON-RPC message on its own line. 
- The transport's stdin/stdout parameters are public, so the test injects in-process text streams - instead of the real process handles and drives the read/write streams directly: a JSON-RPC line on - stdin is parsed and delivered, and every message sent on the write stream appears as exactly one - newline-terminated line whose payload newlines are JSON-escaped. This proves the transport's own - framing; it does not guard `sys.stdout` against handler code that prints to it directly (see the + Each line is newline-terminated with payload newlines JSON-escaped. This proves the + transport's own framing; it does not guard `sys.stdout` against handler code (see the divergence on `transport:stdio:stream-purity`). """ captured = io.StringIO() diff --git a/tests/issues/test_1027_win_unreachable_cleanup.py b/tests/issues/test_1027_win_unreachable_cleanup.py deleted file mode 100644 index c59c5aecae..0000000000 --- a/tests/issues/test_1027_win_unreachable_cleanup.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Regression test for issue #1027: Ensure cleanup procedures run properly during shutdown - -Issue #1027 reported that cleanup code after "yield" in lifespan was unreachable when -processes were terminated. This has been fixed by implementing the MCP spec-compliant -stdio shutdown sequence that closes stdin first, allowing graceful exit. - -These tests verify the fix continues to work correctly across all platforms. -""" - -import sys -import tempfile -import textwrap -from pathlib import Path - -import anyio -import pytest - -from mcp import ClientSession, StdioServerParameters -from mcp.client.stdio import _create_platform_compatible_process, stdio_client -from tests.shared.test_win32_utils import escape_path_for_python - - -@pytest.mark.anyio -async def test_lifespan_cleanup_executed(): - """Regression test ensuring MCP server cleanup code runs during shutdown. - - This test verifies that the fix for issue #1027 works correctly by: - 1. 
Starting an MCP server that writes a marker file on startup - 2. Shutting down the server normally via stdio_client - 3. Verifying the cleanup code (after yield) executed and wrote its marker file - - The fix implements proper stdin closure before termination, giving servers - time to run their cleanup handlers. - """ - - # Create marker files to track server lifecycle - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f: - startup_marker = f.name - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f: - cleanup_marker = f.name - - # Remove the files so we can detect when they're created - Path(startup_marker).unlink() - Path(cleanup_marker).unlink() - - # Create a minimal MCP server using MCPServer that tracks lifecycle - server_code = textwrap.dedent(f""" - import asyncio - import sys - from pathlib import Path - from contextlib import asynccontextmanager - from mcp.server.mcpserver import MCPServer - - STARTUP_MARKER = {escape_path_for_python(startup_marker)} - CLEANUP_MARKER = {escape_path_for_python(cleanup_marker)} - - @asynccontextmanager - async def lifespan(server): - # Write startup marker - Path(STARTUP_MARKER).write_text("started") - try: - yield {{"started": True}} - finally: - # This cleanup code now runs properly during shutdown - Path(CLEANUP_MARKER).write_text("cleaned up") - - mcp = MCPServer("test-server", lifespan=lifespan) - - @mcp.tool() - def echo(text: str) -> str: - return text - - if __name__ == "__main__": - mcp.run() - """) - - # Write the server script to a temporary file - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".py") as f: - server_script = f.name - f.write(server_code) - - try: - # Launch the MCP server - params = StdioServerParameters(command=sys.executable, args=[server_script]) - - async with stdio_client(params) as (read, write): - async with ClientSession(read, write) as session: - # Initialize the session - result = await session.initialize() - assert 
result.protocol_version in ["2024-11-05", "2025-06-18", "2025-11-25"] - - # Verify startup marker was created - assert Path(startup_marker).exists(), "Server startup marker not created" - assert Path(startup_marker).read_text() == "started" - - # Make a test request to ensure server is working - response = await session.call_tool("echo", {"text": "hello"}) - assert response.content[0].type == "text" - assert getattr(response.content[0], "text") == "hello" - - # Session will be closed when exiting the context manager - - # Give server a moment to complete cleanup - with anyio.move_on_after(5.0): - while not Path(cleanup_marker).exists(): # pragma: lax no cover - await anyio.sleep(0.1) - - # Verify cleanup marker was created - this works now that stdio_client - # properly closes stdin before termination, allowing graceful shutdown - assert Path(cleanup_marker).exists(), "Server cleanup marker not created - regression in issue #1027 fix" - assert Path(cleanup_marker).read_text() == "cleaned up" - - finally: - # Clean up files - for path in [server_script, startup_marker, cleanup_marker]: - try: # pragma: lax no cover - Path(path).unlink() - except FileNotFoundError: # pragma: lax no cover - pass - - -@pytest.mark.anyio -@pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") -async def test_stdin_close_triggers_cleanup(): - """Regression test verifying the stdin-based graceful shutdown mechanism. - - This test ensures the core fix for issue #1027 continues to work by: - 1. Manually managing a server process - 2. Closing stdin to trigger graceful shutdown - 3. Verifying cleanup handlers run before the process exits - - This mimics the behavior now implemented in stdio_client's shutdown sequence. - - Note on Windows ResourceWarning: - On Windows, we may see ResourceWarning about unclosed file descriptors. 
- This is expected behavior because: - - We're manually managing the process lifecycle - - Windows file handle cleanup works differently than Unix - - The warning doesn't indicate a real issue - cleanup still works - We filter this warning on Windows only to avoid test noise. - """ - - # Create marker files to track server lifecycle - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f: - startup_marker = f.name - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f: - cleanup_marker = f.name - - # Remove the files so we can detect when they're created - Path(startup_marker).unlink() - Path(cleanup_marker).unlink() - - # Create an MCP server that handles stdin closure gracefully - server_code = textwrap.dedent(f""" - import asyncio - import sys - from pathlib import Path - from contextlib import asynccontextmanager - from mcp.server.mcpserver import MCPServer - - STARTUP_MARKER = {escape_path_for_python(startup_marker)} - CLEANUP_MARKER = {escape_path_for_python(cleanup_marker)} - - @asynccontextmanager - async def lifespan(server): - # Write startup marker - Path(STARTUP_MARKER).write_text("started") - try: - yield {{"started": True}} - finally: - # This cleanup code runs when stdin closes, enabling graceful shutdown - Path(CLEANUP_MARKER).write_text("cleaned up") - - mcp = MCPServer("test-server", lifespan=lifespan) - - @mcp.tool() - def echo(text: str) -> str: - return text - - if __name__ == "__main__": - # The server should exit gracefully when stdin closes - try: - mcp.run() - except Exception: - # Server might get EOF or other errors when stdin closes - pass - """) - - # Write the server script to a temporary file - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".py") as f: - server_script = f.name - f.write(server_code) - - try: - # This test manually manages the process to verify stdin-based shutdown - # Start the server process - process = await _create_platform_compatible_process( - 
command=sys.executable, args=[server_script], env=None, errlog=sys.stderr, cwd=None - ) - - # Wait for server to start - with anyio.move_on_after(10.0): - while not Path(startup_marker).exists(): - await anyio.sleep(0.1) - - # Check if process is still running - if hasattr(process, "returncode") and process.returncode is not None: # pragma: lax no cover - pytest.fail(f"Server process exited with code {process.returncode}") - - assert Path(startup_marker).exists(), "Server startup marker not created" - - # Close stdin to signal shutdown - if process.stdin: # pragma: no branch - await process.stdin.aclose() - - # Wait for process to exit gracefully - try: - with anyio.fail_after(5.0): # Increased from 2.0 to 5.0 - await process.wait() - except TimeoutError: # pragma: lax no cover - # If it doesn't exit after stdin close, terminate it - process.terminate() - await process.wait() - - # Check if cleanup ran - with anyio.move_on_after(5.0): - while not Path(cleanup_marker).exists(): # pragma: lax no cover - await anyio.sleep(0.1) - - # Verify the cleanup ran - stdin closure enables graceful shutdown - assert Path(cleanup_marker).exists(), "Server cleanup marker not created - stdin-based shutdown failed" - assert Path(cleanup_marker).read_text() == "cleaned up" - - finally: - # Clean up files - for path in [server_script, startup_marker, cleanup_marker]: - try: # pragma: lax no cover - Path(path).unlink() - except FileNotFoundError: # pragma: lax no cover - pass diff --git a/tests/issues/test_552_windows_hang.py b/tests/issues/test_552_windows_hang.py index 1adb5d80cb..371d033c2b 100644 --- a/tests/issues/test_552_windows_hang.py +++ b/tests/issues/test_552_windows_hang.py @@ -1,5 +1,6 @@ """Test for issue #552: stdio_client hangs on Windows.""" +import json import sys from textwrap import dedent @@ -8,41 +9,36 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.types import LATEST_PROTOCOL_VERSION, InitializeResult 
@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific test") # pragma: no cover @pytest.mark.anyio -async def test_windows_stdio_client_with_session(): - """Test the exact scenario from issue #552: Using ClientSession with stdio_client. +async def test_initialize_succeeds_and_shutdown_returns_after_the_server_exits_mid_session(): + """Initialize completes and shutdown returns when the server exits mid-session. - This reproduces the original bug report where stdio_client hangs on Windows 11 - when used with ClientSession. + This is the proactor pipe scenario that hung on Windows 11 (issue #552). The positive + assertion matters: a session that errors quickly would also "not hang". """ - # Create a minimal MCP server that responds to initialization - server_script = dedent(""" + # A minimal server: answer initialize correctly, then exit. + server_script = dedent(f""" import json import sys - # Read initialization request line = sys.stdin.readline() + request = json.loads(line) - # Send initialization response - response = { + response = {{ "jsonrpc": "2.0", - "id": 1, - "result": { - "protocolVersion": "1.0", - "capabilities": {}, - "serverInfo": {"name": "test-server", "version": "1.0"} - } - } + "id": request["id"], + "result": {{ + "protocolVersion": {json.dumps(LATEST_PROTOCOL_VERSION)}, + "capabilities": {{}}, + "serverInfo": {{"name": "test-server", "version": "1.0"}} + }} + }} print(json.dumps(response)) sys.stdout.flush() - - # Exit after a short delay - import time - time.sleep(0.1) - sys.exit(0) """).strip() params = StdioServerParameters( @@ -50,14 +46,11 @@ async def test_windows_stdio_client_with_session(): args=["-c", server_script], ) - # This is the exact pattern from the bug report with anyio.fail_after(10): - try: - async with stdio_client(params) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - # Should exit ClientSession without hanging - # Should exit stdio_client without hanging 
- except Exception: - # Connection errors are expected when process exits - pass + async with stdio_client(params) as (read, write): + async with ClientSession(read, write) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.server_info.name == "test-server" + # Exiting ClientSession and stdio_client must not hang even though the + # server process is already gone. diff --git a/tests/server/mcpserver/test_elicitation.py b/tests/server/mcpserver/test_elicitation.py index 679fb848f5..9292586b32 100644 --- a/tests/server/mcpserver/test_elicitation.py +++ b/tests/server/mcpserver/test_elicitation.py @@ -1,4 +1,4 @@ -"""Test the elicitation feature using stdio transport.""" +"""Test the elicitation feature over the in-memory client transport.""" from typing import Any @@ -58,9 +58,9 @@ async def call_tool_and_assert( @pytest.mark.anyio -async def test_stdio_elicitation(): - """Test the elicitation feature using stdio transport.""" - mcp = MCPServer(name="StdioElicitationServer") +async def test_elicitation_accept_returns_the_users_answer_to_the_tool(): + """An accepted elicitation delivers the user's content back to the requesting tool.""" + mcp = MCPServer(name="ElicitationServer") create_ask_user_tool(mcp) # Create a custom handler for elicitation requests @@ -76,9 +76,9 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E @pytest.mark.anyio -async def test_stdio_elicitation_decline(): - """Test elicitation with user declining.""" - mcp = MCPServer(name="StdioElicitationDeclineServer") +async def test_elicitation_decline_reaches_the_tool_without_content(): + """A declined elicitation reports the decline to the tool, with no content attached.""" + mcp = MCPServer(name="ElicitationDeclineServer") create_ask_user_tool(mcp) async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): diff --git a/tests/server/test_stdio.py 
b/tests/server/test_stdio.py index 677a993567..054a157b3b 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,17 +1,26 @@ import io import sys +import threading +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from io import TextIOWrapper import anyio import pytest +from mcp.server.mcpserver import MCPServer from mcp.server.stdio import stdio_server from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter @pytest.mark.anyio -async def test_stdio_server(): +async def test_stdio_server_round_trips_messages_over_injected_streams() -> None: + """stdio_server frames JSON-RPC messages as one line each in both directions. + + Parses one message per stdin line and writes each outgoing message as exactly one + line, driven over injected in-process streams. + """ stdin = io.StringIO() stdout = io.StringIO() @@ -24,52 +33,45 @@ async def test_stdio_server(): stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n") stdin.seek(0) - async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as ( - read_stream, - write_stream, - ): - received_messages: list[JSONRPCMessage] = [] - async with read_stream: - async for message in read_stream: - if isinstance(message, Exception): # pragma: no cover - raise message - received_messages.append(message.message) - if len(received_messages) == 2: - break - - # Verify received messages - assert len(received_messages) == 2 - assert received_messages[0] == JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") - assert received_messages[1] == JSONRPCResponse(jsonrpc="2.0", id=2, result={}) - - # Test sending responses from the server - responses = [ - JSONRPCRequest(jsonrpc="2.0", id=3, method="ping"), - JSONRPCResponse(jsonrpc="2.0", id=4, result={}), - ] - - async with write_stream: + with anyio.fail_after(5): + async with 
stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as ( + read_stream, + write_stream, + ): + async with read_stream: + received_messages: list[JSONRPCMessage] = [] + for _ in range(2): + received = await read_stream.receive() + assert not isinstance(received, Exception) + received_messages.append(received.message) + + assert received_messages[0] == JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + assert received_messages[1] == JSONRPCResponse(jsonrpc="2.0", id=2, result={}) + + responses = [ + JSONRPCRequest(jsonrpc="2.0", id=3, method="ping"), + JSONRPCResponse(jsonrpc="2.0", id=4, result={}), + ] + for response in responses: - session_message = SessionMessage(response) - await write_stream.send(session_message) + await write_stream.send(SessionMessage(response)) + await write_stream.aclose() stdout.seek(0) output_lines = stdout.readlines() assert len(output_lines) == 2 received_responses = [jsonrpc_message_adapter.validate_json(line.strip()) for line in output_lines] - assert len(received_responses) == 2 assert received_responses[0] == JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") assert received_responses[1] == JSONRPCResponse(jsonrpc="2.0", id=4, result={}) @pytest.mark.anyio -async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): - """Non-UTF-8 bytes on stdin must not crash the server. +async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch) -> None: + """Non-UTF-8 stdin bytes surface as an in-stream exception without killing the stream. - Invalid bytes are replaced with U+FFFD, which then fails JSON parsing and - is delivered as an in-stream exception. Subsequent valid messages must - still be processed. + Invalid bytes are replaced with U+FFFD, fail JSON parsing, and arrive as an in-stream + exception; subsequent valid messages are still processed. """ # \xff\xfe are invalid UTF-8 start bytes. 
valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") @@ -92,3 +94,78 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): second = await read_stream.receive() assert isinstance(second, SessionMessage) assert second.message == valid + + +class _KeepOpenBytesIO(io.BytesIO): + """A BytesIO that survives its TextIOWrapper being closed. + + Lets the test read what was written after `run()` has torn the wrapper down. + """ + + def close(self) -> None: + pass + + +def _run_stdio_bounded(server: MCPServer) -> None: + """Run the blocking `server.run("stdio")` in a daemon thread joined with a 5s bound. + + `run()` creates its own event loop, so a sync test cannot arm `anyio.fail_after`; + the join timeout turns a run loop that never returns on stdin EOF into a red test + instead of a silent CI hang. An exception escaping `run()` still fails the test: + pytest's unhandled-thread warning is escalated by `filterwarnings = ["error"]`. + """ + + def target() -> None: + server.run("stdio") + + thread = threading.Thread(target=target, daemon=True) + thread.start() + thread.join(5) + assert not thread.is_alive(), 'run("stdio") did not return after stdin EOF' + + +def test_mcpserver_run_stdio_serves_until_stdin_closes(monkeypatch: pytest.MonkeyPatch) -> None: + """`MCPServer.run("stdio")` serves over process stdio and returns at stdin EOF. + + Answers a request over the process's stdio and returns when stdin reaches EOF, + rather than serving forever. 
+ """ + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + stdin_bytes = io.BytesIO(ping.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + captured = _KeepOpenBytesIO() + monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_bytes, encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(captured, encoding="utf-8")) + + _run_stdio_bounded(MCPServer(name="RunStdioServer")) + + response = jsonrpc_message_adapter.validate_json(captured.getvalue().decode().strip()) + assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + + +def test_mcpserver_run_stdio_runs_lifespan_cleanup_after_stdin_closes(monkeypatch: pytest.MonkeyPatch) -> None: + """Code after `yield` in a lifespan runs when stdin EOF ends `run("stdio")`. + + Regression lock for the issue #1027 shutdown chain: the run loop must end on + stdin EOF and unwind the lifespan rather than be killed before returning. + """ + events: list[str] = [] + + @asynccontextmanager + async def lifespan(server: MCPServer) -> AsyncIterator[None]: + events.append("setup") + try: + yield + finally: + events.append("cleanup") + + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + stdin_bytes = io.BytesIO(ping.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + captured = _KeepOpenBytesIO() + monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_bytes, encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(captured, encoding="utf-8")) + + _run_stdio_bounded(MCPServer(name="LifespanStdioServer", lifespan=lifespan)) + + assert events == ["setup", "cleanup"] + response = jsonrpc_message_adapter.validate_json(captured.getvalue().decode().strip()) + assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={}) diff --git a/tests/shared/test_win32_utils.py b/tests/shared/test_win32_utils.py deleted file mode 100644 index e0f9cb4995..0000000000 --- a/tests/shared/test_win32_utils.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Windows-specific test 
utilities.""" - - -def escape_path_for_python(path: str) -> str: - """Escape a file path for use in Python code strings. - - Converts backslashes to forward slashes which work on all platforms - and don't need escaping in Python strings. - """ - return repr(path.replace("\\", "/")) diff --git a/tests/transports/__init__.py b/tests/transports/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transports/stdio/__init__.py b/tests/transports/stdio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transports/stdio/_liveness.py b/tests/transports/stdio/_liveness.py new file mode 100644 index 0000000000..5e4b679fe0 --- /dev/null +++ b/tests/transports/stdio/_liveness.py @@ -0,0 +1,80 @@ +"""Kernel-synchronized liveness probes for the real-subprocess stdio lifecycle suite. + +A spawned (grand)child connects back to a test-owned TCP listener and sends +`b'alive'`; the kernel then provides every signal a test needs, with no sleeps or +polling. The kernel closes all of a process's file descriptors on exit, so EOF +(clean close / FIN) or `BrokenResourceError` (abrupt close / RST, typical of +SIGKILL and Windows job termination) proves death; only a running process can +answer an echo, so a reply proves liveness without racing a kill. + +Extracted from the real-process section of tests/client/test_stdio.py; the two +copies on this branch are deliberate -- consolidating them is follow-up work. +""" + +import anyio +import anyio.abc +import pytest + + +def connect_back_script(port: int, *, echo: bool = False) -> str: + """Return a `python -c` script body that connects to 127.0.0.1:`port` and sends `b'alive'`. + + After the banner the script blocks forever -- or, with `echo=True`, echoes every + received chunk back so `assert_peer_echoes` can prove the process still runs. + """ + # lax no cover: echo mode is used only by POSIX-gated tests; Windows runners enforce 100% per job. 
+ if echo: # pragma: lax no cover + tail = "while True:\n data = s.recv(65536)\n if not data:\n break\n s.sendall(data)\n" + else: + tail = "time.sleep(3600)\n" + return f"import socket, time\ns = socket.create_connection(('127.0.0.1', {port}))\ns.sendall(b'alive')\n" + tail + + +async def open_liveness_listener() -> tuple[anyio.abc.SocketListener, int]: + """Open a TCP listener on localhost and return it along with its port.""" + multi = await anyio.create_tcp_listener(local_host="127.0.0.1") + sock = multi.listeners[0] + assert isinstance(sock, anyio.abc.SocketListener) + addr = sock.extra(anyio.abc.SocketAttribute.local_address) + # IPv4 local_address is (host: str, port: int) + assert isinstance(addr, tuple) and len(addr) >= 2 and isinstance(addr[1], int) + return sock, addr[1] + + +async def accept_alive(sock: anyio.abc.SocketListener) -> anyio.abc.SocketStream: + """Accept one connection and assert the peer sent `b'alive'`. + + Reads until the full 5-byte banner arrives (TCP may legally split even a tiny + send). Callers bound this with `anyio.fail_after` to catch a subprocess that + never started. + """ + stream = await sock.accept() + msg = b"" + while len(msg) < 5: + msg += await stream.receive(5 - len(msg)) + assert msg == b"alive", f"expected b'alive', got {msg!r}" + return stream + + +async def assert_stream_closed(stream: anyio.abc.SocketStream) -> None: + """Assert the peer holding the other end of `stream` has terminated.""" + with anyio.fail_after(5.0), pytest.raises((anyio.EndOfStream, anyio.BrokenResourceError)): + await stream.receive(1) + + +async def assert_peer_echoes(stream: anyio.abc.SocketStream) -> None: # pragma: lax no cover + """Assert the peer holding the other end of `stream` is still running. + + Round-trips one echo through the stream (the peer must use `echo=True`); a dead + process can never answer, so this cannot pass spuriously. 
+ + lax no cover: only POSIX-gated survival tests call this; Windows runners + enforce 100% coverage per job. + """ + with anyio.fail_after(5.0): + await stream.send(b"ping") + # Read until the full echo has arrived: TCP may legally split even a tiny send. + echoed = b"" + while len(echoed) < 4: + echoed += await stream.receive(4 - len(echoed)) + assert echoed == b"ping", f"expected b'ping', got {echoed!r}" diff --git a/tests/transports/stdio/conftest.py b/tests/transports/stdio/conftest.py new file mode 100644 index 0000000000..e9601ac6d5 --- /dev/null +++ b/tests/transports/stdio/conftest.py @@ -0,0 +1,77 @@ +"""Fixtures for the stdio lifecycle suite. + +Provides recording seams around `stdio_client`'s spawn and tree-termination +internals (the real implementations still run), plus a teardown that keeps a +crashed test from orphaning its sleep-forever subprocesses. +""" + +import os +import signal +import sys +from collections.abc import Generator +from contextlib import suppress +from pathlib import Path +from typing import TextIO + +import anyio.abc +import pytest + +from mcp.client import stdio +from mcp.client.stdio import _create_platform_compatible_process, _terminate_process_tree +from mcp.os.win32.utilities import FallbackProcess + + +@pytest.fixture +def spawned_processes( + monkeypatch: pytest.MonkeyPatch, +) -> Generator[list[anyio.abc.Process | FallbackProcess]]: + """Record every process `stdio_client` spawns; the real spawn still runs. + + Teardown SIGKILLs each spawn-time process group on POSIX: the safety net for a + test that dies mid-body and the reaper for deliberate survivors. On Windows + there is no group to signal (the Job Object covers strays). 
+ """ + spawned: list[anyio.abc.Process | FallbackProcess] = [] + + async def recording_spawn( + command: str, + args: list[str], + env: dict[str, str] | None = None, + errlog: TextIO = sys.stderr, + cwd: Path | str | None = None, + ) -> anyio.abc.Process | FallbackProcess: + process = await _create_platform_compatible_process(command, args, env, errlog, cwd) + spawned.append(process) + return process + + monkeypatch.setattr(stdio, "_create_platform_compatible_process", recording_spawn) + yield spawned + _kill_spawn_groups(spawned) + + +@pytest.fixture +def terminate_calls(monkeypatch: pytest.MonkeyPatch) -> list[anyio.abc.Process | FallbackProcess]: + """Record every invocation of `stdio_client`'s tree-termination seam; the real termination still runs. + + An empty list after the context exits proves the graceful path: a FIN looks the + same whether the peer exited on stdin closure or was killed. + """ + terminated: list[anyio.abc.Process | FallbackProcess] = [] + + async def recording_terminate(process: anyio.abc.Process | FallbackProcess) -> None: + terminated.append(process) + await _terminate_process_tree(process) + + monkeypatch.setattr(stdio, "_terminate_process_tree", recording_terminate) + return terminated + + +# lax no cover: registered on every platform but a no-op on Windows, whose runners enforce 100% per job. +def _kill_spawn_groups(spawned: list[anyio.abc.Process | FallbackProcess]) -> None: # pragma: lax no cover + """SIGKILL each spawn-time process group; see `spawned_processes`.""" + if sys.platform == "win32": + return + for process in spawned: + # macOS killpg raises EPERM for a group holding only unreaped zombies. 
+ with suppress(ProcessLookupError, PermissionError): + os.killpg(process.pid, signal.SIGKILL) diff --git a/tests/transports/stdio/test_lifecycle.py b/tests/transports/stdio/test_lifecycle.py new file mode 100644 index 0000000000..8a370c10f6 --- /dev/null +++ b/tests/transports/stdio/test_lifecycle.py @@ -0,0 +1,276 @@ +"""Real-subprocess stdio lifecycle tests that hold on both POSIX and Windows. + +The `stdio_client` tests each launch a real server through the public API and pin +one lifecycle behaviour, with kernel-level liveness sockets as the only +synchronization; the `FallbackProcess` tests wrap a raw `subprocess.Popen` +directly. Platform-divergent shutdown policy lives in test_posix.py / +test_windows.py; the full protocol round trip is pinned by +tests/interaction/transports/test_stdio.py and in-process shutdown logic by +tests/client/test_stdio.py. +""" + +import os +import subprocess +import sys +import threading +from contextlib import AsyncExitStack +from pathlib import Path + +import anyio +import anyio.abc +import pytest + +from mcp.client import stdio +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.os.win32.utilities import FallbackProcess +from tests.transports.stdio._liveness import ( + accept_alive, + assert_stream_closed, + connect_back_script, + open_liveness_listener, +) + + +@pytest.mark.anyio +async def test_a_server_that_exits_on_stdin_close_is_reaped_and_never_terminated( + spawned_processes: list[anyio.abc.Process | FallbackProcess], + terminate_calls: list[anyio.abc.Process | FallbackProcess], +) -> None: + """The happy path: closing stdin alone shuts a well-behaved server down. + + The server exits with code 0 and the escalation seam is never invoked. + """ + async with AsyncExitStack() as stack: + sock, port = await open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # The server exits on its own at stdin EOF -- the well-behaved response + # to shutdown's first step. 
+ server = ( + f"import socket, sys\n" + f"s = socket.create_connection(('127.0.0.1', {port}))\n" + f"s.sendall(b'alive')\n" + f"sys.stdin.read()\n" + ) + params = StdioServerParameters(command=sys.executable, args=["-c", server]) + + # The bound covers one interpreter cold start on a loaded runner; a healthy + # run takes well under a second. + with anyio.fail_after(10.0): + async with stdio_client(params): + stream = await accept_alive(sock) + stack.push_async_callback(stream.aclose) + + await assert_stream_closed(stream) + + assert spawned_processes[0].returncode == 0 + assert terminate_calls == [] + + +@pytest.mark.anyio +async def test_cancelling_the_client_mid_session_terminates_the_whole_server_tree( + monkeypatch: pytest.MonkeyPatch, + spawned_processes: list[anyio.abc.Process | FallbackProcess], + terminate_calls: list[anyio.abc.Process | FallbackProcess], +) -> None: + """Cancellation still runs the full shutdown against a real process tree. + + Cancellation here stands in for a client timeout or app shutdown: a server that + ignores stdin closure is escalated against, and its child dies with it. + """ + monkeypatch.setattr(stdio, "PROCESS_TERMINATION_TIMEOUT", 0.2) + + async with AsyncExitStack() as stack: + sock, port = await open_liveness_listener() + stack.push_async_callback(sock.aclose) + + child = connect_back_script(port) + # The parent never reads stdin and blocks forever, so only the escalation + # can end it -- which cancellation must not skip. 
+ parent = f"import subprocess, sys\nsubprocess.Popen([sys.executable, '-c', {child!r}])\n" + connect_back_script( + port + ) + params = StdioServerParameters(command=sys.executable, args=["-c", parent]) + + entered = anyio.Event() + # Cancel a scope owned by the client's task, not the test's task group: a + # host self-cancel is delivered by throwing through this test function's + # suspended frames, and Python 3.11's tracer loses coverage events after + # such a throw() traversal (python/cpython#106749). + cancel_scope = anyio.CancelScope() + + async def run_client_until_cancelled() -> None: + with cancel_scope: + async with stdio_client(params): + entered.set() + await anyio.sleep_forever() + + streams: list[anyio.abc.SocketStream] = [] + # The bound covers two interpreter cold starts on a loaded runner plus the + # shortened escalation wait; a healthy run takes around a second. + with anyio.fail_after(10.0): + async with anyio.create_task_group() as tg: + tg.start_soon(run_client_until_cancelled) + await entered.wait() + for _ in range(2): + stream = await accept_alive(sock) + stack.push_async_callback(stream.aclose) + streams.append(stream) + cancel_scope.cancel() + + for stream in streams: + await assert_stream_closed(stream) + + assert terminate_calls == spawned_processes + + +@pytest.mark.anyio +async def test_a_server_that_exits_mid_session_keeps_its_own_exit_code( + spawned_processes: list[anyio.abc.Process | FallbackProcess], + terminate_calls: list[anyio.abc.Process | FallbackProcess], +) -> None: + """A server that dies on its own mid-session is reaped with the exit code it chose. + + The client surfaces the child's true status rather than synthesizing one, and + the escalation seam confirms nothing was terminated along the way. 
+ """ + async with AsyncExitStack() as stack: + sock, port = await open_liveness_listener() + stack.push_async_callback(sock.aclose) + + server = ( + f"import socket, sys\n" + f"s = socket.create_connection(('127.0.0.1', {port}))\n" + f"s.sendall(b'alive')\n" + f"sys.exit(7)\n" + ) + params = StdioServerParameters(command=sys.executable, args=["-c", server]) + + # The bound covers one interpreter cold start on a loaded runner; a healthy + # run takes well under a second. + with anyio.fail_after(10.0): + # no branch: coverage mis-traces the exit arcs of a nested `async with` on 3.11+. + async with stdio_client(params): # pragma: no branch + stream = await accept_alive(sock) + stack.push_async_callback(stream.aclose) + # The server is already gone before shutdown begins. + await assert_stream_closed(stream) + + assert spawned_processes[0].returncode == 7 + assert terminate_calls == [] + + +@pytest.mark.anyio +async def test_server_stderr_output_reaches_the_errlog_file( + tmp_path: Path, + spawned_processes: list[anyio.abc.Process | FallbackProcess], +) -> None: + """What the server writes to stderr lands in the file passed as `errlog`. + + The spawn hands over errlog's file descriptor as the child's stderr, so it must + be a real file -- an in-memory StringIO has no fileno. + """ + marker = "stdio-lifecycle stderr marker 4242" + + async with AsyncExitStack() as stack: + sock, port = await open_liveness_listener() + stack.push_async_callback(sock.aclose) + + server = ( + f"import socket, sys\n" + f"s = socket.create_connection(('127.0.0.1', {port}))\n" + f"s.sendall(b'alive')\n" + f"sys.stderr.write({marker!r} + '\\n')\n" + f"sys.stderr.flush()\n" + f"sys.stdin.read()\n" + ) + params = StdioServerParameters(command=sys.executable, args=["-c", server]) + + with (tmp_path / "errlog.txt").open("w+", encoding="utf-8") as errlog: + # The bound covers one interpreter cold start on a loaded runner; a + # healthy run takes well under a second. 
+ with anyio.fail_after(10.0): + async with stdio_client(params, errlog=errlog): + stream = await accept_alive(sock) + stack.push_async_callback(stream.aclose) + + # The server exited on stdin EOF, so every stderr write it made has + # reached the file descriptor. + errlog.seek(0) + content = errlog.read() + + assert marker in content + assert spawned_processes[0].returncode == 0 + + +@pytest.mark.skipif( + not hasattr(os, "waitid"), reason="needs os.waitid(WNOWAIT); absent on Windows and macOS before 3.13" +) +# lax no cover: Windows runners enforce 100% per job but lack os.waitid and skip this +# test; test_windows.py's SelectorEventLoop lifecycle test exercises the property there. +def test_fallback_process_reports_death_through_returncode_without_a_wait_call() -> None: # pragma: lax no cover + """`FallbackProcess.returncode` observes process death on its own. + + Pre-fix it returned Popen's cached value, which stays None until someone calls wait()/poll(). + + `os.waitid(WEXITED | WNOWAIT)` waits for the child to become reapable without + reaping it or priming Popen's cache (which would mask the regression); the + pre-fix cached read would still see None here. stdout EOF is NOT such a signal: + the kernel closes the pipes before the exit status is published, so an + EOF-then-assert version flakes. + """ + popen = subprocess.Popen( + [sys.executable, "-c", "pass"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + assert popen.stdin is not None and popen.stdout is not None + try: + process = FallbackProcess(popen) + + os.waitid(os.P_PID, popen.pid, os.WEXITED | os.WNOWAIT) + assert process.returncode == 0 + finally: + popen.stdin.close() + popen.stdout.close() + # The WNOWAIT above left the child unreaped; reap it so no zombie (and no + # Popen ResourceWarning) outlives the test. 
+ popen.wait() + + +@pytest.mark.anyio +async def test_fallback_process_wait_is_cancellable_while_the_child_lives() -> None: + """`FallbackProcess.wait()` honours cancellation while the child is still running. + + Pre-fix it parked `Popen.wait()` in a worker thread anyio will not abandon, + which blocks every cancellation aimed at it. Runs everywhere: the wrapper holds + a plain Popen. + """ + popen = subprocess.Popen( + [sys.executable, "-c", "import sys; sys.stdin.read()"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + assert popen.stdin is not None and popen.stdout is not None + # Pre-fix, no timeout below can fire while the worker thread is parked in + # Popen.wait(); killing the child turns that regression's hang into a clean failure. + watchdog = threading.Timer(8.0, popen.kill) + watchdog.start() + try: + process = FallbackProcess(popen) + + # move_on_after's short deadline is the time-based feature under test -- + # cancellability -- not a wait for an async condition. + with anyio.fail_after(5): + with anyio.move_on_after(0.1) as scope: + await process.wait() + + assert scope.cancelled_caught + # Only the wait was cancelled; the child itself is untouched. + assert popen.poll() is None + finally: + watchdog.cancel() + popen.kill() + popen.wait() + popen.stdin.close() + popen.stdout.close() diff --git a/tests/transports/stdio/test_posix.py b/tests/transports/stdio/test_posix.py new file mode 100644 index 0000000000..521b8bd772 --- /dev/null +++ b/tests/transports/stdio/test_posix.py @@ -0,0 +1,116 @@ +"""POSIX-only stdio lifecycle tests: a gracefully-exited server's children survive the client shutdown. + +SDK-defined policy, not spec-mandated (docs/migration.md, "`stdio_client` no +longer kills children of a gracefully-exited server on POSIX"). Windows has the +opposite documented outcome; see tests/transports/stdio/test_windows.py. 
+""" + +import errno +import sys +from contextlib import suppress + +import anyio +import anyio.abc +import pytest + +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.os.win32.utilities import FallbackProcess +from tests.transports.stdio._liveness import ( + accept_alive, + assert_peer_echoes, + connect_back_script, + open_liveness_listener, +) + +pytestmark = pytest.mark.skipif(sys.platform == "win32", reason="POSIX process-group semantics") + + +@pytest.mark.anyio +# lax no cover: the per-job 100% coverage gate also runs on Windows, where this file is skipped. +async def test_a_gracefully_exiting_servers_child_survives_the_client_shutdown( # pragma: lax no cover + spawned_processes: list[anyio.abc.Process | FallbackProcess], + terminate_calls: list[anyio.abc.Process | FallbackProcess], +) -> None: + """A server that exits on stdin closure keeps its background child running after `stdio_client` returns. + + The client never escalates against the gracefully-exited server. SDK-defined + policy per docs/migration.md; regression for the pre-fix client that + tree-killed the child. The Windows twin in test_windows.py pins the opposite outcome. + """ + sock, port = await open_liveness_listener() + async with sock: + child = connect_back_script(port, echo=True) + # The server hands its inherited pipes to a child, then exits as soon as + # its stdin closes: the well-behaved graceful path. + server = f"import subprocess, sys\nsubprocess.Popen([sys.executable, '-c', {child!r}])\nsys.stdin.read()\n" + params = StdioServerParameters(command=sys.executable, args=["-c", server]) + + # Two interpreter cold starts on a loaded runner; healthy runs take ~0.3s. + with anyio.fail_after(10.0): + async with stdio_client(params): + child_stream = await accept_alive(sock) + async with child_stream: + # Only a live process answers an echo: the child survived shutdown. 
+ await assert_peer_echoes(child_stream) + + # A FIN-shaped probe cannot tell graceful exit from a kill; the seam can: + # no escalation was invoked, and the leader exited 0 on stdin closure. + assert terminate_calls == [] + leader = spawned_processes[0] + assert leader.returncode == 0 + # The child is deliberately left running; the spawned_processes teardown + # SIGKILLs the spawn-time process group to reap it. + + +@pytest.mark.anyio +@pytest.mark.usefixtures("spawned_processes") # failure-path safety net for the parked child +# lax no cover: same Windows-runner coverage-gate reason as above. +async def test_a_surviving_childs_write_to_the_inherited_stdout_fails_with_epipe() -> None: # pragma: lax no cover + """A surviving child writing to the stdout pipe it inherited from the server gets EPIPE once the client is gone. + + The pipe's only read end was the client's, and shutdown closed it + deterministically rather than at GC time. Pins the docs/migration.md claim + "a surviving child that keeps writing to an inherited stdout receives + EPIPE/SIGPIPE once the client is gone" (SDK-defined). + + Steps: the server hands its stdio pipes to a child and exits on stdin closure; + the child parks on its socket until `stdio_client` has fully exited (so the + write cannot race transport teardown), then writes one byte to its inherited + fd 1 and reports the errno (0 on success) back over the socket. + """ + sock, port = await open_liveness_listener() + async with sock: + # Pin SIGPIPE to SIG_IGN explicitly (CPython already starts that way) so + # the write fails with EPIPE instead of relying on interpreter startup details. 
+ child = ( + f"import os, signal, socket\n" + f"signal.signal(signal.SIGPIPE, signal.SIG_IGN)\n" + f"s = socket.create_connection(('127.0.0.1', {port}))\n" + f"s.sendall(b'alive')\n" + f"s.recv(4)\n" + f"try:\n" + f" os.write(1, b'x')\n" + f" result = b'0'\n" + f"except OSError as e:\n" + f" result = str(e.errno).encode()\n" + f"s.sendall(result)\n" + ) + server = f"import subprocess, sys\nsubprocess.Popen([sys.executable, '-c', {child!r}])\nsys.stdin.read()\n" + params = StdioServerParameters(command=sys.executable, args=["-c", server]) + + # Two interpreter cold starts on a loaded runner; healthy runs take ~0.3s. + with anyio.fail_after(10.0): + async with stdio_client(params): + child_stream = await accept_alive(sock) + async with child_stream: + # The context has fully exited: the transport, and with it the + # pipe's only read end, is closed. Release the child's write. + await child_stream.send(b"go") + # The child sends its errno report and exits, so read to EOF: the + # complete reply is everything before the kernel's FIN. + reply = b"" + with suppress(anyio.EndOfStream): + while True: + reply += await child_stream.receive(16) + + assert int(reply) == errno.EPIPE, f"child reported errno {reply!r}, expected EPIPE" diff --git a/tests/transports/stdio/test_windows.py b/tests/transports/stdio/test_windows.py new file mode 100644 index 0000000000..0e7ad092c7 --- /dev/null +++ b/tests/transports/stdio/test_windows.py @@ -0,0 +1,235 @@ +"""Windows-only stdio lifecycle behaviors, against real subprocesses. + +Each test pins a contract that exists only on Windows: Job-Object reaping of a +gracefully-exited server's children (the deliberate divergence from the POSIX +policy in test_posix.py), the SelectorEventLoop fallback wrapper, and the CRLF +line endings a native text-mode server emits. Synchronization is kernel-level +only (liveness sockets); see `_liveness`. 
+ +Per-test no-cover pragmas (as in tests/issues/test_552_windows_hang.py): bodies run +only on windows-latest CI legs, the per-job 100% gate would count them uncovered on +non-Windows runners, and strict-no-cover is skipped on Windows where they execute. +""" + +import asyncio +import sys +from contextlib import AsyncExitStack +from pathlib import Path + +import anyio +import anyio.abc +import pytest + +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.os.win32.utilities import FallbackProcess +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCRequest, JSONRPCResponse +from tests.transports.stdio._liveness import ( + accept_alive, + assert_stream_closed, + connect_back_script, + open_liveness_listener, +) + +pytestmark = [ + pytest.mark.anyio, + pytest.mark.skipif(sys.platform != "win32", reason="Windows Job Object / event-loop semantics"), +] + + +async def test_a_gracefully_exited_servers_child_is_reaped_when_the_job_handle_closes( # pragma: no cover + tmp_path: Path, + spawned_processes: list[anyio.abc.Process | FallbackProcess], + terminate_calls: list[anyio.abc.Process | FallbackProcess], +) -> None: + """A gracefully-exited server's child is killed deterministically when shutdown closes the job handle. + + The server exits cleanly on stdin closure, leaving a child behind; shutdown's + close of the server's Job Object handle (`close_process_job` + + `JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE`) kills that child deterministically, not at + GC time. Documented divergence from POSIX (docs/migration.md; the POSIX twin is + test_posix.py::test_a_gracefully_exiting_servers_child_survives_the_client_shutdown). + + `terminate_calls == []` is the load-bearing distinction: the child died through + the graceful path's job-handle close, not the escalation's `TerminateJobObject`; + the two kills are indistinguishable on the socket. 
+ + Both processes connect back and their stderr is captured via `errlog`, so a + timeout failure can report which process never showed and the child's fate + (xdist swallows subprocess stderr on CI). + """ + async with AsyncExitStack() as stack: + sock, port = await open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # The startup marker (and any child traceback, via stderr=sys.stderr below) + # lands in errlog, splitting "never started" from "started but never connected". + child = "import sys\nprint('child-started', file=sys.stderr, flush=True)\n" + connect_back_script(port) + # The server spawns a child, connects back itself, then exits as soon as + # its stdin closes: the graceful path, so the escalation never runs. + # The child inherits Job membership: the SDK assigns the server to the Job + # synchronously after spawn, long before the cold-starting interpreter can + # Popen the child (membership is inherited at CreateProcess, never + # acquired retroactively). + # + # The child's stdin must be DEVNULL: CPython startup queries fd 0, and + # Windows serializes that query behind the server's pending blocking + # `sys.stdin.read()` on the inherited pipe, so the child would freeze at + # interpreter startup until the next inbound byte or EOF. + # + # After stdin EOF ends the server, it reports the child's `poll()` status: + # `None` means alive at server exit; an exit/NTSTATUS code names the killer. 
+ server = ( + f"import socket, subprocess, sys\n" + f"try:\n" + f" p = subprocess.Popen([sys.executable, '-c', {child!r}], " + f"stdin=subprocess.DEVNULL, stderr=sys.stderr)\n" + f"except BaseException as exc:\n" + f" print(exc, file=sys.stderr, flush=True)\n" + f" raise\n" + f"s = socket.create_connection(('127.0.0.1', {port}))\n" + f"s.sendall(b'alive')\n" + f"sys.stdin.read()\n" + f"print('child-rc:%s' % p.poll(), file=sys.stderr, flush=True)\n" + ) + server_params = StdioServerParameters(command=sys.executable, args=["-c", server]) + + with (tmp_path / "errlog.txt").open("w+", encoding="utf-8") as errlog: + + def server_stderr() -> str: + errlog.seek(0) + return errlog.read() + + streams: list[anyio.abc.SocketStream] = [] + spawn_started = anyio.current_time() + entered_at: float | None = None + try: + # Two interpreter cold starts on a loaded runner; healthy runs + # take well under a second. + with anyio.fail_after(15.0): + async with stdio_client(server_params, errlog=errlog): + entered_at = anyio.current_time() + # The server and child race to connect; accept both, + # order-agnostic (accept_alive verifies each banner). + for _ in range(2): + stream = await accept_alive(sock) + stack.push_async_callback(stream.aclose) + streams.append(stream) + except TimeoutError: + # `stdio_client.__aexit__` has already completed its shielded shutdown, + # so the stderr read carries the server's final `child-rc` line, not a + # mid-flight snapshot. 
+ missing_leg = "the server never ran its connect line" if not streams else "the child never connected" + spawn_split = ( + "the context never entered" + if entered_at is None + else f"the context entered {entered_at - spawn_started:.1f}s after spawn began" + ) + pytest.fail( + f"{len(streams)}/2 liveness connections arrived ({missing_leg}); " + f"{spawn_split}; server stderr: {server_stderr()!r}" + ) + + # Context exit closed the job handle: KILL_ON_JOB_CLOSE killed the + # child and the server exited gracefully, so both sockets close. + # The `spawned_processes` strong reference is load-bearing: `_process_jobs` + # is weak-keyed, so without it a GC between context exit and this assert + # could close the job handle itself and mask a regression in the + # deterministic close. + try: + for stream in streams: + await assert_stream_closed(stream) + except TimeoutError: + pytest.fail(f"a socket stayed open after shutdown; server stderr: {server_stderr()!r}") + + leader = spawned_processes[0] + # The graceful path: the server exited on stdin closure with code 0, + # and the tree-termination escalation was never invoked. + assert leader.returncode == 0, server_stderr() + assert terminate_calls == [], server_stderr() + + +# Overrides the suite-wide anyio_backend fixture for this test only: a selector +# event loop cannot run asyncio subprocesses, forcing stdio_client onto FallbackProcess. +@pytest.mark.parametrize("anyio_backend", [("asyncio", {"loop_factory": asyncio.SelectorEventLoop})]) +async def test_a_selector_event_loop_session_uses_the_fallback_process_and_exits_cleanly( # pragma: no cover + spawned_processes: list[anyio.abc.Process | FallbackProcess], + terminate_calls: list[anyio.abc.Process | FallbackProcess], +) -> None: + """Under a `SelectorEventLoop`, `stdio_client` falls back to `FallbackProcess` and still exits cleanly. 
+ + A selector event loop has no asyncio subprocess support, so `stdio_client` + falls back to the Popen-based `FallbackProcess` wrapper; a well-behaved server + still completes the full clean lifecycle: spawn, liveness, exit on stdin + closure, reaped, never escalated against. + + The `isinstance` check is the engagement proof: if a future anyio gains selector + subprocess support, the spawn would silently return a normal Process. A hang here + most likely means the known fallback hazard documented in `stdio_client`'s + shutdown comment (reader thread parked in a synchronous `ReadFile`), which is + why this test pins only the clean-exit path, never a kill path. + """ + async with AsyncExitStack() as stack: + sock, port = await open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # Connect back for liveness, then exit as soon as stdin closes: the + # well-behaved server, so shutdown's first step suffices. + server = ( + f"import socket, sys\n" + f"s = socket.create_connection(('127.0.0.1', {port}))\n" + f"s.sendall(b'alive')\n" + f"sys.stdin.read()\n" + ) + server_params = StdioServerParameters(command=sys.executable, args=["-c", server]) + + # One interpreter cold start on a loaded runner; healthy runs take ~0.3s. + with anyio.fail_after(10.0): + async with stdio_client(server_params): + stream = await accept_alive(sock) + stack.push_async_callback(stream.aclose) + # The engagement proof, asserted while the session is live. + assert isinstance(spawned_processes[0], FallbackProcess) + + # The server exited on stdin closure: socket closed, exit code 0, and the + # escalation never fired. + await assert_stream_closed(stream) + assert spawned_processes[0].returncode == 0 + assert terminate_calls == [] + + +async def test_a_native_server_emitting_crlf_line_endings_round_trips_messages() -> None: # pragma: no cover + """The client round-trips messages from a text-mode Windows server that frames its output with \\r\\n. 
+ + `TextIOWrapper`'s `newline=None` translates "\\n" to `os.linesep`, so such a + server emits \\r\\n; the client still parses each line because the reader + splits on "\\n" only and the JSON parser tolerates the trailing "\\r" as + whitespace. The SDK's own server writes through such a wrapper, so this + tolerance is load-bearing for Windows interop. + + tests/issues/test_552_windows_hang.py exercises the same wire form implicitly + through `initialize()`; this test is the explicit owner of the framing claim. + """ + # Read one request, answer it via print() (which emits \r\n on Windows), then + # exit when stdin closes. json.loads/dumps keep the script free of SDK imports. + server = ( + "import json, sys\n" + "line = sys.stdin.readline()\n" + "request = json.loads(line)\n" + "print(json.dumps({'jsonrpc': '2.0', 'id': request['id'], 'result': {}}))\n" + "sys.stdout.flush()\n" + "sys.stdin.read()\n" + ) + server_params = StdioServerParameters(command=sys.executable, args=["-c", server]) + + ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + + # One interpreter cold start on a loaded runner; healthy runs take ~0.3s. + with anyio.fail_after(10.0): + async with stdio_client(server_params) as (read_stream, write_stream): + await write_stream.send(SessionMessage(ping)) + received = await read_stream.receive() + # A reader that choked on the trailing \r would deliver a ValueError + # here instead of a parsed message. 
+ assert isinstance(received, SessionMessage) + assert received.message == JSONRPCResponse(jsonrpc="2.0", id=1, result={}) From ac96f88abde567370afc44bdf1501e077738afe8 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Fri, 5 Jun 2026 21:33:48 +0100 Subject: [PATCH 83/84] Deflake the session-level timeout test with trio's virtual clock (#2788) --- tests/interaction/lowlevel/test_timeouts.py | 24 ++++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index a9c83d641d..b440f32106 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -1,14 +1,16 @@ """Request timeouts against the low-level Server, driven through the public Client API. The handler blocks on an event that is never set, so the awaited response can never arrive and -any positive timeout fires deterministically on the next event-loop pass. The timeout is therefore -set to an effectively-zero duration: the tests add no wall-clock time to the suite. (Zero itself +any positive timeout fires deterministically on the next event-loop pass. Per-request timeouts are +set to an effectively-zero duration; the session-level test runs on trio's virtual clock instead +(see the comment there). Either way the tests add no wall-clock time to the suite. (Zero itself cannot be used: a falsy read_timeout_seconds is silently treated as "no timeout".) 
""" import anyio import pytest from inline_snapshot import snapshot +from trio.testing import MockClock from mcp import MCPError, types from mcp.client.client import Client @@ -85,7 +87,19 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) +# A session-level timeout cannot use the effectively-zero pattern above: it also governs the +# initialize handshake, which must complete before the blocked tool call can wait the timeout +# out in full. Any real-clock margin is a bet against CI scheduler stalls (a 50ms value lost +# that bet in CI; the in-process handshake tail reaches ~190ms on a loaded windows runner), so +# this test runs on trio's virtual clock instead. With autojump, time advances only when every +# task is blocked: the handshake always has a runnable task and therefore cannot time out no +# matter how slow the runner, and once the tool call blocks on the never-answered request the +# run goes idle and the clock jumps straight to the deadline — deterministic, with no real wait. @requirement("protocol:timeout:session-default") +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) async def test_session_level_timeout_applies_to_every_request() -> None: """A read timeout configured on the client applies to requests that do not set their own.""" @@ -96,12 +110,6 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("blocker", on_call_tool=call_tool) - # The one real wall-clock wait in the suite, and it cannot be made effectively zero like the - # per-request timeouts: a session-level timeout also governs the initialize handshake, so the - # value must be long enough for the in-process handshake to complete before the blocked tool - # call waits it out in full. 
50ms buys a ~50x safety margin over the handshake's actual - # latency; lowering it only erodes the margin against CI scheduler jitter without saving - # anything perceptible. async with Client(server, read_timeout_seconds=0.05) as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("block", {}) From b478bff56d56aea238ebdec26c1ee818e59ee258 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Mon, 8 Jun 2026 12:05:27 +0100 Subject: [PATCH 84/84] Remove the unsupported WebSocket transport (#2785) --- docs/migration.md | 4 + pyproject.toml | 8 +- src/mcp/client/websocket.py | 85 ------------------- src/mcp/server/websocket.py | 52 ------------ tests/client/test_transport_stream_cleanup.py | 14 --- tests/shared/test_ws.py | 51 ----------- tests/test_helpers.py | 57 ------------- uv.lock | 69 +-------------- 8 files changed, 9 insertions(+), 331 deletions(-) delete mode 100644 src/mcp/client/websocket.py delete mode 100644 src/mcp/server/websocket.py delete mode 100644 tests/shared/test_ws.py delete mode 100644 tests/test_helpers.py diff --git a/docs/migration.md b/docs/migration.md index 0f5fc91c3d..da67e034d5 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -145,6 +145,10 @@ group (spawned with `start_new_session=True`); the `getpgid()` lookup and the per-process terminate/kill fallback are gone. The win32 utilities logger is now named `mcp.os.win32.utilities` (was `client.stdio.win32`). +### WebSocket transport removed + +The WebSocket transport has been removed: `mcp.client.websocket.websocket_client`, `mcp.server.websocket.websocket_server`, and the `ws` optional dependency extra (`mcp[ws]`) no longer exist. WebSocket was never part of the MCP specification. Use the streamable HTTP transport instead (`mcp.client.streamable_http.streamable_http_client` on the client, `streamable_http_app()` on the server), which supports bidirectional communication with server-to-client streaming over standard HTTP. 
+ ### Removed type aliases and classes The following deprecated type aliases and classes have been removed from `mcp.types`: diff --git a/pyproject.toml b/pyproject.toml index 6d2319621a..94710c9cef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,6 @@ dependencies = [ [project.optional-dependencies] rich = ["rich>=13.9.4"] cli = ["typer>=0.16.0", "python-dotenv>=1.0.0"] -ws = ["websockets>=15.0.1"] [project.scripts] mcp = "mcp.cli:app [cli]" @@ -75,8 +74,8 @@ build-constraint-dependencies = [ [dependency-groups] dev = [ - # We add mcp[cli,ws] so `uv sync` considers the extras. - "mcp[cli,ws]", + # We add mcp[cli] so `uv sync` considers the extras. + "mcp[cli]", "pyright>=1.1.400", "pytest>=8.4.0", "ruff>=0.8.5", @@ -204,9 +203,6 @@ addopts = """ """ filterwarnings = [ "error", - # This should be fixed on Uvicorn's side. - "ignore::DeprecationWarning:websockets", - "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", # pywin32 internal deprecation warning "ignore:getargs.*The 'u' format is deprecated:DeprecationWarning", ] diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py deleted file mode 100644 index de473f36d3..0000000000 --- a/src/mcp/client/websocket.py +++ /dev/null @@ -1,85 +0,0 @@ -import json -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager - -import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import ValidationError -from websockets.asyncio.client import connect as ws_connect -from websockets.typing import Subprotocol - -from mcp import types -from mcp.shared.message import SessionMessage - - -@asynccontextmanager -async def websocket_client( - url: str, -) -> AsyncGenerator[ - tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], - None, -]: - """WebSocket client transport for MCP, symmetrical to the server version. 
- - Connects to 'url' using the 'mcp' subprotocol, then yields: - (read_stream, write_stream) - - - read_stream: As you read from this stream, you'll receive either valid - SessionMessage objects or Exception objects (when validation fails). - - write_stream: Write SessionMessage objects to this stream to send them - over the WebSocket to the server. - """ - - # Create two in-memory streams: - # - One for incoming messages (read_stream, written by ws_reader) - # - One for outgoing messages (write_stream, read by ws_writer) - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - # Connect using websockets, requesting the "mcp" subprotocol - async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws: - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - async def ws_reader(): - """Reads text messages from the WebSocket, parses them as JSON-RPC messages, - and sends them into read_stream_writer. - """ - async with read_stream_writer: - async for raw_text in ws: - try: - message = types.jsonrpc_message_adapter.validate_json(raw_text, by_name=False) - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - except ValidationError as exc: # pragma: no cover - # If JSON parse or model validation fails, send the exception - await read_stream_writer.send(exc) - - async def ws_writer(): - """Reads JSON-RPC messages from write_stream_reader and - sends them to the server. 
- """ - async with write_stream_reader: - async for session_message in write_stream_reader: - # Convert to a dict, then to JSON - msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True) - await ws.send(json.dumps(msg_dict)) - - async with ( - read_stream_writer, - read_stream, - write_stream, - write_stream_reader, - anyio.create_task_group() as tg, - ): - # Start reader and writer tasks - tg.start_soon(ws_reader) - tg.start_soon(ws_writer) - - # Yield the receive/send streams - yield (read_stream, write_stream) - - # Once the caller's 'async with' block exits, we shut down - tg.cancel_scope.cancel() diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py deleted file mode 100644 index 277f9b5af2..0000000000 --- a/src/mcp/server/websocket.py +++ /dev/null @@ -1,52 +0,0 @@ -from contextlib import asynccontextmanager - -import anyio -from pydantic_core import ValidationError -from starlette.types import Receive, Scope, Send -from starlette.websockets import WebSocket - -from mcp import types -from mcp.shared._context_streams import create_context_streams -from mcp.shared.message import SessionMessage - - -@asynccontextmanager -async def websocket_server(scope: Scope, receive: Receive, send: Send): - """WebSocket server transport for MCP. This is an ASGI application, suitable for use - with a framework like Starlette and a server like Hypercorn. 
- """ - - websocket = WebSocket(scope, receive, send) - await websocket.accept(subprotocol="mcp") - - read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) - write_stream, write_stream_reader = create_context_streams[SessionMessage](0) - - async def ws_reader(): - try: - async with read_stream_writer: - async for msg in websocket.iter_text(): - try: - client_message = types.jsonrpc_message_adapter.validate_json(msg, by_name=False) - except ValidationError as exc: # pragma: no cover - await read_stream_writer.send(exc) - continue - - session_message = SessionMessage(client_message) - await read_stream_writer.send(session_message) - except anyio.ClosedResourceError: # pragma: no cover - await websocket.close() - - async def ws_writer(): - try: - async with write_stream_reader: - async for session_message in write_stream_reader: - obj = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) - await websocket.send_text(obj) - except anyio.ClosedResourceError: # pragma: no cover - await websocket.close() - - async with anyio.create_task_group() as tg: - tg.start_soon(ws_reader) - tg.start_soon(ws_writer) - yield (read_stream, write_stream) diff --git a/tests/client/test_transport_stream_cleanup.py b/tests/client/test_transport_stream_cleanup.py index 1e6be3c725..40d3b2439d 100644 --- a/tests/client/test_transport_stream_cleanup.py +++ b/tests/client/test_transport_stream_cleanup.py @@ -21,7 +21,6 @@ from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client -from mcp.client.websocket import websocket_client @contextmanager @@ -104,16 +103,3 @@ async def test_streamable_http_client_closes_all_streams_on_exit() -> None: with _assert_no_memory_stream_leak(): async with streamable_http_client("http://127.0.0.1:1/mcp"): pass - - -@pytest.mark.anyio -async def test_websocket_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None: - """websocket_client must close 
all 4 stream ends when ws_connect fails. - - Before the fix, there was no try/finally at all — if ws_connect raised, - all 4 streams were leaked. - """ - with _assert_no_memory_stream_leak(): - with pytest.raises(OSError): - async with websocket_client(f"ws://127.0.0.1:{free_tcp_port}/ws"): - pytest.fail("should not reach here") # pragma: no cover diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py deleted file mode 100644 index 482dcdcf32..0000000000 --- a/tests/shared/test_ws.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Smoke test for the WebSocket transport. - -Runs the full WS stack end-to-end over a real TCP connection, covering both -``src/mcp/client/websocket.py`` and ``src/mcp/server/websocket.py``. MCP -semantics (error propagation, timeouts, etc.) are transport-agnostic and are -covered in ``tests/client/test_client.py`` and ``tests/issues/test_88_random_error.py``. -""" - -from collections.abc import Generator - -import pytest -from starlette.applications import Starlette -from starlette.routing import WebSocketRoute -from starlette.websockets import WebSocket - -from mcp.client.session import ClientSession -from mcp.client.websocket import websocket_client -from mcp.server import Server -from mcp.server.websocket import websocket_server -from mcp.types import EmptyResult, InitializeResult -from tests.test_helpers import run_uvicorn_in_thread - -SERVER_NAME = "test_server_for_WS" - - -def make_server_app() -> Starlette: - srv = Server(SERVER_NAME) - - async def handle_ws(websocket: WebSocket) -> None: - async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: - await srv.run(streams[0], streams[1], srv.create_initialization_options()) - - return Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) - - -@pytest.fixture -def ws_server_url() -> Generator[str, None, None]: - with run_uvicorn_in_thread(make_server_app()) as base_url: - yield base_url.replace("http://", "ws://") + "/ws" - - -@pytest.mark.anyio 
-async def test_ws_client_basic_connection(ws_server_url: str) -> None: - async with websocket_client(ws_server_url) as streams: - async with ClientSession(*streams) as session: - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == SERVER_NAME - - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) diff --git a/tests/test_helpers.py b/tests/test_helpers.py deleted file mode 100644 index 0038b18905..0000000000 --- a/tests/test_helpers.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Common test utilities for MCP server tests.""" - -import socket -import threading -from collections.abc import Generator -from contextlib import contextmanager -from typing import Any - -import uvicorn - -_SERVER_SHUTDOWN_TIMEOUT_S = 5.0 - - -@contextmanager -def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None, None]: - """Run a uvicorn server in a background thread on an ephemeral port. - - The socket is bound and put into listening state *before* the thread - starts, so the port is known immediately with no wait. The kernel's - listen queue buffers any connections that arrive before uvicorn's event - loop reaches ``accept()``, so callers can connect as soon as this - function yields — no polling, no sleeps, no startup race. - - This also avoids the TOCTOU race of the old pick-a-port-then-rebind - pattern: the socket passed here is the one uvicorn serves on, with no - gap where another pytest-xdist worker could claim it. - - Args: - app: ASGI application to serve. - **config_kwargs: Additional keyword arguments for :class:`uvicorn.Config` - (e.g. ``log_level``). ``host``/``port`` are ignored since the - socket is pre-bound. - - Yields: - The base URL of the running server, e.g. ``http://127.0.0.1:54321``. 
- """ - host = "127.0.0.1" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind((host, 0)) - sock.listen() - port = sock.getsockname()[1] - - config_kwargs.setdefault("log_level", "error") - # Uvicorn's interface autodetection calls asyncio.iscoroutinefunction, - # which Python 3.14 deprecates. Under filterwarnings=error this crashes - # the server thread silently. Starlette is asgi3; skip the autodetect. - config_kwargs.setdefault("interface", "asgi3") - server = uvicorn.Server(config=uvicorn.Config(app=app, **config_kwargs)) - - thread = threading.Thread(target=server.run, kwargs={"sockets": [sock]}, daemon=True) - thread.start() - try: - yield f"http://{host}:{port}" - finally: - server.should_exit = True - thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S) diff --git a/uv.lock b/uv.lock index df63607f40..b4cfafa09f 100644 --- a/uv.lock +++ b/uv.lock @@ -870,9 +870,6 @@ cli = [ rich = [ { name = "rich" }, ] -ws = [ - { name = "websockets" }, -] [package.dev-dependencies] dev = [ @@ -880,7 +877,7 @@ dev = [ { name = "dirty-equals" }, { name = "inline-snapshot" }, { name = "logfire" }, - { name = "mcp", extra = ["cli", "ws"] }, + { name = "mcp", extra = ["cli"] }, { name = "pillow" }, { name = "pyright" }, { name = "pytest" }, @@ -922,9 +919,8 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.13.0" }, { name = "typing-inspection", specifier = ">=0.4.1" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.31.1" }, - { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] -provides-extras = ["cli", "rich", "ws"] +provides-extras = ["cli", "rich"] [package.metadata.requires-dev] dev = [ @@ -932,7 +928,7 @@ dev = [ { name = "dirty-equals", specifier = ">=0.9.0" }, { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "logfire", specifier = ">=3.0.0" }, - { name = "mcp", extras = ["cli", "ws"], editable = "." 
}, + { name = "mcp", extras = ["cli"], editable = "." }, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, { name = "pytest", specifier = ">=8.4.0" }, @@ -2756,65 +2752,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774, upload-time = "2017-04-05T20:21:32.581Z" }, ] -[[package]] -name = "websockets" -version = "15.0.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/21/e6/26d09fab466b7ca9c7737474c52be4f76a40301b08362eb2dbc19dcc16c1/websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee", size = 177016, upload-time = "2025-03-05T20:03:41.606Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/da/6462a9f510c0c49837bbc9345aca92d767a56c1fb2939e1579df1e1cdcf7/websockets-15.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d63efaa0cd96cf0c5fe4d581521d9fa87744540d4bc999ae6e08595a1014b45b", size = 175423, upload-time = "2025-03-05T20:01:35.363Z" }, - { url = "https://files.pythonhosted.org/packages/1c/9f/9d11c1a4eb046a9e106483b9ff69bce7ac880443f00e5ce64261b47b07e7/websockets-15.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac60e3b188ec7574cb761b08d50fcedf9d77f1530352db4eef1707fe9dee7205", size = 173080, upload-time = "2025-03-05T20:01:37.304Z" }, - { url = "https://files.pythonhosted.org/packages/d5/4f/b462242432d93ea45f297b6179c7333dd0402b855a912a04e7fc61c0d71f/websockets-15.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5756779642579d902eed757b21b0164cd6fe338506a8083eb58af5c372e39d9a", size = 173329, upload-time = "2025-03-05T20:01:39.668Z" }, - { url = 
"https://files.pythonhosted.org/packages/6e/0c/6afa1f4644d7ed50284ac59cc70ef8abd44ccf7d45850d989ea7310538d0/websockets-15.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fdfe3e2a29e4db3659dbd5bbf04560cea53dd9610273917799f1cde46aa725e", size = 182312, upload-time = "2025-03-05T20:01:41.815Z" }, - { url = "https://files.pythonhosted.org/packages/dd/d4/ffc8bd1350b229ca7a4db2a3e1c482cf87cea1baccd0ef3e72bc720caeec/websockets-15.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c2529b320eb9e35af0fa3016c187dffb84a3ecc572bcee7c3ce302bfeba52bf", size = 181319, upload-time = "2025-03-05T20:01:43.967Z" }, - { url = "https://files.pythonhosted.org/packages/97/3a/5323a6bb94917af13bbb34009fac01e55c51dfde354f63692bf2533ffbc2/websockets-15.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac1e5c9054fe23226fb11e05a6e630837f074174c4c2f0fe442996112a6de4fb", size = 181631, upload-time = "2025-03-05T20:01:46.104Z" }, - { url = "https://files.pythonhosted.org/packages/a6/cc/1aeb0f7cee59ef065724041bb7ed667b6ab1eeffe5141696cccec2687b66/websockets-15.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5df592cd503496351d6dc14f7cdad49f268d8e618f80dce0cd5a36b93c3fc08d", size = 182016, upload-time = "2025-03-05T20:01:47.603Z" }, - { url = "https://files.pythonhosted.org/packages/79/f9/c86f8f7af208e4161a7f7e02774e9d0a81c632ae76db2ff22549e1718a51/websockets-15.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0a34631031a8f05657e8e90903e656959234f3a04552259458aac0b0f9ae6fd9", size = 181426, upload-time = "2025-03-05T20:01:48.949Z" }, - { url = "https://files.pythonhosted.org/packages/c7/b9/828b0bc6753db905b91df6ae477c0b14a141090df64fb17f8a9d7e3516cf/websockets-15.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3d00075aa65772e7ce9e990cab3ff1de702aa09be3940d1dc88d5abf1ab8a09c", size = 181360, upload-time = 
"2025-03-05T20:01:50.938Z" }, - { url = "https://files.pythonhosted.org/packages/89/fb/250f5533ec468ba6327055b7d98b9df056fb1ce623b8b6aaafb30b55d02e/websockets-15.0.1-cp310-cp310-win32.whl", hash = "sha256:1234d4ef35db82f5446dca8e35a7da7964d02c127b095e172e54397fb6a6c256", size = 176388, upload-time = "2025-03-05T20:01:52.213Z" }, - { url = "https://files.pythonhosted.org/packages/1c/46/aca7082012768bb98e5608f01658ff3ac8437e563eca41cf068bd5849a5e/websockets-15.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:39c1fec2c11dc8d89bba6b2bf1556af381611a173ac2b511cf7231622058af41", size = 176830, upload-time = "2025-03-05T20:01:53.922Z" }, - { url = "https://files.pythonhosted.org/packages/9f/32/18fcd5919c293a398db67443acd33fde142f283853076049824fc58e6f75/websockets-15.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:823c248b690b2fd9303ba00c4f66cd5e2d8c3ba4aa968b2779be9532a4dad431", size = 175423, upload-time = "2025-03-05T20:01:56.276Z" }, - { url = "https://files.pythonhosted.org/packages/76/70/ba1ad96b07869275ef42e2ce21f07a5b0148936688c2baf7e4a1f60d5058/websockets-15.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678999709e68425ae2593acf2e3ebcbcf2e69885a5ee78f9eb80e6e371f1bf57", size = 173082, upload-time = "2025-03-05T20:01:57.563Z" }, - { url = "https://files.pythonhosted.org/packages/86/f2/10b55821dd40eb696ce4704a87d57774696f9451108cff0d2824c97e0f97/websockets-15.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d50fd1ee42388dcfb2b3676132c78116490976f1300da28eb629272d5d93e905", size = 173330, upload-time = "2025-03-05T20:01:59.063Z" }, - { url = "https://files.pythonhosted.org/packages/a5/90/1c37ae8b8a113d3daf1065222b6af61cc44102da95388ac0018fcb7d93d9/websockets-15.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d99e5546bf73dbad5bf3547174cd6cb8ba7273062a23808ffea025ecb1cf8562", size = 182878, upload-time = "2025-03-05T20:02:00.305Z" }, - { url = 
"https://files.pythonhosted.org/packages/8e/8d/96e8e288b2a41dffafb78e8904ea7367ee4f891dafc2ab8d87e2124cb3d3/websockets-15.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66dd88c918e3287efc22409d426c8f729688d89a0c587c88971a0faa2c2f3792", size = 181883, upload-time = "2025-03-05T20:02:03.148Z" }, - { url = "https://files.pythonhosted.org/packages/93/1f/5d6dbf551766308f6f50f8baf8e9860be6182911e8106da7a7f73785f4c4/websockets-15.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8dd8327c795b3e3f219760fa603dcae1dcc148172290a8ab15158cf85a953413", size = 182252, upload-time = "2025-03-05T20:02:05.29Z" }, - { url = "https://files.pythonhosted.org/packages/d4/78/2d4fed9123e6620cbf1706c0de8a1632e1a28e7774d94346d7de1bba2ca3/websockets-15.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8fdc51055e6ff4adeb88d58a11042ec9a5eae317a0a53d12c062c8a8865909e8", size = 182521, upload-time = "2025-03-05T20:02:07.458Z" }, - { url = "https://files.pythonhosted.org/packages/e7/3b/66d4c1b444dd1a9823c4a81f50231b921bab54eee2f69e70319b4e21f1ca/websockets-15.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:693f0192126df6c2327cce3baa7c06f2a117575e32ab2308f7f8216c29d9e2e3", size = 181958, upload-time = "2025-03-05T20:02:09.842Z" }, - { url = "https://files.pythonhosted.org/packages/08/ff/e9eed2ee5fed6f76fdd6032ca5cd38c57ca9661430bb3d5fb2872dc8703c/websockets-15.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:54479983bd5fb469c38f2f5c7e3a24f9a4e70594cd68cd1fa6b9340dadaff7cf", size = 181918, upload-time = "2025-03-05T20:02:11.968Z" }, - { url = "https://files.pythonhosted.org/packages/d8/75/994634a49b7e12532be6a42103597b71098fd25900f7437d6055ed39930a/websockets-15.0.1-cp311-cp311-win32.whl", hash = "sha256:16b6c1b3e57799b9d38427dda63edcbe4926352c47cf88588c0be4ace18dac85", size = 176388, upload-time = "2025-03-05T20:02:13.32Z" }, - { url = 
"https://files.pythonhosted.org/packages/98/93/e36c73f78400a65f5e236cd376713c34182e6663f6889cd45a4a04d8f203/websockets-15.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:27ccee0071a0e75d22cb35849b1db43f2ecd3e161041ac1ee9d2352ddf72f065", size = 176828, upload-time = "2025-03-05T20:02:14.585Z" }, - { url = "https://files.pythonhosted.org/packages/51/6b/4545a0d843594f5d0771e86463606a3988b5a09ca5123136f8a76580dd63/websockets-15.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3e90baa811a5d73f3ca0bcbf32064d663ed81318ab225ee4f427ad4e26e5aff3", size = 175437, upload-time = "2025-03-05T20:02:16.706Z" }, - { url = "https://files.pythonhosted.org/packages/f4/71/809a0f5f6a06522af902e0f2ea2757f71ead94610010cf570ab5c98e99ed/websockets-15.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:592f1a9fe869c778694f0aa806ba0374e97648ab57936f092fd9d87f8bc03665", size = 173096, upload-time = "2025-03-05T20:02:18.832Z" }, - { url = "https://files.pythonhosted.org/packages/3d/69/1a681dd6f02180916f116894181eab8b2e25b31e484c5d0eae637ec01f7c/websockets-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0701bc3cfcb9164d04a14b149fd74be7347a530ad3bbf15ab2c678a2cd3dd9a2", size = 173332, upload-time = "2025-03-05T20:02:20.187Z" }, - { url = "https://files.pythonhosted.org/packages/a6/02/0073b3952f5bce97eafbb35757f8d0d54812b6174ed8dd952aa08429bcc3/websockets-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8b56bdcdb4505c8078cb6c7157d9811a85790f2f2b3632c7d1462ab5783d215", size = 183152, upload-time = "2025-03-05T20:02:22.286Z" }, - { url = "https://files.pythonhosted.org/packages/74/45/c205c8480eafd114b428284840da0b1be9ffd0e4f87338dc95dc6ff961a1/websockets-15.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0af68c55afbd5f07986df82831c7bff04846928ea8d1fd7f30052638788bc9b5", size = 182096, upload-time = "2025-03-05T20:02:24.368Z" }, - { url = 
"https://files.pythonhosted.org/packages/14/8f/aa61f528fba38578ec553c145857a181384c72b98156f858ca5c8e82d9d3/websockets-15.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64dee438fed052b52e4f98f76c5790513235efaa1ef7f3f2192c392cd7c91b65", size = 182523, upload-time = "2025-03-05T20:02:25.669Z" }, - { url = "https://files.pythonhosted.org/packages/ec/6d/0267396610add5bc0d0d3e77f546d4cd287200804fe02323797de77dbce9/websockets-15.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d5f6b181bb38171a8ad1d6aa58a67a6aa9d4b38d0f8c5f496b9e42561dfc62fe", size = 182790, upload-time = "2025-03-05T20:02:26.99Z" }, - { url = "https://files.pythonhosted.org/packages/02/05/c68c5adbf679cf610ae2f74a9b871ae84564462955d991178f95a1ddb7dd/websockets-15.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5d54b09eba2bada6011aea5375542a157637b91029687eb4fdb2dab11059c1b4", size = 182165, upload-time = "2025-03-05T20:02:30.291Z" }, - { url = "https://files.pythonhosted.org/packages/29/93/bb672df7b2f5faac89761cb5fa34f5cec45a4026c383a4b5761c6cea5c16/websockets-15.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3be571a8b5afed347da347bfcf27ba12b069d9d7f42cb8c7028b5e98bbb12597", size = 182160, upload-time = "2025-03-05T20:02:31.634Z" }, - { url = "https://files.pythonhosted.org/packages/ff/83/de1f7709376dc3ca9b7eeb4b9a07b4526b14876b6d372a4dc62312bebee0/websockets-15.0.1-cp312-cp312-win32.whl", hash = "sha256:c338ffa0520bdb12fbc527265235639fb76e7bc7faafbb93f6ba80d9c06578a9", size = 176395, upload-time = "2025-03-05T20:02:33.017Z" }, - { url = "https://files.pythonhosted.org/packages/7d/71/abf2ebc3bbfa40f391ce1428c7168fb20582d0ff57019b69ea20fa698043/websockets-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcd5cf9e305d7b8338754470cf69cf81f420459dbae8a3b40cee57417f4614a7", size = 176841, upload-time = "2025-03-05T20:02:34.498Z" }, - { url = 
"https://files.pythonhosted.org/packages/cb/9f/51f0cf64471a9d2b4d0fc6c534f323b664e7095640c34562f5182e5a7195/websockets-15.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee443ef070bb3b6ed74514f5efaa37a252af57c90eb33b956d35c8e9c10a1931", size = 175440, upload-time = "2025-03-05T20:02:36.695Z" }, - { url = "https://files.pythonhosted.org/packages/8a/05/aa116ec9943c718905997412c5989f7ed671bc0188ee2ba89520e8765d7b/websockets-15.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5a939de6b7b4e18ca683218320fc67ea886038265fd1ed30173f5ce3f8e85675", size = 173098, upload-time = "2025-03-05T20:02:37.985Z" }, - { url = "https://files.pythonhosted.org/packages/ff/0b/33cef55ff24f2d92924923c99926dcce78e7bd922d649467f0eda8368923/websockets-15.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:746ee8dba912cd6fc889a8147168991d50ed70447bf18bcda7039f7d2e3d9151", size = 173329, upload-time = "2025-03-05T20:02:39.298Z" }, - { url = "https://files.pythonhosted.org/packages/31/1d/063b25dcc01faa8fada1469bdf769de3768b7044eac9d41f734fd7b6ad6d/websockets-15.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:595b6c3969023ecf9041b2936ac3827e4623bfa3ccf007575f04c5a6aa318c22", size = 183111, upload-time = "2025-03-05T20:02:40.595Z" }, - { url = "https://files.pythonhosted.org/packages/93/53/9a87ee494a51bf63e4ec9241c1ccc4f7c2f45fff85d5bde2ff74fcb68b9e/websockets-15.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c714d2fc58b5ca3e285461a4cc0c9a66bd0e24c5da9911e30158286c9b5be7f", size = 182054, upload-time = "2025-03-05T20:02:41.926Z" }, - { url = "https://files.pythonhosted.org/packages/ff/b2/83a6ddf56cdcbad4e3d841fcc55d6ba7d19aeb89c50f24dd7e859ec0805f/websockets-15.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f3c1e2ab208db911594ae5b4f79addeb3501604a165019dd221c0bdcabe4db8", size = 182496, upload-time = 
"2025-03-05T20:02:43.304Z" }, - { url = "https://files.pythonhosted.org/packages/98/41/e7038944ed0abf34c45aa4635ba28136f06052e08fc2168520bb8b25149f/websockets-15.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:229cf1d3ca6c1804400b0a9790dc66528e08a6a1feec0d5040e8b9eb14422375", size = 182829, upload-time = "2025-03-05T20:02:48.812Z" }, - { url = "https://files.pythonhosted.org/packages/e0/17/de15b6158680c7623c6ef0db361da965ab25d813ae54fcfeae2e5b9ef910/websockets-15.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:756c56e867a90fb00177d530dca4b097dd753cde348448a1012ed6c5131f8b7d", size = 182217, upload-time = "2025-03-05T20:02:50.14Z" }, - { url = "https://files.pythonhosted.org/packages/33/2b/1f168cb6041853eef0362fb9554c3824367c5560cbdaad89ac40f8c2edfc/websockets-15.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:558d023b3df0bffe50a04e710bc87742de35060580a293c2a984299ed83bc4e4", size = 182195, upload-time = "2025-03-05T20:02:51.561Z" }, - { url = "https://files.pythonhosted.org/packages/86/eb/20b6cdf273913d0ad05a6a14aed4b9a85591c18a987a3d47f20fa13dcc47/websockets-15.0.1-cp313-cp313-win32.whl", hash = "sha256:ba9e56e8ceeeedb2e080147ba85ffcd5cd0711b89576b83784d8605a7df455fa", size = 176393, upload-time = "2025-03-05T20:02:53.814Z" }, - { url = "https://files.pythonhosted.org/packages/1b/6c/c65773d6cab416a64d191d6ee8a8b1c68a09970ea6909d16965d26bfed1e/websockets-15.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:e09473f095a819042ecb2ab9465aee615bd9c2028e4ef7d933600a8401c79561", size = 176837, upload-time = "2025-03-05T20:02:55.237Z" }, - { url = "https://files.pythonhosted.org/packages/02/9e/d40f779fa16f74d3468357197af8d6ad07e7c5a27ea1ca74ceb38986f77a/websockets-15.0.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0c9e74d766f2818bb95f84c25be4dea09841ac0f734d1966f415e4edfc4ef1c3", size = 173109, upload-time = "2025-03-05T20:03:17.769Z" }, - { url = 
"https://files.pythonhosted.org/packages/bc/cd/5b887b8585a593073fd92f7c23ecd3985cd2c3175025a91b0d69b0551372/websockets-15.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1009ee0c7739c08a0cd59de430d6de452a55e42d6b522de7aa15e6f67db0b8e1", size = 173343, upload-time = "2025-03-05T20:03:19.094Z" }, - { url = "https://files.pythonhosted.org/packages/fe/ae/d34f7556890341e900a95acf4886833646306269f899d58ad62f588bf410/websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d1f20b1c7a2fa82367e04982e708723ba0e7b8d43aa643d3dcd404d74f1475", size = 174599, upload-time = "2025-03-05T20:03:21.1Z" }, - { url = "https://files.pythonhosted.org/packages/71/e6/5fd43993a87db364ec60fc1d608273a1a465c0caba69176dd160e197ce42/websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f29d80eb9a9263b8d109135351caf568cc3f80b9928bccde535c235de55c22d9", size = 174207, upload-time = "2025-03-05T20:03:23.221Z" }, - { url = "https://files.pythonhosted.org/packages/2b/fb/c492d6daa5ec067c2988ac80c61359ace5c4c674c532985ac5a123436cec/websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04", size = 174155, upload-time = "2025-03-05T20:03:25.321Z" }, - { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884, upload-time = "2025-03-05T20:03:27.934Z" }, - { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = 
"2025-03-05T20:03:39.41Z" }, -] - [[package]] name = "wrapt" version = "1.17.3"