From cc22bf54641330046fe95666e8c4e041095bb883 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:23:02 +0000 Subject: [PATCH 01/60] refactor: remove request_ctx ContextVar, thread Context explicitly (#2203) Co-authored-by: Marcelo Trylesinski --- docs/migration.md | 33 +- src/mcp/server/lowlevel/server.py | 9 +- src/mcp/server/mcpserver/__init__.py | 3 +- src/mcp/server/mcpserver/context.py | 280 ++++++++++++++++ src/mcp/server/mcpserver/prompts/base.py | 12 +- src/mcp/server/mcpserver/prompts/manager.py | 8 +- .../mcpserver/resources/resource_manager.py | 6 +- .../server/mcpserver/resources/templates.py | 10 +- src/mcp/server/mcpserver/server.py | 316 ++---------------- src/mcp/server/mcpserver/tools/base.py | 10 +- .../server/mcpserver/tools/tool_manager.py | 6 +- .../mcpserver/utilities/context_injection.py | 4 +- tests/client/test_list_roots_callback.py | 3 +- tests/client/test_logging_callback.py | 13 +- tests/client/test_sampling_callback.py | 10 +- tests/server/mcpserver/prompts/test_base.py | 31 +- .../server/mcpserver/prompts/test_manager.py | 9 +- .../resources/test_resource_manager.py | 7 +- .../resources/test_resource_template.py | 13 +- tests/server/mcpserver/test_server.py | 21 +- tests/server/mcpserver/test_tool_manager.py | 83 ++--- tests/server/mcpserver/tools/__init__.py | 0 tests/server/mcpserver/tools/test_base.py | 10 + 23 files changed, 484 insertions(+), 413 deletions(-) create mode 100644 src/mcp/server/mcpserver/context.py create mode 100644 tests/server/mcpserver/tools/__init__.py create mode 100644 tests/server/mcpserver/tools/test_base.py diff --git a/docs/migration.md b/docs/migration.md index 6316836938..7cf0325533 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -288,6 +288,37 @@ app = Starlette(routes=[Mount("/", app=mcp.streamable_http_app(json_response=Tru **Note:** DNS rebinding protection is automatically enabled when `host` is `127.0.0.1`, `localhost`, or 
`::1`. This now happens in `sse_app()` and `streamable_http_app()` instead of the constructor. +### `MCPServer.get_context()` removed + +`MCPServer.get_context()` has been removed. Context is now injected by the framework and passed explicitly — there is no ambient ContextVar to read from. + +**If you were calling `get_context()` from inside a tool/resource/prompt:** use the `ctx: Context` parameter injection instead. + +**Before (v1):** + +```python +@mcp.tool() +async def my_tool(x: int) -> str: + ctx = mcp.get_context() + await ctx.info("Processing...") + return str(x) +``` + +**After (v2):** + +```python +@mcp.tool() +async def my_tool(x: int, ctx: Context) -> str: + await ctx.info("Processing...") + return str(x) +``` + +### `MCPServer.call_tool()`, `read_resource()`, `get_prompt()` now accept a `context` parameter + +`MCPServer.call_tool()`, `MCPServer.read_resource()`, and `MCPServer.get_prompt()` now accept an optional `context: Context | None = None` parameter. The framework passes this automatically during normal request handling. If you call these methods directly and omit `context`, a Context with no active request is constructed for you — tools that don't use `ctx` work normally, but any attempt to use `ctx.session`, `ctx.request_id`, etc. will raise. + +The internal layers (`ToolManager.call_tool`, `Tool.run`, `Prompt.render`, `ResourceTemplate.create_resource`, etc.) now require `context` as a positional argument. + ### Replace `RootModel` by union types with `TypeAdapter` validation The following union types are no longer `RootModel` subclasses: @@ -694,7 +725,7 @@ If you prefer the convenience of automatic wrapping, use `MCPServer` which still ### Lowlevel `Server`: `request_context` property removed -The `server.request_context` property has been removed. Request context is now passed directly to handlers as the first argument (`ctx`). The `request_ctx` module-level contextvar is now an internal implementation detail and should not be relied upon. 
+The `server.request_context` property has been removed. Request context is now passed directly to handlers as the first argument (`ctx`). The `request_ctx` module-level contextvar has been removed entirely. **Before (v1):** diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index aee6440402..1c84c86107 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -36,7 +36,6 @@ async def main(): from __future__ import annotations -import contextvars import logging import warnings from collections.abc import AsyncIterator, Awaitable, Callable @@ -74,8 +73,6 @@ async def main(): LifespanResultT = TypeVar("LifespanResultT", default=Any) -request_ctx: contextvars.ContextVar[ServerRequestContext[Any]] = contextvars.ContextVar("request_ctx") - class NotificationOptions: def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False): @@ -474,11 +471,7 @@ async def _handle_request( close_sse_stream=close_sse_stream_cb, close_standalone_sse_stream=close_standalone_sse_stream_cb, ) - token = request_ctx.set(ctx) - try: - response = await handler(ctx, req.params) - finally: - request_ctx.reset(token) + response = await handler(ctx, req.params) except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): diff --git a/src/mcp/server/mcpserver/__init__.py b/src/mcp/server/mcpserver/__init__.py index f51c0b0ed8..0857e38bd4 100644 --- a/src/mcp/server/mcpserver/__init__.py +++ b/src/mcp/server/mcpserver/__init__.py @@ -2,7 +2,8 @@ from mcp.types import Icon -from .server import Context, MCPServer +from .context import Context +from .server import MCPServer from .utilities.types import Audio, Image __all__ = ["MCPServer", "Context", "Image", "Audio", "Icon"] diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py new file mode 100644 index 0000000000..1538adc7c7 --- /dev/null +++ b/src/mcp/server/mcpserver/context.py @@ 
-0,0 +1,280 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Generic, Literal + +from pydantic import AnyUrl, BaseModel + +from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext +from mcp.server.elicitation import ( + ElicitationResult, + ElicitSchemaModelT, + UrlElicitationResult, + elicit_url, + elicit_with_validation, +) +from mcp.server.lowlevel.helper_types import ReadResourceContents + +if TYPE_CHECKING: + from mcp.server.mcpserver.server import MCPServer + + +class Context(BaseModel, Generic[LifespanContextT, RequestT]): + """Context object providing access to MCP capabilities. + + This provides a cleaner interface to MCP's RequestContext functionality. + It gets injected into tool and resource functions that request it via type hints. + + To use context in a tool function, add a parameter with the Context type annotation: + + ```python + @server.tool() + async def my_tool(x: int, ctx: Context) -> str: + # Log messages to the client + await ctx.info(f"Processing {x}") + await ctx.debug("Debug info") + await ctx.warning("Warning message") + await ctx.error("Error message") + + # Report progress + await ctx.report_progress(50, 100) + + # Access resources + data = await ctx.read_resource("resource://data") + + # Get request info + request_id = ctx.request_id + client_id = ctx.client_id + + return str(x) + ``` + + The context parameter name can be anything as long as it's annotated with Context. + The context is optional - tools that don't need it can omit the parameter. + """ + + _request_context: ServerRequestContext[LifespanContextT, RequestT] | None + _mcp_server: MCPServer | None + + # TODO(maxisbey): Consider making request_context/mcp_server required, or refactor Context entirely. 
+ def __init__( + self, + *, + request_context: ServerRequestContext[LifespanContextT, RequestT] | None = None, + mcp_server: MCPServer | None = None, + # TODO(Marcelo): We should drop this kwargs parameter. + **kwargs: Any, + ): + super().__init__(**kwargs) + self._request_context = request_context + self._mcp_server = mcp_server + + @property + def mcp_server(self) -> MCPServer: + """Access to the MCPServer instance.""" + if self._mcp_server is None: # pragma: no cover + raise ValueError("Context is not available outside of a request") + return self._mcp_server # pragma: no cover + + @property + def request_context(self) -> ServerRequestContext[LifespanContextT, RequestT]: + """Access to the underlying request context.""" + if self._request_context is None: # pragma: no cover + raise ValueError("Context is not available outside of a request") + return self._request_context + + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for the current operation. + + Args: + progress: Current progress value (e.g., 24) + total: Optional total value (e.g., 100) + message: Optional message (e.g., "Starting render...") + """ + progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None + + if progress_token is None: # pragma: no cover + return + + await self.request_context.session.send_progress_notification( + progress_token=progress_token, + progress=progress, + total=total, + message=message, + related_request_id=self.request_id, + ) + + async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: + """Read a resource by URI. 
+ + Args: + uri: Resource URI to read + + Returns: + The resource content as either text or bytes + """ + assert self._mcp_server is not None, "Context is not available outside of a request" + return await self._mcp_server.read_resource(uri, self) + + async def elicit( + self, + message: str, + schema: type[ElicitSchemaModelT], + ) -> ElicitationResult[ElicitSchemaModelT]: + """Elicit information from the client/user. + + This method can be used to interactively ask for additional information from the + client within a tool's execution. The client might display the message to the + user and collect a response according to the provided schema. If the client + is an agent, it might decide how to handle the elicitation -- either by asking + the user or automatically generating a response. + + Args: + message: Message to present to the user + schema: A Pydantic model class defining the expected response structure. + According to the specification, only primitive types are allowed. + + Returns: + An ElicitationResult containing the action taken and the data if accepted + + Note: + Check the result.action to determine if the user accepted, declined, or cancelled. + The result.data will only be populated if action is "accept" and validation succeeded. + """ + + return await elicit_with_validation( + session=self.request_context.session, + message=message, + schema=schema, + related_request_id=self.request_id, + ) + + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + ) -> UrlElicitationResult: + """Request URL mode elicitation from the client. + + This directs the user to an external URL for out-of-band interactions + that must not pass through the MCP client. 
Use this for: + - Collecting sensitive credentials (API keys, passwords) + - OAuth authorization flows with third-party services + - Payment and subscription flows + - Any interaction where data should not pass through the LLM context + + The response indicates whether the user consented to navigate to the URL. + The actual interaction happens out-of-band. When the elicitation completes, + call `ctx.session.send_elicit_complete(elicitation_id)` to notify the client. + + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + + Returns: + UrlElicitationResult indicating accept, decline, or cancel + """ + return await elicit_url( + session=self.request_context.session, + message=message, + url=url, + elicitation_id=elicitation_id, + related_request_id=self.request_id, + ) + + async def log( + self, + level: Literal["debug", "info", "warning", "error"], + message: str, + *, + logger_name: str | None = None, + extra: dict[str, Any] | None = None, + ) -> None: + """Send a log message to the client. 
+ + Args: + level: Log level (debug, info, warning, error) + message: Log message + logger_name: Optional logger name + extra: Optional dictionary with additional structured data to include + """ + + if extra: + log_data = {"message": message, **extra} + else: + log_data = message + + await self.request_context.session.send_log_message( + level=level, + data=log_data, + logger=logger_name, + related_request_id=self.request_id, + ) + + @property + def client_id(self) -> str | None: + """Get the client ID if available.""" + return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover + + @property + def request_id(self) -> str: + """Get the unique ID for this request.""" + return str(self.request_context.request_id) + + @property + def session(self): + """Access to the underlying session for advanced usage.""" + return self.request_context.session + + async def close_sse_stream(self) -> None: + """Close the SSE stream to trigger client reconnection. + + This method closes the HTTP connection for the current request, triggering + client reconnection. Events continue to be stored in the event store and will + be replayed when the client reconnects with Last-Event-ID. + + Use this to implement polling behavior during long-running operations - + the client will reconnect after the retry interval specified in the priming event. + + Note: + This is a no-op if not using StreamableHTTP transport with event_store. + The callback is only available when event_store is configured. + """ + if self._request_context and self._request_context.close_sse_stream: # pragma: no cover + await self._request_context.close_sse_stream() + + async def close_standalone_sse_stream(self) -> None: + """Close the standalone GET SSE stream to trigger client reconnection. + + This method closes the HTTP connection for the standalone GET stream used + for unsolicited server-to-client notifications. 
The client SHOULD reconnect + with Last-Event-ID to resume receiving notifications. + + Note: + This is a no-op if not using StreamableHTTP transport with event_store. + Currently, client reconnection for standalone GET streams is NOT + implemented - this is a known gap. + """ + if self._request_context and self._request_context.close_standalone_sse_stream: # pragma: no cover + await self._request_context.close_standalone_sse_stream() + + # Convenience methods for common log levels + async def debug(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: + """Send a debug log message.""" + await self.log("debug", message, logger_name=logger_name, extra=extra) + + async def info(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: + """Send an info log message.""" + await self.log("info", message, logger_name=logger_name, extra=extra) + + async def warning( + self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None + ) -> None: + """Send a warning log message.""" + await self.log("warning", message, logger_name=logger_name, extra=extra) + + async def error(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: + """Send an error log message.""" + await self.log("error", message, logger_name=logger_name, extra=extra) diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 17744a6707..0c319d53cc 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context class Message(BaseModel): @@ -135,10 +135,14 @@ def from_function( async def render( self, - arguments: dict[str, Any] | None = None, - context: 
Context[LifespanContextT, RequestT] | None = None, + arguments: dict[str, Any] | None, + context: Context[LifespanContextT, RequestT], ) -> list[Message]: - """Render the prompt with arguments.""" + """Render the prompt with arguments. + + Raises: + ValueError: If required arguments are missing, or if rendering fails. + """ # Validate required arguments if self.arguments: required = {arg.name for arg in self.arguments if arg.required} diff --git a/src/mcp/server/mcpserver/prompts/manager.py b/src/mcp/server/mcpserver/prompts/manager.py index 21b9741318..28a7a6e98c 100644 --- a/src/mcp/server/mcpserver/prompts/manager.py +++ b/src/mcp/server/mcpserver/prompts/manager.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context logger = get_logger(__name__) @@ -48,12 +48,12 @@ def add_prompt( async def render_prompt( self, name: str, - arguments: dict[str, Any] | None = None, - context: Context[LifespanContextT, RequestT] | None = None, + arguments: dict[str, Any] | None, + context: Context[LifespanContextT, RequestT], ) -> list[Message]: """Render a prompt by name with arguments.""" prompt = self.get_prompt(name) if not prompt: raise ValueError(f"Unknown prompt: {name}") - return await prompt.render(arguments, context=context) + return await prompt.render(arguments, context) diff --git a/src/mcp/server/mcpserver/resources/resource_manager.py b/src/mcp/server/mcpserver/resources/resource_manager.py index ed5b741239..6bf17376d1 100644 --- a/src/mcp/server/mcpserver/resources/resource_manager.py +++ b/src/mcp/server/mcpserver/resources/resource_manager.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context logger = get_logger(__name__) @@ -80,9 +80,7 @@ def add_template( 
self._templates[template.uri_template] = template return template - async def get_resource( - self, uri: AnyUrl | str, context: Context[LifespanContextT, RequestT] | None = None - ) -> Resource: + async def get_resource(self, uri: AnyUrl | str, context: Context[LifespanContextT, RequestT]) -> Resource: """Get resource by URI, checking concrete resources first, then templates.""" uri_str = str(uri) logger.debug("Getting resource", extra={"uri": uri_str}) diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index e796823d9d..2d612657c4 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context class ResourceTemplate(BaseModel): @@ -99,9 +99,13 @@ async def create_resource( self, uri: str, params: dict[str, Any], - context: Context[LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT], ) -> Resource: - """Create a resource from the template with the given parameters.""" + """Create a resource from the template with the given parameters. + + Raises: + ValueError: If creating the resource fails. 
+ """ try: # Add context to params if needed params = inject_context(self.fn, params, context, self.context_kwarg) diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 9c7105a7b4..2a7a58117a 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -12,7 +12,6 @@ import anyio import pydantic_core -from pydantic import BaseModel from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette @@ -27,12 +26,11 @@ from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier from mcp.server.auth.settings import AuthSettings -from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext -from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, UrlElicitationResult, elicit_with_validation -from mcp.server.elicitation import elicit_url as _elicit_url +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import LifespanResultT, Server, request_ctx +from mcp.server.lowlevel.server import LifespanResultT, Server from mcp.server.lowlevel.server import lifespan as default_lifespan +from mcp.server.mcpserver.context import Context from mcp.server.mcpserver.exceptions import ResourceError from mcp.server.mcpserver.prompts import Prompt, PromptManager from mcp.server.mcpserver.resources import FunctionResource, Resource, ResourceManager @@ -300,8 +298,9 @@ async def _handle_list_tools( async def _handle_call_tool( self, ctx: ServerRequestContext[LifespanResultT], params: CallToolRequestParams ) -> CallToolResult: + context = Context(request_context=ctx, mcp_server=self) try: - result = await self.call_tool(params.name, params.arguments or {}) + result = await 
self.call_tool(params.name, params.arguments or {}, context) except MCPError: raise except Exception as e: @@ -332,7 +331,8 @@ async def _handle_list_resources( async def _handle_read_resource( self, ctx: ServerRequestContext[LifespanResultT], params: ReadResourceRequestParams ) -> ReadResourceResult: - results = await self.read_resource(params.uri) + context = Context(request_context=ctx, mcp_server=self) + results = await self.read_resource(params.uri, context) contents: list[TextResourceContents | BlobResourceContents] = [] for item in results: if isinstance(item.content, bytes): @@ -368,7 +368,8 @@ async def _handle_list_prompts( async def _handle_get_prompt( self, ctx: ServerRequestContext[LifespanResultT], params: GetPromptRequestParams ) -> GetPromptResult: - return await self.get_prompt(params.name, params.arguments) + context = Context(request_context=ctx, mcp_server=self) + return await self.get_prompt(params.name, params.arguments, context) async def list_tools(self) -> list[MCPTool]: """List all available tools.""" @@ -387,22 +388,13 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> Context[LifespanResultT, Request]: - """Return a Context object. - - Note that the context will only be valid during a request; outside a - request, most methods will error. 
- """ - try: - request_context = request_ctx.get() - except LookupError: - request_context = None - return Context(request_context=request_context, mcp_server=self) - - async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock] | dict[str, Any]: + async def call_tool( + self, name: str, arguments: dict[str, Any], context: Context[LifespanResultT, Any] | None = None + ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" - context = self.get_context() - return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) + if context is None: + context = Context(mcp_server=self) + return await self._tool_manager.call_tool(name, arguments, context, convert_result=True) async def list_resources(self) -> list[MCPResource]: """List all available resources.""" @@ -438,12 +430,14 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]: for template in templates ] - async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]: + async def read_resource( + self, uri: AnyUrl | str, context: Context[LifespanResultT, Any] | None = None + ) -> Iterable[ReadResourceContents]: """Read a resource by URI.""" - - context = self.get_context() + if context is None: + context = Context(mcp_server=self) try: - resource = await self._resource_manager.get_resource(uri, context=context) + resource = await self._resource_manager.get_resource(uri, context) except ValueError: raise ResourceError(f"Unknown resource: {uri}") @@ -1087,14 +1081,18 @@ async def list_prompts(self) -> list[MCPPrompt]: for prompt in prompts ] - async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult: + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None, context: Context[LifespanResultT, Any] | None = None + ) -> GetPromptResult: """Get a prompt by name with arguments.""" + if context is None: + context = 
Context(mcp_server=self) try: prompt = self._prompt_manager.get_prompt(name) if not prompt: raise ValueError(f"Unknown prompt: {name}") - messages = await prompt.render(arguments, context=self.get_context()) + messages = await prompt.render(arguments, context) return GetPromptResult( description=prompt.description, @@ -1103,263 +1101,3 @@ async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) - except Exception as e: logger.exception(f"Error getting prompt {name}") raise ValueError(str(e)) - - -class Context(BaseModel, Generic[LifespanContextT, RequestT]): - """Context object providing access to MCP capabilities. - - This provides a cleaner interface to MCP's RequestContext functionality. - It gets injected into tool and resource functions that request it via type hints. - - To use context in a tool function, add a parameter with the Context type annotation: - - ```python - @server.tool() - async def my_tool(x: int, ctx: Context) -> str: - # Log messages to the client - await ctx.info(f"Processing {x}") - await ctx.debug("Debug info") - await ctx.warning("Warning message") - await ctx.error("Error message") - - # Report progress - await ctx.report_progress(50, 100) - - # Access resources - data = await ctx.read_resource("resource://data") - - # Get request info - request_id = ctx.request_id - client_id = ctx.client_id - - return str(x) - ``` - - The context parameter name can be anything as long as it's annotated with Context. - The context is optional - tools that don't need it can omit the parameter. - """ - - _request_context: ServerRequestContext[LifespanContextT, RequestT] | None - _mcp_server: MCPServer | None - - def __init__( - self, - *, - request_context: ServerRequestContext[LifespanContextT, RequestT] | None = None, - mcp_server: MCPServer | None = None, - # TODO(Marcelo): We should drop this kwargs parameter. 
- **kwargs: Any, - ): - super().__init__(**kwargs) - self._request_context = request_context - self._mcp_server = mcp_server - - @property - def mcp_server(self) -> MCPServer: - """Access to the MCPServer instance.""" - if self._mcp_server is None: # pragma: no cover - raise ValueError("Context is not available outside of a request") - return self._mcp_server # pragma: no cover - - @property - def request_context(self) -> ServerRequestContext[LifespanContextT, RequestT]: - """Access to the underlying request context.""" - if self._request_context is None: # pragma: no cover - raise ValueError("Context is not available outside of a request") - return self._request_context - - async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: - """Report progress for the current operation. - - Args: - progress: Current progress value (e.g., 24) - total: Optional total value (e.g., 100) - message: Optional message (e.g., "Starting render...") - """ - progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None - - if progress_token is None: # pragma: no cover - return - - await self.request_context.session.send_progress_notification( - progress_token=progress_token, - progress=progress, - total=total, - message=message, - related_request_id=self.request_id, - ) - - async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: - """Read a resource by URI. - - Args: - uri: Resource URI to read - - Returns: - The resource content as either text or bytes - """ - assert self._mcp_server is not None, "Context is not available outside of a request" - return await self._mcp_server.read_resource(uri) - - async def elicit( - self, - message: str, - schema: type[ElicitSchemaModelT], - ) -> ElicitationResult[ElicitSchemaModelT]: - """Elicit information from the client/user. 
- - This method can be used to interactively ask for additional information from the - client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. If the client - is an agent, it might decide how to handle the elicitation -- either by asking - the user or automatically generating a response. - - Args: - message: Message to present to the user - schema: A Pydantic model class defining the expected response structure. - According to the specification, only primitive types are allowed. - - Returns: - An ElicitationResult containing the action taken and the data if accepted - - Note: - Check the result.action to determine if the user accepted, declined, or cancelled. - The result.data will only be populated if action is "accept" and validation succeeded. - """ - - return await elicit_with_validation( - session=self.request_context.session, - message=message, - schema=schema, - related_request_id=self.request_id, - ) - - async def elicit_url( - self, - message: str, - url: str, - elicitation_id: str, - ) -> UrlElicitationResult: - """Request URL mode elicitation from the client. - - This directs the user to an external URL for out-of-band interactions - that must not pass through the MCP client. Use this for: - - Collecting sensitive credentials (API keys, passwords) - - OAuth authorization flows with third-party services - - Payment and subscription flows - - Any interaction where data should not pass through the LLM context - - The response indicates whether the user consented to navigate to the URL. - The actual interaction happens out-of-band. When the elicitation completes, - call `ctx.session.send_elicit_complete(elicitation_id)` to notify the client. 
- - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - - Returns: - UrlElicitationResult indicating accept, decline, or cancel - """ - return await _elicit_url( - session=self.request_context.session, - message=message, - url=url, - elicitation_id=elicitation_id, - related_request_id=self.request_id, - ) - - async def log( - self, - level: Literal["debug", "info", "warning", "error"], - message: str, - *, - logger_name: str | None = None, - extra: dict[str, Any] | None = None, - ) -> None: - """Send a log message to the client. - - Args: - level: Log level (debug, info, warning, error) - message: Log message - logger_name: Optional logger name - extra: Optional dictionary with additional structured data to include - """ - - if extra: - log_data = {"message": message, **extra} - else: - log_data = message - - await self.request_context.session.send_log_message( - level=level, - data=log_data, - logger=logger_name, - related_request_id=self.request_id, - ) - - @property - def client_id(self) -> str | None: - """Get the client ID if available.""" - return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover - - @property - def request_id(self) -> str: - """Get the unique ID for this request.""" - return str(self.request_context.request_id) - - @property - def session(self): - """Access to the underlying session for advanced usage.""" - return self.request_context.session - - async def close_sse_stream(self) -> None: - """Close the SSE stream to trigger client reconnection. - - This method closes the HTTP connection for the current request, triggering - client reconnection. Events continue to be stored in the event store and will - be replayed when the client reconnects with Last-Event-ID. 
- - Use this to implement polling behavior during long-running operations - - the client will reconnect after the retry interval specified in the priming event. - - Note: - This is a no-op if not using StreamableHTTP transport with event_store. - The callback is only available when event_store is configured. - """ - if self._request_context and self._request_context.close_sse_stream: # pragma: no cover - await self._request_context.close_sse_stream() - - async def close_standalone_sse_stream(self) -> None: - """Close the standalone GET SSE stream to trigger client reconnection. - - This method closes the HTTP connection for the standalone GET stream used - for unsolicited server-to-client notifications. The client SHOULD reconnect - with Last-Event-ID to resume receiving notifications. - - Note: - This is a no-op if not using StreamableHTTP transport with event_store. - Currently, client reconnection for standalone GET streams is NOT - implemented - this is a known gap. - """ - if self._request_context and self._request_context.close_standalone_sse_stream: # pragma: no cover - await self._request_context.close_standalone_sse_stream() - - # Convenience methods for common log levels - async def debug(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: - """Send a debug log message.""" - await self.log("debug", message, logger_name=logger_name, extra=extra) - - async def info(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: - """Send an info log message.""" - await self.log("info", message, logger_name=logger_name, extra=extra) - - async def warning( - self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None - ) -> None: - """Send a warning log message.""" - await self.log("warning", message, logger_name=logger_name, extra=extra) - - async def error(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = 
None) -> None: - """Send an error log message.""" - await self.log("error", message, logger_name=logger_name, extra=extra) diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index f6bfadbc4d..dc65be9885 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context class Tool(BaseModel): @@ -92,10 +92,14 @@ def from_function( async def run( self, arguments: dict[str, Any], - context: Context[LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT], convert_result: bool = False, ) -> Any: - """Run the tool with arguments.""" + """Run the tool with arguments. + + Raises: + ToolError: If the tool function raises during execution. + """ try: result = await self.fn_metadata.call_fn_with_arg_validation( self.fn, diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index c6f8384bdb..32ed547973 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context logger = get_logger(__name__) @@ -81,7 +81,7 @@ async def call_tool( self, name: str, arguments: dict[str, Any], - context: Context[LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT], convert_result: bool = False, ) -> Any: """Call a tool by name with arguments.""" @@ -89,4 +89,4 @@ async def call_tool( if not tool: raise ToolError(f"Unknown tool: {name}") - return await tool.run(arguments, context=context, convert_result=convert_result) + return await tool.run(arguments, context, 
convert_result=convert_result) diff --git a/src/mcp/server/mcpserver/utilities/context_injection.py b/src/mcp/server/mcpserver/utilities/context_injection.py index 9cba83e860..ac7ab82d05 100644 --- a/src/mcp/server/mcpserver/utilities/context_injection.py +++ b/src/mcp/server/mcpserver/utilities/context_injection.py @@ -7,6 +7,8 @@ from collections.abc import Callable from typing import Any +from mcp.server.mcpserver.context import Context + def find_context_parameter(fn: Callable[..., Any]) -> str | None: """Find the parameter that should receive the Context object. @@ -20,8 +22,6 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None: Returns: The name of the context parameter, or None if not found """ - from mcp.server.mcpserver.server import Context - # Get type hints to properly resolve string annotations try: hints = typing.get_type_hints(fn) diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 6a2f49f390..1ab90be772 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -3,8 +3,7 @@ from mcp import Client from mcp.client.session import ClientSession -from mcp.server.mcpserver import MCPServer -from mcp.server.mcpserver.server import Context +from mcp.server.mcpserver import Context, MCPServer from mcp.shared._context import RequestContext from mcp.types import ListRootsResult, Root, TextContent diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 31cdeece73..1598fd55f6 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -3,7 +3,7 @@ import pytest from mcp import Client, types -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import Context, MCPServer from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, @@ -33,14 +33,10 @@ async def test_tool() -> bool: # Create a function that can send a 
log notification @server.tool("test_tool_with_log") async def test_tool_with_log( - message: str, level: Literal["debug", "info", "warning", "error"], logger: str + message: str, level: Literal["debug", "info", "warning", "error"], logger: str, ctx: Context ) -> bool: """Send a log notification to the client.""" - await server.get_context().log( - level=level, - message=message, - logger_name=logger, - ) + await ctx.log(level=level, message=message, logger_name=logger) return True @server.tool("test_tool_with_log_extra") @@ -50,9 +46,10 @@ async def test_tool_with_log_extra( logger: str, extra_string: str, extra_dict: dict[str, Any], + ctx: Context, ) -> bool: """Send a log notification to the client with extra fields.""" - await server.get_context().log( + await ctx.log( level=level, message=message, logger_name=logger, diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 3357bc921d..6efcac0a52 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -2,7 +2,7 @@ from mcp import Client from mcp.client.session import ClientSession -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import Context, MCPServer from mcp.shared._context import RequestContext from mcp.types import ( CreateMessageRequestParams, @@ -32,8 +32,8 @@ async def sampling_callback( return callback_return @server.tool("test_sampling") - async def test_sampling_tool(message: str): - value = await server.get_context().session.create_message( + async def test_sampling_tool(message: str, ctx: Context) -> bool: + value = await ctx.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) @@ -77,9 +77,9 @@ async def sampling_callback( return callback_return @server.tool("test_backwards_compat") - async def test_tool(message: str): + async def test_tool(message: str, ctx: Context) -> bool: # Call create_message WITHOUT tools - 
result = await server.get_context().session.create_message( + result = await ctx.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) diff --git a/tests/server/mcpserver/prompts/test_base.py b/tests/server/mcpserver/prompts/test_base.py index 553e47363f..fe18e91bd7 100644 --- a/tests/server/mcpserver/prompts/test_base.py +++ b/tests/server/mcpserver/prompts/test_base.py @@ -2,6 +2,7 @@ import pytest +from mcp.server.mcpserver import Context from mcp.server.mcpserver.prompts.base import AssistantMessage, Message, Prompt, UserMessage from mcp.types import EmbeddedResource, TextContent, TextResourceContents @@ -13,7 +14,9 @@ def fn() -> str: return "Hello, world!" prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_async_fn(self): @@ -21,7 +24,9 @@ async def fn() -> str: return "Hello, world!" prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_fn_with_args(self): @@ -29,7 +34,7 @@ async def fn(name: str, age: int = 30) -> str: return f"Hello, {name}! You're {age} years old." prompt = Prompt.from_function(fn) - assert await prompt.render(arguments={"name": "World"}) == [ + assert await prompt.render({"name": "World"}, Context()) == [ UserMessage(content=TextContent(type="text", text="Hello, World! 
You're 30 years old.")) ] @@ -40,7 +45,7 @@ async def fn(name: str, age: int = 30) -> str: # pragma: no cover prompt = Prompt.from_function(fn) with pytest.raises(ValueError): - await prompt.render(arguments={"age": 40}) + await prompt.render({"age": 40}, Context()) @pytest.mark.anyio async def test_fn_returns_message(self): @@ -48,7 +53,9 @@ async def fn() -> UserMessage: return UserMessage(content="Hello, world!") prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_fn_returns_assistant_message(self): @@ -56,7 +63,9 @@ async def fn() -> AssistantMessage: return AssistantMessage(content=TextContent(type="text", text="Hello, world!")) prompt = Prompt.from_function(fn) - assert await prompt.render() == [AssistantMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + AssistantMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_fn_returns_multiple_messages(self): @@ -70,7 +79,7 @@ async def fn() -> list[Message]: return expected prompt = Prompt.from_function(fn) - assert await prompt.render() == expected + assert await prompt.render(None, Context()) == expected @pytest.mark.anyio async def test_fn_returns_list_of_strings(self): @@ -83,7 +92,7 @@ async def fn() -> list[str]: return expected prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(t) for t in expected] + assert await prompt.render(None, Context()) == [UserMessage(t) for t in expected] @pytest.mark.anyio async def test_fn_returns_resource_content(self): @@ -102,7 +111,7 @@ async def fn() -> UserMessage: ) prompt = Prompt.from_function(fn) - assert await prompt.render() == [ + assert await prompt.render(None, Context()) == [ 
UserMessage( content=EmbeddedResource( type="resource", @@ -136,7 +145,7 @@ async def fn() -> list[Message]: ] prompt = Prompt.from_function(fn) - assert await prompt.render() == [ + assert await prompt.render(None, Context()) == [ UserMessage(content=TextContent(type="text", text="Please analyze this file:")), UserMessage( content=EmbeddedResource( @@ -169,7 +178,7 @@ async def fn() -> dict[str, Any]: } prompt = Prompt.from_function(fn) - assert await prompt.render() == [ + assert await prompt.render(None, Context()) == [ UserMessage( content=EmbeddedResource( type="resource", diff --git a/tests/server/mcpserver/prompts/test_manager.py b/tests/server/mcpserver/prompts/test_manager.py index 02f91c6802..99a03db565 100644 --- a/tests/server/mcpserver/prompts/test_manager.py +++ b/tests/server/mcpserver/prompts/test_manager.py @@ -1,5 +1,6 @@ import pytest +from mcp.server.mcpserver import Context from mcp.server.mcpserver.prompts.base import Prompt, UserMessage from mcp.server.mcpserver.prompts.manager import PromptManager from mcp.types import TextContent @@ -72,7 +73,7 @@ def fn() -> str: manager = PromptManager() prompt = Prompt.from_function(fn) manager.add_prompt(prompt) - messages = await manager.render_prompt("fn") + messages = await manager.render_prompt("fn", None, Context()) assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio @@ -85,7 +86,7 @@ def fn(name: str) -> str: manager = PromptManager() prompt = Prompt.from_function(fn) manager.add_prompt(prompt) - messages = await manager.render_prompt("fn", arguments={"name": "World"}) + messages = await manager.render_prompt("fn", {"name": "World"}, Context()) assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))] @pytest.mark.anyio @@ -93,7 +94,7 @@ async def test_render_unknown_prompt(self): """Test rendering a non-existent prompt.""" manager = PromptManager() with pytest.raises(ValueError, match="Unknown prompt: 
unknown"): - await manager.render_prompt("unknown") + await manager.render_prompt("unknown", None, Context()) @pytest.mark.anyio async def test_render_prompt_with_missing_args(self): @@ -106,4 +107,4 @@ def fn(name: str) -> str: # pragma: no cover prompt = Prompt.from_function(fn) manager.add_prompt(prompt) with pytest.raises(ValueError, match="Missing required arguments"): - await manager.render_prompt("fn") + await manager.render_prompt("fn", None, Context()) diff --git a/tests/server/mcpserver/resources/test_resource_manager.py b/tests/server/mcpserver/resources/test_resource_manager.py index eb9b355aaf..724b579974 100644 --- a/tests/server/mcpserver/resources/test_resource_manager.py +++ b/tests/server/mcpserver/resources/test_resource_manager.py @@ -4,6 +4,7 @@ import pytest from pydantic import AnyUrl +from mcp.server.mcpserver import Context from mcp.server.mcpserver.resources import FileResource, FunctionResource, ResourceManager, ResourceTemplate @@ -86,7 +87,7 @@ async def test_get_resource(self, temp_file: Path): path=temp_file, ) manager.add_resource(resource) - retrieved = await manager.get_resource(resource.uri) + retrieved = await manager.get_resource(resource.uri, Context()) assert retrieved == resource @pytest.mark.anyio @@ -104,7 +105,7 @@ def greet(name: str) -> str: ) manager._templates[template.uri_template] = template - resource = await manager.get_resource(AnyUrl("greet://world")) + resource = await manager.get_resource(AnyUrl("greet://world"), Context()) assert isinstance(resource, FunctionResource) content = await resource.read() assert content == "Hello, world!" 
@@ -114,7 +115,7 @@ async def test_get_unknown_resource(self): """Test getting a non-existent resource.""" manager = ResourceManager() with pytest.raises(ValueError, match="Unknown resource"): - await manager.get_resource(AnyUrl("unknown://test")) + await manager.get_resource(AnyUrl("unknown://test"), Context()) def test_list_resources(self, temp_file: Path): """Test listing all resources.""" diff --git a/tests/server/mcpserver/resources/test_resource_template.py b/tests/server/mcpserver/resources/test_resource_template.py index 08f2033bff..640cfe8031 100644 --- a/tests/server/mcpserver/resources/test_resource_template.py +++ b/tests/server/mcpserver/resources/test_resource_template.py @@ -4,7 +4,7 @@ import pytest from pydantic import BaseModel -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.resources import FunctionResource, ResourceTemplate from mcp.types import Annotations @@ -64,6 +64,7 @@ def my_func(key: str, value: int) -> dict[str, Any]: resource = await template.create_resource( "test://foo/123", {"key": "foo", "value": 123}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -86,7 +87,7 @@ def failing_func(x: str) -> str: ) with pytest.raises(ValueError, match="Error creating resource from template"): - await template.create_resource("fail://test", {"x": "test"}) + await template.create_resource("fail://test", {"x": "test"}, Context()) @pytest.mark.anyio async def test_async_text_resource(self): @@ -104,6 +105,7 @@ async def greet(name: str) -> str: resource = await template.create_resource( "greet://world", {"name": "world"}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -126,6 +128,7 @@ async def get_bytes(value: str) -> bytes: resource = await template.create_resource( "bytes://test", {"value": "test"}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -152,6 +155,7 @@ def get_data(key: str, value: int) -> MyModel: resource = await 
template.create_resource( "test://foo/123", {"key": "foo", "value": 123}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -183,6 +187,7 @@ def get_data(value: str) -> CustomData: resource = await template.create_resource( "test://hello", {"value": "hello"}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -249,7 +254,7 @@ def get_item(item_id: str) -> str: ) # Create a resource from the template - resource = await template.create_resource("resource://items/123", {"item_id": "123"}) + resource = await template.create_resource("resource://items/123", {"item_id": "123"}, Context()) # The resource should inherit the template's annotations assert resource.annotations is not None @@ -298,7 +303,7 @@ def get_item(item_id: str) -> str: ) # Create a resource from the template - resource = await template.create_resource("resource://items/123", {"item_id": "123"}) + resource = await template.create_resource("resource://items/123", {"item_id": "123"}, Context()) # The resource should inherit the template's metadata assert resource.meta is not None diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index cfbe6587bb..3d130bfc33 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -891,7 +891,7 @@ def get_data(name: str) -> str: assert len(await mcp.list_resources()) == 0 # When accessed, should create a concrete resource - resource = await mcp._resource_manager.get_resource("resource://test/data") + resource = await mcp._resource_manager.get_resource("resource://test/data", Context()) assert isinstance(resource, FunctionResource) result = await resource.read() assert result == "Data for test" @@ -1231,6 +1231,19 @@ def prompt_no_context(text: str) -> str: class TestServerPrompts: """Test prompt functionality in MCPServer server.""" + async def test_get_prompt_direct_call_without_context(self): + """Test calling mcp.get_prompt() directly without passing context.""" 
+ mcp = MCPServer() + + @mcp.prompt() + def fn() -> str: + return "Hello, world!" + + result = await mcp.get_prompt("fn") + content = result.messages[0].content + assert isinstance(content, TextContent) + assert content.text == "Hello, world!" + async def test_prompt_decorator(self): """Test that the prompt decorator registers prompts correctly.""" mcp = MCPServer() @@ -1243,7 +1256,7 @@ def fn() -> str: assert len(prompts) == 1 assert prompts[0].name == "fn" # Don't compare functions directly since validate_call wraps them - content = await prompts[0].render() + content = await prompts[0].render(None, Context()) assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" @@ -1258,7 +1271,7 @@ def fn() -> str: prompts = mcp._prompt_manager.list_prompts() assert len(prompts) == 1 assert prompts[0].name == "custom_name" - content = await prompts[0].render() + content = await prompts[0].render(None, Context()) assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" @@ -1273,7 +1286,7 @@ def fn() -> str: prompts = mcp._prompt_manager.list_prompts() assert len(prompts) == 1 assert prompts[0].description == "A custom description" - content = await prompts[0].render() + content = await prompts[0].render(None, Context()) assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" 
diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index 550bba50a3..f990ec47b7 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -188,7 +188,7 @@ def sum(a: int, b: int) -> int: manager = ToolManager() manager.add_tool(sum) - result = await manager.call_tool("sum", {"a": 1, "b": 2}) + result = await manager.call_tool("sum", {"a": 1, "b": 2}, Context()) assert result == 3 @pytest.mark.anyio @@ -199,7 +199,7 @@ async def double(n: int) -> int: manager = ToolManager() manager.add_tool(double) - result = await manager.call_tool("double", {"n": 5}) + result = await manager.call_tool("double", {"n": 5}, Context()) assert result == 10 @pytest.mark.anyio @@ -213,7 +213,7 @@ def __call__(self, x: int) -> int: manager = ToolManager() tool = manager.add_tool(MyTool()) - result = await tool.run({"x": 5}) + result = await tool.run({"x": 5}, Context()) assert result == 10 @pytest.mark.anyio @@ -227,7 +227,7 @@ async def __call__(self, x: int) -> int: manager = ToolManager() tool = manager.add_tool(MyAsyncTool()) - result = await tool.run({"x": 5}) + result = await tool.run({"x": 5}, Context()) assert result == 10 @pytest.mark.anyio @@ -238,7 +238,7 @@ def sum(a: int, b: int = 1) -> int: manager = ToolManager() manager.add_tool(sum) - result = await manager.call_tool("sum", {"a": 1}) + result = await manager.call_tool("sum", {"a": 1}, Context()) assert result == 2 @pytest.mark.anyio @@ -250,13 +250,13 @@ def sum(a: int, b: int) -> int: # pragma: no cover manager = ToolManager() manager.add_tool(sum) with pytest.raises(ToolError): - await manager.call_tool("sum", {"a": 1}) + await manager.call_tool("sum", {"a": 1}, Context()) @pytest.mark.anyio async def test_call_unknown_tool(self): manager = ToolManager() with pytest.raises(ToolError): - await manager.call_tool("unknown", {"a": 1}) + await manager.call_tool("unknown", {"a": 1}, Context()) @pytest.mark.anyio async def 
test_call_tool_with_list_int_input(self): @@ -266,9 +266,9 @@ def sum_vals(vals: list[int]) -> int: manager = ToolManager() manager.add_tool(sum_vals) # Try both with plain list and with JSON list - result = await manager.call_tool("sum_vals", {"vals": "[1, 2, 3]"}) + result = await manager.call_tool("sum_vals", {"vals": "[1, 2, 3]"}, Context()) assert result == 6 - result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]}) + result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]}, Context()) assert result == 6 @pytest.mark.anyio @@ -279,13 +279,13 @@ def concat_strs(vals: list[str] | str) -> str: manager = ToolManager() manager.add_tool(concat_strs) # Try both with plain python object and with JSON list - result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}) + result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}, Context()) assert result == "abc" - result = await manager.call_tool("concat_strs", {"vals": '["a", "b", "c"]'}) + result = await manager.call_tool("concat_strs", {"vals": '["a", "b", "c"]'}, Context()) assert result == "abc" - result = await manager.call_tool("concat_strs", {"vals": "a"}) + result = await manager.call_tool("concat_strs", {"vals": "a"}, Context()) assert result == "a" - result = await manager.call_tool("concat_strs", {"vals": '"a"'}) + result = await manager.call_tool("concat_strs", {"vals": '"a"'}, Context()) assert result == '"a"' @pytest.mark.anyio @@ -297,7 +297,7 @@ class Shrimp(BaseModel): shrimp: list[Shrimp] x: None - def name_shrimp(tank: MyShrimpTank, ctx: Context[ServerSessionT, None]) -> list[str]: + def name_shrimp(tank: MyShrimpTank) -> list[str]: return [x.name for x in tank.shrimp] manager = ToolManager() @@ -305,11 +305,13 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context[ServerSessionT, None]) -> list[ result = await manager.call_tool( "name_shrimp", {"tank": {"x": None, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}}, + Context(), ) assert result == ["rex", 
"gertrude"] result = await manager.call_tool( "name_shrimp", {"tank": '{"x": null, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}'}, + Context(), ) assert result == ["rex", "gertrude"] @@ -364,9 +366,7 @@ def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: manager = ToolManager() manager.add_tool(tool_with_context) - mcp = MCPServer() - ctx = mcp.get_context() - result = await manager.call_tool("tool_with_context", {"x": 42}, context=ctx) + result = await manager.call_tool("tool_with_context", {"x": 42}, context=Context()) assert result == "42" @pytest.mark.anyio @@ -380,22 +380,7 @@ async def async_tool(x: int, ctx: Context[ServerSessionT, None]) -> str: manager = ToolManager() manager.add_tool(async_tool) - mcp = MCPServer() - ctx = mcp.get_context() - result = await manager.call_tool("async_tool", {"x": 42}, context=ctx) - assert result == "42" - - @pytest.mark.anyio - async def test_context_optional(self): - """Test that context is optional when calling tools.""" - - def tool_with_context(x: int, ctx: Context[ServerSessionT, None] | None = None) -> str: - return str(x) - - manager = ToolManager() - manager.add_tool(tool_with_context) - # Should not raise an error when context is not provided - result = await manager.call_tool("tool_with_context", {"x": 42}) + result = await manager.call_tool("async_tool", {"x": 42}, context=Context()) assert result == "42" @pytest.mark.anyio @@ -408,10 +393,8 @@ def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: manager = ToolManager() manager.add_tool(tool_with_context) - mcp = MCPServer() - ctx = mcp.get_context() with pytest.raises(ToolError, match="Error executing tool tool_with_context"): - await manager.call_tool("tool_with_context", {"x": 42}, context=ctx) + await manager.call_tool("tool_with_context", {"x": 42}, context=Context()) class TestToolAnnotations: @@ -471,7 +454,7 @@ def get_user(user_id: int) -> UserOutput: manager = ToolManager() manager.add_tool(get_user) - 
result = await manager.call_tool("get_user", {"user_id": 1}, convert_result=True) + result = await manager.call_tool("get_user", {"user_id": 1}, Context(), convert_result=True) # don't test unstructured output here, just the structured conversion assert len(result) == 2 and result[1] == {"name": "John", "age": 30} @@ -485,9 +468,9 @@ def double_number(n: int) -> int: manager = ToolManager() manager.add_tool(double_number) - result = await manager.call_tool("double_number", {"n": 5}) + result = await manager.call_tool("double_number", {"n": 5}, Context()) assert result == 10 - result = await manager.call_tool("double_number", {"n": 5}, convert_result=True) + result = await manager.call_tool("double_number", {"n": 5}, Context(), convert_result=True) assert isinstance(result[0][0], TextContent) and result[1] == {"result": 10} @pytest.mark.anyio @@ -506,7 +489,7 @@ def get_user_dict(user_id: int) -> UserDict: manager = ToolManager() manager.add_tool(get_user_dict) - result = await manager.call_tool("get_user_dict", {"user_id": 1}) + result = await manager.call_tool("get_user_dict", {"user_id": 1}, Context()) assert result == expected_output @pytest.mark.anyio @@ -526,7 +509,7 @@ def get_person() -> Person: manager = ToolManager() manager.add_tool(get_person) - result = await manager.call_tool("get_person", {}, convert_result=True) + result = await manager.call_tool("get_person", {}, Context(), convert_result=True) # don't test unstructured output here, just the structured conversion assert len(result) == 2 and result[1] == expected_output @@ -543,9 +526,9 @@ def get_numbers() -> list[int]: manager = ToolManager() manager.add_tool(get_numbers) - result = await manager.call_tool("get_numbers", {}) + result = await manager.call_tool("get_numbers", {}, Context()) assert result == expected_list - result = await manager.call_tool("get_numbers", {}, convert_result=True) + result = await manager.call_tool("get_numbers", {}, Context(), convert_result=True) assert 
isinstance(result[0][0], TextContent) and result[1] == expected_output @pytest.mark.anyio @@ -558,7 +541,7 @@ def get_dict() -> dict[str, Any]: manager = ToolManager() manager.add_tool(get_dict, structured_output=False) - result = await manager.call_tool("get_dict", {}) + result = await manager.call_tool("get_dict", {}, Context()) assert isinstance(result, dict) assert result == {"key": "value"} @@ -601,12 +584,12 @@ def get_config() -> dict[str, Any]: assert "properties" not in tool.output_schema # dict[str, Any] has no constraints # Test raw result - result = await manager.call_tool("get_config", {}) + result = await manager.call_tool("get_config", {}, Context()) expected = {"debug": True, "port": 8080, "features": ["auth", "logging"]} assert result == expected # Test converted result - result = await manager.call_tool("get_config", {}) + result = await manager.call_tool("get_config", {}, Context()) assert result == expected @pytest.mark.anyio @@ -626,12 +609,12 @@ def get_scores() -> dict[str, int]: assert tool.output_schema["additionalProperties"]["type"] == "integer" # Test raw result - result = await manager.call_tool("get_scores", {}) + result = await manager.call_tool("get_scores", {}, Context()) expected = {"alice": 100, "bob": 85, "charlie": 92} assert result == expected # Test converted result - result = await manager.call_tool("get_scores", {}) + result = await manager.call_tool("get_scores", {}, Context()) assert result == expected @@ -885,7 +868,7 @@ def greet(name: str) -> str: manager.add_tool(greet) # Verify tool works before removal - result = await manager.call_tool("greet", {"name": "World"}) + result = await manager.call_tool("greet", {"name": "World"}, Context()) assert result == "Hello, World!" 
# Remove the tool @@ -893,7 +876,7 @@ def greet(name: str) -> str: # Verify calling removed tool raises error with pytest.raises(ToolError, match="Unknown tool: greet"): - await manager.call_tool("greet", {"name": "World"}) + await manager.call_tool("greet", {"name": "World"}, Context()) def test_remove_tool_case_sensitive(self): """Test that tool removal is case-sensitive.""" diff --git a/tests/server/mcpserver/tools/__init__.py b/tests/server/mcpserver/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/server/mcpserver/tools/test_base.py b/tests/server/mcpserver/tools/test_base.py new file mode 100644 index 0000000000..22d5f973e9 --- /dev/null +++ b/tests/server/mcpserver/tools/test_base.py @@ -0,0 +1,10 @@ +from mcp.server.mcpserver import Context +from mcp.server.mcpserver.tools.base import Tool + + +def test_context_detected_in_union_annotation(): + def my_tool(x: int, ctx: Context | None) -> str: + raise NotImplementedError + + tool = Tool.from_function(my_tool) + assert tool.context_kwarg == "ctx" From b3149d2f33dd929143123195755bf4b9b57efddb Mon Sep 17 00:00:00 2001 From: Varun6578 <34965159+Varun6578@users.noreply.github.com> Date: Wed, 4 Mar 2026 20:15:11 +0530 Subject: [PATCH 02/60] fix: clean up SSE session on client disconnect (#2200) Co-authored-by: Varun Sharma Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Max Isbey <224885523+maxisbey@users.noreply.github.com> --- src/mcp/server/sse.py | 1 + tests/shared/test_sse.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9007230cea..9dcee67f78 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -186,6 +186,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): ) await read_stream_writer.aclose() await write_stream_reader.aclose() + self._read_stream_writers.pop(session_id, None) logging.debug(f"Client session 
disconnected {session_id}") logger.debug("Starting SSE response task") diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 7b2bc0a139..bfbecc0c8a 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -611,3 +611,30 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: assert not isinstance(msg, Exception) assert isinstance(msg.message, types.JSONRPCResponse) assert msg.message.id == 1 + + +@pytest.mark.anyio +async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None: + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227 + + When a client disconnects, the server should remove the session from + _read_stream_writers. Without this cleanup, stale sessions accumulate and + POST requests to disconnected sessions return 202 Accepted followed by a + ClosedResourceError when the server tries to write to the dead stream. + """ + captured: list[str] = [] + + # Connect a client session, then disconnect + async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + + # After disconnect, POST to the stale session should return 404 + # (not 202 as it did before the fix) + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/messages/?session_id={captured[0]}", + json={"jsonrpc": "2.0", "method": "ping", "id": 99}, + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 404 From 528abfab86f1f3c003bd7d54a1f0bbd65d81c59c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 4 Mar 2026 16:11:34 +0000 Subject: [PATCH 03/60] tests: remove lax-no-cover pragmas by moving assertions before cancellation (#2206) --- tests/shared/test_sse.py | 19 ++++----- tests/shared/test_streamable_http.py | 63 +++++++++++++--------------- 2 files changed, 36 
insertions(+), 46 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index bfbecc0c8a..890e997332 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -203,19 +203,15 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.mark.anyio async def test_sse_client_on_session_created(server: None, server_url: str) -> None: - captured_session_id: str | None = None - - def on_session_created(session_id: str) -> None: - nonlocal captured_session_id - captured_session_id = session_id + captured: list[str] = [] - async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: + async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - assert captured_session_id is not None # pragma: lax no cover - assert len(captured_session_id) > 0 # pragma: lax no cover + # Callback fires when the endpoint event arrives, before sse_client yields. + assert len(captured) == 1 + assert len(captured[0]) > 0 @pytest.mark.parametrize( @@ -248,8 +244,9 @@ def mock_extract(url: str) -> None: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - callback_mock.assert_not_called() # pragma: lax no cover + # Callback would have fired by now (endpoint event arrives before + # sse_client yields); if it hasn't, it won't. 
+ callback_mock.assert_not_called() @pytest.fixture diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 42b1a3698a..61ba4a2e54 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1132,22 +1132,19 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream) as session: + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} # Make a request to confirm session is working tools = await session.list_tools() assert len(tools.tools) == 10 - headers: dict[str, str] = {} # pragma: lax no cover - if captured_session_id: # pragma: lax no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -1196,22 +1193,19 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream) as session: + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} # Make a request to confirm session is working tools = await session.list_tools() assert len(tools.tools) == 10 - headers: dict[str, str] = {} # pragma: lax no cover - if 
captured_session_id: # pragma: lax no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -1231,7 +1225,6 @@ async def test_streamable_http_client_resumption(event_server: tuple[SimpleEvent # Variables to track the state captured_resumption_token: str | None = None captured_notifications: list[types.ServerNotification] = [] - captured_protocol_version: str | int | None = None first_notification_received = False async def message_handler( # pragma: no branch @@ -1258,15 +1251,20 @@ async def on_resumption_token_update(token: str) -> None: read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession( # pragma: no branch + read_stream, write_stream, message_handler=message_handler + ) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None - # Capture the negotiated protocol version - captured_protocol_version = result.protocol_version + # Build phase-2 headers now while both values are in scope + headers: dict[str, Any] = { + MCP_SESSION_ID_HEADER: captured_session_id, + MCP_PROTOCOL_VERSION_HEADER: result.protocol_version, + } # Start the tool that will wait on lock in a task async with anyio.create_task_group() as tg: # pragma: no branch @@ -1291,25 +1289,19 @@ async def run_tool(): while not first_notification_received or not captured_resumption_token: await anyio.sleep(0.1) + # The while loop only exits after first_notification_received=True, + # which is set by message_handler immediately after appending to + # captured_notifications. 
The server tool is blocked on its lock, + # so nothing else can arrive before we cancel. + assert len(captured_notifications) == 1 + assert isinstance(captured_notifications[0], types.LoggingMessageNotification) + assert captured_notifications[0].params.data == "First notification before lock" + # Reset for phase 2 before cancelling + captured_notifications.clear() + # Kill the client session while tool is waiting on lock tg.cancel_scope.cancel() - # Verify we received exactly one notification (inside ClientSession - # so coverage tracks these on Python 3.11, see PR #1897 for details) - assert len(captured_notifications) == 1 # pragma: lax no cover - assert isinstance(captured_notifications[0], types.LoggingMessageNotification) # pragma: lax no cover - assert captured_notifications[0].params.data == "First notification before lock" # pragma: lax no cover - - # Clear notifications and set up headers for phase 2 (between connections, - # not tracked by coverage on Python 3.11 due to cancel scope + sys.settrace bug) - captured_notifications = [] # pragma: lax no cover - assert captured_session_id is not None # pragma: lax no cover - assert captured_protocol_version is not None # pragma: lax no cover - headers: dict[str, Any] = { # pragma: lax no cover - MCP_SESSION_ID_HEADER: captured_session_id, - MCP_PROTOCOL_VERSION_HEADER: captured_protocol_version, - } - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -2092,11 +2084,12 @@ async def on_resumption_token(token: str) -> None: assert isinstance(result.content[0], TextContent) assert "Completed 3 checkpoints" in result.content[0].text - # 4 priming + 3 notifications + 1 response = 8 tokens - assert len(resumption_tokens) == 8, ( # pragma: lax no cover - f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " - f"got {len(resumption_tokens)}: {resumption_tokens}" - ) + # 4 priming + 3 
notifications + 1 response = 8 tokens. All tokens are + # captured before send_request returns, so this is safe to check here. + assert len(resumption_tokens) == 8, ( + f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " + f"got {len(resumption_tokens)}: {resumption_tokens}" + ) @pytest.mark.anyio From 7c0224828bc4e62c53208a051aaa209779868f53 Mon Sep 17 00:00:00 2001 From: Giulio Leone Date: Thu, 5 Mar 2026 15:57:33 +0100 Subject: [PATCH 04/60] fix(oauth): include client_id in token request body for client_secret_post (#2185) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/mcp/client/auth/oauth2.py | 5 +- .../extensions/test_client_credentials.py | 66 +++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 7f5af51867..25075dec3b 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -205,8 +205,9 @@ def prepare_token_auth( headers["Authorization"] = f"Basic {encoded_credentials}" # Don't include client_secret in body for basic auth data = {k: v for k, v in data.items() if k != "client_secret"} - elif auth_method == "client_secret_post" and self.client_info.client_secret: - # Include client_secret in request body + elif auth_method == "client_secret_post" and self.client_info.client_id and self.client_info.client_secret: + # Include client_id and client_secret in request body (RFC 6749 §2.3.1) + data["client_id"] = self.client_info.client_id data["client_secret"] = self.client_info.client_secret # For auth_method == "none", don't add any client_secret diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index 0003b16797..09760f4530 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -252,6 +252,72 @@ async def 
test_exchange_token_client_credentials(self, mock_storage: MockTokenSt assert "scope=read write" in content assert "resource=https://api.example.com/v1/mcp" in content + @pytest.mark.anyio + async def test_exchange_token_client_secret_post_includes_client_id(self, mock_storage: MockTokenStorage): + """Test that client_secret_post includes both client_id and client_secret in body (RFC 6749 §2.3.1).""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + token_endpoint_auth_method="client_secret_post", + scopes="read write", + ) + await provider._initialize() + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com"), + authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + ) + provider.context.protocol_version = "2025-06-18" + + request = await provider._perform_authorization() + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "client_id=test-client-id" in content + assert "client_secret=test-client-secret" in content + # Should NOT have Basic auth header + assert "Authorization" not in request.headers + + @pytest.mark.anyio + async def test_exchange_token_client_secret_post_without_client_id(self, mock_storage: MockTokenStorage): + """Test client_secret_post skips body credentials when client_id is None.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="placeholder", + client_secret="test-client-secret", + token_endpoint_auth_method="client_secret_post", + scopes="read write", + ) + await provider._initialize() + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com"), + 
authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + ) + provider.context.protocol_version = "2025-06-18" + # Override client_info to have client_id=None (edge case) + provider.context.client_info = OAuthClientInformationFull( + redirect_uris=None, + client_id=None, + client_secret="test-client-secret", + grant_types=["client_credentials"], + token_endpoint_auth_method="client_secret_post", + scope="read write", + ) + + request = await provider._perform_authorization() + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + # Neither client_id nor client_secret should be in body since client_id is None + # (RFC 6749 §2.3.1 requires both for client_secret_post) + assert "client_id=" not in content + assert "client_secret=" not in content + assert "Authorization" not in request.headers + @pytest.mark.anyio async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): """Test token exchange without scopes.""" From b33c81167572096baeb7f7cff35987fc1168b28d Mon Sep 17 00:00:00 2001 From: Giulio Leone Date: Thu, 5 Mar 2026 16:44:33 +0100 Subject: [PATCH 05/60] perf: use deque for InMemoryTaskMessageQueue FIFO operations (#2165) --- src/mcp/shared/experimental/tasks/message_queue.py | 9 +++++---- tests/experimental/tasks/test_message_queue.py | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/mcp/shared/experimental/tasks/message_queue.py b/src/mcp/shared/experimental/tasks/message_queue.py index 018c2b7b26..e17c4a8650 100644 --- a/src/mcp/shared/experimental/tasks/message_queue.py +++ b/src/mcp/shared/experimental/tasks/message_queue.py @@ -12,6 +12,7 @@ """ from abc import ABC, abstractmethod +from collections import deque from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any, Literal @@ -151,13 +152,13 @@ class 
InMemoryTaskMessageQueue(TaskMessageQueue): """ def __init__(self) -> None: - self._queues: dict[str, list[QueuedMessage]] = {} + self._queues: dict[str, deque[QueuedMessage]] = {} self._events: dict[str, anyio.Event] = {} - def _get_queue(self, task_id: str) -> list[QueuedMessage]: + def _get_queue(self, task_id: str) -> deque[QueuedMessage]: """Get or create the queue for a task.""" if task_id not in self._queues: - self._queues[task_id] = [] + self._queues[task_id] = deque() return self._queues[task_id] async def enqueue(self, task_id: str, message: QueuedMessage) -> None: @@ -172,7 +173,7 @@ async def dequeue(self, task_id: str) -> QueuedMessage | None: queue = self._get_queue(task_id) if not queue: return None - return queue.pop(0) + return queue.popleft() async def peek(self, task_id: str) -> QueuedMessage | None: """Return the next message without removing it.""" diff --git a/tests/experimental/tasks/test_message_queue.py b/tests/experimental/tasks/test_message_queue.py index a8517e535c..eca113d5b4 100644 --- a/tests/experimental/tasks/test_message_queue.py +++ b/tests/experimental/tasks/test_message_queue.py @@ -1,5 +1,6 @@ """Tests for TaskMessageQueue and InMemoryTaskMessageQueue.""" +from collections import deque from datetime import datetime, timezone import anyio @@ -270,7 +271,7 @@ async def is_empty_with_injection(tid: str) -> bool: if call_count == 2 and tid == task_id: # Before second check, inject a message - this simulates a message # arriving between event creation and the double-check - queue._queues[task_id] = [QueuedMessage(type="request", message=make_request())] + queue._queues[task_id] = deque([QueuedMessage(type="request", message=make_request())]) return await original_is_empty(tid) queue.is_empty = is_empty_with_injection # type: ignore[method-assign] From 92f1b1500d808a32d261d9e101d6f3bae3ad9e25 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 6 Mar 2026 14:50:58 +0000 Subject: [PATCH 
06/60] fix: remove MIME type validation from MCPServer Resource (#2235) --- src/mcp/server/mcpserver/resources/base.py | 6 +----- tests/server/mcpserver/resources/test_resources.py | 8 ++++++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/mcpserver/resources/base.py b/src/mcp/server/mcpserver/resources/base.py index d3ccc425ef..d48e0695cb 100644 --- a/src/mcp/server/mcpserver/resources/base.py +++ b/src/mcp/server/mcpserver/resources/base.py @@ -23,11 +23,7 @@ class Resource(BaseModel, abc.ABC): name: str | None = Field(description="Name of the resource", default=None) title: str | None = Field(description="Human-readable title of the resource", default=None) description: str | None = Field(description="Description of the resource", default=None) - mime_type: str = Field( - default="text/plain", - description="MIME type of the resource content", - pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+(;\s*[a-zA-Z0-9\-_.]+=[a-zA-Z0-9\-_.]+)*$", - ) + mime_type: str = Field(default="text/plain", description="MIME type of the resource content") icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this resource") annotations: Annotations | None = Field(default=None, description="Optional annotations for the resource") meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this resource") diff --git a/tests/server/mcpserver/resources/test_resources.py b/tests/server/mcpserver/resources/test_resources.py index 93dc438d5d..5d36beda85 100644 --- a/tests/server/mcpserver/resources/test_resources.py +++ b/tests/server/mcpserver/resources/test_resources.py @@ -91,6 +91,14 @@ def dummy_func() -> str: # pragma: no cover ) assert resource.mime_type == "application/json" + # RFC 2045 quoted parameter value (gh-1756) + resource = FunctionResource( + uri="resource://test", + fn=dummy_func, + mime_type='text/plain; charset="utf-8"', + ) + assert resource.mime_type == 'text/plain; charset="utf-8"' + 
@pytest.mark.anyio async def test_resource_read_abstract(self): """Test that Resource.read() is abstract.""" From eaf971cf252692c94d3f76d3b9063eaddc7f5eb4 Mon Sep 17 00:00:00 2001 From: Ramesh Reddy Adutla <134313151+rameshreddy-adutla@users.noreply.github.com> Date: Fri, 6 Mar 2026 16:55:37 +0000 Subject: [PATCH 07/60] Add warning log when rejecting request with unknown/expired session ID (#2212) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Max Isbey <224885523+maxisbey@users.noreply.github.com> --- src/mcp/server/streamable_http_manager.py | 1 + tests/server/test_streamable_http_manager.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 50bcd5e791..c25314eab6 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -272,6 +272,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE # Unknown or expired session ID - return 404 per MCP spec # TODO: Align error code once spec clarifies # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 + logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}") error_response = JSONRPCError( jsonrpc="2.0", id=None, diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 54a898cc5c..47cfbf14a4 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -1,6 +1,7 @@ """Tests for StreamableHTTPSessionManager.""" import json +import logging from typing import Any from unittest.mock import AsyncMock, patch @@ -269,7 +270,7 @@ async def mock_receive(): @pytest.mark.anyio -async def test_unknown_session_id_returns_404(): +async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture): """Test that requests with unknown session 
IDs return HTTP 404 per MCP spec.""" app = Server("test-unknown-session") manager = StreamableHTTPSessionManager(app=app) @@ -299,7 +300,8 @@ async def mock_send(message: Message): async def mock_receive(): return {"type": "http.request", "body": b"{}", "more_body": False} # pragma: no cover - await manager.handle_request(scope, mock_receive, mock_send) + with caplog.at_level(logging.INFO): + await manager.handle_request(scope, mock_receive, mock_send) # Find the response start message response_start = next( @@ -315,6 +317,7 @@ async def mock_receive(): assert error_data["id"] is None assert error_data["error"]["code"] == INVALID_REQUEST assert error_data["error"]["message"] == "Session not found" + assert "Rejected request with unknown or expired session ID: non-existent-session-id" in caplog.text @pytest.mark.anyio From 7ba41dcfae2044987fbcca0e10744e992489091f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 6 Mar 2026 17:24:18 +0000 Subject: [PATCH 08/60] fix: make local coverage runs reliable (#2236) --- .github/workflows/shared.yml | 1 + CLAUDE.md | 16 +++++++++++++--- scripts/test | 1 + tests/test_examples.py | 27 ++++++++++++--------------- 4 files changed, 27 insertions(+), 18 deletions(-) diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index 72e328b541..efb45c8898 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -70,6 +70,7 @@ jobs: - name: Run pytest with coverage shell: bash run: | + uv run --frozen --no-sync coverage erase uv run --frozen --no-sync coverage run -m pytest -n auto uv run --frozen --no-sync coverage combine uv run --frozen --no-sync coverage report diff --git a/CLAUDE.md b/CLAUDE.md index e48ce6e70c..98bd451152 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,9 +28,19 @@ This document contains critical information about working with this codebase. 
Fo - Bug fixes require regression tests - IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns. - IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible. - - IMPORTANT: Before pushing, verify 100% branch coverage on changed files by running - `uv run --frozen pytest -x` (coverage is configured in `pyproject.toml` with `fail_under = 100` - and `branch = true`). If any branch is uncovered, add a test for it before pushing. + - Coverage: CI requires 100% (`fail_under = 100`, `branch = true`). + - Full check: `./scripts/test` (~20s, matches CI exactly) + - Targeted check while iterating: + + ```bash + uv run --frozen coverage erase + uv run --frozen coverage run -m pytest tests/path/test_foo.py + uv run --frozen coverage combine + uv run --frozen coverage report --include='src/mcp/path/foo.py' --fail-under=0 + ``` + + Partial runs can't hit 100% (coverage tracks `tests/` too), so `--fail-under=0` + and `--include` scope the report to what you actually changed. - Avoid `anyio.sleep()` with a fixed duration to wait for async operations. 
Instead: - Use `anyio.Event` — set it in the callback/handler, `await event.wait()` in the test - For stream messages, use `await stream.receive()` instead of `sleep()` + `receive_nowait()` diff --git a/scripts/test b/scripts/test index 0d08e47b1b..ee1259b597 100755 --- a/scripts/test +++ b/scripts/test @@ -2,6 +2,7 @@ set -ex +uv run --frozen coverage erase uv run --frozen coverage run -m pytest -n auto $@ uv run --frozen coverage combine uv run --frozen coverage report diff --git a/tests/test_examples.py b/tests/test_examples.py index aa9de09579..3af82f04c5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -5,7 +5,6 @@ # pyright: reportUnknownArgumentType=false # pyright: reportUnknownMemberType=false -import sys from pathlib import Path import pytest @@ -65,12 +64,17 @@ async def test_direct_call_tool_result_return(): @pytest.mark.anyio -async def test_desktop(monkeypatch: pytest.MonkeyPatch): +async def test_desktop(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """Test the desktop server""" - # Mock desktop directory listing - mock_files = [Path("/fake/path/file1.txt"), Path("/fake/path/file2.txt")] - monkeypatch.setattr(Path, "iterdir", lambda self: mock_files) # type: ignore[reportUnknownArgumentType] - monkeypatch.setattr(Path, "home", lambda: Path("/fake/home")) + # Build a real Desktop directory under tmp_path rather than patching + # Path.iterdir — a class-level patch breaks jsonschema_specifications' + # import-time schema discovery when this test happens to be the first + # tool call in an xdist worker. 
+ desktop = tmp_path / "Desktop" + desktop.mkdir() + (desktop / "file1.txt").touch() + (desktop / "file2.txt").touch() + monkeypatch.setattr(Path, "home", lambda: tmp_path) from examples.mcpserver.desktop import mcp @@ -85,15 +89,8 @@ async def test_desktop(monkeypatch: pytest.MonkeyPatch): content = result.contents[0] assert isinstance(content, TextResourceContents) assert isinstance(content.text, str) - if sys.platform == "win32": # pragma: no cover - file_1 = "/fake/path/file1.txt".replace("/", "\\\\") # might be a bug - file_2 = "/fake/path/file2.txt".replace("/", "\\\\") # might be a bug - assert file_1 in content.text - assert file_2 in content.text - # might be a bug, but the test is passing - else: # pragma: lax no cover - assert "/fake/path/file1.txt" in content.text - assert "/fake/path/file2.txt" in content.text + assert "file1.txt" in content.text + assert "file2.txt" in content.text # TODO(v2): Change back to README.md when v2 is released From 51c53f2c189ddc8f5ed0925e582565a0f91b1d9b Mon Sep 17 00:00:00 2001 From: Shivam Aggarwal Date: Mon, 9 Mar 2026 22:00:02 +0530 Subject: [PATCH 09/60] fix: accept wildcard media types in Accept header per RFC 7231 (#2152) Co-authored-by: Shivam --- src/mcp/server/streamable_http.py | 15 +++-- tests/shared/test_streamable_http.py | 86 ++++++++++++++++++++++++++-- 2 files changed, 92 insertions(+), 9 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 04aed345e0..aa99e7c887 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -391,12 +391,19 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: - """Check if the request accepts the required media types.""" + """Check if the request accepts the required media types. 
+ + Supports wildcard media types per RFC 7231, section 5.3.2: + - */* matches any media type + - application/* matches any application/ subtype + - text/* matches any text/ subtype + """ accept_header = request.headers.get("accept", "") - accept_types = [media_type.strip() for media_type in accept_header.split(",")] + accept_types = [media_type.strip().split(";")[0].strip().lower() for media_type in accept_header.split(",")] - has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types) - has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types) + has_wildcard = "*/*" in accept_types + has_json = has_wildcard or any(t in (CONTENT_TYPE_JSON, "application/*") for t in accept_types) + has_sse = has_wildcard or any(t in (CONTENT_TYPE_SSE, "text/*") for t in accept_types) return has_json, has_sse diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 61ba4a2e54..f8ca30441b 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -572,8 +572,10 @@ def json_server_url(json_server_port: int) -> str: # Basic request validation tests def test_accept_header_validation(basic_server: None, basic_server_url: str): """Test that Accept header is properly validated.""" - # Test without Accept header - response = requests.post( + # Test without Accept header (suppress requests library default Accept: */*) + session = requests.Session() + session.headers.pop("Accept") + response = session.post( f"{basic_server_url}/mcp", headers={"Content-Type": "application/json"}, json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, @@ -582,6 +584,52 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): assert "Not Acceptable" in response.text +@pytest.mark.parametrize( + "accept_header", + [ + "*/*", + "application/*, text/*", + "text/*, application/json", + "application/json, text/*", + "*/*;q=0.8", + "application/*;q=0.9, 
text/*;q=0.8", + ], +) +def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): + """Test that wildcard Accept headers are accepted per RFC 7231.""" + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + +@pytest.mark.parametrize( + "accept_header", + [ + "text/html", + "application/*", + "text/*", + ], +) +def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): + """Test that incompatible Accept headers are rejected for SSE mode.""" + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + def test_content_type_validation(basic_server: None, basic_server_url: str): """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type @@ -826,7 +874,10 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_ def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): """Test that json_response servers reject requests without Accept header.""" mcp_url = f"{json_server_url}/mcp" - response = requests.post( + # Suppress requests library default Accept: */* header + session = requests.Session() + session.headers.pop("Accept") + response = session.post( mcp_url, headers={ "Content-Type": "application/json", @@ -853,6 +904,29 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ assert "Not Acceptable" in response.text +@pytest.mark.parametrize( + "accept_header", + [ + "*/*", + "application/*", + "application/*;q=0.9", + ], +) +def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: 
str): + """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" + mcp_url = f"{json_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" + + def test_get_sse_stream(basic_server: None, basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session @@ -941,8 +1015,10 @@ def test_get_validation(basic_server: None, basic_server_url: str): assert init_data is not None negotiated_version = init_data["result"]["protocolVersion"] - # Test without Accept header - response = requests.get( + # Test without Accept header (suppress requests library default Accept: */*) + session = requests.Session() + session.headers.pop("Accept") + response = session.get( mcp_url, headers={ MCP_SESSION_ID_HEADER: session_id, From 31a38b50786f8d22c99d781beaa89b71cad26d23 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 9 Mar 2026 16:52:56 +0000 Subject: [PATCH 10/60] fix: correct Context type parameters across examples and tests (#2256) --- README.v2.md | 19 ++++++--------- .../mcp_everything_server/server.py | 15 ++++++------ examples/snippets/servers/elicitation.py | 7 +++--- examples/snippets/servers/notifications.py | 3 +-- examples/snippets/servers/sampling.py | 3 +-- examples/snippets/servers/tool_progress.py | 3 +-- tests/client/test_list_roots_callback.py | 2 +- tests/server/mcpserver/test_elicitation.py | 21 ++++++++--------- tests/server/mcpserver/test_server.py | 17 +++++++------- tests/server/mcpserver/test_tool_manager.py | 11 ++++----- .../server/mcpserver/test_url_elicitation.py | 23 +++++++++---------- .../test_url_elicitation_error_throw.py | 7 +++--- tests/server/test_lifespan.py | 3 +-- 13 files changed, 59 insertions(+), 75 
deletions(-) diff --git a/README.v2.md b/README.v2.md index bd6927bf92..55d867586d 100644 --- a/README.v2.md +++ b/README.v2.md @@ -346,13 +346,12 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -694,13 +693,12 @@ The Context object provides the following capabilities: ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -826,7 +824,6 @@ import uuid from pydantic import BaseModel, Field from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams @@ -844,7 +841,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context) -> str: """Book a table with date availability check. This demonstrates form mode elicitation for collecting non-sensitive user input. 
@@ -868,7 +865,7 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerS @mcp.tool() -async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> str: +async def secure_payment(amount: float, ctx: Context) -> str: """Process a secure payment requiring URL confirmation. This demonstrates URL mode elicitation using ctx.elicit_url() for @@ -892,7 +889,7 @@ async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> st @mcp.tool() -async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: +async def connect_service(service_name: str, ctx: Context) -> str: """Connect to a third-party service requiring OAuth authorization. This demonstrates the "throw error" pattern using UrlElicitationRequiredError. @@ -933,14 +930,13 @@ Tools can interact with LLMs through sampling (generating text): ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.types import SamplingMessage, TextContent mcp = MCPServer(name="Sampling Example") @mcp.tool() -async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str: +async def generate_poem(topic: str, ctx: Context) -> str: """Generate a poem using LLM sampling.""" prompt = f"Write a short poem about {topic}" @@ -970,13 +966,12 @@ Tools can send logs and notifications through the context: ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index 2101cff28f..a0620b9c1d 100644 --- 
a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -13,7 +13,6 @@ from mcp.server import ServerRequestContext from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.prompts.base import UserMessage -from mcp.server.session import ServerSession from mcp.server.streamable_http import EventCallback, EventMessage, EventStore from mcp.types import ( AudioContent, @@ -142,7 +141,7 @@ def test_multiple_content_types() -> list[TextContent | ImageContent | EmbeddedR @mcp.tool() -async def test_tool_with_logging(ctx: Context[ServerSession, None]) -> str: +async def test_tool_with_logging(ctx: Context) -> str: """Tests tool that emits log messages during execution""" await ctx.info("Tool execution started") await asyncio.sleep(0.05) @@ -155,7 +154,7 @@ async def test_tool_with_logging(ctx: Context[ServerSession, None]) -> str: @mcp.tool() -async def test_tool_with_progress(ctx: Context[ServerSession, None]) -> str: +async def test_tool_with_progress(ctx: Context) -> str: """Tests tool that reports progress notifications""" await ctx.report_progress(progress=0, total=100, message="Completed step 0 of 100") await asyncio.sleep(0.05) @@ -173,7 +172,7 @@ async def test_tool_with_progress(ctx: Context[ServerSession, None]) -> str: @mcp.tool() -async def test_sampling(prompt: str, ctx: Context[ServerSession, None]) -> str: +async def test_sampling(prompt: str, ctx: Context) -> str: """Tests server-initiated sampling (LLM completion request)""" try: # Request sampling from client @@ -198,7 +197,7 @@ class UserResponse(BaseModel): @mcp.tool() -async def test_elicitation(message: str, ctx: Context[ServerSession, None]) -> str: +async def test_elicitation(message: str, ctx: Context) -> str: """Tests server-initiated elicitation (user input request)""" try: # Request user input from client @@ -230,7 +229,7 @@ class SEP1034DefaultsSchema(BaseModel): @mcp.tool() -async def 
test_elicitation_sep1034_defaults(ctx: Context[ServerSession, None]) -> str: +async def test_elicitation_sep1034_defaults(ctx: Context) -> str: """Tests elicitation with default values for all primitive types (SEP-1034)""" try: # Request user input with defaults for all primitive types @@ -289,7 +288,7 @@ class EnumSchemasTestSchema(BaseModel): @mcp.tool() -async def test_elicitation_sep1330_enums(ctx: Context[ServerSession, None]) -> str: +async def test_elicitation_sep1330_enums(ctx: Context) -> str: """Tests elicitation with enum schema variations per SEP-1330""" try: result = await ctx.elicit( @@ -313,7 +312,7 @@ def test_error_handling() -> str: @mcp.tool() -async def test_reconnection(ctx: Context[ServerSession, None]) -> str: +async def test_reconnection(ctx: Context) -> str: """Tests SSE polling by closing stream mid-call (SEP-1699)""" await ctx.info("Before disconnect") diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 70e515c75f..79453f543e 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -10,7 +10,6 @@ from pydantic import BaseModel, Field from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams @@ -28,7 +27,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context) -> str: """Book a table with date availability check. This demonstrates form mode elicitation for collecting non-sensitive user input. 
@@ -52,7 +51,7 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerS @mcp.tool() -async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> str: +async def secure_payment(amount: float, ctx: Context) -> str: """Process a secure payment requiring URL confirmation. This demonstrates URL mode elicitation using ctx.elicit_url() for @@ -76,7 +75,7 @@ async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> st @mcp.tool() -async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: +async def connect_service(service_name: str, ctx: Context) -> str: """Connect to a third-party service requiring OAuth authorization. This demonstrates the "throw error" pattern using UrlElicitationRequiredError. diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 5579af510f..d6d903cc7f 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,11 +1,10 @@ from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") diff --git a/examples/snippets/servers/sampling.py b/examples/snippets/servers/sampling.py index 4ffeeda726..43259589a4 100644 --- a/examples/snippets/servers/sampling.py +++ b/examples/snippets/servers/sampling.py @@ -1,12 +1,11 @@ from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.types import SamplingMessage, TextContent mcp = MCPServer(name="Sampling Example") @mcp.tool() -async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str: +async def generate_poem(topic: str, ctx: 
Context) -> str: """Generate a poem using LLM sampling.""" prompt = f"Write a short poem about {topic}" diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index 0b283cb1f5..376dbc5db8 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,11 +1,10 @@ from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 1ab90be772..be4b9a97b9 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -25,7 +25,7 @@ async def list_roots_callback( return callback_return @server.tool("test_list_roots") - async def test_list_roots(context: Context[None], message: str): + async def test_list_roots(context: Context, message: str): roots = await context.session.list_roots() assert roots == callback_return return True diff --git a/tests/server/mcpserver/test_elicitation.py b/tests/server/mcpserver/test_elicitation.py index 6cf49fbd7a..679fb848f5 100644 --- a/tests/server/mcpserver/test_elicitation.py +++ b/tests/server/mcpserver/test_elicitation.py @@ -8,7 +8,6 @@ from mcp import Client, types from mcp.client.session import ClientSession, ElicitationFnT from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -22,7 +21,7 @@ def create_ask_user_tool(mcp: MCPServer): """Create a standard ask_user tool that handles all 
elicitation responses.""" @mcp.tool(description="A tool that uses elicitation") - async def ask_user(prompt: str, ctx: Context[ServerSession, None]) -> str: + async def ask_user(prompt: str, ctx: Context) -> str: result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) if result.action == "accept" and result.data: @@ -97,7 +96,7 @@ async def test_elicitation_schema_validation(): def create_validation_tool(name: str, schema_class: type[BaseModel]): @mcp.tool(name=name, description=f"Tool testing {name}") - async def tool(ctx: Context[ServerSession, None]) -> str: + async def tool(ctx: Context) -> str: try: await ctx.elicit(message="This should fail validation", schema=schema_class) return "Should not reach here" # pragma: no cover @@ -147,7 +146,7 @@ class OptionalSchema(BaseModel): subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?") @mcp.tool(description="Tool with optional fields") - async def optional_tool(ctx: Context[ServerSession, None]) -> str: + async def optional_tool(ctx: Context) -> str: result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema) if result.action == "accept" and result.data: @@ -188,7 +187,7 @@ class InvalidOptionalSchema(BaseModel): optional_list: list[int] | None = Field(default=None, description="Invalid optional list") @mcp.tool(description="Tool with invalid optional field") - async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: + async def invalid_optional_tool(ctx: Context) -> str: try: await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) return "Should not reach here" # pragma: no cover @@ -214,7 +213,7 @@ class ValidMultiSelectSchema(BaseModel): tags: list[str] = Field(description="Tags") @mcp.tool(description="Tool with valid list[str] field") - async def valid_multiselect_tool(ctx: Context[ServerSession, None]) -> str: + async def valid_multiselect_tool(ctx: Context) -> str: result = await 
ctx.elicit(message="Please provide tags", schema=ValidMultiSelectSchema) if result.action == "accept" and result.data: return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}" @@ -233,7 +232,7 @@ class OptionalMultiSelectSchema(BaseModel): tags: list[str] | None = Field(default=None, description="Optional tags") @mcp.tool(description="Tool with optional list[str] field") - async def optional_multiselect_tool(ctx: Context[ServerSession, None]) -> str: + async def optional_multiselect_tool(ctx: Context) -> str: result = await ctx.elicit(message="Please provide optional tags", schema=OptionalMultiSelectSchema) if result.action == "accept" and result.data: tags_str = ", ".join(result.data.tags) if result.data.tags else "none" @@ -262,7 +261,7 @@ class DefaultsSchema(BaseModel): email: str = Field(description="Email address (required)") @mcp.tool(description="Tool with default values") - async def defaults_tool(ctx: Context[ServerSession, None]) -> str: + async def defaults_tool(ctx: Context) -> str: result = await ctx.elicit(message="Please provide your information", schema=DefaultsSchema) if result.action == "accept" and result.data: @@ -327,7 +326,7 @@ class FavoriteColorSchema(BaseModel): ) @mcp.tool(description="Single color selection") - async def select_favorite_color(ctx: Context[ServerSession, None]) -> str: + async def select_favorite_color(ctx: Context) -> str: result = await ctx.elicit(message="Select your favorite color", schema=FavoriteColorSchema) if result.action == "accept" and result.data: return f"User: {result.data.user_name}, Favorite: {result.data.favorite_color}" @@ -351,7 +350,7 @@ class FavoriteColorsSchema(BaseModel): ) @mcp.tool(description="Multiple color selection") - async def select_favorite_colors(ctx: Context[ServerSession, None]) -> str: + async def select_favorite_colors(ctx: Context) -> str: result = await ctx.elicit(message="Select your favorite colors", schema=FavoriteColorsSchema) if result.action == "accept" and 
result.data: return f"User: {result.data.user_name}, Colors: {', '.join(result.data.favorite_colors)}" @@ -366,7 +365,7 @@ class LegacyColorSchema(BaseModel): ) @mcp.tool(description="Legacy enum format") - async def select_color_legacy(ctx: Context[ServerSession, None]) -> str: + async def select_color_legacy(ctx: Context) -> str: result = await ctx.elicit(message="Select a color (legacy format)", schema=LegacyColorSchema) if result.action == "accept" and result.data: return f"User: {result.data.user_name}, Color: {result.data.color}" diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3d130bfc33..3ef06d0381 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -17,7 +17,6 @@ from mcp.server.mcpserver.prompts.base import Message, UserMessage from mcp.server.mcpserver.resources import FileResource, FunctionResource from mcp.server.mcpserver.utilities.types import Audio, Image -from mcp.server.session import ServerSession from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError from mcp.types import ( @@ -1003,7 +1002,7 @@ async def test_context_detection(self): """Test that context parameters are properly detected.""" mcp = MCPServer() - def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: # pragma: no cover + def tool_with_context(x: int, ctx: Context) -> str: # pragma: no cover return f"Request {ctx.request_id}: {x}" tool = mcp._tool_manager.add_tool(tool_with_context) @@ -1013,7 +1012,7 @@ async def test_context_injection(self): """Test that context is properly injected into tool calls.""" mcp = MCPServer() - def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: + def tool_with_context(x: int, ctx: Context) -> str: assert ctx.request_id is not None return f"Request {ctx.request_id}: {x}" @@ -1030,7 +1029,7 @@ async def test_async_context(self): """Test that context works in async 
functions.""" mcp = MCPServer() - async def async_tool(x: int, ctx: Context[ServerSession, None]) -> str: + async def async_tool(x: int, ctx: Context) -> str: assert ctx.request_id is not None return f"Async request {ctx.request_id}: {x}" @@ -1047,7 +1046,7 @@ async def test_context_logging(self): """Test that context logging methods work.""" mcp = MCPServer() - async def logging_tool(msg: str, ctx: Context[ServerSession, None]) -> str: + async def logging_tool(msg: str, ctx: Context) -> str: await ctx.debug("Debug message") await ctx.info("Info message") await ctx.warning("Warning message") @@ -1094,7 +1093,7 @@ def test_resource() -> str: return "resource data" @mcp.tool() - async def tool_with_resource(ctx: Context[ServerSession, None]) -> str: + async def tool_with_resource(ctx: Context) -> str: r_iter = await ctx.read_resource("test://data") r_list = list(r_iter) assert len(r_list) == 1 @@ -1113,7 +1112,7 @@ async def test_resource_with_context(self): mcp = MCPServer() @mcp.resource("resource://context/{name}") - def resource_with_context(name: str, ctx: Context[ServerSession, None]) -> str: + def resource_with_context(name: str, ctx: Context) -> str: """Resource that receives context.""" assert ctx is not None return f"Resource {name} - context injected" @@ -1166,7 +1165,7 @@ async def test_resource_context_custom_name(self): mcp = MCPServer() @mcp.resource("resource://custom/{id}") - def resource_custom_ctx(id: str, my_ctx: Context[ServerSession, None]) -> str: + def resource_custom_ctx(id: str, my_ctx: Context) -> str: """Resource with custom context parameter name.""" assert my_ctx is not None return f"Resource {id} with context" @@ -1194,7 +1193,7 @@ async def test_prompt_with_context(self): mcp = MCPServer() @mcp.prompt("prompt_with_ctx") - def prompt_with_context(text: str, ctx: Context[ServerSession, None]) -> str: + def prompt_with_context(text: str, ctx: Context) -> str: """Prompt that expects context.""" assert ctx is not None return f"Prompt 
'{text}' - context injected" diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index f990ec47b7..e4dfd4ff9b 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -11,7 +11,6 @@ from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.tools import Tool, ToolManager from mcp.server.mcpserver.utilities.func_metadata import ArgModelBase, FuncMetadata -from mcp.server.session import ServerSessionT from mcp.types import TextContent, ToolAnnotations @@ -319,7 +318,7 @@ def name_shrimp(tank: MyShrimpTank) -> list[str]: class TestToolSchema: @pytest.mark.anyio async def test_context_arg_excluded_from_schema(self): - def something(a: int, ctx: Context[ServerSessionT, None]) -> int: # pragma: no cover + def something(a: int, ctx: Context) -> int: # pragma: no cover return a manager = ToolManager() @@ -336,7 +335,7 @@ def test_context_parameter_detection(self): """Test that context parameters are properly detected in Tool.from_function().""" - def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: # pragma: no cover + def tool_with_context(x: int, ctx: Context) -> str: # pragma: no cover return str(x) manager = ToolManager() @@ -359,7 +358,7 @@ def tool_with_parametrized_context(x: int, ctx: Context[LifespanContextT, Reques async def test_context_injection(self): """Test that context is properly injected during tool execution.""" - def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: + def tool_with_context(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) return str(x) @@ -373,7 +372,7 @@ def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: async def test_context_injection_async(self): """Test that context is properly injected in async tools.""" - async def async_tool(x: int, ctx: Context[ServerSessionT, None]) -> str: + async def async_tool(x: int, ctx: Context) -> str: 
assert isinstance(ctx, Context) return str(x) @@ -387,7 +386,7 @@ async def async_tool(x: int, ctx: Context[ServerSessionT, None]) -> str: async def test_context_error_handling(self): """Test error handling when context injection fails.""" - def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: + def tool_with_context(x: int, ctx: Context) -> str: raise ValueError("Test error") manager = ToolManager() diff --git a/tests/server/mcpserver/test_url_elicitation.py b/tests/server/mcpserver/test_url_elicitation.py index 1311bd6728..af90dc208b 100644 --- a/tests/server/mcpserver/test_url_elicitation.py +++ b/tests/server/mcpserver/test_url_elicitation.py @@ -8,7 +8,6 @@ from mcp.client.session import ClientSession from mcp.server.elicitation import CancelledElicitation, DeclinedElicitation, elicit_url from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -19,7 +18,7 @@ async def test_url_elicitation_accept(): mcp = MCPServer(name="URLElicitationServer") @mcp.tool(description="A tool that uses URL elicitation") - async def request_api_key(ctx: Context[ServerSession, None]) -> str: + async def request_api_key(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Please provide your API key to continue.", url="https://example.com/api_key_setup", @@ -49,7 +48,7 @@ async def test_url_elicitation_decline(): mcp = MCPServer(name="URLElicitationDeclineServer") @mcp.tool(description="A tool that uses URL elicitation") - async def oauth_flow(ctx: Context[ServerSession, None]) -> str: + async def oauth_flow(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Authorize access to your files.", url="https://example.com/oauth/authorize", @@ -75,7 +74,7 @@ async def test_url_elicitation_cancel(): mcp = MCPServer(name="URLElicitationCancelServer") @mcp.tool(description="A 
tool that uses URL elicitation") - async def payment_flow(ctx: Context[ServerSession, None]) -> str: + async def payment_flow(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Complete payment to proceed.", url="https://example.com/payment", @@ -101,7 +100,7 @@ async def test_url_elicitation_helper_function(): mcp = MCPServer(name="URLElicitationHelperServer") @mcp.tool(description="Tool using elicit_url helper") - async def setup_credentials(ctx: Context[ServerSession, None]) -> str: + async def setup_credentials(ctx: Context) -> str: result = await elicit_url( session=ctx.session, message="Set up your credentials", @@ -127,7 +126,7 @@ async def test_url_no_content_in_response(): mcp = MCPServer(name="URLContentCheckServer") @mcp.tool(description="Check URL response format") - async def check_url_response(ctx: Context[ServerSession, None]) -> str: + async def check_url_response(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Test message", url="https://example.com/test", @@ -164,7 +163,7 @@ class NameSchema(BaseModel): name: str = Field(description="Your name") @mcp.tool(description="Test form mode") - async def ask_name(ctx: Context[ServerSession, None]) -> str: + async def ask_name(ctx: Context) -> str: result = await ctx.elicit(message="What is your name?", schema=NameSchema) # Test only checks accept path with data assert result.action == "accept" @@ -195,7 +194,7 @@ async def test_elicit_complete_notification(): notification_sent = False @mcp.tool(description="Tool that sends completion notification") - async def trigger_elicitation(ctx: Context[ServerSession, None]) -> str: + async def trigger_elicitation(ctx: Context) -> str: nonlocal notification_sent # Simulate an async operation (e.g., user completing auth in browser) @@ -238,7 +237,7 @@ async def test_elicit_url_typed_results(): mcp = MCPServer(name="TypedResultsServer") @mcp.tool(description="Test declined result") - async def test_decline(ctx: 
Context[ServerSession, None]) -> str: + async def test_decline(ctx: Context) -> str: result = await elicit_url( session=ctx.session, message="Test decline", @@ -251,7 +250,7 @@ async def test_decline(ctx: Context[ServerSession, None]) -> str: return "Not declined" # pragma: no cover @mcp.tool(description="Test cancelled result") - async def test_cancel(ctx: Context[ServerSession, None]) -> str: + async def test_cancel(ctx: Context) -> str: result = await elicit_url( session=ctx.session, message="Test cancel", @@ -293,7 +292,7 @@ class EmailSchema(BaseModel): email: str = Field(description="Email address") @mcp.tool(description="Test deprecated elicit method") - async def use_deprecated_elicit(ctx: Context[ServerSession, None]) -> str: + async def use_deprecated_elicit(ctx: Context) -> str: # Use the deprecated elicit() method which should call elicit_form() result = await ctx.session.elicit( message="Enter your email", @@ -323,7 +322,7 @@ async def test_ctx_elicit_url_convenience_method(): mcp = MCPServer(name="CtxElicitUrlServer") @mcp.tool(description="A tool that uses ctx.elicit_url() directly") - async def direct_elicit_url(ctx: Context[ServerSession, None]) -> str: + async def direct_elicit_url(ctx: Context) -> str: # Use ctx.elicit_url() directly instead of ctx.session.elicit_url() result = await ctx.elicit_url( message="Test the convenience method", diff --git a/tests/server/mcpserver/test_url_elicitation_error_throw.py b/tests/server/mcpserver/test_url_elicitation_error_throw.py index 2d29937995..1f45fd60f0 100644 --- a/tests/server/mcpserver/test_url_elicitation_error_throw.py +++ b/tests/server/mcpserver/test_url_elicitation_error_throw.py @@ -5,7 +5,6 @@ from mcp import Client, ErrorData, types from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError @@ -15,7 +14,7 @@ async def test_url_elicitation_error_thrown_from_tool(): mcp = 
MCPServer(name="UrlElicitationErrorServer") @mcp.tool(description="A tool that raises UrlElicitationRequiredError") - async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: + async def connect_service(service_name: str, ctx: Context) -> str: # This tool cannot proceed without authorization raise UrlElicitationRequiredError( [ @@ -56,7 +55,7 @@ async def test_url_elicitation_error_from_error(): mcp = MCPServer(name="UrlElicitationErrorServer") @mcp.tool(description="A tool that raises UrlElicitationRequiredError with multiple elicitations") - async def multi_auth(ctx: Context[ServerSession, None]) -> str: + async def multi_auth(ctx: Context) -> str: raise UrlElicitationRequiredError( [ types.ElicitRequestURLParams( @@ -97,7 +96,7 @@ async def test_normal_exceptions_still_return_error_result(): mcp = MCPServer(name="NormalErrorServer") @mcp.tool(description="A tool that raises a normal exception") - async def failing_tool(ctx: Context[ServerSession, None]) -> str: + async def failing_tool(ctx: Context) -> str: raise ValueError("Something went wrong") async with Client(mcp) as client: diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 0f8840d291..0d87905042 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -11,7 +11,6 @@ from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.mcpserver import Context, MCPServer from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.types import ( CallToolRequestParams, @@ -143,7 +142,7 @@ async def test_lifespan(server: MCPServer) -> AsyncIterator[dict[str, bool]]: # Add a tool that checks lifespan context @server.tool() - def check_lifespan(ctx: Context[ServerSession, None]) -> bool: + def check_lifespan(ctx: Context) -> bool: """Tool that checks lifespan context.""" assert 
isinstance(ctx.request_context.lifespan_context, dict) assert ctx.request_context.lifespan_context["started"] From 62eb08e5b23944510b8ec500a51c8f895fb58553 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 9 Mar 2026 17:47:27 +0000 Subject: [PATCH 11/60] fix: don't send log notification on transport error (#2257) --- src/mcp/server/lowlevel/server.py | 5 -- src/mcp/shared/session.py | 2 +- .../test_lowlevel_exception_handling.py | 82 +++++++++++++------ 3 files changed, 58 insertions(+), 31 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 1c84c86107..167f34b8bc 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -414,11 +414,6 @@ async def _handle_message( ) case Exception(): logger.error(f"Received exception from stream: {message}") - await session.send_log_message( - level="error", - data="Internal Server Error", - logger="mcp.server.exception_handler", - ) if raise_exceptions: raise message case _: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b617d702fe..9364abb73b 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -334,7 +334,7 @@ async def _receive_loop(self) -> None: async with self._read_stream, self._write_stream: try: async for message in self._read_stream: - if isinstance(message, Exception): # pragma: no cover + if isinstance(message, Exception): await self._handle_incoming(message) elif isinstance(message.message, JSONRPCRequest): try: diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 848b35b299..46925916d9 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -1,55 +1,42 @@ from unittest.mock import AsyncMock, Mock +import anyio import pytest from mcp import types from mcp.server.lowlevel.server import Server from mcp.server.session 
import ServerSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder @pytest.mark.anyio async def test_exception_handling_with_raise_exceptions_true(): - """Test that exceptions are re-raised when raise_exceptions=True""" + """Transport exceptions are re-raised when raise_exceptions=True.""" server = Server("test-server") session = Mock(spec=ServerSession) - session.send_log_message = AsyncMock() test_exception = RuntimeError("Test error") with pytest.raises(RuntimeError, match="Test error"): await server._handle_message(test_exception, session, {}, raise_exceptions=True) - session.send_log_message.assert_called_once() - @pytest.mark.anyio -@pytest.mark.parametrize( - "exception_class,message", - [ - (ValueError, "Test validation error"), - (RuntimeError, "Test runtime error"), - (KeyError, "Test key error"), - (Exception, "Basic error"), - ], -) -async def test_exception_handling_with_raise_exceptions_false(exception_class: type[Exception], message: str): - """Test that exceptions are logged when raise_exceptions=False""" +async def test_exception_handling_with_raise_exceptions_false(): + """Transport exceptions are logged locally but not sent to the client. + + The transport that reported the error is likely broken; writing back + through it races with stream closure (#1967, #2064). The TypeScript, + Go, and C# SDKs all log locally only. 
+ """ server = Server("test-server") session = Mock(spec=ServerSession) session.send_log_message = AsyncMock() - test_exception = exception_class(message) - - await server._handle_message(test_exception, session, {}, raise_exceptions=False) - - # Should send log message - session.send_log_message.assert_called_once() - call_args = session.send_log_message.call_args + await server._handle_message(RuntimeError("Test error"), session, {}, raise_exceptions=False) - assert call_args.kwargs["level"] == "error" - assert call_args.kwargs["data"] == "Internal Server Error" - assert call_args.kwargs["logger"] == "mcp.server.exception_handler" + session.send_log_message.assert_not_called() @pytest.mark.anyio @@ -72,3 +59,48 @@ async def test_normal_message_handling_not_affected(): # Verify _handle_request was called server._handle_request.assert_called_once() + + +@pytest.mark.anyio +async def test_server_run_exits_cleanly_when_transport_yields_exception_then_closes(): + """Regression test for #1967 / #2064. + + Exercises the real Server.run() path with real memory streams, reproducing + what happens in stateless streamable HTTP when a POST handler throws: + + 1. Transport yields an Exception into the read stream + (streamable_http.py does this in its broad POST-handler except). + 2. Transport closes the read stream (terminate() in stateless mode). + 3. _receive_loop exits its `async with read_stream, write_stream:` block, + closing the write stream. + 4. Meanwhile _handle_message(exc) was spawned via tg.start_soon and runs + after the write stream is closed. + + Before the fix, _handle_message tried to send_log_message through the + closed write stream, raising ClosedResourceError inside the TaskGroup + and crashing server.run(). After the fix, it only logs locally. + """ + server = Server("test-server") + + read_send, read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + # Zero-buffer on the write stream forces send() to block until received. 
+ # With no receiver, a send() sits blocked until _receive_loop exits its + # `async with self._read_stream, self._write_stream:` block and closes the + # stream, at which point the blocked send raises ClosedResourceError. + # This deterministically reproduces the race without sleeps. + write_send, write_recv = anyio.create_memory_object_stream[SessionMessage](0) + + # What the streamable HTTP transport does: push the exception, then close. + read_send.send_nowait(RuntimeError("simulated transport error")) + read_send.close() + + with anyio.fail_after(5): + # stateless=True so server.run doesn't wait for initialize handshake. + # Before this fix, this raised ExceptionGroup(ClosedResourceError). + await server.run(read_recv, write_send, server.create_initialization_options(), stateless=True) + + # write_send was closed inside _receive_loop's `async with`; receive_nowait + # raises EndOfStream iff the buffer is empty (i.e., server wrote nothing). + with pytest.raises(anyio.EndOfStream): + write_recv.receive_nowait() + write_recv.close() From dd52713517d88541f9233cb2af62012ad65bb993 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 12 Mar 2026 12:52:32 +0000 Subject: [PATCH 12/60] Rewrite TestChildProcessCleanup with socket-based deterministic liveness probe (#2265) --- src/mcp/os/win32/utilities.py | 5 + tests/client/test_stdio.py | 454 +++++++++++++++------------------- 2 files changed, 201 insertions(+), 258 deletions(-) diff --git a/src/mcp/os/win32/utilities.py b/src/mcp/os/win32/utilities.py index 0e188691f1..6f68405f78 100644 --- a/src/mcp/os/win32/utilities.py +++ b/src/mcp/os/win32/utilities.py @@ -123,6 +123,11 @@ def pid(self) -> int: """Return the process ID.""" return self.popen.pid + @property + def returncode(self) -> int | None: + """Return the exit code, or ``None`` if the process has not yet terminated.""" + return self.popen.returncode + # ------------------------ # Updated function diff --git 
a/tests/client/test_stdio.py b/tests/client/test_stdio.py index f70c24eee7..06e2cba4b1 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,12 +1,12 @@ import errno -import os import shutil import sys -import tempfile import textwrap import time +from contextlib import AsyncExitStack, suppress import anyio +import anyio.abc import pytest from mcp.client.session import ClientSession @@ -16,12 +16,11 @@ _terminate_process_tree, stdio_client, ) +from mcp.os.win32.utilities import FallbackProcess from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse -from ..shared.test_win32_utils import escape_path_for_python - # Timeout for cleanup of processes that ignore SIGTERM # This timeout ensures the test fails quickly if the cleanup logic doesn't have # proper fallback mechanisms (SIGINT/SIGKILL) for processes that ignore SIGTERM @@ -221,291 +220,230 @@ def sigint_handler(signum, frame): raise -class TestChildProcessCleanup: - """Tests for child process cleanup functionality using _terminate_process_tree. - - These tests verify that child processes are properly terminated when the parent - is killed, addressing the issue where processes like npx spawn child processes - that need to be cleaned up. The tests cover various process tree scenarios: - - - Basic parent-child relationship (single child process) - - Multi-level process trees (parent → child → grandchild) - - Race conditions where parent exits during cleanup - - Note on Windows ResourceWarning: - On Windows, we may see ResourceWarning about subprocess still running. 
This is - expected behavior due to how Windows process termination works: - - anyio's process.terminate() calls Windows TerminateProcess() API - - TerminateProcess() immediately kills the process without allowing cleanup - - subprocess.Popen objects in the killed process can't run their cleanup code - - Python detects this during garbage collection and issues a ResourceWarning - - This warning does NOT indicate a process leak - the processes are properly - terminated. It only means the Popen objects couldn't clean up gracefully. - This is a fundamental difference between Windows and Unix process termination. - """ +# --------------------------------------------------------------------------- +# TestChildProcessCleanup — socket-based deterministic child liveness probe +# --------------------------------------------------------------------------- +# +# These tests verify that `_terminate_process_tree()` kills the *entire* process +# tree (not just the immediate child), which is critical for cleaning up tools +# like `npx` that spawn their own subprocesses. +# +# Mechanism: each subprocess in the tree connects a TCP socket back to a +# listener owned by the test. We then use two kernel-guaranteed blocking-I/O +# signals — neither requires any `sleep()` or polling loop: +# +# 1. `await listener.accept()` blocks until the subprocess connects, +# proving it is running. +# 2. After `_terminate_process_tree()`, `await stream.receive(1)` raises +# `EndOfStream` (clean close / FIN) or `BrokenResourceError` (abrupt +# close / RST — typical on Windows after TerminateJobObject) because the +# kernel closes all file descriptors when a process terminates. Either +# is the direct, OS-level proof that the child is dead. +# +# This replaces an older file-growth-watching approach whose fixed `sleep()` +# durations raced against slow Python interpreter startup on loaded CI runners. 
+ + +def _connect_back_script(port: int) -> str: + """Return a ``python -c`` script body that connects to the given port, + sends ``b'alive'``, then blocks forever. Used by TestChildProcessCleanup + subprocesses as a liveness probe.""" + return ( + f"import socket, time\n" + f"s = socket.create_connection(('127.0.0.1', {port}))\n" + f"s.sendall(b'alive')\n" + f"time.sleep(3600)\n" + ) - @pytest.mark.anyio - @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") - async def test_basic_child_process_cleanup(self): - """Test basic parent-child process cleanup. - Parent spawns a single child process that writes continuously to a file. - """ - # Create a marker file for the child process to write to - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - marker_file = f.name - # Also create a file to verify parent started - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - parent_marker = f.name +def _spawn_then_block(child_script: str) -> str: + """Return a ``python -c`` script body that spawns ``child_script`` as a + subprocess, then blocks forever. 
The ``!r`` injection avoids nested-quote + escaping for arbitrary child script content.""" + return ( + f"import subprocess, sys, time\nsubprocess.Popen([sys.executable, '-c', {child_script!r}])\ntime.sleep(3600)\n" + ) - try: - # Parent script that spawns a child process - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import os - - # Mark that parent started - with open({escape_path_for_python(parent_marker)}, 'w') as f: - f.write('parent started\\n') - - # Child script that writes continuously - child_script = f''' - import time - with open({escape_path_for_python(marker_file)}, 'a') as f: - while True: - f.write(f"{time.time()}") - f.flush() - time.sleep(0.1) - ''' - - # Start the child process - child = subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent just sleeps - while True: - time.sleep(0.1) - """ - ) - print("\nStarting child process termination test...") +async def _open_liveness_listener() -> tuple[anyio.abc.SocketListener, int]: + """Open a TCP listener on localhost and return it along with its port.""" + multi = await anyio.create_tcp_listener(local_host="127.0.0.1") + sock = multi.listeners[0] + assert isinstance(sock, anyio.abc.SocketListener) + addr = sock.extra(anyio.abc.SocketAttribute.local_address) + # IPv4 local_address is (host: str, port: int) + assert isinstance(addr, tuple) and len(addr) >= 2 and isinstance(addr[1], int) + return sock, addr[1] + + +async def _accept_alive(sock: anyio.abc.SocketListener) -> anyio.abc.SocketStream: + """Accept one connection and assert the peer sent ``b'alive'``. + + Blocks deterministically until a subprocess connects (no polling). The + outer test bounds this with ``anyio.fail_after`` to catch the case where + the subprocess chain failed to start. 
+ """ + stream = await sock.accept() + msg = await stream.receive(5) + assert msg == b"alive", f"expected b'alive', got {msg!r}" + return stream - # Start the parent process - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) - # Wait for processes to start - await anyio.sleep(0.5) +async def _assert_stream_closed(stream: anyio.abc.SocketStream) -> None: + """Assert the peer holding the other end of ``stream`` has terminated. - # Verify parent started - assert os.path.exists(parent_marker), "Parent process didn't start" + When a process dies, the kernel closes its file descriptors including + sockets. The next ``receive()`` on the peer socket unblocks with one of: - # Verify child is writing - if os.path.exists(marker_file): # pragma: no branch - initial_size = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size_after_wait = os.path.getsize(marker_file) - assert size_after_wait > initial_size, "Child process should be writing" - print(f"Child is writing (file grew from {initial_size} to {size_after_wait} bytes)") + - ``anyio.EndOfStream`` — clean close (FIN), typical after graceful exit + or POSIX ``SIGTERM``. + - ``anyio.BrokenResourceError`` — abrupt close (RST), typical after + Windows ``TerminateJobObject`` or POSIX ``SIGKILL``. - # Terminate using our function - print("Terminating process and children...") + Either is a deterministic, kernel-level signal that the process is dead — + no sleeps or polling required. + """ + with anyio.fail_after(5.0), pytest.raises((anyio.EndOfStream, anyio.BrokenResourceError)): + await stream.receive(1) + + +async def _terminate_and_reap(proc: anyio.abc.Process | FallbackProcess) -> None: + """Terminate the process tree, reap, and tear down pipe transports. + + ``_terminate_process_tree`` kills the OS process group / Job Object but does + not call ``process.wait()`` or clean up the asyncio pipe transports. 
On + Windows those transports leak and emit ``ResourceWarning`` when GC'd in a + later test, causing ``PytestUnraisableExceptionWarning`` knock-on failures. + + Production ``stdio.py`` avoids this via its ``stdout_reader`` task which + reads stdout to EOF (triggering ``_ProactorReadPipeTransport._eof_received`` + → ``close()``) plus ``async with process:`` which waits and closes stdin. + These tests call ``_terminate_process_tree`` directly, so they replicate + both parts here: ``wait()`` + close stdin + drain stdout to EOF. + + The stdout drain is the non-obvious part: anyio's ``StreamReaderWrapper.aclose()`` + only marks the Python-level reader closed — it never touches the underlying + ``_ProactorReadPipeTransport``. That transport starts paused and only detects + pipe EOF when someone reads, so without a drain it lives until ``__del__``. + + Idempotent: the ``returncode`` guard skips termination if already reaped + (avoids spurious WARNING/ERROR logs from ``terminate_posix_process_tree``'s + fallback path, visible because ``log_cli = true``); ``wait()`` and stream + ``aclose()`` no-op on subsequent calls; the drain raises ``ClosedResourceError`` + on the second call, caught by the suppress. The tests call this explicitly + as the action under test and ``AsyncExitStack`` calls it again on exit as a + safety net. Bounded by ``move_on_after`` to prevent hangs. 
+ """ + with anyio.move_on_after(5.0): + if proc.returncode is None: await _terminate_process_tree(proc) + await proc.wait() + assert proc.stdin is not None + assert proc.stdout is not None + await proc.stdin.aclose() + with suppress(anyio.EndOfStream, anyio.BrokenResourceError, anyio.ClosedResourceError): + await proc.stdout.receive(65536) + await proc.stdout.aclose() - # Verify processes stopped - await anyio.sleep(0.5) - if os.path.exists(marker_file): # pragma: no branch - size_after_cleanup = os.path.getsize(marker_file) - await anyio.sleep(0.5) - final_size = os.path.getsize(marker_file) - print(f"After cleanup: file size {size_after_cleanup} -> {final_size}") - assert final_size == size_after_cleanup, ( - f"Child process still running! File grew by {final_size - size_after_cleanup} bytes" - ) +class TestChildProcessCleanup: + """Integration tests for ``_terminate_process_tree`` covering basic, + nested, and early-parent-exit process tree scenarios. See module-level + comment above for the socket-based liveness probe mechanism. + """ - print("SUCCESS: Child process was properly terminated") + @pytest.mark.anyio + async def test_basic_child_process_cleanup(self): + """Parent spawns one child; terminating the tree kills both.""" + async with AsyncExitStack() as stack: + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # Parent spawns a child; the child connects back to us. + parent_script = _spawn_then_block(_connect_back_script(port)) + proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + stack.push_async_callback(_terminate_and_reap, proc) + + # Deterministic: accept() blocks until the child connects. No sleep. 
+ with anyio.fail_after(10.0): + stream = await _accept_alive(sock) + stack.push_async_callback(stream.aclose) - finally: - # Clean up files - for f in [marker_file, parent_marker]: - try: - os.unlink(f) - except OSError: # pragma: no cover - pass + # Terminate, reap and close transports (wraps _terminate_process_tree, + # the behavior under test). + await _terminate_and_reap(proc) + + # Deterministic: kernel closed child's socket when it died. + await _assert_stream_closed(stream) @pytest.mark.anyio - @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") async def test_nested_process_tree(self): - """Test nested process tree cleanup (parent → child → grandchild). - Each level writes to a different file to verify all processes are terminated. - """ - # Create temporary files for each process level - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1: - parent_file = f1.name - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f2: - child_file = f2.name - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f3: - grandchild_file = f3.name - - try: - # Simple nested process tree test - # We create parent -> child -> grandchild, each writing to a file - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import os - - # Child will spawn grandchild and write to child file - child_script = f'''import subprocess - import sys - import time - - # Grandchild just writes to file - grandchild_script = \"\"\"import time - with open({escape_path_for_python(grandchild_file)}, 'a') as f: - while True: - f.write(f"gc {{time.time()}}") - f.flush() - time.sleep(0.1)\"\"\" - - # Spawn grandchild - subprocess.Popen([sys.executable, '-c', grandchild_script]) - - # Child writes to its file - with open({escape_path_for_python(child_file)}, 'a') as f: - while True: - f.write(f"c {time.time()}") - f.flush() - time.sleep(0.1)''' - - # Spawn child process - 
subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent writes to its file - with open({escape_path_for_python(parent_file)}, 'a') as f: - while True: - f.write(f"p {time.time()}") - f.flush() - time.sleep(0.1) - """ + """Parent → child → grandchild; terminating the tree kills all three.""" + async with AsyncExitStack() as stack: + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # Build a three-level chain: parent spawns child, child spawns + # grandchild. Every level connects back to our socket. + grandchild = _connect_back_script(port) + child = ( + f"import subprocess, sys\n" + f"subprocess.Popen([sys.executable, '-c', {grandchild!r}])\n" + _connect_back_script(port) + ) + parent_script = ( + f"import subprocess, sys\n" + f"subprocess.Popen([sys.executable, '-c', {child!r}])\n" + _connect_back_script(port) ) - - # Start the parent process proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + stack.push_async_callback(_terminate_and_reap, proc) - # Let all processes start - await anyio.sleep(1.0) - - # Verify all are writing - for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): # pragma: no branch - initial_size = os.path.getsize(file_path) - await anyio.sleep(0.3) - new_size = os.path.getsize(file_path) - assert new_size > initial_size, f"{name} process should be writing" + # Deterministic: three blocking accepts, one per tree level. + streams: list[anyio.abc.SocketStream] = [] + with anyio.fail_after(10.0): + for _ in range(3): + stream = await _accept_alive(sock) + stack.push_async_callback(stream.aclose) + streams.append(stream) - # Terminate the whole tree - await _terminate_process_tree(proc) + # Terminate the entire tree (wraps _terminate_process_tree). 
+ await _terminate_and_reap(proc) - # Verify all stopped - await anyio.sleep(0.5) - for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): # pragma: no branch - size1 = os.path.getsize(file_path) - await anyio.sleep(0.3) - size2 = os.path.getsize(file_path) - assert size1 == size2, f"{name} still writing after cleanup!" - - print("SUCCESS: All processes in tree terminated") - - finally: - # Clean up all marker files - for f in [parent_file, child_file, grandchild_file]: - try: - os.unlink(f) - except OSError: # pragma: no cover - pass + # Every level of the tree must be dead: three kernel-level EOFs. + for stream in streams: + await _assert_stream_closed(stream) @pytest.mark.anyio - @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") async def test_early_parent_exit(self): - """Test cleanup when parent exits during termination sequence. - Tests the race condition where parent might die during our termination - sequence but we can still clean up the children via the process group. + """Parent exits immediately on SIGTERM; process-group termination still + catches the child (exercises the race where the parent dies mid-cleanup). 
""" - # Create a temporary file for the child - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - marker_file = f.name - - try: - # Parent that spawns child and waits briefly - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import signal - - # Child that continues running - child_script = f'''import time - with open({escape_path_for_python(marker_file)}, 'a') as f: - while True: - f.write(f"child {time.time()}") - f.flush() - time.sleep(0.1)''' - - # Start child in same process group - subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent waits a bit then exits on SIGTERM - def handle_term(sig, frame): - sys.exit(0) - - signal.signal(signal.SIGTERM, handle_term) - - # Wait - while True: - time.sleep(0.1) - """ + async with AsyncExitStack() as stack: + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # Parent installs a SIGTERM handler that exits immediately, spawns a + # child that connects back to us, then blocks. + child = _connect_back_script(port) + parent_script = ( + f"import signal, subprocess, sys, time\n" + f"signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))\n" + f"subprocess.Popen([sys.executable, '-c', {child!r}])\n" + f"time.sleep(3600)\n" ) - - # Start the parent process proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + stack.push_async_callback(_terminate_and_reap, proc) - # Let child start writing - await anyio.sleep(0.5) - - # Verify child is writing - if os.path.exists(marker_file): # pragma: no branch - size1 = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size2 = os.path.getsize(marker_file) - assert size2 > size1, "Child should be writing" + # Deterministic: child connected means both parent and child are up. 
+ with anyio.fail_after(10.0): + stream = await _accept_alive(sock) + stack.push_async_callback(stream.aclose) - # Terminate - this will kill the process group even if parent exits first - await _terminate_process_tree(proc) + # Parent will sys.exit(0) on SIGTERM, but the process-group kill + # (POSIX killpg / Windows Job Object) must still terminate the child. + await _terminate_and_reap(proc) - # Verify child stopped - await anyio.sleep(0.5) - if os.path.exists(marker_file): # pragma: no branch - size3 = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size4 = os.path.getsize(marker_file) - assert size3 == size4, "Child should be terminated" - - print("SUCCESS: Child terminated even with parent exit during cleanup") - - finally: - # Clean up marker file - try: - os.unlink(marker_file) - except OSError: # pragma: no cover - pass + # Child must be dead despite parent's early exit. + await _assert_stream_closed(stream) @pytest.mark.anyio From 2c73a2a8811de5d4fa34bf23a1bf390d9328bedb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 10:39:56 +0000 Subject: [PATCH 13/60] chore(deps): bump black from 25.1.0 to 26.3.1 in the uv group across 1 directory (#2290) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- uv.lock | 95 +++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 72 insertions(+), 23 deletions(-) diff --git a/uv.lock b/uv.lock index d01d510f17..c01e96a4a6 100644 --- a/uv.lock +++ b/uv.lock @@ -96,7 +96,7 @@ wheels = [ [[package]] name = "black" -version = "25.1.0" +version = "26.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -104,28 +104,38 @@ dependencies = [ { name = "packaging" }, { name = "pathspec" }, { name = "platformdirs" }, + { name = "pytokens" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = 
"typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449, upload-time = "2025-01-29T04:15:40.373Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/3b/4ba3f93ac8d90410423fdd31d7541ada9bcee1df32fb90d26de41ed40e1d/black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32", size = 1629419, upload-time = "2025-01-29T05:37:06.642Z" }, - { url = "https://files.pythonhosted.org/packages/b4/02/0bde0485146a8a5e694daed47561785e8b77a0466ccc1f3e485d5ef2925e/black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da", size = 1461080, upload-time = "2025-01-29T05:37:09.321Z" }, - { url = "https://files.pythonhosted.org/packages/52/0e/abdf75183c830eaca7589144ff96d49bce73d7ec6ad12ef62185cc0f79a2/black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7", size = 1766886, upload-time = "2025-01-29T04:18:24.432Z" }, - { url = "https://files.pythonhosted.org/packages/dc/a6/97d8bb65b1d8a41f8a6736222ba0a334db7b7b77b8023ab4568288f23973/black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9", size = 1419404, upload-time = "2025-01-29T04:19:04.296Z" }, - { url = "https://files.pythonhosted.org/packages/7e/4f/87f596aca05c3ce5b94b8663dbfe242a12843caaa82dd3f85f1ffdc3f177/black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0", size = 1614372, upload-time = "2025-01-29T05:37:11.71Z" }, - { url = 
"https://files.pythonhosted.org/packages/e7/d0/2c34c36190b741c59c901e56ab7f6e54dad8df05a6272a9747ecef7c6036/black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299", size = 1442865, upload-time = "2025-01-29T05:37:14.309Z" }, - { url = "https://files.pythonhosted.org/packages/21/d4/7518c72262468430ead45cf22bd86c883a6448b9eb43672765d69a8f1248/black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096", size = 1749699, upload-time = "2025-01-29T04:18:17.688Z" }, - { url = "https://files.pythonhosted.org/packages/58/db/4f5beb989b547f79096e035c4981ceb36ac2b552d0ac5f2620e941501c99/black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2", size = 1428028, upload-time = "2025-01-29T04:18:51.711Z" }, - { url = "https://files.pythonhosted.org/packages/83/71/3fe4741df7adf015ad8dfa082dd36c94ca86bb21f25608eb247b4afb15b2/black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b", size = 1650988, upload-time = "2025-01-29T05:37:16.707Z" }, - { url = "https://files.pythonhosted.org/packages/13/f3/89aac8a83d73937ccd39bbe8fc6ac8860c11cfa0af5b1c96d081facac844/black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc", size = 1453985, upload-time = "2025-01-29T05:37:18.273Z" }, - { url = "https://files.pythonhosted.org/packages/6f/22/b99efca33f1f3a1d2552c714b1e1b5ae92efac6c43e790ad539a163d1754/black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f", size = 1783816, upload-time = "2025-01-29T04:18:33.823Z" }, - { url = 
"https://files.pythonhosted.org/packages/18/7e/a27c3ad3822b6f2e0e00d63d58ff6299a99a5b3aee69fa77cd4b0076b261/black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba", size = 1440860, upload-time = "2025-01-29T04:19:12.944Z" }, - { url = "https://files.pythonhosted.org/packages/98/87/0edf98916640efa5d0696e1abb0a8357b52e69e82322628f25bf14d263d1/black-25.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f0b18a02996a836cc9c9c78e5babec10930862827b1b724ddfe98ccf2f2fe4f", size = 1650673, upload-time = "2025-01-29T05:37:20.574Z" }, - { url = "https://files.pythonhosted.org/packages/52/e5/f7bf17207cf87fa6e9b676576749c6b6ed0d70f179a3d812c997870291c3/black-25.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afebb7098bfbc70037a053b91ae8437c3857482d3a690fefc03e9ff7aa9a5fd3", size = 1453190, upload-time = "2025-01-29T05:37:22.106Z" }, - { url = "https://files.pythonhosted.org/packages/e3/ee/adda3d46d4a9120772fae6de454c8495603c37c4c3b9c60f25b1ab6401fe/black-25.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:030b9759066a4ee5e5aca28c3c77f9c64789cdd4de8ac1df642c40b708be6171", size = 1782926, upload-time = "2025-01-29T04:18:58.564Z" }, - { url = "https://files.pythonhosted.org/packages/cc/64/94eb5f45dcb997d2082f097a3944cfc7fe87e071907f677e80788a2d7b7a/black-25.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a22f402b410566e2d1c950708c77ebf5ebd5d0d88a6a2e87c86d9fb48afa0d18", size = 1442613, upload-time = "2025-01-29T04:19:27.63Z" }, - { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, +sdist = { url = 
"https://files.pythonhosted.org/packages/e1/c5/61175d618685d42b005847464b8fb4743a67b1b8fdb75e50e5a96c31a27a/black-26.3.1.tar.gz", hash = "sha256:2c50f5063a9641c7eed7795014ba37b0f5fa227f3d408b968936e24bc0566b07", size = 666155, upload-time = "2026-03-12T03:36:03.593Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/a8/11170031095655d36ebc6664fe0897866f6023892396900eec0e8fdc4299/black-26.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:86a8b5035fce64f5dcd1b794cf8ec4d31fe458cf6ce3986a30deb434df82a1d2", size = 1866562, upload-time = "2026-03-12T03:39:58.639Z" }, + { url = "https://files.pythonhosted.org/packages/69/ce/9e7548d719c3248c6c2abfd555d11169457cbd584d98d179111338423790/black-26.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5602bdb96d52d2d0672f24f6ffe5218795736dd34807fd0fd55ccd6bf206168b", size = 1703623, upload-time = "2026-03-12T03:40:00.347Z" }, + { url = "https://files.pythonhosted.org/packages/7f/0a/8d17d1a9c06f88d3d030d0b1d4373c1551146e252afe4547ed601c0e697f/black-26.3.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6c54a4a82e291a1fee5137371ab488866b7c86a3305af4026bdd4dc78642e1ac", size = 1768388, upload-time = "2026-03-12T03:40:01.765Z" }, + { url = "https://files.pythonhosted.org/packages/52/79/c1ee726e221c863cde5164f925bacf183dfdf0397d4e3f94889439b947b4/black-26.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:6e131579c243c98f35bce64a7e08e87fb2d610544754675d4a0e73a070a5aa3a", size = 1412969, upload-time = "2026-03-12T03:40:03.252Z" }, + { url = "https://files.pythonhosted.org/packages/73/a5/15c01d613f5756f68ed8f6d4ec0a1e24b82b18889fa71affd3d1f7fad058/black-26.3.1-cp310-cp310-win_arm64.whl", hash = "sha256:5ed0ca58586c8d9a487352a96b15272b7fa55d139fc8496b519e78023a8dab0a", size = 1220345, upload-time = "2026-03-12T03:40:04.892Z" }, + { url = 
"https://files.pythonhosted.org/packages/17/57/5f11c92861f9c92eb9dddf515530bc2d06db843e44bdcf1c83c1427824bc/black-26.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:28ef38aee69e4b12fda8dba75e21f9b4f979b490c8ac0baa7cb505369ac9e1ff", size = 1851987, upload-time = "2026-03-12T03:40:06.248Z" }, + { url = "https://files.pythonhosted.org/packages/54/aa/340a1463660bf6831f9e39646bf774086dbd8ca7fc3cded9d59bbdf4ad0a/black-26.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf9bf162ed91a26f1adba8efda0b573bc6924ec1408a52cc6f82cb73ec2b142c", size = 1689499, upload-time = "2026-03-12T03:40:07.642Z" }, + { url = "https://files.pythonhosted.org/packages/f3/01/b726c93d717d72733da031d2de10b92c9fa4c8d0c67e8a8a372076579279/black-26.3.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:474c27574d6d7037c1bc875a81d9be0a9a4f9ee95e62800dab3cfaadbf75acd5", size = 1754369, upload-time = "2026-03-12T03:40:09.279Z" }, + { url = "https://files.pythonhosted.org/packages/e3/09/61e91881ca291f150cfc9eb7ba19473c2e59df28859a11a88248b5cbbc4d/black-26.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5e9d0d86df21f2e1677cc4bd090cd0e446278bcbbe49bf3659c308c3e402843e", size = 1413613, upload-time = "2026-03-12T03:40:10.943Z" }, + { url = "https://files.pythonhosted.org/packages/16/73/544f23891b22e7efe4d8f812371ab85b57f6a01b2fc45e3ba2e52ba985b8/black-26.3.1-cp311-cp311-win_arm64.whl", hash = "sha256:9a5e9f45e5d5e1c5b5c29b3bd4265dcc90e8b92cf4534520896ed77f791f4da5", size = 1219719, upload-time = "2026-03-12T03:40:12.597Z" }, + { url = "https://files.pythonhosted.org/packages/dc/f8/da5eae4fc75e78e6dceb60624e1b9662ab00d6b452996046dfa9b8a6025b/black-26.3.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b5e6f89631eb88a7302d416594a32faeee9fb8fb848290da9d0a5f2903519fc1", size = 1895920, upload-time = "2026-03-12T03:40:13.921Z" }, + { url = 
"https://files.pythonhosted.org/packages/2c/9f/04e6f26534da2e1629b2b48255c264cabf5eedc5141d04516d9d68a24111/black-26.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41cd2012d35b47d589cb8a16faf8a32ef7a336f56356babd9fcf70939ad1897f", size = 1718499, upload-time = "2026-03-12T03:40:15.239Z" }, + { url = "https://files.pythonhosted.org/packages/04/91/a5935b2a63e31b331060c4a9fdb5a6c725840858c599032a6f3aac94055f/black-26.3.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f76ff19ec5297dd8e66eb64deda23631e642c9393ab592826fd4bdc97a4bce7", size = 1794994, upload-time = "2026-03-12T03:40:17.124Z" }, + { url = "https://files.pythonhosted.org/packages/e7/0a/86e462cdd311a3c2a8ece708d22aba17d0b2a0d5348ca34b40cdcbea512e/black-26.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:ddb113db38838eb9f043623ba274cfaf7d51d5b0c22ecb30afe58b1bb8322983", size = 1420867, upload-time = "2026-03-12T03:40:18.83Z" }, + { url = "https://files.pythonhosted.org/packages/5b/e5/22515a19cb7eaee3440325a6b0d95d2c0e88dd180cb011b12ae488e031d1/black-26.3.1-cp312-cp312-win_arm64.whl", hash = "sha256:dfdd51fc3e64ea4f35873d1b3fb25326773d55d2329ff8449139ebaad7357efb", size = 1230124, upload-time = "2026-03-12T03:40:20.425Z" }, + { url = "https://files.pythonhosted.org/packages/f5/77/5728052a3c0450c53d9bb3945c4c46b91baa62b2cafab6801411b6271e45/black-26.3.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:855822d90f884905362f602880ed8b5df1b7e3ee7d0db2502d4388a954cc8c54", size = 1895034, upload-time = "2026-03-12T03:40:21.813Z" }, + { url = "https://files.pythonhosted.org/packages/52/73/7cae55fdfdfbe9d19e9a8d25d145018965fe2079fa908101c3733b0c55a0/black-26.3.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8a33d657f3276328ce00e4d37fe70361e1ec7614da5d7b6e78de5426cb56332f", size = 1718503, upload-time = "2026-03-12T03:40:23.666Z" }, + { url = 
"https://files.pythonhosted.org/packages/e1/87/af89ad449e8254fdbc74654e6467e3c9381b61472cc532ee350d28cfdafb/black-26.3.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f1cd08e99d2f9317292a311dfe578fd2a24b15dbce97792f9c4d752275c1fa56", size = 1793557, upload-time = "2026-03-12T03:40:25.497Z" }, + { url = "https://files.pythonhosted.org/packages/43/10/d6c06a791d8124b843bf325ab4ac7d2f5b98731dff84d6064eafd687ded1/black-26.3.1-cp313-cp313-win_amd64.whl", hash = "sha256:c7e72339f841b5a237ff14f7d3880ddd0fc7f98a1199e8c4327f9a4f478c1839", size = 1422766, upload-time = "2026-03-12T03:40:27.14Z" }, + { url = "https://files.pythonhosted.org/packages/59/4f/40a582c015f2d841ac24fed6390bd68f0fc896069ff3a886317959c9daf8/black-26.3.1-cp313-cp313-win_arm64.whl", hash = "sha256:afc622538b430aa4c8c853f7f63bc582b3b8030fd8c80b70fb5fa5b834e575c2", size = 1232140, upload-time = "2026-03-12T03:40:28.882Z" }, + { url = "https://files.pythonhosted.org/packages/d5/da/e36e27c9cebc1311b7579210df6f1c86e50f2d7143ae4fcf8a5017dc8809/black-26.3.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:2d6bfaf7fd0993b420bed691f20f9492d53ce9a2bcccea4b797d34e947318a78", size = 1889234, upload-time = "2026-03-12T03:40:30.964Z" }, + { url = "https://files.pythonhosted.org/packages/0e/7b/9871acf393f64a5fa33668c19350ca87177b181f44bb3d0c33b2d534f22c/black-26.3.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:f89f2ab047c76a9c03f78d0d66ca519e389519902fa27e7a91117ef7611c0568", size = 1720522, upload-time = "2026-03-12T03:40:32.346Z" }, + { url = "https://files.pythonhosted.org/packages/03/87/e766c7f2e90c07fb7586cc787c9ae6462b1eedab390191f2b7fc7f6170a9/black-26.3.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b07fc0dab849d24a80a29cfab8d8a19187d1c4685d8a5e6385a5ce323c1f015f", size = 1787824, upload-time = "2026-03-12T03:40:33.636Z" }, + { url = 
"https://files.pythonhosted.org/packages/ac/94/2424338fb2d1875e9e83eed4c8e9c67f6905ec25afd826a911aea2b02535/black-26.3.1-cp314-cp314-win_amd64.whl", hash = "sha256:0126ae5b7c09957da2bdbd91a9ba1207453feada9e9fe51992848658c6c8e01c", size = 1445855, upload-time = "2026-03-12T03:40:35.442Z" }, + { url = "https://files.pythonhosted.org/packages/86/43/0c3338bd928afb8ee7471f1a4eec3bdbe2245ccb4a646092a222e8669840/black-26.3.1-cp314-cp314-win_arm64.whl", hash = "sha256:92c0ec1f2cc149551a2b7b47efc32c866406b6891b0ee4625e95967c8f4acfb1", size = 1258109, upload-time = "2026-03-12T03:40:36.832Z" }, + { url = "https://files.pythonhosted.org/packages/8e/0d/52d98722666d6fc6c3dd4c76df339501d6efd40e0ff95e6186a7b7f0befd/black-26.3.1-py3-none-any.whl", hash = "sha256:2bd5aa94fc267d38bb21a70d7410a89f1a1d318841855f698746f8e7f51acd1b", size = 207542, upload-time = "2026-03-12T03:36:01.668Z" }, ] [[package]] @@ -1636,11 +1646,11 @@ wheels = [ [[package]] name = "pathspec" -version = "0.12.1" +version = "1.0.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/36/e27608899f9b8d4dff0617b2d9ab17ca5608956ca44461ac14ac48b44015/pathspec-1.0.4.tar.gz", hash = "sha256:0210e2ae8a21a9137c0d470578cb0e595af87edaa6ebf12ff176f14a02e0e645", size = 131200, upload-time = "2026-01-27T03:59:46.938Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, + { url = 
"https://files.pythonhosted.org/packages/ef/3c/2c197d226f9ea224a9ab8d197933f9da0ae0aac5b6e0f884e2b8d9c8e9f7/pathspec-1.0.4-py3-none-any.whl", hash = "sha256:fb6ae2fd4e7c921a165808a552060e722767cfa526f99ca5156ed2ce45a5c723", size = 55206, upload-time = "2026-01-27T03:59:45.137Z" }, ] [[package]] @@ -2064,6 +2074,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1b/d0/397f9626e711ff749a95d96b7af99b9c566a9bb5129b8e4c10fc4d100304/python_multipart-0.0.22-py3-none-any.whl", hash = "sha256:2b2cd894c83d21bf49d702499531c7bafd057d730c201782048f7945d82de155", size = 24579, upload-time = "2026-01-25T10:15:54.811Z" }, ] +[[package]] +name = "pytokens" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b6/34/b4e015b99031667a7b960f888889c5bd34ef585c85e1cb56a594b92836ac/pytokens-0.4.1.tar.gz", hash = "sha256:292052fe80923aae2260c073f822ceba21f3872ced9a68bb7953b348e561179a", size = 23015, upload-time = "2026-01-30T01:03:45.924Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/24/f206113e05cb8ef51b3850e7ef88f20da6f4bf932190ceb48bd3da103e10/pytokens-0.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a44ed93ea23415c54f3face3b65ef2b844d96aeb3455b8a69b3df6beab6acc5", size = 161522, upload-time = "2026-01-30T01:02:50.393Z" }, + { url = "https://files.pythonhosted.org/packages/d4/e9/06a6bf1b90c2ed81a9c7d2544232fe5d2891d1cd480e8a1809ca354a8eb2/pytokens-0.4.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:add8bf86b71a5d9fb5b89f023a80b791e04fba57960aa790cc6125f7f1d39dfe", size = 246945, upload-time = "2026-01-30T01:02:52.399Z" }, + { url = "https://files.pythonhosted.org/packages/69/66/f6fb1007a4c3d8b682d5d65b7c1fb33257587a5f782647091e3408abe0b8/pytokens-0.4.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:670d286910b531c7b7e3c0b453fd8156f250adb140146d234a82219459b9640c", size = 259525, upload-time = "2026-01-30T01:02:53.737Z" }, + { url = "https://files.pythonhosted.org/packages/04/92/086f89b4d622a18418bac74ab5db7f68cf0c21cf7cc92de6c7b919d76c88/pytokens-0.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4e691d7f5186bd2842c14813f79f8884bb03f5995f0575272009982c5ac6c0f7", size = 262693, upload-time = "2026-01-30T01:02:54.871Z" }, + { url = "https://files.pythonhosted.org/packages/b4/7b/8b31c347cf94a3f900bdde750b2e9131575a61fdb620d3d3c75832262137/pytokens-0.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:27b83ad28825978742beef057bfe406ad6ed524b2d28c252c5de7b4a6dd48fa2", size = 103567, upload-time = "2026-01-30T01:02:56.414Z" }, + { url = "https://files.pythonhosted.org/packages/3d/92/790ebe03f07b57e53b10884c329b9a1a308648fc083a6d4a39a10a28c8fc/pytokens-0.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d70e77c55ae8380c91c0c18dea05951482e263982911fc7410b1ffd1dadd3440", size = 160864, upload-time = "2026-01-30T01:02:57.882Z" }, + { url = "https://files.pythonhosted.org/packages/13/25/a4f555281d975bfdd1eba731450e2fe3a95870274da73fb12c40aeae7625/pytokens-0.4.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4a58d057208cb9075c144950d789511220b07636dd2e4708d5645d24de666bdc", size = 248565, upload-time = "2026-01-30T01:02:59.912Z" }, + { url = "https://files.pythonhosted.org/packages/17/50/bc0394b4ad5b1601be22fa43652173d47e4c9efbf0044c62e9a59b747c56/pytokens-0.4.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b49750419d300e2b5a3813cf229d4e5a4c728dae470bcc89867a9ad6f25a722d", size = 260824, upload-time = "2026-01-30T01:03:01.471Z" }, + { url = "https://files.pythonhosted.org/packages/4e/54/3e04f9d92a4be4fc6c80016bc396b923d2a6933ae94b5f557c939c460ee0/pytokens-0.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:d9907d61f15bf7261d7e775bd5d7ee4d2930e04424bab1972591918497623a16", size = 264075, upload-time = "2026-01-30T01:03:04.143Z" }, + { url = "https://files.pythonhosted.org/packages/d1/1b/44b0326cb5470a4375f37988aea5d61b5cc52407143303015ebee94abfd6/pytokens-0.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:ee44d0f85b803321710f9239f335aafe16553b39106384cef8e6de40cb4ef2f6", size = 103323, upload-time = "2026-01-30T01:03:05.412Z" }, + { url = "https://files.pythonhosted.org/packages/41/5d/e44573011401fb82e9d51e97f1290ceb377800fb4eed650b96f4753b499c/pytokens-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:140709331e846b728475786df8aeb27d24f48cbcf7bcd449f8de75cae7a45083", size = 160663, upload-time = "2026-01-30T01:03:06.473Z" }, + { url = "https://files.pythonhosted.org/packages/f0/e6/5bbc3019f8e6f21d09c41f8b8654536117e5e211a85d89212d59cbdab381/pytokens-0.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6d6c4268598f762bc8e91f5dbf2ab2f61f7b95bdc07953b602db879b3c8c18e1", size = 255626, upload-time = "2026-01-30T01:03:08.177Z" }, + { url = "https://files.pythonhosted.org/packages/bf/3c/2d5297d82286f6f3d92770289fd439956b201c0a4fc7e72efb9b2293758e/pytokens-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:24afde1f53d95348b5a0eb19488661147285ca4dd7ed752bbc3e1c6242a304d1", size = 269779, upload-time = "2026-01-30T01:03:09.756Z" }, + { url = "https://files.pythonhosted.org/packages/20/01/7436e9ad693cebda0551203e0bf28f7669976c60ad07d6402098208476de/pytokens-0.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5ad948d085ed6c16413eb5fec6b3e02fa00dc29a2534f088d3302c47eb59adf9", size = 268076, upload-time = "2026-01-30T01:03:10.957Z" }, + { url = "https://files.pythonhosted.org/packages/2e/df/533c82a3c752ba13ae7ef238b7f8cdd272cf1475f03c63ac6cf3fcfb00b6/pytokens-0.4.1-cp312-cp312-win_amd64.whl", hash = 
"sha256:3f901fe783e06e48e8cbdc82d631fca8f118333798193e026a50ce1b3757ea68", size = 103552, upload-time = "2026-01-30T01:03:12.066Z" }, + { url = "https://files.pythonhosted.org/packages/cb/dc/08b1a080372afda3cceb4f3c0a7ba2bde9d6a5241f1edb02a22a019ee147/pytokens-0.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8bdb9d0ce90cbf99c525e75a2fa415144fd570a1ba987380190e8b786bc6ef9b", size = 160720, upload-time = "2026-01-30T01:03:13.843Z" }, + { url = "https://files.pythonhosted.org/packages/64/0c/41ea22205da480837a700e395507e6a24425151dfb7ead73343d6e2d7ffe/pytokens-0.4.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5502408cab1cb18e128570f8d598981c68a50d0cbd7c61312a90507cd3a1276f", size = 254204, upload-time = "2026-01-30T01:03:14.886Z" }, + { url = "https://files.pythonhosted.org/packages/e0/d2/afe5c7f8607018beb99971489dbb846508f1b8f351fcefc225fcf4b2adc0/pytokens-0.4.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:29d1d8fb1030af4d231789959f21821ab6325e463f0503a61d204343c9b355d1", size = 268423, upload-time = "2026-01-30T01:03:15.936Z" }, + { url = "https://files.pythonhosted.org/packages/68/d4/00ffdbd370410c04e9591da9220a68dc1693ef7499173eb3e30d06e05ed1/pytokens-0.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:970b08dd6b86058b6dc07efe9e98414f5102974716232d10f32ff39701e841c4", size = 266859, upload-time = "2026-01-30T01:03:17.458Z" }, + { url = "https://files.pythonhosted.org/packages/a7/c9/c3161313b4ca0c601eeefabd3d3b576edaa9afdefd32da97210700e47652/pytokens-0.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:9bd7d7f544d362576be74f9d5901a22f317efc20046efe2034dced238cbbfe78", size = 103520, upload-time = "2026-01-30T01:03:18.652Z" }, + { url = "https://files.pythonhosted.org/packages/8f/a7/b470f672e6fc5fee0a01d9e75005a0e617e162381974213a945fcd274843/pytokens-0.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = 
"sha256:4a14d5f5fc78ce85e426aa159489e2d5961acf0e47575e08f35584009178e321", size = 160821, upload-time = "2026-01-30T01:03:19.684Z" }, + { url = "https://files.pythonhosted.org/packages/80/98/e83a36fe8d170c911f864bfded690d2542bfcfacb9c649d11a9e6eb9dc41/pytokens-0.4.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97f50fd18543be72da51dd505e2ed20d2228c74e0464e4262e4899797803d7fa", size = 254263, upload-time = "2026-01-30T01:03:20.834Z" }, + { url = "https://files.pythonhosted.org/packages/0f/95/70d7041273890f9f97a24234c00b746e8da86df462620194cef1d411ddeb/pytokens-0.4.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dc74c035f9bfca0255c1af77ddd2d6ae8419012805453e4b0e7513e17904545d", size = 268071, upload-time = "2026-01-30T01:03:21.888Z" }, + { url = "https://files.pythonhosted.org/packages/da/79/76e6d09ae19c99404656d7db9c35dfd20f2086f3eb6ecb496b5b31163bad/pytokens-0.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:f66a6bbe741bd431f6d741e617e0f39ec7257ca1f89089593479347cc4d13324", size = 271716, upload-time = "2026-01-30T01:03:23.633Z" }, + { url = "https://files.pythonhosted.org/packages/79/37/482e55fa1602e0a7ff012661d8c946bafdc05e480ea5a32f4f7e336d4aa9/pytokens-0.4.1-cp314-cp314-win_amd64.whl", hash = "sha256:b35d7e5ad269804f6697727702da3c517bb8a5228afa450ab0fa787732055fc9", size = 104539, upload-time = "2026-01-30T01:03:24.788Z" }, + { url = "https://files.pythonhosted.org/packages/30/e8/20e7db907c23f3d63b0be3b8a4fd1927f6da2395f5bcc7f72242bb963dfe/pytokens-0.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:8fcb9ba3709ff77e77f1c7022ff11d13553f3c30299a9fe246a166903e9091eb", size = 168474, upload-time = "2026-01-30T01:03:26.428Z" }, + { url = "https://files.pythonhosted.org/packages/d6/81/88a95ee9fafdd8f5f3452107748fd04c24930d500b9aba9738f3ade642cc/pytokens-0.4.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", 
hash = "sha256:79fc6b8699564e1f9b521582c35435f1bd32dd06822322ec44afdeba666d8cb3", size = 290473, upload-time = "2026-01-30T01:03:27.415Z" }, + { url = "https://files.pythonhosted.org/packages/cf/35/3aa899645e29b6375b4aed9f8d21df219e7c958c4c186b465e42ee0a06bf/pytokens-0.4.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d31b97b3de0f61571a124a00ffe9a81fb9939146c122c11060725bd5aea79975", size = 303485, upload-time = "2026-01-30T01:03:28.558Z" }, + { url = "https://files.pythonhosted.org/packages/52/a0/07907b6ff512674d9b201859f7d212298c44933633c946703a20c25e9d81/pytokens-0.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:967cf6e3fd4adf7de8fc73cd3043754ae79c36475c1c11d514fc72cf5490094a", size = 306698, upload-time = "2026-01-30T01:03:29.653Z" }, + { url = "https://files.pythonhosted.org/packages/39/2a/cbbf9250020a4a8dd53ba83a46c097b69e5eb49dd14e708f496f548c6612/pytokens-0.4.1-cp314-cp314t-win_amd64.whl", hash = "sha256:584c80c24b078eec1e227079d56dc22ff755e0ba8654d8383b2c549107528918", size = 116287, upload-time = "2026-01-30T01:03:30.912Z" }, + { url = "https://files.pythonhosted.org/packages/c6/78/397db326746f0a342855b81216ae1f0a32965deccfd7c830a2dbc66d2483/pytokens-0.4.1-py3-none-any.whl", hash = "sha256:26cef14744a8385f35d0e095dc8b3a7583f6c953c2e3d269c7f82484bf5ad2de", size = 13729, upload-time = "2026-01-30T01:03:45.029Z" }, +] + [[package]] name = "pywin32" version = "311" From e1fd62e0f324bf64479d663cf02ba492326f64df Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 13 Mar 2026 14:43:54 +0000 Subject: [PATCH 14/60] fix: close all memory stream ends in client transport cleanup (#2266) --- pyproject.toml | 2 +- src/mcp/client/sse.py | 2 + src/mcp/client/streamable_http.py | 2 + src/mcp/client/websocket.py | 13 ++- tests/client/test_transport_stream_cleanup.py | 102 ++++++++++++++++++ tests/shared/test_sse.py | 6 -- uv.lock | 2 +- 7 files changed, 117 
insertions(+), 12 deletions(-) create mode 100644 tests/client/test_transport_stream_cleanup.py diff --git a/pyproject.toml b/pyproject.toml index 737839a23c..f275b90cfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] dependencies = [ - "anyio>=4.5", + "anyio>=4.9", "httpx>=0.27.1", "httpx-sse>=0.4", "pydantic>=2.12.0", diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 61026aa0c9..972efce588 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -160,3 +160,5 @@ async def post_writer(endpoint_url: str): finally: await read_stream_writer.aclose() await write_stream.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9f3dd5e0ba..3416bbc816 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -577,3 +577,5 @@ def start_get_stream() -> None: finally: await read_stream_writer.aclose() await write_stream.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 79e75fad18..de473f36d3 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -38,11 +38,10 @@ async def websocket_client( write_stream: MemoryObjectSendStream[SessionMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - # Connect using websockets, requesting the "mcp" subprotocol async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws: + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) async def ws_reader(): """Reads text messages from the WebSocket, parses them as JSON-RPC 
messages, @@ -68,7 +67,13 @@ async def ws_writer(): msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True) await ws.send(json.dumps(msg_dict)) - async with anyio.create_task_group() as tg: + async with ( + read_stream_writer, + read_stream, + write_stream, + write_stream_reader, + anyio.create_task_group() as tg, + ): # Start reader and writer tasks tg.start_soon(ws_reader) tg.start_soon(ws_writer) diff --git a/tests/client/test_transport_stream_cleanup.py b/tests/client/test_transport_stream_cleanup.py new file mode 100644 index 0000000000..631b0fff22 --- /dev/null +++ b/tests/client/test_transport_stream_cleanup.py @@ -0,0 +1,102 @@ +"""Regression tests for memory stream leaks in client transports. + +When a connection error occurs (404, 403, ConnectError), transport context +managers must close ALL 4 memory stream ends they created. anyio memory streams +are paired but independent — closing the writer does NOT close the reader. +Unclosed stream ends emit ResourceWarning on GC, which pytest promotes to a +test failure in whatever test happens to be running when GC triggers. + +These tests force GC after the transport context exits, so any leaked stream +triggers a ResourceWarning immediately and deterministically here, rather than +nondeterministically in an unrelated later test. +""" + +import gc +import sys +from collections.abc import Iterator +from contextlib import contextmanager + +import httpx +import pytest + +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamable_http_client +from mcp.client.websocket import websocket_client + + +@contextmanager +def _assert_no_memory_stream_leak() -> Iterator[None]: + """Fail if any anyio MemoryObject stream emits ResourceWarning during the block. + + Uses a custom sys.unraisablehook to capture ONLY MemoryObject stream leaks, + ignoring unrelated resources (e.g. PipeHandle from flaky stdio tests on the + same xdist worker). 
gc.collect() is forced after the block to make leaks + deterministic. + """ + leaked: list[str] = [] + old_hook = sys.unraisablehook + + def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover + # Only executes if a leak occurs (i.e. the bug is present). + # args.object is the __del__ function (not the stream instance) when + # unraisablehook fires from a finalizer, so check exc_value — the + # actual ResourceWarning("Unclosed "). + # Non-MemoryObject unraisables (e.g. PipeHandle leaked by an earlier + # flaky test on the same xdist worker) are deliberately ignored — + # this test should not fail for another test's resource leak. + if "MemoryObject" in str(args.exc_value): + leaked.append(str(args.exc_value)) + + sys.unraisablehook = hook + try: + yield + gc.collect() + assert not leaked, f"Memory streams leaked: {leaked}" + finally: + sys.unraisablehook = old_hook + + +@pytest.mark.anyio +async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None: + """sse_client must close all 4 stream ends when the connection fails. + + Before the fix, only read_stream_writer and write_stream were closed in + the finally block. read_stream and write_stream_reader were leaked. + """ + with _assert_no_memory_stream_leak(): + # sse_client enters a task group BEFORE connecting, so anyio wraps the + # ConnectError from aconnect_sse in an ExceptionGroup. + with pytest.raises(Exception) as exc_info: # noqa: B017 + async with sse_client(f"http://127.0.0.1:{free_tcp_port}/sse"): + pytest.fail("should not reach here") # pragma: no cover + + assert exc_info.group_contains(httpx.ConnectError) + # exc_info holds the traceback → holds frame locals → keeps leaked + # streams alive. Must drop it before gc.collect() can detect a leak. + del exc_info + + +@pytest.mark.anyio +async def test_streamable_http_client_closes_all_streams_on_exit() -> None: + """streamable_http_client must close all 4 stream ends on exit. 
+ + Before the fix, read_stream was never closed — not even on the happy path. + This test enters and exits the context without sending any messages, so no + network connection is ever attempted (streamable_http connects lazily). + """ + with _assert_no_memory_stream_leak(): + async with streamable_http_client("http://127.0.0.1:1/mcp"): + pass + + +@pytest.mark.anyio +async def test_websocket_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None: + """websocket_client must close all 4 stream ends when ws_connect fails. + + Before the fix, there was no try/finally at all — if ws_connect raised, + all 4 streams were leaked. + """ + with _assert_no_memory_stream_leak(): + with pytest.raises(OSError): + async with websocket_client(f"ws://127.0.0.1:{free_tcp_port}/ws"): + pytest.fail("should not reach here") # pragma: no cover diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 890e997332..5629a5707b 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -544,12 +544,6 @@ def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result assert sse._endpoint.startswith("/") -# ResourceWarning filter: When mocking aconnect_sse, the sse_client's internal task -# group doesn't receive proper cancellation signals, so the sse_reader task's finally -# block (which closes read_stream_writer) doesn't execute. This is a test artifact - -# the actual code path (`if not sse.data: continue`) IS exercised and works correctly. -# Production code with real SSE connections cleans up properly. -@pytest.mark.filterwarnings("ignore::ResourceWarning") @pytest.mark.anyio async def test_sse_client_handles_empty_keepalive_pings() -> None: """Test that SSE client properly handles empty data lines (keep-alive pings). 
diff --git a/uv.lock b/uv.lock index c01e96a4a6..c25047e48d 100644 --- a/uv.lock +++ b/uv.lock @@ -847,7 +847,7 @@ docs = [ [package.metadata] requires-dist = [ - { name = "anyio", specifier = ">=4.5" }, + { name = "anyio", specifier = ">=4.9" }, { name = "httpx", specifier = ">=0.27.1" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" }, From abfb482246a65f829d092b594d109cf23ac35a7e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 16 Mar 2026 11:37:01 +0000 Subject: [PATCH 15/60] refactor(examples): migrate all HTTP examples to streamable_http_app() (#2291) --- docs/experimental/tasks-server.md | 22 +---------- examples/servers/simple-pagination/README.md | 6 +-- .../mcp_simple_pagination/server.py | 29 ++------------ examples/servers/simple-prompt/README.md | 6 +-- .../simple-prompt/mcp_simple_prompt/server.py | 29 ++------------ examples/servers/simple-resource/README.md | 6 +-- .../mcp_simple_resource/server.py | 29 ++------------ .../simple-streamablehttp-stateless/README.md | 1 - .../server.py | 34 ++-------------- .../servers/simple-streamablehttp/README.md | 2 - .../mcp_simple_streamablehttp/server.py | 39 ++----------------- .../mcp_simple_task_interactive/server.py | 20 +--------- .../simple-task/mcp_simple_task/server.py | 18 +-------- examples/servers/simple-tool/README.md | 6 +-- .../simple-tool/mcp_simple_tool/server.py | 29 ++------------ .../mcp_sse_polling_demo/server.py | 38 +++--------------- 16 files changed, 43 insertions(+), 271 deletions(-) diff --git a/docs/experimental/tasks-server.md b/docs/experimental/tasks-server.md index 761dc5de5c..b350ee3bb6 100644 --- a/docs/experimental/tasks-server.md +++ b/docs/experimental/tasks-server.md @@ -408,16 +408,10 @@ For custom error messages, call `task.fail()` before raising. 
For web applications, use the Streamable HTTP transport: ```python -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager - import uvicorn -from starlette.applications import Starlette -from starlette.routing import Mount from mcp.server import Server from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import ( CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, TASK_REQUIRED, ) @@ -462,22 +456,8 @@ async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTask return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) -def create_app(): - session_manager = StreamableHTTPSessionManager(app=server) - - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncIterator[None]: - async with session_manager.run(): - yield - - return Starlette( - routes=[Mount("/mcp", app=session_manager.handle_request)], - lifespan=lifespan, - ) - - if __name__ == "__main__": - uvicorn.run(create_app(), host="127.0.0.1", port=8000) + uvicorn.run(server.streamable_http_app(), host="127.0.0.1", port=8000) ``` ## Testing Task Servers diff --git a/examples/servers/simple-pagination/README.md b/examples/servers/simple-pagination/README.md index e732b8efbe..4cab40fd34 100644 --- a/examples/servers/simple-pagination/README.md +++ b/examples/servers/simple-pagination/README.md @@ -4,14 +4,14 @@ A simple MCP server demonstrating pagination for tools, resources, and prompts u ## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-pagination -# Using SSE transport on custom port -uv run mcp-simple-pagination --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-pagination --transport 
streamable-http --port 8000 ``` The server exposes: diff --git a/examples/servers/simple-pagination/mcp_simple_pagination/server.py b/examples/servers/simple-pagination/mcp_simple_pagination/server.py index bac27a0f1f..c94f2ac3d1 100644 --- a/examples/servers/simple-pagination/mcp_simple_pagination/server.py +++ b/examples/servers/simple-pagination/mcp_simple_pagination/server.py @@ -10,7 +10,6 @@ import click from mcp import types from mcp.server import Server, ServerRequestContext -from starlette.requests import Request T = TypeVar("T") @@ -143,10 +142,10 @@ async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRe @click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option("--port", default=8000, help="Port to listen on for HTTP") @click.option( "--transport", - type=click.Choice(["stdio", "sse"]), + type=click.Choice(["stdio", "streamable-http"]), default="stdio", help="Transport type", ) @@ -161,30 +160,10 @@ def main(port: int, transport: str) -> int: on_get_prompt=handle_get_prompt, ) - if transport == "sse": - from mcp.server.sse import SseServerTransport - from starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route - - sse = SseServerTransport("/messages/") - - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], app.create_initialization_options()) - return Response() - - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse, methods=["GET"]), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff 
--git a/examples/servers/simple-prompt/README.md b/examples/servers/simple-prompt/README.md index 48e796e198..c837da876e 100644 --- a/examples/servers/simple-prompt/README.md +++ b/examples/servers/simple-prompt/README.md @@ -4,14 +4,14 @@ A simple MCP server that exposes a customizable prompt template with optional co ## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-prompt -# Using SSE transport on custom port -uv run mcp-simple-prompt --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-prompt --transport streamable-http --port 8000 ``` The server exposes a prompt named "simple" that accepts two optional arguments: diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index 6cf99d4b69..74b71b3f38 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -2,7 +2,6 @@ import click from mcp import types from mcp.server import Server, ServerRequestContext -from starlette.requests import Request def create_messages(context: str | None = None, topic: str | None = None) -> list[types.PromptMessage]: @@ -69,10 +68,10 @@ async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRe @click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option("--port", default=8000, help="Port to listen on for HTTP") @click.option( "--transport", - type=click.Choice(["stdio", "sse"]), + type=click.Choice(["stdio", "streamable-http"]), default="stdio", help="Transport type", ) @@ -83,30 +82,10 @@ def main(port: int, transport: str) -> int: on_get_prompt=handle_get_prompt, ) - if transport == "sse": - from mcp.server.sse import SseServerTransport - from 
starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route - - sse = SseServerTransport("/messages/") - - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], app.create_initialization_options()) - return Response() - - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff --git a/examples/servers/simple-resource/README.md b/examples/servers/simple-resource/README.md index df674e91e4..7fb2ab7cdc 100644 --- a/examples/servers/simple-resource/README.md +++ b/examples/servers/simple-resource/README.md @@ -4,14 +4,14 @@ A simple MCP server that exposes sample text files as resources. ## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-resource -# Using SSE transport on custom port -uv run mcp-simple-resource --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-resource --transport streamable-http --port 8000 ``` The server exposes some basic text file resources that can be read by clients. 
diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index b9b6a1d960..8d11054145 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -4,7 +4,6 @@ import click from mcp import types from mcp.server import Server, ServerRequestContext -from starlette.requests import Request SAMPLE_RESOURCES = { "greeting": { @@ -62,10 +61,10 @@ async def handle_read_resource( @click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option("--port", default=8000, help="Port to listen on for HTTP") @click.option( "--transport", - type=click.Choice(["stdio", "sse"]), + type=click.Choice(["stdio", "streamable-http"]), default="stdio", help="Transport type", ) @@ -76,30 +75,10 @@ def main(port: int, transport: str) -> int: on_read_resource=handle_read_resource, ) - if transport == "sse": - from mcp.server.sse import SseServerTransport - from starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route - - sse = SseServerTransport("/messages/") - - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], app.create_initialization_options()) - return Response() - - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse, methods=["GET"]), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff --git a/examples/servers/simple-streamablehttp-stateless/README.md 
b/examples/servers/simple-streamablehttp-stateless/README.md index b87250b353..a254f88d14 100644 --- a/examples/servers/simple-streamablehttp-stateless/README.md +++ b/examples/servers/simple-streamablehttp-stateless/README.md @@ -7,7 +7,6 @@ A stateless MCP server example demonstrating the StreamableHttp transport withou - Uses the StreamableHTTP transport in stateless mode (mcp_session_id=None) - Each request creates a new ephemeral connection - No session state maintained between requests -- Task lifecycle scoped to individual requests - Suitable for deployment in multi-node environments ## Usage diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index cb4a6503ce..e2b8d2ef2f 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -1,17 +1,11 @@ -import contextlib import logging -from collections.abc import AsyncIterator import anyio import click import uvicorn from mcp import types from mcp.server import Server, ServerRequestContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware -from starlette.routing import Mount -from starlette.types import Receive, Scope, Send logger = logging.getLogger(__name__) @@ -104,39 +98,17 @@ def main( on_call_tool=handle_call_tool, ) - # Create the session manager with true stateless mode - session_manager = StreamableHTTPSessionManager( - app=app, - event_store=None, + starlette_app = app.streamable_http_app( + stateless_http=True, json_response=json_response, - stateless=True, - ) - - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await 
session_manager.handle_request(scope, receive, send) - - @contextlib.asynccontextmanager - async def lifespan(app: Starlette) -> AsyncIterator[None]: - """Context manager for session manager.""" - async with session_manager.run(): - logger.info("Application started with StreamableHTTP session manager!") - try: - yield - finally: - logger.info("Application shutting down...") - - # Create an ASGI application using the transport - starlette_app = Starlette( debug=True, - routes=[Mount("/mcp", app=handle_streamable_http)], - lifespan=lifespan, ) # Wrap ASGI application with CORS middleware to expose Mcp-Session-Id header # for browser-based clients (ensures 500 errors get proper CORS headers) starlette_app = CORSMiddleware( starlette_app, - allow_origins=["*"], # Allow all origins - adjust as needed for production + allow_origins=["*"], # Note: streamable_http_app() enforces localhost-only Origin by default allow_methods=["GET", "POST", "DELETE"], # MCP streamable HTTP methods expose_headers=["Mcp-Session-Id"], ) diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md index 9836367170..3eed3320e7 100644 --- a/examples/servers/simple-streamablehttp/README.md +++ b/examples/servers/simple-streamablehttp/README.md @@ -6,9 +6,7 @@ A simple MCP server example demonstrating the StreamableHttp transport, which en - Uses the StreamableHTTP transport for server-client communication - Supports REST API operations (POST, GET, DELETE) for `/mcp` endpoint -- Task management with anyio task groups - Ability to send multiple notifications over time to the client -- Proper resource cleanup and lifespan management - Resumability support via InMemoryEventStore ## Usage diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 2f2a53b1b1..ec9761d1b0 100644 --- 
a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -1,16 +1,11 @@ -import contextlib import logging -from collections.abc import AsyncIterator import anyio import click +import uvicorn from mcp import types from mcp.server import Server, ServerRequestContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware -from starlette.routing import Mount -from starlette.types import Receive, Scope, Send from .event_store import InMemoryEventStore @@ -127,47 +122,21 @@ def main( # For production, use a persistent storage solution. event_store = InMemoryEventStore() - # Create the session manager with our app and event store - session_manager = StreamableHTTPSessionManager( - app=app, - event_store=event_store, # Enable resumability + starlette_app = app.streamable_http_app( + event_store=event_store, json_response=json_response, - ) - - # ASGI handler for streamable HTTP connections - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, send) - - @contextlib.asynccontextmanager - async def lifespan(app: Starlette) -> AsyncIterator[None]: - """Context manager for managing session manager lifecycle.""" - async with session_manager.run(): - logger.info("Application started with StreamableHTTP session manager!") - try: - yield - finally: - logger.info("Application shutting down...") - - # Create an ASGI application using the transport - starlette_app = Starlette( debug=True, - routes=[ - Mount("/mcp", app=handle_streamable_http), - ], - lifespan=lifespan, ) # Wrap ASGI application with CORS middleware to expose Mcp-Session-Id header # for browser-based clients (ensures 500 errors get proper CORS headers) starlette_app = CORSMiddleware( starlette_app, - 
allow_origins=["*"], # Allow all origins - adjust as needed for production + allow_origins=["*"], # Note: streamable_http_app() enforces localhost-only Origin by default allow_methods=["GET", "POST", "DELETE"], # MCP streamable HTTP methods expose_headers=["Mcp-Session-Id"], ) - import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) return 0 diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py index 6938b6552a..bc06e12088 100644 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py +++ b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py @@ -6,8 +6,6 @@ - ServerTaskContext.elicit() and ServerTaskContext.create_message() queue requests properly """ -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager from typing import Any import click @@ -15,9 +13,6 @@ from mcp import types from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette -from starlette.routing import Mount async def handle_list_tools( @@ -134,23 +129,10 @@ async def handle_call_tool( server.experimental.enable_tasks() -def create_app(session_manager: StreamableHTTPSessionManager) -> Starlette: - @asynccontextmanager - async def app_lifespan(app: Starlette) -> AsyncIterator[None]: - async with session_manager.run(): - yield - - return Starlette( - routes=[Mount("/mcp", app=session_manager.handle_request)], - lifespan=app_lifespan, - ) - - @click.command() @click.option("--port", default=8000, help="Port to listen on") def main(port: int) -> int: - session_manager = StreamableHTTPSessionManager(app=server) - starlette_app = create_app(session_manager) + starlette_app = server.streamable_http_app() 
print(f"Starting server on http://localhost:{port}/mcp") uvicorn.run(starlette_app, host="127.0.0.1", port=port) return 0 diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py index 50ae3ca9af..7583cd8f0e 100644 --- a/examples/servers/simple-task/mcp_simple_task/server.py +++ b/examples/servers/simple-task/mcp_simple_task/server.py @@ -1,17 +1,11 @@ """Simple task server demonstrating MCP tasks over streamable HTTP.""" -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager - import anyio import click import uvicorn from mcp import types from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette -from starlette.routing import Mount async def handle_list_tools( @@ -69,17 +63,7 @@ async def work(task: ServerTaskContext) -> types.CallToolResult: @click.command() @click.option("--port", default=8000, help="Port to listen on") def main(port: int) -> int: - session_manager = StreamableHTTPSessionManager(app=server) - - @asynccontextmanager - async def app_lifespan(app: Starlette) -> AsyncIterator[None]: - async with session_manager.run(): - yield - - starlette_app = Starlette( - routes=[Mount("/mcp", app=session_manager.handle_request)], - lifespan=app_lifespan, - ) + starlette_app = server.streamable_http_app() print(f"Starting server on http://localhost:{port}/mcp") uvicorn.run(starlette_app, host="127.0.0.1", port=port) diff --git a/examples/servers/simple-tool/README.md b/examples/servers/simple-tool/README.md index 06020b4b0e..7d3759f9de 100644 --- a/examples/servers/simple-tool/README.md +++ b/examples/servers/simple-tool/README.md @@ -3,14 +3,14 @@ A simple MCP server that exposes a website fetching tool. 
## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-tool -# Using SSE transport on custom port -uv run mcp-simple-tool --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-tool --transport streamable-http --port 8000 ``` The server exposes a tool named "fetch" that accepts one required argument: diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 9fe71e5b7a..226058b955 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -3,7 +3,6 @@ from mcp import types from mcp.server import Server, ServerRequestContext from mcp.shared._httpx_utils import create_mcp_http_client -from starlette.requests import Request async def fetch_website( @@ -51,10 +50,10 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ @click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option("--port", default=8000, help="Port to listen on for HTTP") @click.option( "--transport", - type=click.Choice(["stdio", "sse"]), + type=click.Choice(["stdio", "streamable-http"]), default="stdio", help="Transport type", ) @@ -65,30 +64,10 @@ def main(port: int, transport: str) -> int: on_call_tool=handle_call_tool, ) - if transport == "sse": - from mcp.server.sse import SseServerTransport - from starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route - - sse = SseServerTransport("/messages/") - - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], 
app.create_initialization_options()) - return Response() - - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse, methods=["GET"]), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py index c8178c35a4..14bc174c47 100644 --- a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py @@ -12,18 +12,13 @@ uv run mcp-sse-polling-demo --port 3000 """ -import contextlib import logging -from collections.abc import AsyncIterator import anyio import click +import uvicorn from mcp import types from mcp.server import Server, ServerRequestContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette -from starlette.routing import Mount -from starlette.types import Receive, Scope, Send from .event_store import InMemoryEventStore @@ -149,37 +144,14 @@ def main(port: int, log_level: str, retry_interval: int) -> int: on_call_tool=handle_call_tool, ) - # Create event store for resumability - event_store = InMemoryEventStore() - - # Create session manager with event store and retry interval - session_manager = StreamableHTTPSessionManager( - app=app, - event_store=event_store, + starlette_app = app.streamable_http_app( + event_store=InMemoryEventStore(), retry_interval=retry_interval, - ) - - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, send) - - @contextlib.asynccontextmanager - async def lifespan(starlette_app: Starlette) -> 
AsyncIterator[None]: - async with session_manager.run(): - logger.info(f"SSE Polling Demo server started on port {port}") - logger.info("Try: POST /mcp with tools/call for 'process_batch'") - yield - logger.info("Server shutting down...") - - starlette_app = Starlette( debug=True, - routes=[ - Mount("/mcp", app=handle_streamable_http), - ], - lifespan=lifespan, ) - import uvicorn - + logger.info(f"SSE Polling Demo server starting on port {port}") + logger.info("Try: POST /mcp with tools/call for 'process_batch'") uvicorn.run(starlette_app, host="127.0.0.1", port=port) return 0 From 75a80b6f07f0dfbe0ceb8df50c920a4bda3266d0 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 16 Mar 2026 23:30:20 +0000 Subject: [PATCH 16/60] refactor: connect-first stream lifecycle for sse and streamable_http (#2292) Co-authored-by: Marcelo Trylesinski --- src/mcp/client/sse.py | 201 +++++++++--------- src/mcp/client/streamable_http.py | 76 ++++--- tests/client/test_transport_stream_cleanup.py | 37 +++- 3 files changed, 161 insertions(+), 153 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 972efce588..7b66b5c1b6 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -57,108 +57,101 @@ async def sse_client( write_stream: MemoryObjectSendStream[SessionMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx_client_factory( - headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) - ) as client: - async with aconnect_sse( - client, - "GET", - url, - ) as event_source: - event_source.response.raise_for_status() - logger.debug("SSE connection established") - - async def 
sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): - try: - async for sse in event_source.aiter_sse(): # pragma: no branch - logger.debug(f"Received SSE event: {sse.event}") - match sse.event: - case "endpoint": - endpoint_url = urljoin(url, sse.data) - logger.debug(f"Received endpoint URL: {endpoint_url}") - - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - if ( # pragma: no cover - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme != endpoint_parsed.scheme - ): - error_msg = ( # pragma: no cover - f"Endpoint origin does not match connection origin: {endpoint_url}" - ) - logger.error(error_msg) # pragma: no cover - raise ValueError(error_msg) # pragma: no cover - - if on_session_created: - session_id = _extract_session_id_from_endpoint(endpoint_url) - if session_id: - on_session_created(session_id) - - task_status.started(endpoint_url) - - case "message": - # Skip empty data (keep-alive pings) - if not sse.data: - continue - try: - message = types.jsonrpc_message_adapter.validate_json( - sse.data, by_name=False - ) - logger.debug(f"Received server message: {message}") - except Exception as exc: # pragma: no cover - logger.exception("Error parsing server message") # pragma: no cover - await read_stream_writer.send(exc) # pragma: no cover - continue # pragma: no cover - - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - case _: # pragma: no cover - logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover - except SSEError as sse_exc: # pragma: lax no cover - logger.exception("Encountered SSE exception") - raise sse_exc - except Exception as exc: # pragma: lax no cover - logger.exception("Error in sse_reader") - await read_stream_writer.send(exc) - finally: - await read_stream_writer.aclose() - - async def post_writer(endpoint_url: str): - try: - async with write_stream_reader: - async for session_message in write_stream_reader: - logger.debug(f"Sending 
client message: {session_message}") - response = await client.post( - endpoint_url, - json=session_message.message.model_dump( - by_alias=True, - mode="json", - exclude_unset=True, - ), + logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") + async with httpx_client_factory( + headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) + ) as client: + async with aconnect_sse(client, "GET", url) as event_source: + event_source.response.raise_for_status() + logger.debug("SSE connection established") + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): + try: + async for sse in event_source.aiter_sse(): # pragma: no branch + logger.debug(f"Received SSE event: {sse.event}") + match sse.event: + case "endpoint": + endpoint_url = urljoin(url, sse.data) + logger.debug(f"Received endpoint URL: {endpoint_url}") + + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) + if ( # pragma: no cover + url_parsed.netloc != endpoint_parsed.netloc + or url_parsed.scheme != endpoint_parsed.scheme + ): + error_msg = ( # pragma: no cover + f"Endpoint origin does not match connection origin: {endpoint_url}" ) - response.raise_for_status() - logger.debug(f"Client message sent successfully: {response.status_code}") - except Exception: # pragma: lax no cover - logger.exception("Error in post_writer") - finally: - await write_stream.aclose() - - endpoint_url = await tg.start(sse_reader) - logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") - tg.start_soon(post_writer, endpoint_url) - - try: - yield read_stream, write_stream - finally: - tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() - await read_stream.aclose() - await write_stream_reader.aclose() + logger.error(error_msg) # 
pragma: no cover + raise ValueError(error_msg) # pragma: no cover + + if on_session_created: + session_id = _extract_session_id_from_endpoint(endpoint_url) + if session_id: + on_session_created(session_id) + + task_status.started(endpoint_url) + + case "message": + # Skip empty data (keep-alive pings) + if not sse.data: + continue + try: + message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False) + logger.debug(f"Received server message: {message}") + except Exception as exc: # pragma: no cover + logger.exception("Error parsing server message") # pragma: no cover + await read_stream_writer.send(exc) # pragma: no cover + continue # pragma: no cover + + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + case _: # pragma: no cover + logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover + except SSEError as sse_exc: # pragma: lax no cover + logger.exception("Encountered SSE exception") + raise sse_exc + except Exception as exc: # pragma: lax no cover + logger.exception("Error in sse_reader") + await read_stream_writer.send(exc) + finally: + await read_stream_writer.aclose() + + async def post_writer(endpoint_url: str): + try: + async with write_stream_reader, write_stream: + async for session_message in write_stream_reader: + logger.debug(f"Sending client message: {session_message}") + response = await client.post( + endpoint_url, + json=session_message.message.model_dump( + by_alias=True, + mode="json", + exclude_unset=True, + ), + ) + response.raise_for_status() + logger.debug(f"Client message sent successfully: {response.status_code}") + except Exception: # pragma: lax no cover + logger.exception("Error in post_writer") + + # On Python 3.14, coverage.py reports a phantom branch arc on this + # line (->yield) when nested two async-with levels deep. The branch + # is the unreachable "did __aexit__ suppress?" arm for memory streams. 
+ async with ( # pragma: no branch + read_stream_writer, + read_stream, + write_stream, + write_stream_reader, + anyio.create_task_group() as tg, + ): + endpoint_url = await tg.start(sse_reader) + logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") + tg.start_soon(post_writer, endpoint_url) + + yield read_stream, write_stream + tg.cancel_scope.cancel() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3416bbc816..3afb94b034 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -440,7 +440,7 @@ async def post_writer( ) -> None: """Handle writing requests to the server.""" try: - async with write_stream_reader: + async with write_stream_reader, read_stream_writer, write_stream: async for session_message in write_stream_reader: message = session_message.message metadata = ( @@ -480,9 +480,6 @@ async def handle_request_async(): except Exception: # pragma: lax no cover logger.exception("Error in post_writer") - finally: - await read_stream_writer.aclose() - await write_stream.aclose() async def terminate_session(self, client: httpx.AsyncClient) -> None: """Terminate the session by sending a DELETE request.""" @@ -533,9 +530,6 @@ async def streamable_http_client( Example: See examples/snippets/clients/ for usage patterns. 
""" - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - # Determine if we need to create and manage the client client_provided = http_client is not None client = http_client @@ -546,36 +540,40 @@ async def streamable_http_client( transport = StreamableHTTPTransport(url) - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") - - async with contextlib.AsyncExitStack() as stack: - # Only manage client lifecycle if we created it - if not client_provided: - await stack.enter_async_context(client) - - def start_get_stream() -> None: - tg.start_soon(transport.handle_get_stream, client, read_stream_writer) - - tg.start_soon( - transport.post_writer, - client, - write_stream_reader, - read_stream_writer, - write_stream, - start_get_stream, - tg, - ) + logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") + + async with contextlib.AsyncExitStack() as stack: + # Only manage client lifecycle if we created it + if not client_provided: + await stack.enter_async_context(client) + + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + + async with ( + read_stream_writer, + read_stream, + write_stream, + write_stream_reader, + anyio.create_task_group() as tg, + ): + + def start_get_stream() -> None: + tg.start_soon(transport.handle_get_stream, client, read_stream_writer) + + tg.start_soon( + transport.post_writer, + client, + write_stream_reader, + read_stream_writer, + write_stream, + start_get_stream, + tg, + ) - try: - yield read_stream, write_stream - finally: - if transport.session_id and terminate_on_close: - await transport.terminate_session(client) - tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await 
write_stream.aclose() - await read_stream.aclose() - await write_stream_reader.aclose() + try: + yield read_stream, write_stream + finally: + if transport.session_id and terminate_on_close: + await transport.terminate_session(client) + tg.cancel_scope.cancel() diff --git a/tests/client/test_transport_stream_cleanup.py b/tests/client/test_transport_stream_cleanup.py index 631b0fff22..1e6be3c725 100644 --- a/tests/client/test_transport_stream_cleanup.py +++ b/tests/client/test_transport_stream_cleanup.py @@ -58,22 +58,39 @@ def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover @pytest.mark.anyio async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None: - """sse_client must close all 4 stream ends when the connection fails. + """sse_client creates streams only after the SSE connection succeeds, so a + ConnectError propagates directly with nothing to leak. - Before the fix, only read_stream_writer and write_stream were closed in - the finally block. read_stream and write_stream_reader were leaked. + Before the fix, streams were created before connecting and only 2 of 4 were + closed in the finally block. """ with _assert_no_memory_stream_leak(): - # sse_client enters a task group BEFORE connecting, so anyio wraps the - # ConnectError from aconnect_sse in an ExceptionGroup. - with pytest.raises(Exception) as exc_info: # noqa: B017 + with pytest.raises(httpx.ConnectError): async with sse_client(f"http://127.0.0.1:{free_tcp_port}/sse"): pytest.fail("should not reach here") # pragma: no cover - assert exc_info.group_contains(httpx.ConnectError) - # exc_info holds the traceback → holds frame locals → keeps leaked - # streams alive. Must drop it before gc.collect() can detect a leak. 
- del exc_info + +@pytest.mark.anyio +async def test_sse_client_closes_all_streams_on_http_error() -> None: + """sse_client creates streams only after raise_for_status() passes, so an + HTTPStatusError from a 4xx/5xx response propagates bare (not wrapped in an + ExceptionGroup) with nothing to leak — the task group is never entered. + """ + + def return_403(request: httpx.Request) -> httpx.Response: + return httpx.Response(403) + + def mock_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient(transport=httpx.MockTransport(return_403)) + + with _assert_no_memory_stream_leak(): + with pytest.raises(httpx.HTTPStatusError): + async with sse_client("http://test/sse", httpx_client_factory=mock_factory): + pytest.fail("should not reach here") # pragma: no cover @pytest.mark.anyio From 1a2244f402bdb08f1611bfd4ad34ef36a91a23a7 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 17 Mar 2026 18:40:39 +0000 Subject: [PATCH 17/60] fix: handle non-UTF-8 bytes in stdio server stdin (#2302) --- src/mcp/server/stdio.py | 4 ++-- tests/server/test_stdio.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index e526bab569..5ea6c4e778 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -39,7 +39,7 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. # python is platform-dependent (Windows is particularly problematic), so we # re-wrap the underlying binary stream to ensure UTF-8. 
if not stdin: - stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8")) + stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace")) if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) @@ -58,7 +58,7 @@ async def stdin_reader(): async for line in stdin: try: message = types.jsonrpc_message_adapter.validate_json(line, by_name=False) - except Exception as exc: # pragma: no cover + except Exception as exc: await read_stream_writer.send(exc) continue diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 9a7ddaab40..677a993567 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,4 +1,6 @@ import io +import sys +from io import TextIOWrapper import anyio import pytest @@ -59,3 +61,34 @@ async def test_stdio_server(): assert len(received_responses) == 2 assert received_responses[0] == JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") assert received_responses[1] == JSONRPCResponse(jsonrpc="2.0", id=4, result={}) + + +@pytest.mark.anyio +async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): + """Non-UTF-8 bytes on stdin must not crash the server. + + Invalid bytes are replaced with U+FFFD, which then fails JSON parsing and + is delivered as an in-stream exception. Subsequent valid messages must + still be processed. + """ + # \xff\xfe are invalid UTF-8 start bytes. + valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + raw_stdin = io.BytesIO(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + + # Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that + # stdio_server()'s default path wraps it with errors='replace'. 
+ monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8")) + + with anyio.fail_after(5): + async with stdio_server() as (read_stream, write_stream): + await write_stream.aclose() + async with read_stream: # pragma: no branch + # First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream + first = await read_stream.receive() + assert isinstance(first, Exception) + + # Second line: valid message still comes through + second = await read_stream.receive() + assert isinstance(second, SessionMessage) + assert second.message == valid From ff50351f9e08b0a7dbbcade3813c48986b737b05 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 17 Mar 2026 19:53:39 +0000 Subject: [PATCH 18/60] ci: run strict-no-cover in scripts/test to catch stale pragmas locally (#2305) --- CLAUDE.md | 19 ++++++++++++++++--- scripts/test | 3 +++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 98bd451152..2eee085e13 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -29,18 +29,31 @@ This document contains critical information about working with this codebase. Fo - IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns. - IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible. - Coverage: CI requires 100% (`fail_under = 100`, `branch = true`). - - Full check: `./scripts/test` (~20s, matches CI exactly) - - Targeted check while iterating: + - Full check: `./scripts/test` (~23s). Runs coverage + `strict-no-cover` on the + default Python. Not identical to CI: CI also runs 3.10–3.14 × {ubuntu, windows}, + and some branch-coverage quirks only surface on specific matrix entries. 
+ - Targeted check while iterating (~4s, deterministic): ```bash uv run --frozen coverage erase uv run --frozen coverage run -m pytest tests/path/test_foo.py uv run --frozen coverage combine uv run --frozen coverage report --include='src/mcp/path/foo.py' --fail-under=0 + UV_FROZEN=1 uv run --frozen strict-no-cover ``` Partial runs can't hit 100% (coverage tracks `tests/` too), so `--fail-under=0` - and `--include` scope the report to what you actually changed. + and `--include` scope the report. `strict-no-cover` has no false positives on + partial runs — if your new test executes a line marked `# pragma: no cover`, + even a single-file run catches it. + - Coverage pragmas: + - `# pragma: no cover` — line is never executed. CI's `strict-no-cover` fails if + it IS executed. When your test starts covering such a line, remove the pragma. + - `# pragma: lax no cover` — excluded from coverage but not checked by + `strict-no-cover`. Use for lines covered on some platforms/versions but not + others. + - `# pragma: no branch` — excludes branch arcs only. coverage.py misreports the + `->exit` arc for nested `async with` on Python 3.11+ (worse on 3.14/Windows). - Avoid `anyio.sleep()` with a fixed duration to wait for async operations. Instead: - Use `anyio.Event` — set it in the callback/handler, `await event.wait()` in the test - For stream messages, use `await stream.receive()` instead of `sleep()` + `receive_nowait()` diff --git a/scripts/test b/scripts/test index ee1259b597..dc43f351dd 100755 --- a/scripts/test +++ b/scripts/test @@ -6,3 +6,6 @@ uv run --frozen coverage erase uv run --frozen coverage run -m pytest -n auto $@ uv run --frozen coverage combine uv run --frozen coverage report +# strict-no-cover spawns `uv run coverage json` internally without --frozen; +# UV_FROZEN=1 propagates to that subprocess so it doesn't touch uv.lock. 
+UV_FROZEN=1 uv run --frozen strict-no-cover From 7826ade12b50ce377d0856a3e3448ea36269a4c7 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:25:11 +0000 Subject: [PATCH 19/60] test: convert test_integration.py to in-memory transport (fix flaky) (#2277) --- src/mcp/server/session.py | 2 +- tests/server/mcpserver/test_integration.py | 744 +++++++-------------- 2 files changed, 231 insertions(+), 515 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 759d2131a1..ce467e6c93 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -474,7 +474,7 @@ async def send_progress_notification( related_request_id, ) - async def send_resource_list_changed(self) -> None: # pragma: no cover + async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" await self.send_notification(types.ResourceListChangedNotification()) diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index c4ea2dad65..90e333b775 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -1,7 +1,7 @@ """Integration tests for MCPServer server functionality. These tests validate the proper functioning of MCPServer features using focused, -single-feature servers across different transports (SSE and StreamableHTTP). +single-feature example servers over an in-memory transport. """ # TODO(Marcelo): The `examples` package is not being imported as package. We need to solve this. 
# pyright: reportUnknownMemberType=false @@ -10,12 +10,8 @@ # pyright: reportUnknownArgumentType=false import json -import multiprocessing -import socket -from collections.abc import Generator import pytest -import uvicorn from inline_snapshot import snapshot from examples.snippets.servers import ( @@ -30,9 +26,8 @@ structured_output, tool_progress, ) +from mcp.client import Client from mcp.client.session import ClientSession -from mcp.client.sse import sse_client -from mcp.client.streamable_http import streamable_http_client from mcp.shared._context import RequestContext from mcp.shared.session import RequestResponder from mcp.types import ( @@ -42,7 +37,6 @@ ElicitRequestParams, ElicitResult, GetPromptResult, - InitializeResult, LoggingMessageNotification, LoggingMessageNotificationParams, NotificationParams, @@ -58,7 +52,8 @@ TextResourceContents, ToolListChangedNotification, ) -from tests.test_helpers import wait_for_server + +pytestmark = pytest.mark.anyio class NotificationCollector: @@ -85,105 +80,6 @@ async def handle_generic_notification( self.tool_notifications.append(message.params) -# Common fixtures -@pytest.fixture -def server_port() -> int: - """Get a free port for testing.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - """Get the server URL for testing.""" - return f"http://127.0.0.1:{server_port}" - - -def run_server_with_transport(module_name: str, port: int, transport: str) -> None: # pragma: no cover - """Run server with specified transport.""" - # Get the MCP instance based on module name - if module_name == "basic_tool": - mcp = basic_tool.mcp - elif module_name == "basic_resource": - mcp = basic_resource.mcp - elif module_name == "basic_prompt": - mcp = basic_prompt.mcp - elif module_name == "tool_progress": - mcp = tool_progress.mcp - elif module_name == "sampling": - mcp = sampling.mcp - elif module_name == "elicitation": - mcp = 
elicitation.mcp - elif module_name == "completion": - mcp = completion.mcp - elif module_name == "notifications": - mcp = notifications.mcp - elif module_name == "mcpserver_quickstart": - mcp = mcpserver_quickstart.mcp - elif module_name == "structured_output": - mcp = structured_output.mcp - else: - raise ImportError(f"Unknown module: {module_name}") - - # Create app based on transport type - if transport == "sse": - app = mcp.sse_app() - elif transport == "streamable-http": - app = mcp.streamable_http_app() - else: - raise ValueError(f"Invalid transport for test server: {transport}") - - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error")) - print(f"Starting {transport} server on port {port}") - server.run() - - -@pytest.fixture -def server_transport(request: pytest.FixtureRequest, server_port: int) -> Generator[str, None, None]: - """Start server in a separate process with specified MCP instance and transport. - - Args: - request: pytest request with param tuple of (module_name, transport) - server_port: Port to run the server on - - Yields: - str: The transport type ('sse' or 'streamable_http') - """ - module_name, transport = request.param - - proc = multiprocessing.Process( - target=run_server_with_transport, - args=(module_name, server_port, transport), - daemon=True, - ) - proc.start() - - # Wait for server to be ready - wait_for_server(server_port) - - yield transport - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Server process failed to terminate") - - -# Helper function to create client based on transport -def create_client_for_transport(transport: str, server_url: str): - """Create the appropriate client context manager based on transport type.""" - if transport == "sse": - endpoint = f"{server_url}/sse" - return sse_client(endpoint) - elif transport == "streamable-http": - endpoint = f"{server_url}/mcp" - return streamable_http_client(endpoint) - else: # 
pragma: no cover - raise ValueError(f"Invalid transport: {transport}") - - -# Callback functions for testing async def sampling_callback( context: RequestContext[ClientSession], params: CreateMessageRequestParams ) -> CreateMessageResult: @@ -210,147 +106,85 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E return ElicitResult(action="decline") -# Test basic tools -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_tool", "sse"), - ("basic_tool", "streamable-http"), - ], - indirect=True, -) -async def test_basic_tools(server_transport: str, server_url: str) -> None: +async def test_basic_tools() -> None: """Test basic tool functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Tool Example" - assert result.capabilities.tools is not None - - # Test sum tool - tool_result = await session.call_tool("sum", {"a": 5, "b": 3}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "8" - - # Test weather tool - weather_result = await session.call_tool("get_weather", {"city": "London"}) - assert len(weather_result.content) == 1 - assert isinstance(weather_result.content[0], TextContent) - assert "Weather in London: 22degreesC" in weather_result.content[0].text - - -# Test resources -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_resource", "sse"), - ("basic_resource", "streamable-http"), - ], - indirect=True, -) -async def test_basic_resources(server_transport: str, server_url: str) -> None: + async with Client(basic_tool.mcp) as client: + assert client.server_capabilities is 
not None + assert client.server_capabilities.tools is not None + + # Test sum tool + tool_result = await client.call_tool("sum", {"a": 5, "b": 3}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "8" + + # Test weather tool + weather_result = await client.call_tool("get_weather", {"city": "London"}) + assert len(weather_result.content) == 1 + assert isinstance(weather_result.content[0], TextContent) + assert "Weather in London: 22degreesC" in weather_result.content[0].text + + +async def test_basic_resources() -> None: """Test basic resource functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Resource Example" - assert result.capabilities.resources is not None - - # Test document resource - doc_content = await session.read_resource("file://documents/readme") - assert isinstance(doc_content, ReadResourceResult) - assert len(doc_content.contents) == 1 - assert isinstance(doc_content.contents[0], TextResourceContents) - assert "Content of readme" in doc_content.contents[0].text - - # Test settings resource - settings_content = await session.read_resource("config://settings") - assert isinstance(settings_content, ReadResourceResult) - assert len(settings_content.contents) == 1 - assert isinstance(settings_content.contents[0], TextResourceContents) - settings_json = json.loads(settings_content.contents[0].text) - assert settings_json["theme"] == "dark" - assert settings_json["language"] == "en" - - -# Test prompts -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_prompt", "sse"), - ("basic_prompt", "streamable-http"), - ], - 
indirect=True, -) -async def test_basic_prompts(server_transport: str, server_url: str) -> None: + async with Client(basic_resource.mcp) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.resources is not None + + # Test document resource + doc_content = await client.read_resource("file://documents/readme") + assert isinstance(doc_content, ReadResourceResult) + assert len(doc_content.contents) == 1 + assert isinstance(doc_content.contents[0], TextResourceContents) + assert "Content of readme" in doc_content.contents[0].text + + # Test settings resource + settings_content = await client.read_resource("config://settings") + assert isinstance(settings_content, ReadResourceResult) + assert len(settings_content.contents) == 1 + assert isinstance(settings_content.contents[0], TextResourceContents) + settings_json = json.loads(settings_content.contents[0].text) + assert settings_json["theme"] == "dark" + assert settings_json["language"] == "en" + + +async def test_basic_prompts() -> None: """Test basic prompt functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Prompt Example" - assert result.capabilities.prompts is not None - - # Test review_code prompt - prompts = await session.list_prompts() - review_prompt = next((p for p in prompts.prompts if p.name == "review_code"), None) - assert review_prompt is not None - - prompt_result = await session.get_prompt("review_code", {"code": "def hello():\n print('Hello')"}) - assert isinstance(prompt_result, GetPromptResult) - assert len(prompt_result.messages) == 1 - assert isinstance(prompt_result.messages[0].content, TextContent) - assert "Please review this 
code:" in prompt_result.messages[0].content.text - assert "def hello():" in prompt_result.messages[0].content.text - - # Test debug_error prompt - debug_result = await session.get_prompt( - "debug_error", {"error": "TypeError: 'NoneType' object is not subscriptable"} - ) - assert isinstance(debug_result, GetPromptResult) - assert len(debug_result.messages) == 3 - assert debug_result.messages[0].role == "user" - assert isinstance(debug_result.messages[0].content, TextContent) - assert "I'm seeing this error:" in debug_result.messages[0].content.text - assert debug_result.messages[1].role == "user" - assert isinstance(debug_result.messages[1].content, TextContent) - assert "TypeError" in debug_result.messages[1].content.text - assert debug_result.messages[2].role == "assistant" - assert isinstance(debug_result.messages[2].content, TextContent) - assert "I'll help debug that" in debug_result.messages[2].content.text - - -# Test progress reporting -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("tool_progress", "sse"), - ("tool_progress", "streamable-http"), - ], - indirect=True, -) -async def test_tool_progress(server_transport: str, server_url: str) -> None: + async with Client(basic_prompt.mcp) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.prompts is not None + + # Test review_code prompt + prompts = await client.list_prompts() + review_prompt = next((p for p in prompts.prompts if p.name == "review_code"), None) + assert review_prompt is not None + + prompt_result = await client.get_prompt("review_code", {"code": "def hello():\n print('Hello')"}) + assert isinstance(prompt_result, GetPromptResult) + assert len(prompt_result.messages) == 1 + assert isinstance(prompt_result.messages[0].content, TextContent) + assert "Please review this code:" in prompt_result.messages[0].content.text + assert "def hello():" in prompt_result.messages[0].content.text + + # Test debug_error prompt + debug_result 
= await client.get_prompt( + "debug_error", {"error": "TypeError: 'NoneType' object is not subscriptable"} + ) + assert isinstance(debug_result, GetPromptResult) + assert len(debug_result.messages) == 3 + assert debug_result.messages[0].role == "user" + assert isinstance(debug_result.messages[0].content, TextContent) + assert "I'm seeing this error:" in debug_result.messages[0].content.text + assert debug_result.messages[1].role == "user" + assert isinstance(debug_result.messages[1].content, TextContent) + assert "TypeError" in debug_result.messages[1].content.text + assert debug_result.messages[2].role == "assistant" + assert isinstance(debug_result.messages[2].content, TextContent) + assert "I'll help debug that" in debug_result.messages[2].content.text + + +async def test_tool_progress() -> None: """Test tool progress reporting.""" - transport = server_transport collector = NotificationCollector() async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): @@ -358,134 +192,79 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] if isinstance(message, Exception): # pragma: no cover raise message - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Progress Example" - - # Test progress callback - progress_updates = [] - - async def progress_callback(progress: float, total: float | None, message: str | None) -> None: - progress_updates.append((progress, total, message)) - - # Call tool with progress - steps = 3 - tool_result = await session.call_tool( - "long_running_task", - {"task_name": "Test Task", "steps": steps}, - progress_callback=progress_callback, - ) - assert 
tool_result.content == snapshot([TextContent(text="Task 'Test Task' completed")]) - - # Verify progress updates - assert len(progress_updates) == steps - for i, (progress, total, message) in enumerate(progress_updates): - expected_progress = (i + 1) / steps - assert abs(progress - expected_progress) < 0.01 - assert total == 1.0 - assert f"Step {i + 1}/{steps}" in message - - # Verify log messages - assert len(collector.log_messages) > 0 - - -# Test sampling -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("sampling", "sse"), - ("sampling", "streamable-http"), - ], - indirect=True, -) -async def test_sampling(server_transport: str, server_url: str) -> None: + async with Client(tool_progress.mcp, message_handler=message_handler) as client: + # Test progress callback + progress_updates = [] + + async def progress_callback(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append((progress, total, message)) + + # Call tool with progress + steps = 3 + tool_result = await client.call_tool( + "long_running_task", + {"task_name": "Test Task", "steps": steps}, + progress_callback=progress_callback, + ) + assert tool_result.content == snapshot([TextContent(text="Task 'Test Task' completed")]) + + # Verify progress updates + assert len(progress_updates) == steps + for i, (progress, total, message) in enumerate(progress_updates): + expected_progress = (i + 1) / steps + assert abs(progress - expected_progress) < 0.01 + assert total == 1.0 + assert f"Step {i + 1}/{steps}" in message + + # Verify log messages + assert len(collector.log_messages) > 0 + + +async def test_sampling() -> None: """Test sampling (LLM interaction) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session: - # Test initialization - 
result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Sampling Example" - assert result.capabilities.tools is not None - - # Test sampling tool - sampling_result = await session.call_tool("generate_poem", {"topic": "nature"}) - assert len(sampling_result.content) == 1 - assert isinstance(sampling_result.content[0], TextContent) - assert "This is a simulated LLM response" in sampling_result.content[0].text - - -# Test elicitation -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("elicitation", "sse"), - ("elicitation", "streamable-http"), - ], - indirect=True, -) -async def test_elicitation(server_transport: str, server_url: str) -> None: + async with Client(sampling.mcp, sampling_callback=sampling_callback) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.tools is not None + + # Test sampling tool + sampling_result = await client.call_tool("generate_poem", {"topic": "nature"}) + assert len(sampling_result.content) == 1 + assert isinstance(sampling_result.content[0], TextContent) + assert "This is a simulated LLM response" in sampling_result.content[0].text + + +async def test_elicitation() -> None: """Test elicitation (user interaction) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, elicitation_callback=elicitation_callback) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Elicitation Example" - - # Test booking with unavailable date (triggers elicitation) - booking_result = await session.call_tool( - "book_table", - { - "date": "2024-12-25", # Unavailable date - "time": "19:00", - "party_size": 4, - }, - ) - assert len(booking_result.content) == 1 - 
assert isinstance(booking_result.content[0], TextContent) - assert "[SUCCESS] Booked for 2024-12-26" in booking_result.content[0].text - - # Test booking with available date (no elicitation) - booking_result = await session.call_tool( - "book_table", - { - "date": "2024-12-20", # Available date - "time": "20:00", - "party_size": 2, - }, - ) - assert len(booking_result.content) == 1 - assert isinstance(booking_result.content[0], TextContent) - assert "[SUCCESS] Booked for 2024-12-20 at 20:00" in booking_result.content[0].text - - -# Test notifications -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("notifications", "sse"), - ("notifications", "streamable-http"), - ], - indirect=True, -) -async def test_notifications(server_transport: str, server_url: str) -> None: + async with Client(elicitation.mcp, elicitation_callback=elicitation_callback) as client: + # Test booking with unavailable date (triggers elicitation) + booking_result = await client.call_tool( + "book_table", + { + "date": "2024-12-25", # Unavailable date + "time": "19:00", + "party_size": 4, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + assert "[SUCCESS] Booked for 2024-12-26" in booking_result.content[0].text + + # Test booking with available date (no elicitation) + booking_result = await client.call_tool( + "book_table", + { + "date": "2024-12-20", # Available date + "time": "20:00", + "party_size": 2, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + assert "[SUCCESS] Booked for 2024-12-20 at 20:00" in booking_result.content[0].text + + +async def test_notifications() -> None: """Test notifications and logging functionality.""" - transport = server_transport collector = NotificationCollector() async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): @@ -493,150 +272,87 @@ async def 
message_handler(message: RequestResponder[ServerRequest, ClientResult] if isinstance(message, Exception): # pragma: no cover raise message - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Notifications Example" - - # Call tool that generates notifications - tool_result = await session.call_tool("process_data", {"data": "test_data"}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert "Processed: test_data" in tool_result.content[0].text - - # Verify log messages at different levels - assert len(collector.log_messages) >= 4 - log_levels = {msg.level for msg in collector.log_messages} - assert "debug" in log_levels - assert "info" in log_levels - assert "warning" in log_levels - assert "error" in log_levels - - # Verify resource list changed notification - assert len(collector.resource_notifications) > 0 - - -# Test completion -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("completion", "sse"), - ("completion", "streamable-http"), - ], - indirect=True, -) -async def test_completion(server_transport: str, server_url: str) -> None: + async with Client(notifications.mcp, message_handler=message_handler) as client: + # Call tool that generates notifications + tool_result = await client.call_tool("process_data", {"data": "test_data"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert "Processed: test_data" in tool_result.content[0].text + + # Verify log messages at different levels + assert len(collector.log_messages) >= 4 + log_levels = {msg.level for msg in collector.log_messages} + assert "debug" in log_levels + assert 
"info" in log_levels + assert "warning" in log_levels + assert "error" in log_levels + + # Verify resource list changed notification + assert len(collector.resource_notifications) > 0 + + +async def test_completion() -> None: """Test completion (autocomplete) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Example" - assert result.capabilities.resources is not None - assert result.capabilities.prompts is not None - - # Test resource completion - completion_result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), - argument={"name": "repo", "value": ""}, - context_arguments={"owner": "modelcontextprotocol"}, - ) - - assert completion_result is not None - assert hasattr(completion_result, "completion") - assert completion_result.completion is not None - assert len(completion_result.completion.values) == 3 - assert "python-sdk" in completion_result.completion.values - assert "typescript-sdk" in completion_result.completion.values - assert "specification" in completion_result.completion.values - - # Test prompt completion - completion_result = await session.complete( - ref=PromptReference(type="ref/prompt", name="review_code"), - argument={"name": "language", "value": "py"}, - ) - - assert completion_result is not None - assert hasattr(completion_result, "completion") - assert completion_result.completion is not None - assert "python" in completion_result.completion.values - assert all(lang.startswith("py") for lang in completion_result.completion.values) - - -# Test MCPServer quickstart example -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - 
("mcpserver_quickstart", "sse"), - ("mcpserver_quickstart", "streamable-http"), - ], - indirect=True, -) -async def test_mcpserver_quickstart(server_transport: str, server_url: str) -> None: + async with Client(completion.mcp) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.resources is not None + assert client.server_capabilities.prompts is not None + + # Test resource completion + completion_result = await client.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "modelcontextprotocol"}, + ) + + assert completion_result is not None + assert hasattr(completion_result, "completion") + assert completion_result.completion is not None + assert len(completion_result.completion.values) == 3 + assert "python-sdk" in completion_result.completion.values + assert "typescript-sdk" in completion_result.completion.values + assert "specification" in completion_result.completion.values + + # Test prompt completion + completion_result = await client.complete( + ref=PromptReference(type="ref/prompt", name="review_code"), + argument={"name": "language", "value": "py"}, + ) + + assert completion_result is not None + assert hasattr(completion_result, "completion") + assert completion_result.completion is not None + assert "python" in completion_result.completion.values + assert all(lang.startswith("py") for lang in completion_result.completion.values) + + +async def test_mcpserver_quickstart() -> None: """Test MCPServer quickstart example.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Demo" - - # Test 
add tool - tool_result = await session.call_tool("add", {"a": 10, "b": 20}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "30" - - # Test greeting resource directly - resource_result = await session.read_resource("greeting://Alice") - assert len(resource_result.contents) == 1 - assert isinstance(resource_result.contents[0], TextResourceContents) - assert resource_result.contents[0].text == "Hello, Alice!" - - -# Test structured output example -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("structured_output", "sse"), - ("structured_output", "streamable-http"), - ], - indirect=True, -) -async def test_structured_output(server_transport: str, server_url: str) -> None: + async with Client(mcpserver_quickstart.mcp) as client: + # Test add tool + tool_result = await client.call_tool("add", {"a": 10, "b": 20}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "30" + + # Test greeting resource directly + resource_result = await client.read_resource("greeting://Alice") + assert len(resource_result.contents) == 1 + assert isinstance(resource_result.contents[0], TextResourceContents) + assert resource_result.contents[0].text == "Hello, Alice!" 
+ + +async def test_structured_output() -> None: """Test structured output functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Structured Output Example" - - # Test get_weather tool - weather_result = await session.call_tool("get_weather", {"city": "New York"}) - assert len(weather_result.content) == 1 - assert isinstance(weather_result.content[0], TextContent) - - # Check that the result contains expected weather data - result_text = weather_result.content[0].text - assert "22.5" in result_text # temperature - assert "sunny" in result_text # condition - assert "45" in result_text # humidity - assert "5.2" in result_text # wind_speed + async with Client(structured_output.mcp) as client: + # Test get_weather tool + weather_result = await client.call_tool("get_weather", {"city": "New York"}) + assert len(weather_result.content) == 1 + assert isinstance(weather_result.content[0], TextContent) + + # Check that the result contains expected weather data + result_text = weather_result.content[0].text + assert "22.5" in result_text # temperature + assert "sunny" in result_text # condition + assert "45" in result_text # humidity + assert "5.2" in result_text # wind_speed From 67201a9bbd9c31419716886d9c1036e6370f83dc Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:48:30 +0000 Subject: [PATCH 20/60] test: fix WS test port race; narrow to single smoke test covering both transport ends (#2267) --- src/mcp/server/websocket.py | 8 +- tests/client/test_client.py | 17 ++- tests/shared/test_ws.py | 205 ++++-------------------------------- tests/test_helpers.py | 54 ++++++++++ 4 files changed, 97 
insertions(+), 187 deletions(-) diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 3e675da5fd..32b50560cc 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -10,7 +10,7 @@ from mcp.shared.message import SessionMessage -@asynccontextmanager # pragma: no cover +@asynccontextmanager async def websocket_server(scope: Scope, receive: Receive, send: Send): """WebSocket server transport for MCP. This is an ASGI application, suitable for use with a framework like Starlette and a server like Hypercorn. @@ -34,13 +34,13 @@ async def ws_reader(): async for msg in websocket.iter_text(): try: client_message = types.jsonrpc_message_adapter.validate_json(msg, by_name=False) - except ValidationError as exc: + except ValidationError as exc: # pragma: no cover await read_stream_writer.send(exc) continue session_message = SessionMessage(client_message) await read_stream_writer.send(session_message) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await websocket.close() async def ws_writer(): @@ -49,7 +49,7 @@ async def ws_writer(): async for session_message in write_stream_reader: obj = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await websocket.send_text(obj) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await websocket.close() async with anyio.create_task_group() as tg: diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 45300063a2..3bdd305702 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -8,7 +8,7 @@ import pytest from inline_snapshot import snapshot -from mcp import types +from mcp import MCPError, types from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.server import Server, ServerRequestContext @@ -175,6 +175,21 @@ async def test_read_resource(app: MCPServer): ) +async def 
test_read_resource_error_propagates(): + """MCPError raised by a server handler propagates to the client with its code intact.""" + + async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> ReadResourceResult: + raise MCPError(code=404, message="no resource with that URI was found") + + server = Server("test", on_read_resource=handle_read_resource) + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("unknown://example") + assert exc_info.value.error.code == 404 + + async def test_get_prompt(app: MCPServer): """Test getting a prompt.""" async with Client(app) as client: diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 9addb661de..482dcdcf32 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -1,210 +1,51 @@ -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator -from urllib.parse import urlparse +"""Smoke test for the WebSocket transport. + +Runs the full WS stack end-to-end over a real TCP connection, covering both +``src/mcp/client/websocket.py`` and ``src/mcp/server/websocket.py``. MCP +semantics (error propagation, timeouts, etc.) are transport-agnostic and are +covered in ``tests/client/test_client.py`` and ``tests/issues/test_88_random_error.py``. 
+""" + +from collections.abc import Generator -import anyio import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp import MCPError from mcp.client.session import ClientSession from mcp.client.websocket import websocket_client -from mcp.server import Server, ServerRequestContext +from mcp.server import Server from mcp.server.websocket import websocket_server -from mcp.types import ( - CallToolRequestParams, - CallToolResult, - EmptyResult, - InitializeResult, - ListToolsResult, - PaginatedRequestParams, - ReadResourceRequestParams, - ReadResourceResult, - TextContent, - TextResourceContents, - Tool, -) -from tests.test_helpers import wait_for_server +from mcp.types import EmptyResult, InitializeResult +from tests.test_helpers import run_uvicorn_in_thread SERVER_NAME = "test_server_for_WS" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"ws://127.0.0.1:{server_port}" - - -async def handle_read_resource( # pragma: no cover - ctx: ServerRequestContext, params: ReadResourceRequestParams -) -> ReadResourceResult: - parsed = urlparse(str(params.uri)) - if parsed.scheme == "foobar": - return ReadResourceResult( - contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")] - ) - elif parsed.scheme == "slow": - await anyio.sleep(2.0) - return ReadResourceResult( - contents=[ - TextResourceContents( - uri=str(params.uri), text=f"Slow response from {parsed.netloc}", mime_type="text/plain" - ) - ] - ) - raise MCPError(code=404, message="OOPS! 
no resource with that URI was found") - - -async def handle_list_tools( # pragma: no cover - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) - +def make_server_app() -> Starlette: + srv = Server(SERVER_NAME) -async def handle_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) - - -def _create_server() -> Server: # pragma: no cover - return Server( - SERVER_NAME, - on_read_resource=handle_read_resource, - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - - -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover - """Create test Starlette app with WebSocket transport""" - server = _create_server() - - async def handle_ws(websocket: WebSocket): + async def handle_ws(websocket: WebSocket) -> None: async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: - await server.run(streams[0], streams[1], server.create_initialization_options()) - - app = Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) - return app - - -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() + await srv.run(streams[0], streams[1], srv.create_initialization_options()) - # Wait for server to be running - print("waiting for server to start") - 
wait_for_server(server_port) + return Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) - yield - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.fixture() -async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - """Create and initialize a WebSocket client session""" - async with websocket_client(server_url + "/ws") as streams: - async with ClientSession(*streams) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == SERVER_NAME - - # Test ping - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) - - yield session +@pytest.fixture +def ws_server_url() -> Generator[str, None, None]: + with run_uvicorn_in_thread(make_server_app()) as base_url: + yield base_url.replace("http://", "ws://") + "/ws" -# Tests @pytest.mark.anyio -async def test_ws_client_basic_connection(server: None, server_url: str) -> None: - """Test the WebSocket connection establishment""" - async with websocket_client(server_url + "/ws") as streams: +async def test_ws_client_basic_connection(ws_server_url: str) -> None: + async with websocket_client(ws_server_url) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.server_info.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) - - -@pytest.mark.anyio -async def test_ws_client_happy_request_and_response( - initialized_ws_client_session: ClientSession, -) -> None: - """Test a successful request and response via WebSocket""" - result = await initialized_ws_client_session.read_resource("foobar://example") - assert 
isinstance(result, ReadResourceResult) - assert isinstance(result.contents, list) - assert len(result.contents) > 0 - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Read example" - - -@pytest.mark.anyio -async def test_ws_client_exception_handling( - initialized_ws_client_session: ClientSession, -) -> None: - """Test exception handling in WebSocket communication""" - with pytest.raises(MCPError) as exc_info: - await initialized_ws_client_session.read_resource("unknown://example") - assert exc_info.value.error.code == 404 - - -@pytest.mark.anyio -async def test_ws_client_timeout( - initialized_ws_client_session: ClientSession, -) -> None: - """Test timeout handling in WebSocket communication""" - # Set a very short timeout to trigger a timeout exception - with pytest.raises(TimeoutError): - with anyio.fail_after(0.1): # 100ms timeout - await initialized_ws_client_session.read_resource("slow://example") - - # Now test that we can still use the session after a timeout - with anyio.fail_after(5): # Longer timeout to allow completion - result = await initialized_ws_client_session.read_resource("foobar://example") - assert isinstance(result, ReadResourceResult) - assert isinstance(result.contents, list) - assert len(result.contents) > 0 - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Read example" diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5c04c269ff..810c72820b 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,7 +1,61 @@ """Common test utilities for MCP server tests.""" import socket +import threading import time +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +import uvicorn + +_SERVER_SHUTDOWN_TIMEOUT_S = 5.0 + + +@contextmanager +def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None, None]: + """Run a uvicorn server in a background thread on an 
ephemeral port. + + The socket is bound and put into listening state *before* the thread + starts, so the port is known immediately with no wait. The kernel's + listen queue buffers any connections that arrive before uvicorn's event + loop reaches ``accept()``, so callers can connect as soon as this + function yields — no polling, no sleeps, no startup race. + + This also avoids the TOCTOU race of the old pick-a-port-then-rebind + pattern: the socket passed here is the one uvicorn serves on, with no + gap where another pytest-xdist worker could claim it. + + Args: + app: ASGI application to serve. + **config_kwargs: Additional keyword arguments for :class:`uvicorn.Config` + (e.g. ``log_level``). ``host``/``port`` are ignored since the + socket is pre-bound. + + Yields: + The base URL of the running server, e.g. ``http://127.0.0.1:54321``. + """ + host = "127.0.0.1" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((host, 0)) + sock.listen() + port = sock.getsockname()[1] + + config_kwargs.setdefault("log_level", "error") + # Uvicorn's interface autodetection calls asyncio.iscoroutinefunction, + # which Python 3.14 deprecates. Under filterwarnings=error this crashes + # the server thread silently. Starlette is asgi3; skip the autodetect. 
+ config_kwargs.setdefault("interface", "asgi3") + server = uvicorn.Server(config=uvicorn.Config(app=app, **config_kwargs)) + + thread = threading.Thread(target=server.run, kwargs={"sockets": [sock]}, daemon=True) + thread.start() + try: + yield f"http://{host}:{port}" + finally: + server.should_exit = True + thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S) def wait_for_server(port: int, timeout: float = 20.0) -> None: From 20dd94632ef1b8a8e2db5ca9c0c38dab845fa5bb Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 17:31:26 +0000 Subject: [PATCH 21/60] feat(client): store InitializeResult as initialize_result (#2300) --- docs/migration.md | 24 +++++++++++++++++++ src/mcp/client/client.py | 15 ++++++++---- src/mcp/client/session.py | 13 ++++++----- tests/client/test_client.py | 3 ++- tests/client/test_session.py | 27 +++++++++++----------- tests/client/transports/test_memory.py | 2 +- tests/server/mcpserver/test_integration.py | 17 +++++--------- 7 files changed, 64 insertions(+), 37 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 7cf0325533..dd6a7a18f4 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -169,6 +169,30 @@ result = await session.list_resources(params=PaginatedRequestParams(cursor="next result = await session.list_tools(params=PaginatedRequestParams(cursor="next_page_token")) ``` +### `ClientSession.get_server_capabilities()` replaced by `initialize_result` property + +`ClientSession` now stores the full `InitializeResult` via an `initialize_result` property. This provides access to `server_info`, `capabilities`, `instructions`, and the negotiated `protocol_version` through a single property. The `get_server_capabilities()` method has been removed. 
+ +**Before (v1):** + +```python +capabilities = session.get_server_capabilities() +# server_info, instructions, protocol_version were not stored — had to capture initialize() return value +``` + +**After (v2):** + +```python +result = session.initialize_result +if result is not None: + capabilities = result.capabilities + server_info = result.server_info + instructions = result.instructions + version = result.protocol_version +``` + +The high-level `Client.initialize_result` returns the same `InitializeResult` but is non-nullable — initialization is guaranteed inside the context manager, so no `None` check is needed. This replaces v1's `Client.server_capabilities`; use `client.initialize_result.capabilities` instead. + ### `McpError` renamed to `MCPError` The `McpError` exception class has been renamed to `MCPError` for consistent naming with the MCP acronym style used throughout the SDK. diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 7dc67c5844..34d6a360fa 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -19,6 +19,7 @@ EmptyResult, GetPromptResult, Implementation, + InitializeResult, ListPromptsResult, ListResourcesResult, ListResourceTemplatesResult, @@ -29,7 +30,6 @@ ReadResourceResult, RequestParamsMeta, ResourceTemplateReference, - ServerCapabilities, ) @@ -155,9 +155,16 @@ def session(self) -> ClientSession: return self._session @property - def server_capabilities(self) -> ServerCapabilities | None: - """The server capabilities received during initialization, or None if not yet initialized.""" - return self.session.get_server_capabilities() + def initialize_result(self) -> InitializeResult: + """The server's InitializeResult. + + Contains server_info, capabilities, instructions, and the negotiated protocol_version. + Raises RuntimeError if accessed outside the context manager. 
+ """ + result = self.session.initialize_result + if result is None: # pragma: no cover + raise RuntimeError("Client must be used within an async context manager") + return result async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> EmptyResult: """Send a ping request to the server.""" diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a0ca751bd7..7c964a334c 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -131,7 +131,7 @@ def __init__( self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} - self._server_capabilities: types.ServerCapabilities | None = None + self._initialize_result: types.InitializeResult | None = None self._experimental_features: ExperimentalClientFeatures | None = None # Experimental: Task handlers (use defaults if not provided) @@ -185,18 +185,19 @@ async def initialize(self) -> types.InitializeResult: if result.protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocol_version}") - self._server_capabilities = result.capabilities + self._initialize_result = result await self.send_notification(types.InitializedNotification()) return result - def get_server_capabilities(self) -> types.ServerCapabilities | None: - """Return the server capabilities received during initialization. + @property + def initialize_result(self) -> types.InitializeResult | None: + """The server's InitializeResult. None until initialize() has been called. - Returns None if the session has not been initialized yet. + Contains server_info, capabilities, instructions, and the negotiated protocol_version. 
""" - return self._server_capabilities + return self._initialize_result @property def experimental(self) -> ExperimentalClientFeatures: diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 3bdd305702..18368e6bb3 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -99,7 +99,7 @@ def greeting_prompt(name: str) -> str: async def test_client_is_initialized(app: MCPServer): """Test that the client is initialized after entering context.""" async with Client(app) as client: - assert client.server_capabilities == snapshot( + assert client.initialize_result.capabilities == snapshot( ServerCapabilities( experimental={}, prompts=PromptsCapability(list_changed=False), @@ -107,6 +107,7 @@ async def test_client_is_initialized(app: MCPServer): tools=ToolsCapability(list_changed=False), ) ) + assert client.initialize_result.server_info.name == "test" async def test_client_with_simple_server(simple_server: Server): diff --git a/tests/client/test_session.py b/tests/client/test_session.py index d6d13e273c..f25c964f03 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -540,8 +540,8 @@ async def mock_server(): @pytest.mark.anyio -async def test_get_server_capabilities(): - """Test that get_server_capabilities returns None before init and capabilities after""" +async def test_initialize_result(): + """Test that initialize_result is None before init and contains the full result after.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -551,6 +551,8 @@ async def test_get_server_capabilities(): resources=types.ResourcesCapability(subscribe=True, list_changed=True), tools=types.ToolsCapability(list_changed=False), ) + expected_server_info = Implementation(name="mock-server", version="0.1.0") + expected_instructions = "Use the tools wisely." 
async def mock_server(): session_message = await client_to_server_receive.receive() @@ -564,7 +566,8 @@ async def mock_server(): result = InitializeResult( protocol_version=LATEST_PROTOCOL_VERSION, capabilities=expected_capabilities, - server_info=Implementation(name="mock-server", version="0.1.0"), + server_info=expected_server_info, + instructions=expected_instructions, ) async with server_to_client_send: @@ -590,21 +593,17 @@ async def mock_server(): server_to_client_send, server_to_client_receive, ): - assert session.get_server_capabilities() is None + assert session.initialize_result is None tg.start_soon(mock_server) await session.initialize() - capabilities = session.get_server_capabilities() - assert capabilities is not None - assert capabilities == expected_capabilities - assert capabilities.logging is not None - assert capabilities.prompts is not None - assert capabilities.prompts.list_changed is True - assert capabilities.resources is not None - assert capabilities.resources.subscribe is True - assert capabilities.tools is not None - assert capabilities.tools.list_changed is False + result = session.initialize_result + assert result is not None + assert result.server_info == expected_server_info + assert result.capabilities == expected_capabilities + assert result.instructions == expected_instructions + assert result.protocol_version == LATEST_PROTOCOL_VERSION @pytest.mark.anyio diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 47be3e2089..c8fc41fd5d 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -69,7 +69,7 @@ async def test_with_mcpserver(mcpserver_server: MCPServer): async def test_server_is_running(mcpserver_server: MCPServer): """Test that the server is running and responding to requests.""" async with Client(mcpserver_server) as client: - assert client.server_capabilities is not None + assert client.initialize_result.capabilities.tools is not None 
async def test_list_tools(mcpserver_server: MCPServer): diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index 90e333b775..f71c0574cd 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -109,8 +109,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E async def test_basic_tools() -> None: """Test basic tool functionality.""" async with Client(basic_tool.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.tools is not None + assert client.initialize_result.capabilities.tools is not None # Test sum tool tool_result = await client.call_tool("sum", {"a": 5, "b": 3}) @@ -128,8 +127,7 @@ async def test_basic_tools() -> None: async def test_basic_resources() -> None: """Test basic resource functionality.""" async with Client(basic_resource.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.resources is not None + assert client.initialize_result.capabilities.resources is not None # Test document resource doc_content = await client.read_resource("file://documents/readme") @@ -151,8 +149,7 @@ async def test_basic_resources() -> None: async def test_basic_prompts() -> None: """Test basic prompt functionality.""" async with Client(basic_prompt.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.prompts is not None + assert client.initialize_result.capabilities.prompts is not None # Test review_code prompt prompts = await client.list_prompts() @@ -223,8 +220,7 @@ async def progress_callback(progress: float, total: float | None, message: str | async def test_sampling() -> None: """Test sampling (LLM interaction) functionality.""" async with Client(sampling.mcp, sampling_callback=sampling_callback) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.tools is 
not None + assert client.initialize_result.capabilities.tools is not None # Test sampling tool sampling_result = await client.call_tool("generate_poem", {"topic": "nature"}) @@ -294,9 +290,8 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] async def test_completion() -> None: """Test completion (autocomplete) functionality.""" async with Client(completion.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.resources is not None - assert client.server_capabilities.prompts is not None + assert client.initialize_result.capabilities.resources is not None + assert client.initialize_result.capabilities.prompts is not None # Test resource completion completion_result = await client.complete( From 5388bea53ad7b13db7d031cdfd27392474c89007 Mon Sep 17 00:00:00 2001 From: Jonathan Hefner Date: Wed, 18 Mar 2026 13:15:17 -0500 Subject: [PATCH 22/60] docs: generate hierarchical per-module API reference pages (#2103) --- docs/api.md | 1 - docs/hooks/gen_ref_pages.py | 35 +++++++++++++++++++++++++++++++++++ docs/index.md | 2 +- docs/migration.md | 2 +- mkdocs.yml | 11 +++++++---- pyproject.toml | 2 ++ uv.lock | 28 ++++++++++++++++++++++++++++ 7 files changed, 74 insertions(+), 7 deletions(-) delete mode 100644 docs/api.md create mode 100644 docs/hooks/gen_ref_pages.py diff --git a/docs/api.md b/docs/api.md deleted file mode 100644 index 3f696af543..0000000000 --- a/docs/api.md +++ /dev/null @@ -1 +0,0 @@ -::: mcp diff --git a/docs/hooks/gen_ref_pages.py b/docs/hooks/gen_ref_pages.py new file mode 100644 index 0000000000..ad8c19b45f --- /dev/null +++ b/docs/hooks/gen_ref_pages.py @@ -0,0 +1,35 @@ +"""Generate the code reference pages and navigation.""" + +from pathlib import Path + +import mkdocs_gen_files + +nav = mkdocs_gen_files.Nav() + +root = Path(__file__).parent.parent.parent +src = root / "src" + +for path in sorted(src.rglob("*.py")): + module_path = path.relative_to(src).with_suffix("") + 
doc_path = path.relative_to(src).with_suffix(".md") + full_doc_path = Path("api", doc_path) + + parts = tuple(module_path.parts) + + if parts[-1] == "__init__": + parts = parts[:-1] + doc_path = doc_path.with_name("index.md") + full_doc_path = full_doc_path.with_name("index.md") + elif parts[-1].startswith("_"): + continue + + nav[parts] = doc_path.as_posix() + + with mkdocs_gen_files.open(full_doc_path, "w") as fd: + ident = ".".join(parts) + fd.write(f"::: {ident}") + + mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root)) + +with mkdocs_gen_files.open("api/SUMMARY.md", "w") as nav_file: + nav_file.writelines(nav.build_literate_nav()) diff --git a/docs/index.md b/docs/index.md index e096d910b4..436d1c8fcd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -64,4 +64,4 @@ npx -y @modelcontextprotocol/inspector ## API Reference -Full API documentation is available in the [API Reference](api.md). +Full API documentation is available in the [API Reference](api/mcp/index.md). diff --git a/docs/migration.md b/docs/migration.md index dd6a7a18f4..3b47f9aade 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -883,6 +883,6 @@ The lowlevel `Server` also now exposes a `session_manager` property to access th If you encounter issues during migration: -1. Check the [API Reference](api.md) for updated method signatures +1. Check the [API Reference](api/mcp/index.md) for updated method signatures 2. Review the [examples](https://github.com/modelcontextprotocol/python-sdk/tree/main/examples) for updated usage patterns 3. 
Open an issue on [GitHub](https://github.com/modelcontextprotocol/python-sdk/issues) if you find a bug or need further assistance diff --git a/mkdocs.yml b/mkdocs.yml index 070c533e31..3a555785a7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -25,7 +25,7 @@ nav: - Introduction: experimental/tasks.md - Server Implementation: experimental/tasks-server.md - Client Usage: experimental/tasks-client.md - - API Reference: api.md + - API Reference: api/ theme: name: "material" @@ -115,10 +115,15 @@ plugins: - social: enabled: !ENV [ENABLE_SOCIAL_CARDS, false] - glightbox + - gen-files: + scripts: + - docs/hooks/gen_ref_pages.py + - literate-nav: + nav_file: SUMMARY.md - mkdocstrings: handlers: python: - paths: [src/mcp] + paths: [src] options: relative_crossrefs: true members_order: source @@ -126,8 +131,6 @@ plugins: show_signature_annotations: true signature_crossrefs: true group_by_category: false - # 3 because docs are in pages with an H2 just above them - heading_level: 3 inventories: - url: https://docs.python.org/3/objects.inv - url: https://docs.pydantic.dev/latest/objects.inv diff --git a/pyproject.toml b/pyproject.toml index f275b90cfd..624ade1709 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,9 @@ dev = [ ] docs = [ "mkdocs>=1.6.1", + "mkdocs-gen-files>=0.5.0", "mkdocs-glightbox>=0.4.0", + "mkdocs-literate-nav>=0.6.1", "mkdocs-material[imaging]>=9.5.45", "mkdocstrings-python>=2.0.1", ] diff --git a/uv.lock b/uv.lock index c25047e48d..4af3532ea2 100644 --- a/uv.lock +++ b/uv.lock @@ -840,7 +840,9 @@ dev = [ ] docs = [ { name = "mkdocs" }, + { name = "mkdocs-gen-files" }, { name = "mkdocs-glightbox" }, + { name = "mkdocs-literate-nav" }, { name = "mkdocs-material", extra = ["imaging"] }, { name = "mkdocstrings-python" }, ] @@ -888,7 +890,9 @@ dev = [ ] docs = [ { name = "mkdocs", specifier = ">=1.6.1" }, + { name = "mkdocs-gen-files", specifier = ">=0.5.0" }, { name = "mkdocs-glightbox", specifier = ">=0.4.0" }, + { name = "mkdocs-literate-nav", 
specifier = ">=0.6.1" }, { name = "mkdocs-material", extras = ["imaging"], specifier = ">=9.5.45" }, { name = "mkdocstrings-python", specifier = ">=2.0.1" }, ] @@ -1501,6 +1505,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/4d/7123b6fa2278000688ebd338e2a06d16870aaf9eceae6ba047ea05f92df1/mkdocs_autorefs-1.4.3-py3-none-any.whl", hash = "sha256:469d85eb3114801d08e9cc55d102b3ba65917a869b893403b8987b601cf55dc9", size = 25034, upload-time = "2025-08-26T14:23:15.906Z" }, ] +[[package]] +name = "mkdocs-gen-files" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mkdocs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/61/35/f26349f7fa18414eb2e25d75a6fa9c7e3186c36e1d227c0b2d785a7bd5c4/mkdocs_gen_files-0.6.0.tar.gz", hash = "sha256:52022dc14dcc0451e05e54a8f5d5e7760351b6701eff816d1e9739577ec5635e", size = 8642, upload-time = "2025-11-23T12:13:22.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/ec/72417415563c60ae01b36f0d497f1f4c803972f447ef4fb7f7746d6e07db/mkdocs_gen_files-0.6.0-py3-none-any.whl", hash = "sha256:815af15f3e2dbfda379629c1b95c02c8e6f232edf2a901186ea3b204ab1135b2", size = 8182, upload-time = "2025-11-23T12:13:20.756Z" }, +] + [[package]] name = "mkdocs-get-deps" version = "0.2.0" @@ -1527,6 +1543,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/cf/e9a0ce9da269746906fdc595c030f6df66793dad1487abd1699af2ba44f1/mkdocs_glightbox-0.5.1-py3-none-any.whl", hash = "sha256:f47af0daff164edf8d36e553338425be3aab6e34b987d9cbbc2ae7819a98cb01", size = 26431, upload-time = "2025-09-04T13:10:27.933Z" }, ] +[[package]] +name = "mkdocs-literate-nav" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mkdocs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/5f/99aa379b305cd1c2084d42db3d26f6de0ea9bf2cc1d10ed17f61aff35b9a/mkdocs_literate_nav-0.6.2.tar.gz", hash = 
"sha256:760e1708aa4be86af81a2b56e82c739d5a8388a0eab1517ecfd8e5aa40810a75", size = 17419, upload-time = "2025-03-18T21:53:09.711Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/84/b5b14d2745e4dd1a90115186284e9ee1b4d0863104011ab46abb7355a1c3/mkdocs_literate_nav-0.6.2-py3-none-any.whl", hash = "sha256:0a6489a26ec7598477b56fa112056a5e3a6c15729f0214bea8a4dbc55bd5f630", size = 13261, upload-time = "2025-03-18T21:53:08.1Z" }, +] + [[package]] name = "mkdocs-material" version = "9.7.2" From 883d89309755def3048d2c8b6ad85a2b9861130b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:16:34 +0000 Subject: [PATCH 23/60] test: rewrite cli.claude config tests to assert JSON output directly (#2311) Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> Co-authored-by: Felix Weinberger --- src/mcp/cli/claude.py | 20 +++-- tests/cli/test_claude.py | 146 ++++++++++++++++++++++++++++++++++++ tests/client/test_config.py | 75 ------------------ 3 files changed, 158 insertions(+), 83 deletions(-) create mode 100644 tests/cli/test_claude.py delete mode 100644 tests/client/test_config.py diff --git a/src/mcp/cli/claude.py b/src/mcp/cli/claude.py index 071b4b6fbb..93bf218fbb 100644 --- a/src/mcp/cli/claude.py +++ b/src/mcp/cli/claude.py @@ -33,7 +33,7 @@ def get_claude_config_path() -> Path | None: # pragma: no cover def get_uv_path() -> str: """Get the full path to the uv executable.""" uv_path = shutil.which("uv") - if not uv_path: # pragma: no cover + if not uv_path: logger.error( "uv executable not found in PATH, falling back to 'uv'. Please ensure uv is installed and in your PATH" ) @@ -65,7 +65,7 @@ def update_claude_config( """ config_dir = get_claude_config_path() uv_path = get_uv_path() - if not config_dir: # pragma: no cover + if not config_dir: raise RuntimeError( "Claude Desktop config directory not found. 
Please ensure Claude Desktop" " is installed and has been run at least once to initialize its config." @@ -90,7 +90,7 @@ def update_claude_config( config["mcpServers"] = {} # Always preserve existing env vars and merge with new ones - if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: # pragma: no cover + if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: existing_env = config["mcpServers"][server_name]["env"] if env_vars: # New vars take precedence over existing ones @@ -103,22 +103,26 @@ def update_claude_config( # Collect all packages in a set to deduplicate packages = {MCP_PACKAGE} - if with_packages: # pragma: no cover + if with_packages: packages.update(pkg for pkg in with_packages if pkg) # Add all packages with --with for pkg in sorted(packages): args.extend(["--with", pkg]) - if with_editable: # pragma: no cover + if with_editable: args.extend(["--with-editable", str(with_editable)]) # Convert file path to absolute before adding to command # Split off any :object suffix first - if ":" in file_spec: + # First check if we have a Windows path (e.g., C:\...) 
+ has_windows_drive = len(file_spec) > 1 and file_spec[1] == ":" + + # Split on the last colon, but only if it's not part of the Windows drive letter + if ":" in (file_spec[2:] if has_windows_drive else file_spec): file_path, server_object = file_spec.rsplit(":", 1) file_spec = f"{Path(file_path).resolve()}:{server_object}" - else: # pragma: no cover + else: file_spec = str(Path(file_spec).resolve()) # Add mcp run command @@ -127,7 +131,7 @@ def update_claude_config( server_config: dict[str, Any] = {"command": uv_path, "args": args} # Add environment variables if specified - if env_vars: # pragma: no cover + if env_vars: server_config["env"] = env_vars config["mcpServers"][server_name] = server_config diff --git a/tests/cli/test_claude.py b/tests/cli/test_claude.py new file mode 100644 index 0000000000..73d4f0eb52 --- /dev/null +++ b/tests/cli/test_claude.py @@ -0,0 +1,146 @@ +"""Tests for mcp.cli.claude — Claude Desktop config file generation.""" + +import json +from pathlib import Path +from typing import Any + +import pytest + +from mcp.cli.claude import get_uv_path, update_claude_config + + +@pytest.fixture +def config_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + """Temp Claude config dir with get_claude_config_path and get_uv_path mocked.""" + claude_dir = tmp_path / "Claude" + claude_dir.mkdir() + monkeypatch.setattr("mcp.cli.claude.get_claude_config_path", lambda: claude_dir) + monkeypatch.setattr("mcp.cli.claude.get_uv_path", lambda: "/fake/bin/uv") + return claude_dir + + +def _read_server(config_dir: Path, name: str) -> dict[str, Any]: + config = json.loads((config_dir / "claude_desktop_config.json").read_text()) + return config["mcpServers"][name] + + +def test_generates_uv_run_command(config_dir: Path): + """Should write a uv run command that invokes mcp run on the resolved file spec.""" + assert update_claude_config(file_spec="server.py:app", server_name="my_server") + + resolved = Path("server.py").resolve() + assert 
_read_server(config_dir, "my_server") == { + "command": "/fake/bin/uv", + "args": ["run", "--frozen", "--with", "mcp[cli]", "mcp", "run", f"{resolved}:app"], + } + + +def test_file_spec_without_object_suffix(config_dir: Path): + """File specs without :object should still resolve to an absolute path.""" + assert update_claude_config(file_spec="server.py", server_name="s") + + assert _read_server(config_dir, "s")["args"][-1] == str(Path("server.py").resolve()) + + +def test_with_packages_sorted_and_deduplicated(config_dir: Path): + """Extra packages should appear as --with flags, sorted and deduplicated with mcp[cli].""" + assert update_claude_config(file_spec="s.py:app", server_name="s", with_packages=["zebra", "aardvark", "zebra"]) + + args = _read_server(config_dir, "s")["args"] + assert args[:8] == ["run", "--frozen", "--with", "aardvark", "--with", "mcp[cli]", "--with", "zebra"] + + +def test_with_editable_adds_flag(config_dir: Path, tmp_path: Path): + """with_editable should add --with-editable after the --with flags.""" + editable = tmp_path / "project" + assert update_claude_config(file_spec="s.py:app", server_name="s", with_editable=editable) + + args = _read_server(config_dir, "s")["args"] + assert args[4:6] == ["--with-editable", str(editable)] + + +def test_env_vars_written(config_dir: Path): + """env_vars should be written under the server's env key.""" + assert update_claude_config(file_spec="s.py:app", server_name="s", env_vars={"KEY": "val"}) + + assert _read_server(config_dir, "s")["env"] == {"KEY": "val"} + + +def test_existing_env_vars_merged_new_wins(config_dir: Path): + """Re-installing should merge env vars, with new values overriding existing ones.""" + (config_dir / "claude_desktop_config.json").write_text( + json.dumps({"mcpServers": {"s": {"env": {"OLD": "keep", "KEY": "old"}}}}) + ) + + assert update_claude_config(file_spec="s.py:app", server_name="s", env_vars={"KEY": "new"}) + + assert _read_server(config_dir, "s")["env"] == {"OLD": 
"keep", "KEY": "new"} + + +def test_existing_env_vars_preserved_without_new(config_dir: Path): + """Re-installing without env_vars should keep the existing env block intact.""" + (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"s": {"env": {"KEEP": "me"}}}})) + + assert update_claude_config(file_spec="s.py:app", server_name="s") + + assert _read_server(config_dir, "s")["env"] == {"KEEP": "me"} + + +def test_other_servers_preserved(config_dir: Path): + """Installing a new server should not clobber existing mcpServers entries.""" + (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"other": {"command": "x"}}})) + + assert update_claude_config(file_spec="s.py:app", server_name="s") + + config = json.loads((config_dir / "claude_desktop_config.json").read_text()) + assert set(config["mcpServers"]) == {"other", "s"} + assert config["mcpServers"]["other"] == {"command": "x"} + + +def test_raises_when_config_dir_missing(monkeypatch: pytest.MonkeyPatch): + """Should raise RuntimeError when Claude Desktop config dir can't be found.""" + monkeypatch.setattr("mcp.cli.claude.get_claude_config_path", lambda: None) + monkeypatch.setattr("mcp.cli.claude.get_uv_path", lambda: "/fake/bin/uv") + + with pytest.raises(RuntimeError, match="Claude Desktop config directory not found"): + update_claude_config(file_spec="s.py:app", server_name="s") + + +@pytest.mark.parametrize("which_result, expected", [("/usr/local/bin/uv", "/usr/local/bin/uv"), (None, "uv")]) +def test_get_uv_path(monkeypatch: pytest.MonkeyPatch, which_result: str | None, expected: str): + """Should return shutil.which's result, or fall back to bare 'uv' when not on PATH.""" + + def fake_which(cmd: str) -> str | None: + return which_result + + monkeypatch.setattr("shutil.which", fake_which) + assert get_uv_path() == expected + + +@pytest.mark.parametrize( + "file_spec, expected_last_arg", + [ + ("C:\\Users\\server.py", "C:\\Users\\server.py"), + 
("C:\\Users\\server.py:app", "C:\\Users\\server.py:app"), + ], +) +def test_windows_drive_letter_not_split( + config_dir: Path, monkeypatch: pytest.MonkeyPatch, file_spec: str, expected_last_arg: str +): + """Drive-letter paths like 'C:\\server.py' must not be split on the drive colon. + + Before the fix, a bare 'C:\\path\\server.py' would hit rsplit(":", 1) and yield + ("C", "\\path\\server.py"), calling resolve() on Path("C") instead of the full path. + """ + seen: list[str] = [] + + def fake_resolve(self: Path) -> Path: + seen.append(str(self)) + return self + + monkeypatch.setattr(Path, "resolve", fake_resolve) + + assert update_claude_config(file_spec=file_spec, server_name="s") + + assert seen == ["C:\\Users\\server.py"] + assert _read_server(config_dir, "s")["args"][-1] == expected_last_arg diff --git a/tests/client/test_config.py b/tests/client/test_config.py deleted file mode 100644 index d1a0576ff3..0000000000 --- a/tests/client/test_config.py +++ /dev/null @@ -1,75 +0,0 @@ -import json -import subprocess -from pathlib import Path -from unittest.mock import patch - -import pytest - -from mcp.cli.claude import update_claude_config - - -@pytest.fixture -def temp_config_dir(tmp_path: Path): - """Create a temporary Claude config directory.""" - config_dir = tmp_path / "Claude" - config_dir.mkdir() - return config_dir - - -@pytest.fixture -def mock_config_path(temp_config_dir: Path): - """Mock get_claude_config_path to return our temporary directory.""" - with patch("mcp.cli.claude.get_claude_config_path", return_value=temp_config_dir): - yield temp_config_dir - - -def test_command_execution(mock_config_path: Path): - """Test that the generated command can actually be executed.""" - # Setup - server_name = "test_server" - file_spec = "test_server.py:app" - - # Update config - success = update_claude_config(file_spec=file_spec, server_name=server_name) - assert success - - # Read the generated config - config_file = mock_config_path / 
"claude_desktop_config.json" - config = json.loads(config_file.read_text()) - - # Get the command and args - server_config = config["mcpServers"][server_name] - command = server_config["command"] - args = server_config["args"] - - test_args = [command] + args + ["--help"] - - result = subprocess.run(test_args, capture_output=True, text=True, timeout=20, check=False) - - assert result.returncode == 0 - assert "usage" in result.stdout.lower() - - -def test_absolute_uv_path(mock_config_path: Path): - """Test that the absolute path to uv is used when available.""" - # Mock the shutil.which function to return a fake path - mock_uv_path = "/usr/local/bin/uv" - - with patch("mcp.cli.claude.get_uv_path", return_value=mock_uv_path): - # Setup - server_name = "test_server" - file_spec = "test_server.py:app" - - # Update config - success = update_claude_config(file_spec=file_spec, server_name=server_name) - assert success - - # Read the generated config - config_file = mock_config_path / "claude_desktop_config.json" - config = json.loads(config_file.read_text()) - - # Verify the command is the absolute path - server_config = config["mcpServers"][server_name] - command = server_config["command"] - - assert command == mock_uv_path From 92c693bb730d059b2cc3836ccf0cdfe0720e9e18 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 20 Mar 2026 13:37:32 +0000 Subject: [PATCH 24/60] fix: cancel in-flight handlers when transport closes in server.run() (#2306) --- src/mcp/server/lowlevel/server.py | 51 ++++++--- src/mcp/shared/session.py | 6 +- tests/server/test_cancel_handling.py | 158 +++++++++++++++++++++++++++ 3 files changed, 199 insertions(+), 16 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 167f34b8bc..c288422720 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -387,16 +387,23 @@ async def run( await 
stack.enter_async_context(task_support.run()) async with anyio.create_task_group() as tg: - async for message in session.incoming_messages: - logger.debug("Received message: %s", message) - - tg.start_soon( - self._handle_message, - message, - session, - lifespan_context, - raise_exceptions, - ) + try: + async for message in session.incoming_messages: + logger.debug("Received message: %s", message) + + tg.start_soon( + self._handle_message, + message, + session, + lifespan_context, + raise_exceptions, + ) + finally: + # Transport closed: cancel in-flight handlers. Without this the + # TG join waits for them, and when they eventually try to + # respond they hit a closed write stream (the session's + # _receive_loop closed it when the read stream ended). + tg.cancel_scope.cancel() async def _handle_message( self, @@ -470,16 +477,32 @@ async def _handle_request( except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): - logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) - return + if message.cancelled: + # Client sent CancelledNotification; responder.cancel() already + # sent an error response, so skip the duplicate. + logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) + return + # Transport-close cancellation from the TG in run(); re-raise so the + # TG swallows its own cancellation. + raise except Exception as err: if raise_exceptions: # pragma: no cover raise err response = types.ErrorData(code=0, message=str(err)) + else: # pragma: no cover + response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") + try: await message.respond(response) - else: # pragma: no cover - await message.respond(types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found")) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + # Transport closed between handler unblocking and respond. 
Happens + # when _receive_loop's finally wakes a handler blocked on + # send_request: the handler runs to respond() before run()'s TG + # cancel fires, but after the write stream closed. Closed if our + # end closed (_receive_loop's async-with exit); Broken if the peer + # end closed first (streamable_http terminate()). + logger.debug("Response for %s dropped - transport closed", message.request_id) + return logger.debug("Response sent") diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 9364abb73b..6fc59923f7 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -105,7 +105,7 @@ def __exit__( ) -> None: """Exit the context manager, performing cleanup and notifying completion.""" try: - if self._completed: # pragma: no branch + if self._completed: self._on_complete(self) finally: self._entered = False @@ -418,7 +418,9 @@ async def _receive_loop(self) -> None: finally: # after the read stream is closed, we need to send errors # to any pending requests - for id, stream in self._response_streams.items(): + # Snapshot: stream.send() wakes the waiter, whose finally pops + # from _response_streams before the next __next__() call. 
+ for id, stream in list(self._response_streams.items()): error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") try: await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 297f3d6a5c..cff5a37c15 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -6,12 +6,19 @@ from mcp import Client from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError +from mcp.shared.message import SessionMessage from mcp.types import ( + LATEST_PROTOCOL_VERSION, CallToolRequest, CallToolRequestParams, CallToolResult, CancelledNotification, CancelledNotificationParams, + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCNotification, + JSONRPCRequest, ListToolsResult, PaginatedRequestParams, TextContent, @@ -90,3 +97,154 @@ async def first_request(): assert isinstance(content, TextContent) assert content.text == "Call number: 2" assert call_count == 2 + + +@pytest.mark.anyio +async def test_server_cancels_in_flight_handlers_on_transport_close(): + """When the transport closes mid-request, server.run() must cancel in-flight + handlers rather than join on them. + + Without the cancel, the task group waits for the handler, which then tries + to respond through a write stream that _receive_loop already closed, + raising ClosedResourceError and crashing server.run() with exit code 1. + + This drives server.run() with raw memory streams because InMemoryTransport + wraps it in its own finally-cancel (_memory.py) which masks the bug. 
+ """ + handler_started = anyio.Event() + handler_cancelled = anyio.Event() + server_run_returned = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_cancelled.set() + # unreachable: sleep_forever only exits via cancellation + raise AssertionError # pragma: no cover + + server = Server("test", on_call_tool=handle_call_tool) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + await server.run(server_read, server_write, server.create_initialization_options()) + server_run_returned.set() + + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test", version="1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + call_req = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"), + ) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + + await to_server.send(SessionMessage(init_req)) + await from_server.receive() # init response + await to_server.send(SessionMessage(initialized)) + await to_server.send(SessionMessage(call_req)) + + await handler_started.wait() + + # Close the server's input stream — this is what stdin EOF does. + # server.run()'s incoming_messages loop ends, finally-cancel fires, + # handler gets CancelledError, server.run() returns. 
+ await to_server.aclose() + + await server_run_returned.wait() + + assert handler_cancelled.is_set() + + +@pytest.mark.anyio +async def test_server_handles_transport_close_with_pending_server_to_client_requests(): + """When the transport closes while handlers are blocked on server→client + requests (sampling, roots, elicitation), server.run() must still exit cleanly. + + Two bugs covered: + 1. _receive_loop's finally iterates _response_streams with await checkpoints + inside; the woken handler's send_request finally pops from that dict + before the next __next__() — RuntimeError: dictionary changed size. + 2. The woken handler's MCPError is caught in _handle_request, which falls + through to respond() against a write stream _receive_loop already closed. + """ + handlers_started = 0 + both_started = anyio.Event() + server_run_returned = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + nonlocal handlers_started + handlers_started += 1 + if handlers_started == 2: + both_started.set() + # Blocks on send_request waiting for a client response that never comes. + # _receive_loop's finally will wake this with CONNECTION_CLOSED. 
+ await ctx.session.list_roots() + raise AssertionError # pragma: no cover + + server = Server("test", on_call_tool=handle_call_tool) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + await server.run(server_read, server_write, server.create_initialization_options()) + server_run_returned.set() + + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test", version="1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + + await to_server.send(SessionMessage(init_req)) + await from_server.receive() # init response + await to_server.send(SessionMessage(initialized)) + + # Two tool calls → two handlers → two _response_streams entries. + for rid in (2, 3): + call_req = JSONRPCRequest( + jsonrpc="2.0", + id=rid, + method="tools/call", + params=CallToolRequestParams(name="t", arguments={}).model_dump(by_alias=True, mode="json"), + ) + await to_server.send(SessionMessage(call_req)) + + await both_started.wait() + # Drain the two roots/list requests so send_request's _write_stream.send() + # completes and both handlers are parked at response_stream_reader.receive(). + await from_server.receive() + await from_server.receive() + + await to_server.aclose() + + # Without the fixes: RuntimeError (dict mutation) or ClosedResourceError + # (respond after write-stream close) escapes run_server and this hangs. 
+ await server_run_returned.wait() From 7ba4fb881d85406f44a5af8169fb7200fa7c8e49 Mon Sep 17 00:00:00 2001 From: Felix Weinberger <3823880+felixweinberger@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:36:36 +0000 Subject: [PATCH 25/60] ci: skip claude.yml when comment is '@claude review' (#2337) --- .github/workflows/claude.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 8421cf954c..59dac99dcb 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -14,7 +14,7 @@ on: jobs: claude: if: | - (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude') && !startsWith(github.event.comment.body, '@claude review')) || (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) From 98f8ef295a6a4178ebeb75fd2e9f346be5c9eb0e Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 25 Mar 2026 23:29:13 +0100 Subject: [PATCH 26/60] Restrict httpx version to <1.0.0 (#2345) Co-authored-by: Max Isbey <224885523+maxisbey@users.noreply.github.com> --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 624ade1709..e1b19e3c9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ ] dependencies = [ "anyio>=4.9", - "httpx>=0.27.1", + "httpx>=0.27.1,<1.0.0", "httpx-sse>=0.4", "pydantic>=2.12.0", "starlette>=0.48.0; python_version >= '3.14'", diff --git a/uv.lock b/uv.lock index 4af3532ea2..8f9a5396a2 100644 --- a/uv.lock +++ b/uv.lock @@ -850,7 +850,7 @@ docs = [ 
[package.metadata] requires-dist = [ { name = "anyio", specifier = ">=4.9" }, - { name = "httpx", specifier = ">=0.27.1" }, + { name = "httpx", specifier = ">=0.27.1,<1.0.0" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" }, { name = "pydantic", specifier = ">=2.12.0" }, From 3517a29c828d596f0e3fb5f82fcfc86fd7a14dd0 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 27 Mar 2026 13:42:15 +0000 Subject: [PATCH 27/60] feat(server): restore `dependencies` parameter on MCPServer (#2358) --- src/mcp/server/mcpserver/server.py | 6 ++++++ tests/server/mcpserver/test_server.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 2a7a58117a..6f9bb0e287 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -105,6 +105,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # prompt settings warn_on_duplicate_prompts: bool + dependencies: list[str] + """List of dependencies to install in the server environment. 
Used by the `mcp install` and `mcp dev` CLI.""" + lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None """An async context manager that will be called when the server is started.""" @@ -142,6 +145,7 @@ def __init__( warn_on_duplicate_resources: bool = True, warn_on_duplicate_tools: bool = True, warn_on_duplicate_prompts: bool = True, + dependencies: list[str] | None = None, lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, auth: AuthSettings | None = None, ): @@ -151,9 +155,11 @@ def __init__( warn_on_duplicate_resources=warn_on_duplicate_resources, warn_on_duplicate_tools=warn_on_duplicate_tools, warn_on_duplicate_prompts=warn_on_duplicate_prompts, + dependencies=dependencies or [], lifespan=lifespan, auth=auth, ) + self.dependencies = self.settings.dependencies self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3ef06d0381..49b6deb4bb 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -65,6 +65,15 @@ async def test_create_server(self): assert len(mcp.icons) == 1 assert mcp.icons[0].src == "https://example.com/icon.png" + def test_dependencies(self): + """Dependencies list is read by `mcp install` / `mcp dev` CLI commands.""" + mcp = MCPServer("test", dependencies=["pandas", "numpy"]) + assert mcp.dependencies == ["pandas", "numpy"] + assert mcp.settings.dependencies == ["pandas", "numpy"] + + mcp_no_deps = MCPServer("test") + assert mcp_no_deps.dependencies == [] + async def test_sse_app_returns_starlette_app(self): """Test that sse_app returns a Starlette application with correct routes.""" mcp = MCPServer("test") From 
fb2276b95fb5ac6631f9c160c9b1bd7a6d7312a9 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 27 Mar 2026 14:02:41 +0000 Subject: [PATCH 28/60] ci: remove claude-code-review workflow (#2359) --- .github/workflows/claude-code-review.yml | 33 ------------------------ 1 file changed, 33 deletions(-) delete mode 100644 .github/workflows/claude-code-review.yml diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml deleted file mode 100644 index 514f979d7c..0000000000 --- a/.github/workflows/claude-code-review.yml +++ /dev/null @@ -1,33 +0,0 @@ -# Source: https://github.com/anthropics/claude-code-action/blob/main/docs/code-review.md -name: Claude Code Review - -on: - pull_request: - types: [opened, synchronize, ready_for_review, reopened] - -jobs: - claude-review: - # Fork PRs don't have access to secrets or OIDC tokens, so the action - # cannot authenticate. See https://github.com/anthropics/claude-code-action/issues/339 - if: github.event.pull_request.head.repo.fork == false && github.actor != 'dependabot[bot]' - runs-on: ubuntu-latest - permissions: - contents: read - pull-requests: read - issues: read - id-token: write - - steps: - - name: Checkout repository - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - fetch-depth: 1 - - - name: Run Claude Code Review - id: claude-review - uses: anthropics/claude-code-action@2f8ba26a219c06cfb0f468eef8d97055fa814f97 # v1.0.53 - with: - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} - plugin_marketplaces: "https://github.com/anthropics/claude-code.git" - plugins: "code-review@claude-code-plugins" - prompt: "/code-review:code-review ${{ github.repository }}/pull/${{ github.event.pull_request.number }}" From e6235d1667ee59c63cfa365ce0136beae05f067a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 31 Mar 2026 12:49:38 -0400 Subject: [PATCH 29/60] Propagate contextvars.Context through anyio 
streams without modifying SessionMessage (#2298) --- .../mcp_simple_auth_client/main.py | 6 +- pyproject.toml | 5 +- src/mcp/client/__main__.py | 6 +- src/mcp/client/_transport.py | 7 +- src/mcp/client/session.py | 6 +- src/mcp/client/sse.py | 23 ++-- src/mcp/client/streamable_http.py | 23 ++-- src/mcp/server/lowlevel/server.py | 15 ++- src/mcp/server/session.py | 7 +- src/mcp/server/sse.py | 13 +- src/mcp/server/stdio.py | 12 +- src/mcp/server/streamable_http.py | 18 +-- src/mcp/server/websocket.py | 12 +- src/mcp/shared/_context_streams.py | 119 ++++++++++++++++++ src/mcp/shared/_stream_protocols.py | 49 ++++++++ src/mcp/shared/memory.py | 10 +- src/mcp/shared/session.py | 26 ++-- tests/client/conftest.py | 4 +- tests/client/test_client.py | 33 +++++ tests/shared/test_streamable_http.py | 10 +- 20 files changed, 310 insertions(+), 94 deletions(-) create mode 100644 src/mcp/shared/_context_streams.py create mode 100644 src/mcp/shared/_stream_protocols.py diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 5fac56be5d..6ef2f0b113 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -18,7 +18,7 @@ from urllib.parse import parse_qs, urlparse import httpx -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp.client._transport import ReadStream, WriteStream from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.session import ClientSession from mcp.client.sse import sse_client @@ -241,8 +241,8 @@ async def _default_redirect_handler(authorization_url: str) -> None: async def _run_session( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: 
WriteStream[SessionMessage], ): """Run the MCP session with the given streams.""" print("🤝 Initializing MCP session...") diff --git a/pyproject.toml b/pyproject.toml index e1b19e3c9f..7d8b4a8743 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -219,13 +219,10 @@ skip_covered = true show_missing = true ignore_errors = true precision = 2 -exclude_lines = [ - "pragma: no cover", +exclude_also = [ "pragma: lax no cover", - "if TYPE_CHECKING:", "@overload", "raise NotImplementedError", - "^\\s*\\.\\.\\.\\s*$", ] # https://coverage.readthedocs.io/en/latest/config.html#paths diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index f3db17906d..b9ec344226 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -6,9 +6,9 @@ from urllib.parse import urlparse import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import types +from mcp.client._transport import ReadStream, WriteStream from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client @@ -33,8 +33,8 @@ async def message_handler( async def run_session( - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], client_info: types.Implementation | None = None, ): async with ClientSession( diff --git a/src/mcp/client/_transport.py b/src/mcp/client/_transport.py index a863629005..0163fef950 100644 --- a/src/mcp/client/_transport.py +++ b/src/mcp/client/_transport.py @@ -5,11 +5,12 @@ from contextlib import AbstractAsyncContextManager from typing import Protocol -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.message import SessionMessage -TransportStreams = 
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] +__all__ = ["ReadStream", "WriteStream", "Transport", "TransportStreams"] + +TransportStreams = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]] class Transport(AbstractAsyncContextManager[TransportStreams], Protocol): diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7c964a334c..0cea454a77 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -4,10 +4,10 @@ from typing import Any, Protocol import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import TypeAdapter from mcp import types +from mcp.client._transport import ReadStream, WriteStream from mcp.client.experimental import ExperimentalClientFeatures from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared._context import RequestContext @@ -109,8 +109,8 @@ class ClientSession( ): def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 7b66b5c1b6..193204a153 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -7,11 +7,11 @@ import anyio import httpx from anyio.abc import TaskStatus -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse from httpx_sse._exceptions import SSEError from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import 
SessionMessage @@ -51,12 +51,6 @@ async def sse_client( auth: Optional HTTPX authentication handler. on_session_created: Optional callback invoked with the session ID when received. """ - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) @@ -65,8 +59,8 @@ async def sse_client( event_source.response.raise_for_status() logger.debug("SSE connection established") - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): try: @@ -124,7 +118,8 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): async def post_writer(endpoint_url: str): try: async with write_stream_reader, write_stream: - async for session_message in write_stream_reader: + + async def _send_message(session_message: SessionMessage) -> None: logger.debug(f"Sending client message: {session_message}") response = await client.post( endpoint_url, @@ -136,6 +131,14 @@ async def post_writer(endpoint_url: str): ) response.raise_for_status() logger.debug(f"Client message sent successfully: {response.status_code}") + + async for session_message in write_stream_reader: + sender_ctx = write_stream_reader.last_context + if sender_ctx is not None: + async with anyio.create_task_group() as tg: + sender_ctx.run(tg.start_soon, _send_message, session_message) + 
else: + await _send_message(session_message) # pragma: no cover except Exception: # pragma: lax no cover logger.exception("Error in post_writer") diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3afb94b034..9a119c6338 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -11,11 +11,11 @@ import anyio import httpx from anyio.abc import TaskGroup -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse from pydantic import ValidationError from mcp.client._transport import TransportStreams +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( @@ -38,8 +38,8 @@ # TODO(Marcelo): Put the TransportStreams in a module under shared, so we can import here. 
SessionMessageOrError = SessionMessage | Exception -StreamWriter = MemoryObjectSendStream[SessionMessageOrError] -StreamReader = MemoryObjectReceiveStream[SessionMessage] +StreamWriter = ContextSendStream[SessionMessageOrError] +StreamReader = ContextReceiveStream[SessionMessage] MCP_SESSION_ID = "mcp-session-id" MCP_PROTOCOL_VERSION = "mcp-protocol-version" @@ -434,14 +434,15 @@ async def post_writer( client: httpx.AsyncClient, write_stream_reader: StreamReader, read_stream_writer: StreamWriter, - write_stream: MemoryObjectSendStream[SessionMessage], + write_stream: ContextSendStream[SessionMessage], start_get_stream: Callable[[], None], tg: TaskGroup, ) -> None: """Handle writing requests to the server.""" try: async with write_stream_reader, read_stream_writer, write_stream: - async for session_message in write_stream_reader: + + async def _handle_message(session_message: SessionMessage) -> None: message = session_message.message metadata = ( session_message.metadata @@ -478,6 +479,14 @@ async def handle_request_async(): else: await handle_request_async() + async for session_message in write_stream_reader: + sender_ctx = write_stream_reader.last_context + if sender_ctx is not None: + async with anyio.create_task_group() as tg_local: + sender_ctx.run(tg_local.start_soon, _handle_message, session_message) + else: + await _handle_message(session_message) # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Error in post_writer") @@ -547,8 +556,8 @@ async def streamable_http_client( if not client_provided: await stack.enter_async_context(client) - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async with ( read_stream_writer, diff 
--git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index c288422720..0fdbbff866 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -36,6 +36,7 @@ async def main(): from __future__ import annotations +import contextvars import logging import warnings from collections.abc import AsyncIterator, Awaitable, Callable @@ -44,7 +45,6 @@ async def main(): from typing import Any, Generic import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -65,6 +65,7 @@ async def main(): from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder @@ -355,8 +356,8 @@ def session_manager(self) -> StreamableHTTPSessionManager: async def run( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], initialization_options: InitializationOptions, # When False, exceptions are returned as messages to the client. 
# When True, exceptions are raised, which will cause the server to shut down @@ -391,7 +392,13 @@ async def run( async for message in session.incoming_messages: logger.debug("Received message: %s", message) - tg.start_soon( + if isinstance(message, RequestResponder) and message.context is not None: + context = message.context + else: + context = contextvars.copy_context() + + context.run( + tg.start_soon, self._handle_message, message, session, diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ce467e6c93..20b640527a 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -33,13 +33,14 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.streams.memory import MemoryObjectReceiveStream from pydantic import AnyUrl, TypeAdapter from mcp import types from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY @@ -79,8 +80,8 @@ class ServerSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, ) -> None: diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9dcee67f78..48192ff612 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -43,7 +43,6 @@ 
async def handle_sse(request): from uuid import UUID, uuid4 import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from sse_starlette import EventSourceResponse from starlette.requests import Request @@ -55,6 +54,7 @@ async def handle_sse(request): TransportSecurityMiddleware, TransportSecuritySettings, ) +from mcp.shared._context_streams import ContextSendStream, create_context_streams from mcp.shared.message import ServerMessageMetadata, SessionMessage logger = logging.getLogger(__name__) @@ -72,7 +72,7 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] + _read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]] _security: TransportSecurityMiddleware def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: @@ -129,14 +129,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag raise ValueError("Request validation failed") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) session_id = uuid4() self._read_stream_writers[session_id] = read_stream_writer diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 5ea6c4e778..5c1459dff6 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -23,9 +23,9 @@ async def 
run_server(): import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared.message import SessionMessage @@ -43,14 +43,8 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async def stdin_reader(): try: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index aa99e7c887..f14201857c 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -25,6 +25,8 @@ from starlette.types import Receive, Scope, Send from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( @@ -119,10 +121,10 @@ class StreamableHTTPServerTransport: """ # Server notification streams for POST requests as well as standalone SSE stream - _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None - 
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None - _write_stream: MemoryObjectSendStream[SessionMessage] | None = None - _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None + _read_stream_writer: ContextSendStream[SessionMessage | Exception] | None = None + _read_stream: ContextReceiveStream[SessionMessage | Exception] | None = None + _write_stream: ContextSendStream[SessionMessage] | None = None + _write_stream_reader: ContextReceiveStream[SessionMessage] | None = None _security: TransportSecurityMiddleware def __init__( @@ -954,8 +956,8 @@ async def connect( self, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], - MemoryObjectSendStream[SessionMessage], + ReadStream[SessionMessage | Exception], + WriteStream[SessionMessage], ], None, ]: @@ -967,8 +969,8 @@ async def connect( # Create the memory streams for this connection - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) # Store the streams self._read_stream_writer = read_stream_writer diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 32b50560cc..277f9b5af2 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -1,12 +1,12 @@ from contextlib import asynccontextmanager import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic_core import ValidationError from starlette.types import Receive, Scope, Send from starlette.websockets import WebSocket from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared.message import SessionMessage @@ -19,14 +19,8 @@ async def 
websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async def ws_reader(): try: diff --git a/src/mcp/shared/_context_streams.py b/src/mcp/shared/_context_streams.py new file mode 100644 index 0000000000..04c33306d9 --- /dev/null +++ b/src/mcp/shared/_context_streams.py @@ -0,0 +1,119 @@ +"""Context-aware memory stream wrappers. + +anyio memory streams do not propagate ``contextvars.Context`` across task +boundaries. These thin wrappers capture the sender's context at ``send()`` +time and expose it on the receive side via ``last_context``, so consumers +can restore it with ``ctx.run(handler, item)``. + +The iteration interface is unchanged (yields ``T``, not tuples), keeping +these wrappers duck-type compatible with plain ``MemoryObjectSendStream`` +and ``MemoryObjectReceiveStream``. +""" + +from __future__ import annotations + +import contextvars +from types import TracebackType +from typing import Any, Generic, TypeVar + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +T = TypeVar("T") + +# Internal payload carried through the underlying raw stream. 
+_Envelope = tuple[contextvars.Context, T] + + +class ContextSendStream(Generic[T]): + """Send-side wrapper that snapshots ``contextvars.copy_context()`` on every ``send()``.""" + + __slots__ = ("_inner",) + + def __init__(self, inner: MemoryObjectSendStream[_Envelope[T]]) -> None: + self._inner = inner + + async def send(self, item: T) -> None: + await self._inner.send((contextvars.copy_context(), item)) + + def close(self) -> None: + self._inner.close() + + async def aclose(self) -> None: + await self._inner.aclose() + + def clone(self) -> ContextSendStream[T]: # pragma: no cover + return ContextSendStream(self._inner.clone()) + + async def __aenter__(self) -> ContextSendStream[T]: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.aclose() + return None + + +class ContextReceiveStream(Generic[T]): + """Receive-side wrapper that yields ``T`` and stores the sender's context in ``last_context``.""" + + __slots__ = ("_inner", "last_context") + + def __init__(self, inner: MemoryObjectReceiveStream[_Envelope[T]]) -> None: + self._inner = inner + self.last_context: contextvars.Context | None = None + + async def receive(self) -> T: + ctx, item = await self._inner.receive() + self.last_context = ctx + return item + + def close(self) -> None: + self._inner.close() + + async def aclose(self) -> None: + await self._inner.aclose() + + def clone(self) -> ContextReceiveStream[T]: # pragma: no cover + return ContextReceiveStream(self._inner.clone()) + + def __aiter__(self) -> ContextReceiveStream[T]: + return self + + async def __anext__(self) -> T: + try: + return await self.receive() + except anyio.EndOfStream: + raise StopAsyncIteration + + async def __aenter__(self) -> ContextReceiveStream[T]: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | 
None, + ) -> bool | None: + await self.aclose() + return None + + +class create_context_streams( + tuple[ContextSendStream[T], ContextReceiveStream[T]], +): + """Create context-aware memory object streams. + + Supports ``create_context_streams[T](n)`` bracket syntax, + matching anyio's ``create_memory_object_stream`` API style. + """ + + def __new__(cls, max_buffer_size: float = 0) -> tuple[ContextSendStream[T], ContextReceiveStream[T]]: # type: ignore[type-var] + raw_send: MemoryObjectSendStream[Any] + raw_receive: MemoryObjectReceiveStream[Any] + raw_send, raw_receive = anyio.create_memory_object_stream(max_buffer_size) + return (ContextSendStream(raw_send), ContextReceiveStream(raw_receive)) diff --git a/src/mcp/shared/_stream_protocols.py b/src/mcp/shared/_stream_protocols.py new file mode 100644 index 0000000000..b799751329 --- /dev/null +++ b/src/mcp/shared/_stream_protocols.py @@ -0,0 +1,49 @@ +"""Stream protocols for MCP transports. + +These are general-purpose protocols satisfied by both ``MemoryObjectSendStream``/ +``MemoryObjectReceiveStream`` and the context-aware wrappers in ``_context_streams``. +""" + +from __future__ import annotations + +from types import TracebackType +from typing import Protocol, TypeVar + +from typing_extensions import Self + +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +class ReadStream(Protocol[T_co]): + """Protocol for reading items from a stream. + + Consumers that need the sender's context should use + ``getattr(stream, 'last_context', None)``. + """ + + async def receive(self) -> T_co: ... + async def aclose(self) -> None: ... + def __aiter__(self) -> ReadStream[T_co]: ... + async def __anext__(self) -> T_co: ... + async def __aenter__(self) -> Self: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... 
+ + +class WriteStream(Protocol[T_contra]): + """Protocol for writing items to a stream.""" + + async def send(self, item: T_contra, /) -> None: ... + async def aclose(self) -> None: ... + async def __aenter__(self) -> Self: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index f2d5e2b9ad..468590d095 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -5,12 +5,10 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared.message import SessionMessage -MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] +MessageStream = tuple[ContextReceiveStream[SessionMessage | Exception], ContextSendStream[SessionMessage | Exception]] @asynccontextmanager @@ -22,8 +20,8 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS (read_stream, write_stream) """ # Create streams for both directions - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_to_client_send, server_to_client_receive = create_context_streams[SessionMessage | Exception](1) + client_to_server_send, client_to_server_receive = create_context_streams[SessionMessage | Exception](1) client_streams = (server_to_client_receive, client_to_server_send) server_streams = (client_to_server_receive, server_to_client_send) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 
6fc59923f7..3534fbb768 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextvars import logging from collections.abc import Callable from contextlib import AsyncExitStack @@ -7,10 +8,11 @@ from typing import Any, Generic, Protocol, TypeVar import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.streams.memory import MemoryObjectSendStream from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.response_router import ResponseRouter @@ -79,11 +81,13 @@ def __init__( session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any], message_metadata: MessageMetadata = None, + context: contextvars.Context | None = None, ) -> None: self.request_id = request_id self.request_meta = request_meta self.request = request self.message_metadata = message_metadata + self.context = context self._session = session self._completed = False self._cancel_scope = anyio.CancelScope() @@ -181,8 +185,8 @@ class BaseSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], # If none, reading will never time out read_timeout_seconds: float | None = None, ) -> None: @@ -333,10 +337,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: async def _receive_loop(self) -> None: async with self._read_stream, self._write_stream: try: - async for message in self._read_stream: - if 
isinstance(message, Exception): - await self._handle_incoming(message) - elif isinstance(message.message, JSONRPCRequest): + + async def _handle_session_message(message: SessionMessage) -> None: + sender_context: contextvars.Context | None = getattr(self._read_stream, "last_context", None) + if isinstance(message.message, JSONRPCRequest): try: validated_request = self._receive_request_adapter.validate_python( message.message.model_dump(by_alias=True, mode="json", exclude_none=True), @@ -349,6 +353,7 @@ async def _receive_loop(self) -> None: session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), message_metadata=message.metadata, + context=sender_context, ) self._in_flight[responder.request_id] = responder await self._received_request(responder) @@ -406,6 +411,13 @@ async def _receive_loop(self) -> None: else: # Response or error await self._handle_response(message) + async for message in self._read_stream: + if isinstance(message, Exception): + await self._handle_incoming(message) + continue + + await _handle_session_message(message) + except anyio.ClosedResourceError: # This is expected when the client disconnects abruptly. 
# Without this handler, the exception would propagate up and diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 2e39f13630..081e1d68ea 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -4,15 +4,15 @@ from unittest.mock import patch import pytest -from anyio.streams.memory import MemoryObjectSendStream import mcp.shared.memory +from mcp.client._transport import WriteStream from mcp.shared.message import SessionMessage from mcp.types import JSONRPCNotification, JSONRPCRequest class SpyMemoryObjectSendStream: - def __init__(self, original_stream: MemoryObjectSendStream[SessionMessage]): + def __init__(self, original_stream: WriteStream[SessionMessage]): self.original_stream = original_stream self.sent_messages: list[SessionMessage] = [] diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 18368e6bb3..ac52a9024a 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -2,6 +2,9 @@ from __future__ import annotations +import contextvars +from collections.abc import Iterator +from contextlib import contextmanager from unittest.mock import patch import anyio @@ -320,3 +323,33 @@ async def test_client_uses_transport_directly(app: MCPServer): structured_content={"result": "Hello, Transport!"}, ) ) + + +_TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial") + + +@contextmanager +def _set_test_contextvar(value: str) -> Iterator[None]: + token = _TEST_CONTEXTVAR.set(value) + try: + yield + finally: + _TEST_CONTEXTVAR.reset(token) + + +async def test_context_propagation(): + """Sender's contextvars.Context is propagated to the server handler.""" + server = MCPServer("test") + + @server.tool() + async def check_context() -> str: + """Return the contextvar value visible to the handler.""" + return _TEST_CONTEXTVAR.get() + + async with Client(server) as client: + with _set_test_contextvar("client_value"): + result = await client.call_tool("check_context", {}) + + assert 
result.content[0].text == "client_value", ( # type: ignore[union-attr] + "Server handler did not see the sender's contextvars.Context" + ) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f8ca30441b..3d5770fb61 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -45,6 +45,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._context import RequestContext +from mcp.shared._context_streams import create_context_streams from mcp.shared._httpx_utils import ( MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT, @@ -1783,8 +1784,8 @@ async def test_handle_sse_event_skips_empty_data(): # Create a mock SSE event with empty data (keep-alive ping) mock_sse = ServerSentEvent(event="message", data="", id=None, retry=None) - # Create a mock stream writer - write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + # Create a context-aware stream writer (matches StreamWriter type alias) + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) try: # Call _handle_sse_event with empty data - should return False and not raise @@ -1794,8 +1795,9 @@ async def test_handle_sse_event_skips_empty_data(): assert result is False # Nothing should have been written to the stream - # Check buffer is empty (statistics().current_buffer_used returns buffer size) - assert write_stream.statistics().current_buffer_used == 0 + with pytest.raises(TimeoutError): + with anyio.fail_after(0): + await read_stream.receive() finally: await write_stream.aclose() await read_stream.aclose() From 3ce0f76e6e3b33f035b6b28421e1b7c6dbe8c77f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 31 Mar 2026 13:43:56 -0400 Subject: [PATCH 30/60] Don't block the event loop on sync resource and prompt functions (#2380) 
--- src/mcp/server/mcpserver/prompts/base.py | 10 ++-- .../server/mcpserver/resources/templates.py | 10 ++-- src/mcp/server/mcpserver/resources/types.py | 9 ++-- tests/server/mcpserver/prompts/test_base.py | 19 +++++++ .../resources/test_function_resources.py | 52 +++++++++++++++++++ .../resources/test_resource_template.py | 20 +++++++ 6 files changed, 107 insertions(+), 13 deletions(-) diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 0c319d53cc..b4810c100e 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -2,10 +2,12 @@ from __future__ import annotations +import functools import inspect from collections.abc import Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any, Literal +import anyio.to_thread import pydantic_core from pydantic import BaseModel, Field, TypeAdapter, validate_call @@ -155,10 +157,10 @@ async def render( # Add context to arguments if needed call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg) - # Call function and check if result is a coroutine - result = self.fn(**call_args) - if inspect.iscoroutine(result): - result = await result + if inspect.iscoroutinefunction(self.fn): + result = await self.fn(**call_args) + else: + result = await anyio.to_thread.run_sync(functools.partial(self.fn, **call_args)) # Validate messages if not isinstance(result, list | tuple): diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index 2d612657c4..542b5e6f81 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -2,12 +2,14 @@ from __future__ import annotations +import functools import inspect import re from collections.abc import Callable from typing import TYPE_CHECKING, Any from urllib.parse import unquote +import anyio.to_thread from pydantic import BaseModel, Field, validate_call from 
mcp.server.mcpserver.resources.types import FunctionResource, Resource @@ -110,10 +112,10 @@ async def create_resource( # Add context to params if needed params = inject_context(self.fn, params, context, self.context_kwarg) - # Call function and check if result is a coroutine - result = self.fn(**params) - if inspect.iscoroutine(result): - result = await result + if inspect.iscoroutinefunction(self.fn): + result = await self.fn(**params) + else: + result = await anyio.to_thread.run_sync(functools.partial(self.fn, **params)) return FunctionResource( uri=uri, # type: ignore diff --git a/src/mcp/server/mcpserver/resources/types.py b/src/mcp/server/mcpserver/resources/types.py index 42aecd6e39..04763be8ba 100644 --- a/src/mcp/server/mcpserver/resources/types.py +++ b/src/mcp/server/mcpserver/resources/types.py @@ -55,11 +55,10 @@ class FunctionResource(Resource): async def read(self) -> str | bytes: """Read the resource by calling the wrapped function.""" try: - # Call the function first to see if it returns a coroutine - result = self.fn() - # If it's a coroutine, await it - if inspect.iscoroutine(result): - result = await result + if inspect.iscoroutinefunction(self.fn): + result = await self.fn() + else: + result = await anyio.to_thread.run_sync(self.fn) if isinstance(result, Resource): # pragma: no cover return await result.read() diff --git a/tests/server/mcpserver/prompts/test_base.py b/tests/server/mcpserver/prompts/test_base.py index fe18e91bd7..d4e4e6b5a6 100644 --- a/tests/server/mcpserver/prompts/test_base.py +++ b/tests/server/mcpserver/prompts/test_base.py @@ -1,3 +1,4 @@ +import threading from typing import Any import pytest @@ -190,3 +191,21 @@ async def fn() -> dict[str, Any]: ) ) ] + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync prompt functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn() -> str: + 
fn_thread.append(threading.get_ident()) + return "hello" + + prompt = Prompt.from_function(blocking_fn) + messages = await prompt.render(None, Context()) + + assert messages == [UserMessage(content=TextContent(type="text", text="hello"))] + assert fn_thread[0] != main_thread diff --git a/tests/server/mcpserver/resources/test_function_resources.py b/tests/server/mcpserver/resources/test_function_resources.py index 5f5c216ed1..c1ff960617 100644 --- a/tests/server/mcpserver/resources/test_function_resources.py +++ b/tests/server/mcpserver/resources/test_function_resources.py @@ -1,3 +1,7 @@ +import threading + +import anyio +import anyio.from_thread import pytest from pydantic import BaseModel @@ -190,3 +194,51 @@ def get_data() -> str: # pragma: no cover ) assert resource.meta is None + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync resource functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn() -> str: + fn_thread.append(threading.get_ident()) + return "data" + + resource = FunctionResource(uri="resource://test", name="test", fn=blocking_fn) + result = await resource.read() + + assert result == "data" + assert fn_thread[0] != main_thread + + +@pytest.mark.anyio +async def test_sync_fn_does_not_block_event_loop(): + """A blocking sync resource function must not stall the event loop. + + On regression (sync runs inline), anyio.from_thread.run_sync raises + RuntimeError because there is no worker-thread context, failing fast. 
+ """ + handler_entered = anyio.Event() + release = threading.Event() + + def blocking_fn() -> str: + anyio.from_thread.run_sync(handler_entered.set) + release.wait() + return "done" + + resource = FunctionResource(uri="resource://test", name="test", fn=blocking_fn) + result: list[str | bytes] = [] + + async def run() -> None: + result.append(await resource.read()) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(run) + await handler_entered.wait() + release.set() + + assert result == ["done"] diff --git a/tests/server/mcpserver/resources/test_resource_template.py b/tests/server/mcpserver/resources/test_resource_template.py index 640cfe8031..2a7ba8d503 100644 --- a/tests/server/mcpserver/resources/test_resource_template.py +++ b/tests/server/mcpserver/resources/test_resource_template.py @@ -1,4 +1,5 @@ import json +import threading from typing import Any import pytest @@ -310,3 +311,22 @@ def get_item(item_id: str) -> str: assert resource.meta == metadata assert resource.meta["category"] == "inventory" assert resource.meta["cacheable"] is True + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync template functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn(name: str) -> str: + fn_thread.append(threading.get_ident()) + return f"hello {name}" + + template = ResourceTemplate.from_function(fn=blocking_fn, uri_template="test://{name}") + resource = await template.create_resource("test://world", {"name": "world"}, Context()) + + assert isinstance(resource, FunctionResource) + assert await resource.read() == "hello world" + assert fn_thread[0] != main_thread From 37891f42a409ac10a46fb1577e7fd3ec3f70d5eb Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 31 Mar 2026 16:33:33 -0400 Subject: [PATCH 31/60] Add basic OpenTelemetry tracing for client and server requests (#2381) --- pyproject.toml | 2 
+ src/mcp/server/lowlevel/server.py | 150 +++++++++++-------- src/mcp/shared/_otel.py | 36 +++++ src/mcp/shared/session.py | 50 ++++--- tests/shared/test_otel.py | 44 ++++++ uv.lock | 237 ++++++++++++++++++++++++++++++ 6 files changed, 436 insertions(+), 83 deletions(-) create mode 100644 src/mcp/shared/_otel.py create mode 100644 tests/shared/test_otel.py diff --git a/pyproject.toml b/pyproject.toml index 7d8b4a8743..be1200cff0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "pyjwt[crypto]>=2.10.1", "typing-extensions>=4.13.0", "typing-inspection>=0.4.1", + "opentelemetry-api>=1.28.0", ] [project.optional-dependencies] @@ -71,6 +72,7 @@ dev = [ "coverage[toml]>=7.10.7,<=7.13", "pillow>=12.0", "strict-no-cover", + "logfire>=3.0.0", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 0fdbbff866..59de0ace45 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -42,9 +42,10 @@ async def main(): from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from importlib.metadata import version as importlib_version -from typing import Any, Generic +from typing import Any, Generic, cast import anyio +from opentelemetry.trace import SpanKind, StatusCode from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -65,6 +66,7 @@ async def main(): from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._otel import extract_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from 
mcp.shared.message import ServerMessageMetadata, SessionMessage @@ -446,72 +448,90 @@ async def _handle_request( ): logger.info("Processing request of type %s", type(req).__name__) - if handler := self._request_handlers.get(req.method): - logger.debug("Dispatching request of type %s", type(req).__name__) + target = getattr(req.params, "name", None) if req.params else None + span_name = f"MCP handle {req.method} {target}" if target else f"MCP handle {req.method}" - try: - # Extract request context and close_sse_stream from message metadata - request_data = None - close_sse_stream_cb = None - close_standalone_sse_stream_cb = None - if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata): - request_data = message.message_metadata.request_context - close_sse_stream_cb = message.message_metadata.close_sse_stream - close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream + # Extract W3C trace context from _meta (SEP-414). 
+ meta = cast(dict[str, Any] | None, getattr(req.params, "meta", None)) if req.params else None + parent_context = extract_trace_context(meta) if meta is not None else None - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - # Get task metadata from request params if present - task_metadata = None - if hasattr(req, "params") and req.params is not None: - task_metadata = getattr(req.params, "task", None) - ctx = ServerRequestContext( - request_id=message.request_id, - meta=message.request_meta, - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=task_metadata, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - request=request_data, - close_sse_stream=close_sse_stream_cb, - close_standalone_sse_stream=close_standalone_sse_stream_cb, - ) - response = await handler(ctx, req.params) - except MCPError as err: - response = err.error - except anyio.get_cancelled_exc_class(): - if message.cancelled: - # Client sent CancelledNotification; responder.cancel() already - # sent an error response, so skip the duplicate. - logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) - return - # Transport-close cancellation from the TG in run(); re-raise so the - # TG swallows its own cancellation. - raise - except Exception as err: - if raise_exceptions: # pragma: no cover - raise err - response = types.ErrorData(code=0, message=str(err)) - else: # pragma: no cover - response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") - - try: - await message.respond(response) - except (anyio.BrokenResourceError, anyio.ClosedResourceError): - # Transport closed between handler unblocking and respond. 
Happens - # when _receive_loop's finally wakes a handler blocked on - # send_request: the handler runs to respond() before run()'s TG - # cancel fires, but after the write stream closed. Closed if our - # end closed (_receive_loop's async-with exit); Broken if the peer - # end closed first (streamable_http terminate()). - logger.debug("Response for %s dropped - transport closed", message.request_id) - return - - logger.debug("Response sent") + with otel_span( + span_name, + kind=SpanKind.SERVER, + attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id}, + context=parent_context, + ) as span: + if handler := self._request_handlers.get(req.method): + logger.debug("Dispatching request of type %s", type(req).__name__) + + try: + # Extract request context and close_sse_stream from message metadata + request_data = None + close_sse_stream_cb = None + close_standalone_sse_stream_cb = None + if message.message_metadata is not None and isinstance( + message.message_metadata, ServerMessageMetadata + ): + request_data = message.message_metadata.request_context + close_sse_stream_cb = message.message_metadata.close_sse_stream + close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream + + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + # Get task metadata from request params if present + task_metadata = None + if hasattr(req, "params") and req.params is not None: # pragma: no branch + task_metadata = getattr(req.params, "task", None) + ctx = ServerRequestContext( + request_id=message.request_id, + meta=message.request_meta, + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=task_metadata, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + request=request_data, + 
close_sse_stream=close_sse_stream_cb, + close_standalone_sse_stream=close_standalone_sse_stream_cb, + ) + response = await handler(ctx, req.params) + except MCPError as err: + response = err.error + except anyio.get_cancelled_exc_class(): + if message.cancelled: + # Client sent CancelledNotification; responder.cancel() already + # sent an error response, so skip the duplicate. + logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) + return + # Transport-close cancellation from the TG in run(); re-raise so the + # TG swallows its own cancellation. + raise + except Exception as err: + if raise_exceptions: # pragma: no cover + raise err + response = types.ErrorData(code=0, message=str(err)) + else: # pragma: no cover + response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") + + if isinstance(response, types.ErrorData) and span is not None: + span.set_status(StatusCode.ERROR, response.message) + + try: + await message.respond(response) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + # Transport closed between handler unblocking and respond. Happens + # when _receive_loop's finally wakes a handler blocked on + # send_request: the handler runs to respond() before run()'s TG + # cancel fires, but after the write stream closed. Closed if our + # end closed (_receive_loop's async-with exit); Broken if the peer + # end closed first (streamable_http terminate()). 
+ logger.debug("Response for %s dropped - transport closed", message.request_id) + return + + logger.debug("Response sent") async def _handle_notification( self, diff --git a/src/mcp/shared/_otel.py b/src/mcp/shared/_otel.py new file mode 100644 index 0000000000..170e873a0f --- /dev/null +++ b/src/mcp/shared/_otel.py @@ -0,0 +1,36 @@ +"""OpenTelemetry helpers for MCP.""" + +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any + +from opentelemetry.context import Context +from opentelemetry.propagate import extract, inject +from opentelemetry.trace import SpanKind, get_tracer + +_tracer = get_tracer("mcp-python-sdk") + + +@contextmanager +def otel_span( + name: str, + *, + kind: SpanKind, + attributes: dict[str, Any] | None = None, + context: Context | None = None, +) -> Iterator[Any]: + """Create an OTel span.""" + with _tracer.start_as_current_span(name, kind=kind, attributes=attributes, context=context) as span: + yield span + + +def inject_trace_context(meta: dict[str, Any]) -> None: + """Inject W3C trace context (traceparent/tracestate) into a `_meta` dict.""" + inject(meta) + + +def extract_trace_context(meta: dict[str, Any]) -> Context: + """Extract W3C trace context from a `_meta` dict.""" + return extract(meta) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 3534fbb768..243eef5ae6 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -9,9 +9,11 @@ import anyio from anyio.streams.memory import MemoryObjectSendStream +from opentelemetry.trace import SpanKind from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +from mcp.shared._otel import inject_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage @@ -268,24 +270,36 @@ async def send_request( 
self._progress_callbacks[request_id] = progress_callback try: - jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) - await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) - - # request read timeout takes precedence over session read timeout - timeout = request_read_timeout_seconds or self._session_read_timeout_seconds - - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - class_name = request.__class__.__name__ - message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." - raise MCPError(code=REQUEST_TIMEOUT, message=message) - - if isinstance(response_or_error, JSONRPCError): - raise MCPError.from_jsonrpc_error(response_or_error) - else: - return result_type.model_validate(response_or_error.result, by_name=False) + target = request_data.get("params", {}).get("name") + span_name = f"MCP send {request.method} {target}" if target else f"MCP send {request.method}" + + with otel_span( + span_name, + kind=SpanKind.CLIENT, + attributes={"mcp.method.name": request.method, "jsonrpc.request.id": request_id}, + ): + # Inject W3C trace context into _meta (SEP-414). + meta: dict[str, Any] = request_data.setdefault("params", {}).setdefault("_meta", {}) + inject_trace_context(meta) + + jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) + await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) + + # request read timeout takes precedence over session read timeout + timeout = request_read_timeout_seconds or self._session_read_timeout_seconds + + try: + with anyio.fail_after(timeout): + response_or_error = await response_stream_reader.receive() + except TimeoutError: + class_name = request.__class__.__name__ + message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." 
+ raise MCPError(code=REQUEST_TIMEOUT, message=message) + + if isinstance(response_or_error, JSONRPCError): + raise MCPError.from_jsonrpc_error(response_or_error) + else: + return result_type.model_validate(response_or_error.result, by_name=False) finally: self._response_streams.pop(request_id, None) diff --git a/tests/shared/test_otel.py b/tests/shared/test_otel.py new file mode 100644 index 0000000000..ec7ff78cc1 --- /dev/null +++ b/tests/shared/test_otel.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import pytest +from logfire.testing import CaptureLogfire + +from mcp import types +from mcp.client.client import Client +from mcp.server.mcpserver import MCPServer + +pytestmark = pytest.mark.anyio + + +# Logfire warns about propagated trace context by default (distributed_tracing=None). +# This is expected here since we're testing cross-boundary context propagation. +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +async def test_client_and_server_spans(capfire: CaptureLogfire): + """Verify that calling a tool produces client and server spans with correct attributes.""" + server = MCPServer("test") + + @server.tool() + def greet(name: str) -> str: + """Greet someone.""" + return f"Hello, {name}!" + + async with Client(server) as client: + result = await client.call_tool("greet", {"name": "World"}) + + assert isinstance(result.content[0], types.TextContent) + assert result.content[0].text == "Hello, World!" 
+ + spans = capfire.exporter.exported_spans_as_dict() + span_names = {s["name"] for s in spans} + + assert "MCP send tools/call greet" in span_names + assert "MCP handle tools/call greet" in span_names + + client_span = next(s for s in spans if s["name"] == "MCP send tools/call greet") + server_span = next(s for s in spans if s["name"] == "MCP handle tools/call greet") + + assert client_span["attributes"]["mcp.method.name"] == "tools/call" + assert server_span["attributes"]["mcp.method.name"] == "tools/call" + + # Server span should be in the same trace as the client span (context propagation). + assert server_span["context"]["trace_id"] == client_span["context"]["trace_id"] diff --git a/uv.lock b/uv.lock index 8f9a5396a2..5efbb05dce 100644 --- a/uv.lock +++ b/uv.lock @@ -579,6 +579,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034, upload-time = "2022-05-02T15:47:14.552Z" }, ] +[[package]] +name = "googleapis-common-protos" +version = "1.73.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/c0/4a54c386282c13449eca8bbe2ddb518181dc113e78d240458a68856b4d69/googleapis_common_protos-1.73.1.tar.gz", hash = "sha256:13114f0e9d2391756a0194c3a8131974ed7bffb06086569ba193364af59163b6", size = 147506, upload-time = "2026-03-26T22:17:38.451Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/82/fcb6520612bec0c39b973a6c0954b6a0d948aadfe8f7e9487f60ceb8bfa6/googleapis_common_protos-1.73.1-py3-none-any.whl", hash = "sha256:e51f09eb0a43a8602f5a915870972e6b4a394088415c79d79605a46d8e826ee8", size = 297556, upload-time = "2026-03-26T22:15:58.455Z" }, +] + [[package]] name = "griffe" version = "1.14.0" @@ -646,6 +658,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "importlib-metadata" +version = "8.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/49/3b30cad09e7771a4982d9975a8cbf64f00d4a1ececb53297f1d9a7be1b10/importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb", size = 57107, upload-time = "2025-12-21T10:00:19.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, +] + [[package]] name = "iniconfig" version = "2.1.0" @@ -710,6 +734,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, ] +[[package]] +name = "logfire" +version = "4.31.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "executing" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-sdk" }, + { name = "protobuf" }, + { name = "rich" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/61/fc/21f923243d8c3ca2ebfa97de46970ced734e66ac634c1c35b6abb41300f1/logfire-4.31.0.tar.gz", hash = "sha256:361bfda17c9d70ada5d220211033bae06b871ddac9d5b06978bc0ceca6b8e658", size = 1080609, upload-time = "2026-03-27T19:00:46.339Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/1a/8c860e35bf847ac0d647d94bad89dccbb66cbcafdd61d8334f8cc7cfdd58/logfire-4.31.0-py3-none-any.whl", hash = "sha256:49fad38b5e6f199a98e9c8814e860c8a42595bb81479b52a20413e53ee475b72", size = 308896, upload-time = "2026-03-27T19:00:43.107Z" }, +] + [[package]] name = "markdown" version = "3.9" @@ -797,6 +840,7 @@ dependencies = [ { name = "httpx" }, { name = "httpx-sse" }, { name = "jsonschema" }, + { name = "opentelemetry-api" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyjwt", extra = ["crypto"] }, @@ -826,6 +870,7 @@ dev = [ { name = "coverage", extra = ["toml"] }, { name = "dirty-equals" }, { name = "inline-snapshot" }, + { name = "logfire" }, { name = "mcp", extra = ["cli", "ws"] }, { name = "pillow" }, { name = "pyright" }, @@ -853,6 +898,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.27.1,<1.0.0" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" }, + { name = "opentelemetry-api", specifier = ">=1.28.0" }, { name = "pydantic", specifier = ">=2.12.0" }, { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.1" }, @@ -876,6 +922,7 @@ dev = [ { name = "coverage", extras = ["toml"], specifier = ">=7.10.7,<=7.13" }, { name = "dirty-equals", specifier = ">=0.9.0" }, { name = "inline-snapshot", specifier = ">=0.23.0" }, + { name = "logfire", specifier = ">=3.0.0" }, { name = "mcp", extras = ["cli", "ws"], editable = "." 
}, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, @@ -1642,6 +1689,103 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, ] +[[package]] +name = "opentelemetry-api" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/97/b9/3161be15bb8e3ad01be8be5a968a9237c3027c5be504362ff800fca3e442/opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c", size = 65767, upload-time = "2025-12-11T13:32:39.182Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/df/d3f1ddf4bb4cb50ed9b1139cc7b1c54c34a1e7ce8fd1b9a37c0d1551a6bd/opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950", size = 66356, upload-time = "2025-12-11T13:32:17.304Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-common" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-proto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/9d/22d241b66f7bbde88a3bfa6847a351d2c46b84de23e71222c6aae25c7050/opentelemetry_exporter_otlp_proto_common-1.39.1.tar.gz", hash = "sha256:763370d4737a59741c89a67b50f9e39271639ee4afc999dadfe768541c027464", size = 20409, upload-time = "2025-12-11T13:32:40.885Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/02/ffc3e143d89a27ac21fd557365b98bd0653b98de8a101151d5805b5d4c33/opentelemetry_exporter_otlp_proto_common-1.39.1-py3-none-any.whl", hash = 
"sha256:08f8a5862d64cc3435105686d0216c1365dc5701f86844a8cd56597d0c764fde", size = 18366, upload-time = "2025-12-11T13:32:20.2Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-http" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "googleapis-common-protos" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-common" }, + { name = "opentelemetry-proto" }, + { name = "opentelemetry-sdk" }, + { name = "requests" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/80/04/2a08fa9c0214ae38880df01e8bfae12b067ec0793446578575e5080d6545/opentelemetry_exporter_otlp_proto_http-1.39.1.tar.gz", hash = "sha256:31bdab9745c709ce90a49a0624c2bd445d31a28ba34275951a6a362d16a0b9cb", size = 17288, upload-time = "2025-12-11T13:32:42.029Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/f1/b27d3e2e003cd9a3592c43d099d2ed8d0a947c15281bf8463a256db0b46c/opentelemetry_exporter_otlp_proto_http-1.39.1-py3-none-any.whl", hash = "sha256:d9f5207183dd752a412c4cd564ca8875ececba13be6e9c6c370ffb752fd59985", size = 19641, upload-time = "2025-12-11T13:32:22.248Z" }, +] + +[[package]] +name = "opentelemetry-instrumentation" +version = "0.60b1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "packaging" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/0f/7e6b713ac117c1f5e4e3300748af699b9902a2e5e34c9cf443dde25a01fa/opentelemetry_instrumentation-0.60b1.tar.gz", hash = "sha256:57ddc7974c6eb35865af0426d1a17132b88b2ed8586897fee187fd5b8944bd6a", size = 31706, upload-time = "2025-12-11T13:36:42.515Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/d2/6788e83c5c86a2690101681aeef27eeb2a6bf22df52d3f263a22cee20915/opentelemetry_instrumentation-0.60b1-py3-none-any.whl", hash = 
"sha256:04480db952b48fb1ed0073f822f0ee26012b7be7c3eac1a3793122737c78632d", size = 33096, upload-time = "2025-12-11T13:35:33.067Z" }, +] + +[[package]] +name = "opentelemetry-proto" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/1d/f25d76d8260c156c40c97c9ed4511ec0f9ce353f8108ca6e7561f82a06b2/opentelemetry_proto-1.39.1.tar.gz", hash = "sha256:6c8e05144fc0d3ed4d22c2289c6b126e03bcd0e6a7da0f16cedd2e1c2772e2c8", size = 46152, upload-time = "2025-12-11T13:32:48.681Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/95/b40c96a7b5203005a0b03d8ce8cd212ff23f1793d5ba289c87a097571b18/opentelemetry_proto-1.39.1-py3-none-any.whl", hash = "sha256:22cdc78efd3b3765d09e68bfbd010d4fc254c9818afd0b6b423387d9dee46007", size = 72535, upload-time = "2025-12-11T13:32:33.866Z" }, +] + +[[package]] +name = "opentelemetry-sdk" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/fb/c76080c9ba07e1e8235d24cdcc4d125ef7aa3edf23eb4e497c2e50889adc/opentelemetry_sdk-1.39.1.tar.gz", hash = "sha256:cf4d4563caf7bff906c9f7967e2be22d0d6b349b908be0d90fb21c8e9c995cc6", size = 171460, upload-time = "2025-12-11T13:32:49.369Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/98/e91cf858f203d86f4eccdf763dcf01cf03f1dae80c3750f7e635bfa206b6/opentelemetry_sdk-1.39.1-py3-none-any.whl", hash = "sha256:4d5482c478513ecb0a5d938dcc61394e647066e0cc2676bee9f3af3f3f45f01c", size = 132565, upload-time = "2025-12-11T13:32:35.069Z" }, +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.60b1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = 
"typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/df/553f93ed38bf22f4b999d9be9c185adb558982214f33eae539d3b5cd0858/opentelemetry_semantic_conventions-0.60b1.tar.gz", hash = "sha256:87c228b5a0669b748c76d76df6c364c369c28f1c465e50f661e39737e84bc953", size = 137935, upload-time = "2025-12-11T13:32:50.487Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/5e/5958555e09635d09b75de3c4f8b9cae7335ca545d77392ffe7331534c402/opentelemetry_semantic_conventions-0.60b1-py3-none-any.whl", hash = "sha256:9fa8c8b0c110da289809292b0591220d3a7b53c1526a23021e977d68597893fb", size = 219982, upload-time = "2025-12-11T13:32:36.955Z" }, +] + [[package]] name = "outcome" version = "1.3.0.post0" @@ -1797,6 +1941,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "protobuf" +version = "6.33.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/70/e908e9c5e52ef7c3a6c7902c9dfbb34c7e29c25d2f81ade3856445fd5c94/protobuf-6.33.6.tar.gz", hash = "sha256:a6768d25248312c297558af96a9f9c929e8c4cee0659cb07e780731095f38135", size = 444531, upload-time = "2026-03-18T19:05:00.988Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/9f/2f509339e89cfa6f6a4c4ff50438db9ca488dec341f7e454adad60150b00/protobuf-6.33.6-cp310-abi3-win32.whl", hash = "sha256:7d29d9b65f8afef196f8334e80d6bc1d5d4adedb449971fefd3723824e6e77d3", size = 425739, upload-time = "2026-03-18T19:04:48.373Z" }, + { url = "https://files.pythonhosted.org/packages/76/5d/683efcd4798e0030c1bab27374fd13a89f7c2515fb1f3123efdfaa5eab57/protobuf-6.33.6-cp310-abi3-win_amd64.whl", hash = 
"sha256:0cd27b587afca21b7cfa59a74dcbd48a50f0a6400cfb59391340ad729d91d326", size = 437089, upload-time = "2026-03-18T19:04:50.381Z" }, + { url = "https://files.pythonhosted.org/packages/5c/01/a3c3ed5cd186f39e7880f8303cc51385a198a81469d53d0fdecf1f64d929/protobuf-6.33.6-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:9720e6961b251bde64edfdab7d500725a2af5280f3f4c87e57c0208376aa8c3a", size = 427737, upload-time = "2026-03-18T19:04:51.866Z" }, + { url = "https://files.pythonhosted.org/packages/ee/90/b3c01fdec7d2f627b3a6884243ba328c1217ed2d978def5c12dc50d328a3/protobuf-6.33.6-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e2afbae9b8e1825e3529f88d514754e094278bb95eadc0e199751cdd9a2e82a2", size = 324610, upload-time = "2026-03-18T19:04:53.096Z" }, + { url = "https://files.pythonhosted.org/packages/9b/ca/25afc144934014700c52e05103c2421997482d561f3101ff352e1292fb81/protobuf-6.33.6-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:c96c37eec15086b79762ed265d59ab204dabc53056e3443e702d2681f4b39ce3", size = 339381, upload-time = "2026-03-18T19:04:54.616Z" }, + { url = "https://files.pythonhosted.org/packages/16/92/d1e32e3e0d894fe00b15ce28ad4944ab692713f2e7f0a99787405e43533a/protobuf-6.33.6-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:e9db7e292e0ab79dd108d7f1a94fe31601ce1ee3f7b79e0692043423020b0593", size = 323436, upload-time = "2026-03-18T19:04:55.768Z" }, + { url = "https://files.pythonhosted.org/packages/c4/72/02445137af02769918a93807b2b7890047c32bfb9f90371cbc12688819eb/protobuf-6.33.6-py3-none-any.whl", hash = "sha256:77179e006c476e69bf8e8ce866640091ec42e1beb80b213c3900006ecfba6901", size = 170656, upload-time = "2026-03-18T19:04:59.826Z" }, +] + [[package]] name = "pycparser" version = "2.23" @@ -2766,3 +2925,81 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = 
"sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884, upload-time = "2025-03-05T20:03:27.934Z" }, { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] + +[[package]] +name = "wrapt" +version = "1.17.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/8f/aeb76c5b46e273670962298c23e7ddde79916cb74db802131d49a85e4b7d/wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0", size = 55547, upload-time = "2025-08-12T05:53:21.714Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/23/bb82321b86411eb51e5a5db3fb8f8032fd30bd7c2d74bfe936136b2fa1d6/wrapt-1.17.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88bbae4d40d5a46142e70d58bf664a89b6b4befaea7b2ecc14e03cedb8e06c04", size = 53482, upload-time = "2025-08-12T05:51:44.467Z" }, + { url = "https://files.pythonhosted.org/packages/45/69/f3c47642b79485a30a59c63f6d739ed779fb4cc8323205d047d741d55220/wrapt-1.17.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6b13af258d6a9ad602d57d889f83b9d5543acd471eee12eb51f5b01f8eb1bc2", size = 38676, upload-time = "2025-08-12T05:51:32.636Z" }, + { url = "https://files.pythonhosted.org/packages/d1/71/e7e7f5670c1eafd9e990438e69d8fb46fa91a50785332e06b560c869454f/wrapt-1.17.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd341868a4b6714a5962c1af0bd44f7c404ef78720c7de4892901e540417111c", size = 38957, upload-time = "2025-08-12T05:51:54.655Z" }, + { url = "https://files.pythonhosted.org/packages/de/17/9f8f86755c191d6779d7ddead1a53c7a8aa18bccb7cea8e7e72dfa6a8a09/wrapt-1.17.3-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = 
"sha256:f9b2601381be482f70e5d1051a5965c25fb3625455a2bf520b5a077b22afb775", size = 81975, upload-time = "2025-08-12T05:52:30.109Z" }, + { url = "https://files.pythonhosted.org/packages/f2/15/dd576273491f9f43dd09fce517f6c2ce6eb4fe21681726068db0d0467096/wrapt-1.17.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343e44b2a8e60e06a7e0d29c1671a0d9951f59174f3709962b5143f60a2a98bd", size = 83149, upload-time = "2025-08-12T05:52:09.316Z" }, + { url = "https://files.pythonhosted.org/packages/0c/c4/5eb4ce0d4814521fee7aa806264bf7a114e748ad05110441cd5b8a5c744b/wrapt-1.17.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:33486899acd2d7d3066156b03465b949da3fd41a5da6e394ec49d271baefcf05", size = 82209, upload-time = "2025-08-12T05:52:10.331Z" }, + { url = "https://files.pythonhosted.org/packages/31/4b/819e9e0eb5c8dc86f60dfc42aa4e2c0d6c3db8732bce93cc752e604bb5f5/wrapt-1.17.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e6f40a8aa5a92f150bdb3e1c44b7e98fb7113955b2e5394122fa5532fec4b418", size = 81551, upload-time = "2025-08-12T05:52:31.137Z" }, + { url = "https://files.pythonhosted.org/packages/f8/83/ed6baf89ba3a56694700139698cf703aac9f0f9eb03dab92f57551bd5385/wrapt-1.17.3-cp310-cp310-win32.whl", hash = "sha256:a36692b8491d30a8c75f1dfee65bef119d6f39ea84ee04d9f9311f83c5ad9390", size = 36464, upload-time = "2025-08-12T05:53:01.204Z" }, + { url = "https://files.pythonhosted.org/packages/2f/90/ee61d36862340ad7e9d15a02529df6b948676b9a5829fd5e16640156627d/wrapt-1.17.3-cp310-cp310-win_amd64.whl", hash = "sha256:afd964fd43b10c12213574db492cb8f73b2f0826c8df07a68288f8f19af2ebe6", size = 38748, upload-time = "2025-08-12T05:53:00.209Z" }, + { url = "https://files.pythonhosted.org/packages/bd/c3/cefe0bd330d389c9983ced15d326f45373f4073c9f4a8c2f99b50bfea329/wrapt-1.17.3-cp310-cp310-win_arm64.whl", hash = "sha256:af338aa93554be859173c39c85243970dc6a289fa907402289eeae7543e1ae18", size = 36810, upload-time = 
"2025-08-12T05:52:51.906Z" }, + { url = "https://files.pythonhosted.org/packages/52/db/00e2a219213856074a213503fdac0511203dceefff26e1daa15250cc01a0/wrapt-1.17.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:273a736c4645e63ac582c60a56b0acb529ef07f78e08dc6bfadf6a46b19c0da7", size = 53482, upload-time = "2025-08-12T05:51:45.79Z" }, + { url = "https://files.pythonhosted.org/packages/5e/30/ca3c4a5eba478408572096fe9ce36e6e915994dd26a4e9e98b4f729c06d9/wrapt-1.17.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5531d911795e3f935a9c23eb1c8c03c211661a5060aab167065896bbf62a5f85", size = 38674, upload-time = "2025-08-12T05:51:34.629Z" }, + { url = "https://files.pythonhosted.org/packages/31/25/3e8cc2c46b5329c5957cec959cb76a10718e1a513309c31399a4dad07eb3/wrapt-1.17.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0610b46293c59a3adbae3dee552b648b984176f8562ee0dba099a56cfbe4df1f", size = 38959, upload-time = "2025-08-12T05:51:56.074Z" }, + { url = "https://files.pythonhosted.org/packages/5d/8f/a32a99fc03e4b37e31b57cb9cefc65050ea08147a8ce12f288616b05ef54/wrapt-1.17.3-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b32888aad8b6e68f83a8fdccbf3165f5469702a7544472bdf41f582970ed3311", size = 82376, upload-time = "2025-08-12T05:52:32.134Z" }, + { url = "https://files.pythonhosted.org/packages/31/57/4930cb8d9d70d59c27ee1332a318c20291749b4fba31f113c2f8ac49a72e/wrapt-1.17.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cccf4f81371f257440c88faed6b74f1053eef90807b77e31ca057b2db74edb1", size = 83604, upload-time = "2025-08-12T05:52:11.663Z" }, + { url = "https://files.pythonhosted.org/packages/a8/f3/1afd48de81d63dd66e01b263a6fbb86e1b5053b419b9b33d13e1f6d0f7d0/wrapt-1.17.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8a210b158a34164de8bb68b0e7780041a903d7b00c87e906fb69928bf7890d5", size = 82782, upload-time = "2025-08-12T05:52:12.626Z" }, + { url = 
"https://files.pythonhosted.org/packages/1e/d7/4ad5327612173b144998232f98a85bb24b60c352afb73bc48e3e0d2bdc4e/wrapt-1.17.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:79573c24a46ce11aab457b472efd8d125e5a51da2d1d24387666cd85f54c05b2", size = 82076, upload-time = "2025-08-12T05:52:33.168Z" }, + { url = "https://files.pythonhosted.org/packages/bb/59/e0adfc831674a65694f18ea6dc821f9fcb9ec82c2ce7e3d73a88ba2e8718/wrapt-1.17.3-cp311-cp311-win32.whl", hash = "sha256:c31eebe420a9a5d2887b13000b043ff6ca27c452a9a22fa71f35f118e8d4bf89", size = 36457, upload-time = "2025-08-12T05:53:03.936Z" }, + { url = "https://files.pythonhosted.org/packages/83/88/16b7231ba49861b6f75fc309b11012ede4d6b0a9c90969d9e0db8d991aeb/wrapt-1.17.3-cp311-cp311-win_amd64.whl", hash = "sha256:0b1831115c97f0663cb77aa27d381237e73ad4f721391a9bfb2fe8bc25fa6e77", size = 38745, upload-time = "2025-08-12T05:53:02.885Z" }, + { url = "https://files.pythonhosted.org/packages/9a/1e/c4d4f3398ec073012c51d1c8d87f715f56765444e1a4b11e5180577b7e6e/wrapt-1.17.3-cp311-cp311-win_arm64.whl", hash = "sha256:5a7b3c1ee8265eb4c8f1b7d29943f195c00673f5ab60c192eba2d4a7eae5f46a", size = 36806, upload-time = "2025-08-12T05:52:53.368Z" }, + { url = "https://files.pythonhosted.org/packages/9f/41/cad1aba93e752f1f9268c77270da3c469883d56e2798e7df6240dcb2287b/wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0", size = 53998, upload-time = "2025-08-12T05:51:47.138Z" }, + { url = "https://files.pythonhosted.org/packages/60/f8/096a7cc13097a1869fe44efe68dace40d2a16ecb853141394047f0780b96/wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba", size = 39020, upload-time = "2025-08-12T05:51:35.906Z" }, + { url = "https://files.pythonhosted.org/packages/33/df/bdf864b8997aab4febb96a9ae5c124f700a5abd9b5e13d2a3214ec4be705/wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd", size = 39098, upload-time = "2025-08-12T05:51:57.474Z" }, + { url = "https://files.pythonhosted.org/packages/9f/81/5d931d78d0eb732b95dc3ddaeeb71c8bb572fb01356e9133916cd729ecdd/wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828", size = 88036, upload-time = "2025-08-12T05:52:34.784Z" }, + { url = "https://files.pythonhosted.org/packages/ca/38/2e1785df03b3d72d34fc6252d91d9d12dc27a5c89caef3335a1bbb8908ca/wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9", size = 88156, upload-time = "2025-08-12T05:52:13.599Z" }, + { url = "https://files.pythonhosted.org/packages/b3/8b/48cdb60fe0603e34e05cffda0b2a4adab81fd43718e11111a4b0100fd7c1/wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396", size = 87102, upload-time = "2025-08-12T05:52:14.56Z" }, + { url = "https://files.pythonhosted.org/packages/3c/51/d81abca783b58f40a154f1b2c56db1d2d9e0d04fa2d4224e357529f57a57/wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc", size = 87732, upload-time = "2025-08-12T05:52:36.165Z" }, + { url = "https://files.pythonhosted.org/packages/9e/b1/43b286ca1392a006d5336412d41663eeef1ad57485f3e52c767376ba7e5a/wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe", size = 36705, upload-time = "2025-08-12T05:53:07.123Z" }, + { url = "https://files.pythonhosted.org/packages/28/de/49493f962bd3c586ab4b88066e967aa2e0703d6ef2c43aa28cb83bf7b507/wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c", 
size = 38877, upload-time = "2025-08-12T05:53:05.436Z" }, + { url = "https://files.pythonhosted.org/packages/f1/48/0f7102fe9cb1e8a5a77f80d4f0956d62d97034bbe88d33e94699f99d181d/wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6", size = 36885, upload-time = "2025-08-12T05:52:54.367Z" }, + { url = "https://files.pythonhosted.org/packages/fc/f6/759ece88472157acb55fc195e5b116e06730f1b651b5b314c66291729193/wrapt-1.17.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a47681378a0439215912ef542c45a783484d4dd82bac412b71e59cf9c0e1cea0", size = 54003, upload-time = "2025-08-12T05:51:48.627Z" }, + { url = "https://files.pythonhosted.org/packages/4f/a9/49940b9dc6d47027dc850c116d79b4155f15c08547d04db0f07121499347/wrapt-1.17.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:54a30837587c6ee3cd1a4d1c2ec5d24e77984d44e2f34547e2323ddb4e22eb77", size = 39025, upload-time = "2025-08-12T05:51:37.156Z" }, + { url = "https://files.pythonhosted.org/packages/45/35/6a08de0f2c96dcdd7fe464d7420ddb9a7655a6561150e5fc4da9356aeaab/wrapt-1.17.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ecf15d6af39246fe33e507105d67e4b81d8f8d2c6598ff7e3ca1b8a37213f7", size = 39108, upload-time = "2025-08-12T05:51:58.425Z" }, + { url = "https://files.pythonhosted.org/packages/0c/37/6faf15cfa41bf1f3dba80cd3f5ccc6622dfccb660ab26ed79f0178c7497f/wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6fd1ad24dc235e4ab88cda009e19bf347aabb975e44fd5c2fb22a3f6e4141277", size = 88072, upload-time = "2025-08-12T05:52:37.53Z" }, + { url = "https://files.pythonhosted.org/packages/78/f2/efe19ada4a38e4e15b6dff39c3e3f3f73f5decf901f66e6f72fe79623a06/wrapt-1.17.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ed61b7c2d49cee3c027372df5809a59d60cf1b6c2f81ee980a091f3afed6a2d", size = 88214, upload-time = "2025-08-12T05:52:15.886Z" }, + { url 
= "https://files.pythonhosted.org/packages/40/90/ca86701e9de1622b16e09689fc24b76f69b06bb0150990f6f4e8b0eeb576/wrapt-1.17.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:423ed5420ad5f5529db9ce89eac09c8a2f97da18eb1c870237e84c5a5c2d60aa", size = 87105, upload-time = "2025-08-12T05:52:17.914Z" }, + { url = "https://files.pythonhosted.org/packages/fd/e0/d10bd257c9a3e15cbf5523025252cc14d77468e8ed644aafb2d6f54cb95d/wrapt-1.17.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e01375f275f010fcbf7f643b4279896d04e571889b8a5b3f848423d91bf07050", size = 87766, upload-time = "2025-08-12T05:52:39.243Z" }, + { url = "https://files.pythonhosted.org/packages/e8/cf/7d848740203c7b4b27eb55dbfede11aca974a51c3d894f6cc4b865f42f58/wrapt-1.17.3-cp313-cp313-win32.whl", hash = "sha256:53e5e39ff71b3fc484df8a522c933ea2b7cdd0d5d15ae82e5b23fde87d44cbd8", size = 36711, upload-time = "2025-08-12T05:53:10.074Z" }, + { url = "https://files.pythonhosted.org/packages/57/54/35a84d0a4d23ea675994104e667ceff49227ce473ba6a59ba2c84f250b74/wrapt-1.17.3-cp313-cp313-win_amd64.whl", hash = "sha256:1f0b2f40cf341ee8cc1a97d51ff50dddb9fcc73241b9143ec74b30fc4f44f6cb", size = 38885, upload-time = "2025-08-12T05:53:08.695Z" }, + { url = "https://files.pythonhosted.org/packages/01/77/66e54407c59d7b02a3c4e0af3783168fff8e5d61def52cda8728439d86bc/wrapt-1.17.3-cp313-cp313-win_arm64.whl", hash = "sha256:7425ac3c54430f5fc5e7b6f41d41e704db073309acfc09305816bc6a0b26bb16", size = 36896, upload-time = "2025-08-12T05:52:55.34Z" }, + { url = "https://files.pythonhosted.org/packages/02/a2/cd864b2a14f20d14f4c496fab97802001560f9f41554eef6df201cd7f76c/wrapt-1.17.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:cf30f6e3c077c8e6a9a7809c94551203c8843e74ba0c960f4a98cd80d4665d39", size = 54132, upload-time = "2025-08-12T05:51:49.864Z" }, + { url = "https://files.pythonhosted.org/packages/d5/46/d011725b0c89e853dc44cceb738a307cde5d240d023d6d40a82d1b4e1182/wrapt-1.17.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = 
"sha256:e228514a06843cae89621384cfe3a80418f3c04aadf8a3b14e46a7be704e4235", size = 39091, upload-time = "2025-08-12T05:51:38.935Z" }, + { url = "https://files.pythonhosted.org/packages/2e/9e/3ad852d77c35aae7ddebdbc3b6d35ec8013af7d7dddad0ad911f3d891dae/wrapt-1.17.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5ea5eb3c0c071862997d6f3e02af1d055f381b1d25b286b9d6644b79db77657c", size = 39172, upload-time = "2025-08-12T05:51:59.365Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f7/c983d2762bcce2326c317c26a6a1e7016f7eb039c27cdf5c4e30f4160f31/wrapt-1.17.3-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:281262213373b6d5e4bb4353bc36d1ba4084e6d6b5d242863721ef2bf2c2930b", size = 87163, upload-time = "2025-08-12T05:52:40.965Z" }, + { url = "https://files.pythonhosted.org/packages/e4/0f/f673f75d489c7f22d17fe0193e84b41540d962f75fce579cf6873167c29b/wrapt-1.17.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dc4a8d2b25efb6681ecacad42fca8859f88092d8732b170de6a5dddd80a1c8fa", size = 87963, upload-time = "2025-08-12T05:52:20.326Z" }, + { url = "https://files.pythonhosted.org/packages/df/61/515ad6caca68995da2fac7a6af97faab8f78ebe3bf4f761e1b77efbc47b5/wrapt-1.17.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:373342dd05b1d07d752cecbec0c41817231f29f3a89aa8b8843f7b95992ed0c7", size = 86945, upload-time = "2025-08-12T05:52:21.581Z" }, + { url = "https://files.pythonhosted.org/packages/d3/bd/4e70162ce398462a467bc09e768bee112f1412e563620adc353de9055d33/wrapt-1.17.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d40770d7c0fd5cbed9d84b2c3f2e156431a12c9a37dc6284060fb4bec0b7ffd4", size = 86857, upload-time = "2025-08-12T05:52:43.043Z" }, + { url = "https://files.pythonhosted.org/packages/2b/b8/da8560695e9284810b8d3df8a19396a6e40e7518059584a1a394a2b35e0a/wrapt-1.17.3-cp314-cp314-win32.whl", hash = 
"sha256:fbd3c8319de8e1dc79d346929cd71d523622da527cca14e0c1d257e31c2b8b10", size = 37178, upload-time = "2025-08-12T05:53:12.605Z" }, + { url = "https://files.pythonhosted.org/packages/db/c8/b71eeb192c440d67a5a0449aaee2310a1a1e8eca41676046f99ed2487e9f/wrapt-1.17.3-cp314-cp314-win_amd64.whl", hash = "sha256:e1a4120ae5705f673727d3253de3ed0e016f7cd78dc463db1b31e2463e1f3cf6", size = 39310, upload-time = "2025-08-12T05:53:11.106Z" }, + { url = "https://files.pythonhosted.org/packages/45/20/2cda20fd4865fa40f86f6c46ed37a2a8356a7a2fde0773269311f2af56c7/wrapt-1.17.3-cp314-cp314-win_arm64.whl", hash = "sha256:507553480670cab08a800b9463bdb881b2edeed77dc677b0a5915e6106e91a58", size = 37266, upload-time = "2025-08-12T05:52:56.531Z" }, + { url = "https://files.pythonhosted.org/packages/77/ed/dd5cf21aec36c80443c6f900449260b80e2a65cf963668eaef3b9accce36/wrapt-1.17.3-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ed7c635ae45cfbc1a7371f708727bf74690daedc49b4dba310590ca0bd28aa8a", size = 56544, upload-time = "2025-08-12T05:51:51.109Z" }, + { url = "https://files.pythonhosted.org/packages/8d/96/450c651cc753877ad100c7949ab4d2e2ecc4d97157e00fa8f45df682456a/wrapt-1.17.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:249f88ed15503f6492a71f01442abddd73856a0032ae860de6d75ca62eed8067", size = 40283, upload-time = "2025-08-12T05:51:39.912Z" }, + { url = "https://files.pythonhosted.org/packages/d1/86/2fcad95994d9b572db57632acb6f900695a648c3e063f2cd344b3f5c5a37/wrapt-1.17.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a03a38adec8066d5a37bea22f2ba6bbf39fcdefbe2d91419ab864c3fb515454", size = 40366, upload-time = "2025-08-12T05:52:00.693Z" }, + { url = "https://files.pythonhosted.org/packages/64/0e/f4472f2fdde2d4617975144311f8800ef73677a159be7fe61fa50997d6c0/wrapt-1.17.3-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5d4478d72eb61c36e5b446e375bbc49ed002430d17cdec3cecb36993398e1a9e", size = 108571, upload-time = 
"2025-08-12T05:52:44.521Z" }, + { url = "https://files.pythonhosted.org/packages/cc/01/9b85a99996b0a97c8a17484684f206cbb6ba73c1ce6890ac668bcf3838fb/wrapt-1.17.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223db574bb38637e8230eb14b185565023ab624474df94d2af18f1cdb625216f", size = 113094, upload-time = "2025-08-12T05:52:22.618Z" }, + { url = "https://files.pythonhosted.org/packages/25/02/78926c1efddcc7b3aa0bc3d6b33a822f7d898059f7cd9ace8c8318e559ef/wrapt-1.17.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e405adefb53a435f01efa7ccdec012c016b5a1d3f35459990afc39b6be4d5056", size = 110659, upload-time = "2025-08-12T05:52:24.057Z" }, + { url = "https://files.pythonhosted.org/packages/dc/ee/c414501ad518ac3e6fe184753632fe5e5ecacdcf0effc23f31c1e4f7bfcf/wrapt-1.17.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:88547535b787a6c9ce4086917b6e1d291aa8ed914fdd3a838b3539dc95c12804", size = 106946, upload-time = "2025-08-12T05:52:45.976Z" }, + { url = "https://files.pythonhosted.org/packages/be/44/a1bd64b723d13bb151d6cc91b986146a1952385e0392a78567e12149c7b4/wrapt-1.17.3-cp314-cp314t-win32.whl", hash = "sha256:41b1d2bc74c2cac6f9074df52b2efbef2b30bdfe5f40cb78f8ca22963bc62977", size = 38717, upload-time = "2025-08-12T05:53:15.214Z" }, + { url = "https://files.pythonhosted.org/packages/79/d9/7cfd5a312760ac4dd8bf0184a6ee9e43c33e47f3dadc303032ce012b8fa3/wrapt-1.17.3-cp314-cp314t-win_amd64.whl", hash = "sha256:73d496de46cd2cdbdbcce4ae4bcdb4afb6a11234a1df9c085249d55166b95116", size = 41334, upload-time = "2025-08-12T05:53:14.178Z" }, + { url = "https://files.pythonhosted.org/packages/46/78/10ad9781128ed2f99dbc474f43283b13fea8ba58723e98844367531c18e9/wrapt-1.17.3-cp314-cp314t-win_arm64.whl", hash = "sha256:f38e60678850c42461d4202739f9bf1e3a737c7ad283638251e79cc49effb6b6", size = 38471, upload-time = "2025-08-12T05:52:57.784Z" }, + { url = 
"https://files.pythonhosted.org/packages/1f/f6/a933bd70f98e9cf3e08167fc5cd7aaaca49147e48411c0bd5ae701bb2194/wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22", size = 23591, upload-time = "2025-08-12T05:53:20.674Z" }, +] + +[[package]] +name = "zipp" +version = "3.23.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, +] From d5b9155f14ab3b4b203a47c55b53955b2d0bca09 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 17:11:12 -0400 Subject: [PATCH 32/60] chore(deps): bump requests from 2.32.5 to 2.33.0 in the uv group across 1 directory (#2350) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/uv.lock b/uv.lock index 5efbb05dce..e6fb526167 100644 --- a/uv.lock +++ b/uv.lock @@ -2394,7 +2394,7 @@ wheels = [ [[package]] name = "requests" -version = "2.32.5" +version = "2.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -2402,9 +2402,9 @@ dependencies = [ { name = "idna" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = 
"sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/64/8860370b167a9721e8956ae116825caff829224fbca0ca6e7bf8ddef8430/requests-2.33.0.tar.gz", hash = "sha256:c7ebc5e8b0f21837386ad0e1c8fe8b829fa5f544d8df3b2253bff14ef29d7652", size = 134232, upload-time = "2026-03-25T15:10:41.586Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, + { url = "https://files.pythonhosted.org/packages/56/5d/c814546c2333ceea4ba42262d8c4d55763003e767fa169adc693bd524478/requests-2.33.0-py3-none-any.whl", hash = "sha256:3324635456fa185245e24865e810cecec7b4caf933d7eb133dcde67d48cee69b", size = 65017, upload-time = "2026-03-25T15:10:40.382Z" }, ] [[package]] From cf4e435db015c9fe2d6534ad1acb30b9bb94ca7d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 8 Apr 2026 14:12:20 +0200 Subject: [PATCH 33/60] Use shared `is_async_callable` instead of `inspect.iscoroutinefunction` (#2389) --- src/mcp/server/mcpserver/prompts/base.py | 7 ++-- .../server/mcpserver/resources/templates.py | 7 ++-- src/mcp/server/mcpserver/resources/types.py | 11 ++++--- src/mcp/server/mcpserver/tools/base.py | 14 ++------ src/mcp/shared/_callable_inspection.py | 33 +++++++++++++++++++ 5 files changed, 50 insertions(+), 22 deletions(-) create mode 100644 src/mcp/shared/_callable_inspection.py diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index b4810c100e..e5b2af7d82 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -3,7 +3,6 @@ from __future__ import annotations import functools -import inspect from collections.abc import 
Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any, Literal @@ -13,6 +12,7 @@ from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context from mcp.server.mcpserver.utilities.func_metadata import func_metadata +from mcp.shared._callable_inspection import is_async_callable from mcp.types import ContentBlock, Icon, TextContent if TYPE_CHECKING: @@ -157,8 +157,9 @@ async def render( # Add context to arguments if needed call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg) - if inspect.iscoroutinefunction(self.fn): - result = await self.fn(**call_args) + fn = self.fn + if is_async_callable(fn): + result = await fn(**call_args) else: result = await anyio.to_thread.run_sync(functools.partial(self.fn, **call_args)) diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index 542b5e6f81..f1ee29a37f 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -3,7 +3,6 @@ from __future__ import annotations import functools -import inspect import re from collections.abc import Callable from typing import TYPE_CHECKING, Any @@ -15,6 +14,7 @@ from mcp.server.mcpserver.resources.types import FunctionResource, Resource from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context from mcp.server.mcpserver.utilities.func_metadata import func_metadata +from mcp.shared._callable_inspection import is_async_callable from mcp.types import Annotations, Icon if TYPE_CHECKING: @@ -112,8 +112,9 @@ async def create_resource( # Add context to params if needed params = inject_context(self.fn, params, context, self.context_kwarg) - if inspect.iscoroutinefunction(self.fn): - result = await self.fn(**params) + fn = self.fn + if is_async_callable(fn): + result = await fn(**params) else: result = await anyio.to_thread.run_sync(functools.partial(self.fn, **params)) 
diff --git a/src/mcp/server/mcpserver/resources/types.py b/src/mcp/server/mcpserver/resources/types.py index 04763be8ba..d9e472e362 100644 --- a/src/mcp/server/mcpserver/resources/types.py +++ b/src/mcp/server/mcpserver/resources/types.py @@ -1,6 +1,7 @@ """Concrete resource implementations.""" -import inspect +from __future__ import annotations + import json from collections.abc import Callable from pathlib import Path @@ -14,6 +15,7 @@ from pydantic import Field, ValidationInfo, validate_call from mcp.server.mcpserver.resources.base import Resource +from mcp.shared._callable_inspection import is_async_callable from mcp.types import Annotations, Icon @@ -55,8 +57,9 @@ class FunctionResource(Resource): async def read(self) -> str | bytes: """Read the resource by calling the wrapped function.""" try: - if inspect.iscoroutinefunction(self.fn): - result = await self.fn() + fn = self.fn + if is_async_callable(fn): + result = await fn() else: result = await anyio.to_thread.run_sync(self.fn) @@ -83,7 +86,7 @@ def from_function( icons: list[Icon] | None = None, annotations: Annotations | None = None, meta: dict[str, Any] | None = None, - ) -> "FunctionResource": + ) -> FunctionResource: """Create a FunctionResource from a function.""" func_name = name or fn.__name__ if func_name == "": # pragma: no cover diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index dc65be9885..754313eb8a 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -1,7 +1,5 @@ from __future__ import annotations -import functools -import inspect from collections.abc import Callable from functools import cached_property from typing import TYPE_CHECKING, Any @@ -11,6 +9,7 @@ from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata +from 
mcp.shared._callable_inspection import is_async_callable from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.shared.tool_name_validation import validate_and_warn_tool_name from mcp.types import Icon, ToolAnnotations @@ -63,7 +62,7 @@ def from_function( raise ValueError("You must provide a name for lambda functions") func_doc = description or fn.__doc__ or "" - is_async = _is_async_callable(fn) + is_async = is_async_callable(fn) if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) @@ -118,12 +117,3 @@ async def run( raise except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e - - -def _is_async_callable(obj: Any) -> bool: - while isinstance(obj, functools.partial): # pragma: lax no cover - obj = obj.func - - return inspect.iscoroutinefunction(obj) or ( - callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) - ) diff --git a/src/mcp/shared/_callable_inspection.py b/src/mcp/shared/_callable_inspection.py new file mode 100644 index 0000000000..0e89e446f8 --- /dev/null +++ b/src/mcp/shared/_callable_inspection.py @@ -0,0 +1,33 @@ +"""Callable inspection utilities. + +Adapted from Starlette's `is_async_callable` implementation. +https://github.com/encode/starlette/blob/main/starlette/_utils.py +""" + +from __future__ import annotations + +import functools +import inspect +from collections.abc import Awaitable, Callable +from typing import Any, TypeGuard, TypeVar, overload + +T = TypeVar("T") + +AwaitableCallable = Callable[..., Awaitable[T]] + + +@overload +def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ... + + +@overload +def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ... 
+ + +def is_async_callable(obj: Any) -> Any: + while isinstance(obj, functools.partial): # pragma: lax no cover + obj = obj.func + + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) + ) From f27d2aac055ed031512eba84b31ac849bb81da2f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 9 Apr 2026 13:25:16 +0100 Subject: [PATCH 34/60] docs: fill migration guide gaps surfaced by automated upgrade eval (#2412) --- docs/migration.md | 240 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 235 insertions(+), 5 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 3b47f9aade..2528f046c6 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -38,6 +38,7 @@ http_client = httpx.AsyncClient( headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30, read=300), auth=my_auth, + follow_redirects=True, ) async with http_client: @@ -48,6 +49,8 @@ async with http_client: ... ``` +v1's internal client set `follow_redirects=True`; set it explicitly when supplying your own `httpx.AsyncClient` to preserve that behavior. + ### `get_session_id` callback removed from `streamable_http_client` The `get_session_id` callback (third element of the returned tuple) has been removed from `streamable_http_client`. The function now returns a 2-tuple `(read_stream, write_stream)` instead of a 3-tuple. @@ -100,6 +103,8 @@ async with http_client: The `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters have been removed from `StreamableHTTPTransport`. Configure these on the `httpx.AsyncClient` instead (see example above). +Note: `sse_client` retains its `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters — only the streamable HTTP transport changed. 
+ ### Removed type aliases and classes The following deprecated type aliases and classes have been removed from `mcp.types`: @@ -126,6 +131,52 @@ from mcp.types import ContentBlock, ResourceTemplateReference # Use `str` instead of `Cursor` for pagination cursors ``` +### Field names changed from camelCase to snake_case + +All Pydantic model fields in `mcp.types` now use snake_case names for Python attribute access. The JSON wire format is unchanged — serialization still uses camelCase via Pydantic aliases. + +**Before (v1):** + +```python +result = await session.call_tool("my_tool", {"x": 1}) +if result.isError: + ... + +tools = await session.list_tools() +cursor = tools.nextCursor +schema = tools.tools[0].inputSchema +``` + +**After (v2):** + +```python +result = await session.call_tool("my_tool", {"x": 1}) +if result.is_error: + ... + +tools = await session.list_tools() +cursor = tools.next_cursor +schema = tools.tools[0].input_schema +``` + +Common renames: + +| v1 (camelCase) | v2 (snake_case) | +|----------------|-----------------| +| `inputSchema` | `input_schema` | +| `outputSchema` | `output_schema` | +| `isError` | `is_error` | +| `nextCursor` | `next_cursor` | +| `mimeType` | `mime_type` | +| `structuredContent` | `structured_content` | +| `serverInfo` | `server_info` | +| `protocolVersion` | `protocol_version` | +| `uriTemplate` | `uri_template` | +| `listChanged` | `list_changed` | +| `progressToken` | `progress_token` | + +Because `populate_by_name=True` is set, the old camelCase names still work as constructor kwargs (e.g., `Tool(inputSchema={...})` is accepted), but attribute access must use snake_case (`tool.input_schema`). + ### `args` parameter removed from `ClientSessionGroup.call_tool()` The deprecated `args` parameter has been removed from `ClientSessionGroup.call_tool()`. Use `arguments` instead. 
@@ -225,6 +276,28 @@ except MCPError as e: from mcp import MCPError ``` +The constructor signature also changed — it now takes `code`, `message`, and optional `data` directly instead of wrapping an `ErrorData`: + +**Before (v1):** + +```python +from mcp.shared.exceptions import McpError +from mcp.types import ErrorData, INVALID_REQUEST + +raise McpError(ErrorData(code=INVALID_REQUEST, message="bad input")) +``` + +**After (v2):** + +```python +from mcp.shared.exceptions import MCPError +from mcp.types import INVALID_REQUEST + +raise MCPError(INVALID_REQUEST, "bad input") +# or, if you already have an ErrorData: +raise MCPError.from_error_data(error_data) +``` + ### `FastMCP` renamed to `MCPServer` The `FastMCP` class has been renamed to `MCPServer` to better reflect its role as the main server class in the SDK. This is a simple rename with no functional changes to the class itself. @@ -240,11 +313,19 @@ mcp = FastMCP("Demo") **After (v2):** ```python -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import MCPServer, Context mcp = MCPServer("Demo") ``` +`Context` is the type annotation for the `ctx` parameter injected into tools, resources, and prompts (see [`get_context()` removed](#mcpserverget_context-removed) below). + +All submodules under `mcp.server.fastmcp.*` are now under `mcp.server.mcpserver.*` with the same structure. Common imports: + +- `Image`, `Audio` — from `mcp.server.mcpserver` (or `.utilities.types`) +- `UserMessage`, `AssistantMessage` — from `mcp.server.mcpserver.prompts.base` +- `ToolError`, `ResourceError` — from `mcp.server.mcpserver.exceptions` + ### `mount_path` parameter removed from MCPServer The `mount_path` parameter has been removed from `MCPServer.__init__()`, `MCPServer.run()`, `MCPServer.run_sse_async()`, and `MCPServer.sse_app()`. It was also removed from the `Settings` class. 
@@ -312,6 +393,8 @@ app = Starlette(routes=[Mount("/", app=mcp.streamable_http_app(json_response=Tru **Note:** DNS rebinding protection is automatically enabled when `host` is `127.0.0.1`, `localhost`, or `::1`. This now happens in `sse_app()` and `streamable_http_app()` instead of the constructor. +If you were mutating these via `mcp.settings` after construction (e.g., `mcp.settings.port = 9000`), pass them to `run()` / `sse_app()` / `streamable_http_app()` instead — these fields no longer exist on `Settings`. The `debug` and `log_level` parameters remain on the constructor. + ### `MCPServer.get_context()` removed `MCPServer.get_context()` has been removed. Context is now injected by the framework and passed explicitly — there is no ambient ContextVar to read from. @@ -331,6 +414,8 @@ async def my_tool(x: int) -> str: **After (v2):** ```python +from mcp.server.mcpserver import Context + @mcp.tool() async def my_tool(x: int, ctx: Context) -> str: await ctx.info("Processing...") @@ -343,6 +428,45 @@ async def my_tool(x: int, ctx: Context) -> str: The internal layers (`ToolManager.call_tool`, `Tool.run`, `Prompt.render`, `ResourceTemplate.create_resource`, etc.) now require `context` as a positional argument. +### Registering lowlevel handlers on `MCPServer` (workaround) + +`MCPServer` does not expose public APIs for `subscribe_resource`, `unsubscribe_resource`, or `set_logging_level` handlers. In v1, the workaround was to reach into the private lowlevel server and use its decorator methods: + +**Before (v1):** + +```python +@mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage] +async def handle_set_logging_level(level: str) -> None: + ... 
+ +mcp._mcp_server.subscribe_resource()(handle_subscribe) # pyright: ignore[reportPrivateUsage] +``` + +In v2, the lowlevel `Server` no longer has decorator methods (handlers are constructor-only), so the equivalent workaround is `_add_request_handler`: + +**After (v2):** + +```python +from mcp.server import ServerRequestContext +from mcp.types import EmptyResult, SetLevelRequestParams, SubscribeRequestParams + + +async def handle_set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + ... + return EmptyResult() + + +async def handle_subscribe(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: + ... + return EmptyResult() + + +mcp._lowlevel_server._add_request_handler("logging/setLevel", handle_set_logging_level) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscribe) # pyright: ignore[reportPrivateUsage] +``` + +This is a private API and may change. A public way to register these handlers on `MCPServer` is planned; until then, use this workaround or use the lowlevel `Server` directly. 
+ ### Replace `RootModel` by union types with `TypeAdapter` validation The following union types are no longer `RootModel` subclasses: @@ -383,6 +507,22 @@ notification = server_notification_adapter.validate_python(data) # No .root access needed - notification is the actual type ``` +The same applies when constructing values — the wrapper call is no longer needed: + +**Before (v1):** + +```python +await session.send_notification(ClientNotification(InitializedNotification())) +await session.send_request(ClientRequest(PingRequest()), EmptyResult) +``` + +**After (v2):** + +```python +await session.send_notification(InitializedNotification()) +await session.send_request(PingRequest(), EmptyResult) +``` + **Available adapters:** | Union Type | Adapter | @@ -428,6 +568,8 @@ server = Server("my-server", on_call_tool=handle_call_tool) ### `RequestContext` type parameters simplified +The `mcp.shared.context` module has been removed. `RequestContext` is now split into `ClientRequestContext` (in `mcp.client.context`) and `ServerRequestContext` (in `mcp.server.context`). + The `RequestContext` class has been split to separate shared fields from server-specific fields. The shared `RequestContext` now only takes 1 type parameter (the session type) instead of 3. **`RequestContext` changes:** @@ -458,11 +600,27 @@ ctx: ClientRequestContext server_ctx: ServerRequestContext[LifespanContextT, RequestT] ``` +The high-level `Context` class (injected into `@mcp.tool()` etc.) similarly dropped its `ServerSessionT` parameter: `Context[ServerSessionT, LifespanContextT, RequestT]` → `Context[LifespanContextT, RequestT]`. Both remaining parameters have defaults, so bare `Context` is usually sufficient: + +**Before (v1):** + +```python +async def my_tool(ctx: Context[ServerSession, None]) -> str: ... +``` + +**After (v2):** + +```python +async def my_tool(ctx: Context) -> str: ... +# or, with an explicit lifespan type: +async def my_tool(ctx: Context[MyLifespanState]) -> str: ... 
+``` + ### `ProgressContext` and `progress()` context manager removed The `mcp.shared.progress` module (`ProgressContext`, `Progress`, and the `progress()` context manager) has been removed. This module had no real-world adoption — all users send progress notifications via `Context.report_progress()` or `session.send_progress_notification()` directly. -**Before:** +**Before (v1):** ```python from mcp.shared.progress import progress @@ -490,6 +648,46 @@ await session.send_progress_notification( ) ``` +### `create_connected_server_and_client_session` removed + +The `create_connected_server_and_client_session` helper in `mcp.shared.memory` has been removed. Use `mcp.client.Client` instead — it accepts a `Server` or `MCPServer` instance directly and handles the in-memory transport and session setup for you. + +**Before (v1):** + +```python +from mcp.shared.memory import create_connected_server_and_client_session + +async with create_connected_server_and_client_session(server) as session: + result = await session.call_tool("my_tool", {"x": 1}) +``` + +**After (v2):** + +```python +from mcp.client import Client + +async with Client(server) as client: + result = await client.call_tool("my_tool", {"x": 1}) +``` + +`Client` accepts the same callback parameters the old helper did (`sampling_callback`, `list_roots_callback`, `logging_callback`, `message_handler`, `elicitation_callback`, `client_info`) plus `raise_exceptions` to surface server-side errors. 
+ +If you need direct access to the underlying `ClientSession` and memory streams (e.g., for low-level transport testing), `create_client_server_memory_streams` is still available in `mcp.shared.memory`: + +```python +import anyio +from mcp.client.session import ClientSession +from mcp.shared.memory import create_client_server_memory_streams + +async with create_client_server_memory_streams() as (client_streams, server_streams): + async with anyio.create_task_group() as tg: + tg.start_soon(lambda: server.run(*server_streams, server.create_initialization_options())) + async with ClientSession(*client_streams) as session: + await session.initialize() + ... + tg.cancel_scope.cancel() +``` + ### Resource URI type changed from `AnyUrl` to `str` The `uri` field on resource-related types now uses `str` instead of Pydantic's `AnyUrl`. This aligns with the [MCP specification schema](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/draft/schema.ts) which defines URIs as plain strings (`uri: string`) without strict URL validation. This change allows relative paths like `users/me` that were previously rejected. @@ -593,6 +791,8 @@ if ListToolsRequest in server.request_handlers: server = Server("my-server", on_list_tools=handle_list_tools) ``` +If you need to check whether a handler is registered, track this yourself — there is currently no public introspection API. + ### Lowlevel `Server`: decorator-based handlers replaced with constructor `on_*` params The lowlevel `Server` class no longer uses decorator methods for handler registration. Instead, handlers are passed as `on_*` keyword arguments to the constructor. @@ -645,6 +845,29 @@ server = Server("my-server", on_list_tools=handle_list_tools, on_call_tool=handl - Handlers return the full result type (e.g. `ListToolsResult`) rather than unwrapped values (e.g. `list[Tool]`). - The automatic `jsonschema` input/output validation that the old `call_tool()` decorator performed has been removed. 
There is no built-in replacement — if you relied on schema validation in the lowlevel server, you will need to validate inputs yourself in your handler. +**Complete handler reference:** + +All handlers receive `ctx: ServerRequestContext` as the first argument. The second argument and return type are: + +| v1 decorator | v2 constructor kwarg | `params` type | return type | +|---|---|---|---| +| `@server.list_tools()` | `on_list_tools` | `PaginatedRequestParams \| None` | `ListToolsResult` | +| `@server.call_tool()` | `on_call_tool` | `CallToolRequestParams` | `CallToolResult \| CreateTaskResult` | +| `@server.list_resources()` | `on_list_resources` | `PaginatedRequestParams \| None` | `ListResourcesResult` | +| `@server.list_resource_templates()` | `on_list_resource_templates` | `PaginatedRequestParams \| None` | `ListResourceTemplatesResult` | +| `@server.read_resource()` | `on_read_resource` | `ReadResourceRequestParams` | `ReadResourceResult` | +| `@server.subscribe_resource()` | `on_subscribe_resource` | `SubscribeRequestParams` | `EmptyResult` | +| `@server.unsubscribe_resource()` | `on_unsubscribe_resource` | `UnsubscribeRequestParams` | `EmptyResult` | +| `@server.list_prompts()` | `on_list_prompts` | `PaginatedRequestParams \| None` | `ListPromptsResult` | +| `@server.get_prompt()` | `on_get_prompt` | `GetPromptRequestParams` | `GetPromptResult` | +| `@server.completion()` | `on_completion` | `CompleteRequestParams` | `CompleteResult` | +| `@server.set_logging_level()` | `on_set_logging_level` | `SetLevelRequestParams` | `EmptyResult` | +| — | `on_ping` | `RequestParams \| None` | `EmptyResult` | +| `@server.progress_notification()` | `on_progress` | `ProgressNotificationParams` | `None` | +| — | `on_roots_list_changed` | `NotificationParams \| None` | `None` | + +All `params` and return types are importable from `mcp.types`. 
+ **Notification handlers:** ```python @@ -694,10 +917,17 @@ Note: `params.arguments` can be `None` (the old decorator defaulted it to `{}`). **`read_resource()` — content type wrapping removed:** -The old decorator auto-wrapped `str` into `TextResourceContents` and `bytes` into `BlobResourceContents` (with base64 encoding), and applied a default mime type of `text/plain`: +The old decorator auto-wrapped `Iterable[ReadResourceContents]` (and the deprecated `str`/`bytes` shorthand) into `TextResourceContents`/`BlobResourceContents`, handling base64 encoding and mime-type defaulting: ```python -# Before (v1) — str/bytes auto-wrapped with mime type defaulting +# Before (v1) — Iterable[ReadResourceContents] auto-wrapped +from mcp.server.lowlevel.helper_types import ReadResourceContents + +@server.read_resource() +async def handle(uri: AnyUrl) -> Iterable[ReadResourceContents]: + return [ReadResourceContents(content="file contents", mime_type="text/plain")] + +# Before (v1) — str/bytes shorthand (already deprecated in v1) @server.read_resource() async def handle(uri: str) -> str: return "file contents" @@ -849,7 +1079,7 @@ params = CallToolRequestParams( params = CallToolRequestParams( name="my_tool", arguments={}, - _meta={"progressToken": "tok", "customField": "value"}, # OK + _meta={"my_custom_key": "value", "another": 123}, # OK ) ``` From c5f12ec1e969ef69d4964816b650faf322eda3f3 Mon Sep 17 00:00:00 2001 From: Matt LeMay Date: Sun, 12 Apr 2026 07:42:12 -0500 Subject: [PATCH 35/60] Add `resources` parameter to `MCPServer` (#2414) Co-authored-by: Marcelo Trylesinski --- .../mcpserver/resources/resource_manager.py | 16 +- src/mcp/server/mcpserver/server.py | 5 +- .../server/mcpserver/tools/tool_manager.py | 16 +- .../resources/test_resource_manager.py | 292 ++++++++---------- tests/server/mcpserver/test_server.py | 26 ++ 5 files changed, 170 insertions(+), 185 deletions(-) diff --git a/src/mcp/server/mcpserver/resources/resource_manager.py 
b/src/mcp/server/mcpserver/resources/resource_manager.py index 6bf17376d1..766cf51aea 100644 --- a/src/mcp/server/mcpserver/resources/resource_manager.py +++ b/src/mcp/server/mcpserver/resources/resource_manager.py @@ -22,28 +22,26 @@ class ResourceManager: """Manages MCPServer resources.""" - def __init__(self, warn_on_duplicate_resources: bool = True): + def __init__(self, warn_on_duplicate_resources: bool = True, *, resources: list[Resource] | None = None): self._resources: dict[str, Resource] = {} self._templates: dict[str, ResourceTemplate] = {} self.warn_on_duplicate_resources = warn_on_duplicate_resources + for resource in resources or (): + self.add_resource(resource) + def add_resource(self, resource: Resource) -> Resource: """Add a resource to the manager. Args: - resource: A Resource instance to add + resource: A Resource instance to add. Returns: - The added resource. If a resource with the same URI already exists, - returns the existing resource. + The added resource. If a resource with the same URI already exists, returns the existing resource. 
""" logger.debug( "Adding resource", - extra={ - "uri": resource.uri, - "type": type(resource).__name__, - "resource_name": resource.name, - }, + extra={"uri": resource.uri, "type": type(resource).__name__, "resource_name": resource.name}, ) existing = self._resources.get(str(resource.uri)) if existing: diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 6f9bb0e287..be77705da6 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -140,6 +140,7 @@ def __init__( token_verifier: TokenVerifier | None = None, *, tools: list[Tool] | None = None, + resources: list[Resource] | None = None, debug: bool = False, log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", warn_on_duplicate_resources: bool = True, @@ -162,7 +163,9 @@ def __init__( self.dependencies = self.settings.dependencies self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) + self._resource_manager = ResourceManager( + resources=resources, warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources + ) self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) self._lowlevel_server = Server( name=name or "mcp-server", diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index 32ed547973..eef4911f9e 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -18,18 +18,12 @@ class ToolManager: """Manages MCPServer tools.""" - def __init__( - self, - warn_on_duplicate_tools: bool = True, - *, - tools: list[Tool] | None = None, - ): + def __init__(self, warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | None = None): self._tools: dict[str, Tool] = {} - if tools is not None: 
- for tool in tools: - if warn_on_duplicate_tools and tool.name in self._tools: - logger.warning(f"Tool already exists: {tool.name}") - self._tools[tool.name] = tool + for tool in tools or (): + if warn_on_duplicate_tools and tool.name in self._tools: + logger.warning(f"Tool already exists: {tool.name}") + self._tools[tool.name] = tool self.warn_on_duplicate_tools = warn_on_duplicate_tools diff --git a/tests/server/mcpserver/resources/test_resource_manager.py b/tests/server/mcpserver/resources/test_resource_manager.py index 724b579974..b91c71581c 100644 --- a/tests/server/mcpserver/resources/test_resource_manager.py +++ b/tests/server/mcpserver/resources/test_resource_manager.py @@ -1,5 +1,5 @@ +import logging from pathlib import Path -from tempfile import NamedTemporaryFile import pytest from pydantic import AnyUrl @@ -8,170 +8,134 @@ from mcp.server.mcpserver.resources import FileResource, FunctionResource, ResourceManager, ResourceTemplate -@pytest.fixture -def temp_file(): +@pytest.fixture() +def temp_file(tmp_path: Path): """Create a temporary file for testing. File is automatically cleaned up after the test if it still exists. 
""" - content = "test content" - with NamedTemporaryFile(mode="w", delete=False) as f: - f.write(content) - path = Path(f.name).resolve() - yield path - try: # pragma: lax no cover - path.unlink() - except FileNotFoundError: # pragma: lax no cover - pass # File was already deleted by the test - - -class TestResourceManager: - """Test ResourceManager functionality.""" - - def test_add_resource(self, temp_file: Path): - """Test adding a resource.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - added = manager.add_resource(resource) - assert added == resource - assert manager.list_resources() == [resource] - - def test_add_duplicate_resource(self, temp_file: Path): - """Test adding the same resource twice.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - first = manager.add_resource(resource) - second = manager.add_resource(resource) - assert first == second - assert manager.list_resources() == [resource] - - def test_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): - """Test warning on duplicate resources.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - manager.add_resource(resource) - manager.add_resource(resource) - assert "Resource already exists" in caplog.text - - def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): - """Test disabling warning on duplicate resources.""" - manager = ResourceManager(warn_on_duplicate_resources=False) - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - manager.add_resource(resource) - manager.add_resource(resource) - assert "Resource already exists" not in caplog.text - - @pytest.mark.anyio - async def test_get_resource(self, temp_file: Path): - """Test getting a 
resource by URI.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - manager.add_resource(resource) - retrieved = await manager.get_resource(resource.uri, Context()) - assert retrieved == resource - - @pytest.mark.anyio - async def test_get_resource_from_template(self): - """Test getting a resource through a template.""" - manager = ResourceManager() - - def greet(name: str) -> str: - return f"Hello, {name}!" - - template = ResourceTemplate.from_function( - fn=greet, - uri_template="greet://{name}", - name="greeter", - ) - manager._templates[template.uri_template] = template - - resource = await manager.get_resource(AnyUrl("greet://world"), Context()) - assert isinstance(resource, FunctionResource) - content = await resource.read() - assert content == "Hello, world!" - - @pytest.mark.anyio - async def test_get_unknown_resource(self): - """Test getting a non-existent resource.""" - manager = ResourceManager() - with pytest.raises(ValueError, match="Unknown resource"): - await manager.get_resource(AnyUrl("unknown://test"), Context()) - - def test_list_resources(self, temp_file: Path): - """Test listing all resources.""" - manager = ResourceManager() - resource1 = FileResource( - uri=f"file://{temp_file}", - name="test1", - path=temp_file, - ) - resource2 = FileResource( - uri=f"file://{temp_file}2", - name="test2", - path=temp_file, - ) - manager.add_resource(resource1) - manager.add_resource(resource2) - resources = manager.list_resources() - assert len(resources) == 2 - assert resources == [resource1, resource2] - - -class TestResourceManagerMetadata: - """Test ResourceManager Metadata""" - - def test_add_template_with_metadata(self): - """Test that ResourceManager.add_template() accepts and passes meta parameter.""" - - manager = ResourceManager() - - def get_item(id: str) -> str: # pragma: no cover - return f"Item {id}" - - metadata = {"source": "database", "cached": True} - - template = 
manager.add_template( - fn=get_item, - uri_template="resource://items/{id}", - meta=metadata, - ) - - assert template.meta is not None - assert template.meta == metadata - assert template.meta["source"] == "database" - assert template.meta["cached"] is True - - def test_add_template_without_metadata(self): - """Test that ResourceManager.add_template() works without meta parameter.""" - - manager = ResourceManager() - - def get_item(id: str) -> str: # pragma: no cover - return f"Item {id}" - - template = manager.add_template( - fn=get_item, - uri_template="resource://items/{id}", - ) - - assert template.meta is None + tmp_file = tmp_path / "file" + tmp_file.touch() + yield tmp_file + + +def test_init_with_resources(temp_file: Path, caplog: pytest.LogCaptureFixture): + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager = ResourceManager(resources=[resource]) + assert manager.list_resources() == [resource] + + duplicate_resource = FileResource(uri=f"file://{temp_file}", name="duplicate", path=temp_file) + + with caplog.at_level(logging.WARNING): + manager = ResourceManager(True, resources=[resource, duplicate_resource]) + + assert "Resource already exists" in caplog.text + assert manager.list_resources() == [resource] + + +def test_add_resource(temp_file: Path): + """Test adding a resource.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + added = manager.add_resource(resource) + assert added == resource + assert manager.list_resources() == [resource] + + +def test_add_duplicate_resource(temp_file: Path): + """Test adding the same resource twice.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + first = manager.add_resource(resource) + second = manager.add_resource(resource) + assert first == second + assert manager.list_resources() == [resource] + + +def test_warn_on_duplicate_resources(temp_file: 
Path, caplog: pytest.LogCaptureFixture): + """Test warning on duplicate resources.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager.add_resource(resource) + manager.add_resource(resource) + assert "Resource already exists" in caplog.text + + +def test_disable_warn_on_duplicate_resources(temp_file: Path, caplog: pytest.LogCaptureFixture): + """Test disabling warning on duplicate resources.""" + manager = ResourceManager(warn_on_duplicate_resources=False) + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager.add_resource(resource) + manager.add_resource(resource) + assert "Resource already exists" not in caplog.text + + +@pytest.mark.anyio +async def test_get_resource(temp_file: Path): + """Test getting a resource by URI.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager.add_resource(resource) + retrieved = await manager.get_resource(resource.uri, Context()) + assert retrieved == resource + + +@pytest.mark.anyio +async def test_get_resource_from_template(): + """Test getting a resource through a template.""" + manager = ResourceManager() + + def greet(name: str) -> str: + return f"Hello, {name}!" + + template = ResourceTemplate.from_function(fn=greet, uri_template="greet://{name}", name="greeter") + manager._templates[template.uri_template] = template + + resource = await manager.get_resource(AnyUrl("greet://world"), Context()) + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert content == "Hello, world!" 
+ + +@pytest.mark.anyio +async def test_get_unknown_resource(): + """Test getting a non-existent resource.""" + manager = ResourceManager() + with pytest.raises(ValueError, match="Unknown resource"): + await manager.get_resource(AnyUrl("unknown://test"), Context()) + + +def test_list_resources(temp_file: Path): + """Test listing all resources.""" + manager = ResourceManager() + resource1 = FileResource(uri=f"file://{temp_file}", name="test1", path=temp_file) + resource2 = FileResource(uri=f"file://{temp_file}2", name="test2", path=temp_file) + + manager.add_resource(resource1) + manager.add_resource(resource2) + + resources = manager.list_resources() + assert len(resources) == 2 + assert resources == [resource1, resource2] + + +def get_item(id: str) -> str: ... + + +def test_add_template_with_metadata(): + """Test that ResourceManager.add_template() accepts and passes meta parameter.""" + manager = ResourceManager() + metadata = {"source": "database", "cached": True} + template = manager.add_template(fn=get_item, uri_template="resource://items/{id}", meta=metadata) + + assert template.meta is not None + assert template.meta == metadata + assert template.meta["source"] == "database" + assert template.meta["cached"] is True + + +def test_add_template_without_metadata(): + """Test that ResourceManager.add_template() works without meta parameter.""" + manager = ResourceManager() + template = manager.add_template(fn=get_item, uri_template="resource://items/{id}") + assert template.meta is None diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 49b6deb4bb..3457ec944a 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -685,6 +685,32 @@ async def test_remove_tool_and_call(self): class TestServerResources: + async def test_init_with_resources(self): + def get_text() -> str: + """Seeded resource.""" + return "Hello from init!" 
+ + resource = FunctionResource.from_function(fn=get_text, uri="resource://init", name="init_resource") + + mcp = MCPServer(resources=[resource]) + + async with Client(mcp) as client: + assert client.initialize_result.capabilities.resources is not None + + resources = await client.list_resources() + assert len(resources.resources) == 1 + listed = resources.resources[0] + assert listed.uri == "resource://init" + assert listed.name == "init_resource" + assert listed.description == "Seeded resource." + + result = await client.read_resource("resource://init") + + assert len(result.contents) == 1 + content = result.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Hello from init!" + async def test_text_resource(self): mcp = MCPServer() From 8f806da611131e012937cd6ba189c0bca1bceea2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 12 Apr 2026 12:48:13 +0000 Subject: [PATCH 36/60] chore(deps): bump cryptography from 46.0.5 to 46.0.7 in the uv group across 1 directory (#2406) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- uv.lock | 102 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/uv.lock b/uv.lock index e6fb526167..705d014aa5 100644 --- a/uv.lock +++ b/uv.lock @@ -448,62 +448,62 @@ toml = [ [[package]] name = "cryptography" -version = "46.0.5" +version = "46.0.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/60/04/ee2a9e8542e4fa2773b81771ff8349ff19cdd56b7258a0cc442639052edb/cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d", size = 750064, 
upload-time = "2026-02-10T19:18:38.255Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/81/b0bb27f2ba931a65409c6b8a8b358a7f03c0e46eceacddff55f7c84b1f3b/cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad", size = 7176289, upload-time = "2026-02-10T19:17:08.274Z" }, - { url = "https://files.pythonhosted.org/packages/ff/9e/6b4397a3e3d15123de3b1806ef342522393d50736c13b20ec4c9ea6693a6/cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b", size = 4275637, upload-time = "2026-02-10T19:17:10.53Z" }, - { url = "https://files.pythonhosted.org/packages/63/e7/471ab61099a3920b0c77852ea3f0ea611c9702f651600397ac567848b897/cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b", size = 4424742, upload-time = "2026-02-10T19:17:12.388Z" }, - { url = "https://files.pythonhosted.org/packages/37/53/a18500f270342d66bf7e4d9f091114e31e5ee9e7375a5aba2e85a91e0044/cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263", size = 4277528, upload-time = "2026-02-10T19:17:13.853Z" }, - { url = "https://files.pythonhosted.org/packages/22/29/c2e812ebc38c57b40e7c583895e73c8c5adb4d1e4a0cc4c5a4fdab2b1acc/cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d", size = 4947993, upload-time = "2026-02-10T19:17:15.618Z" }, - { url = "https://files.pythonhosted.org/packages/6b/e7/237155ae19a9023de7e30ec64e5d99a9431a567407ac21170a046d22a5a3/cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed", size = 4456855, upload-time = 
"2026-02-10T19:17:17.221Z" }, - { url = "https://files.pythonhosted.org/packages/2d/87/fc628a7ad85b81206738abbd213b07702bcbdada1dd43f72236ef3cffbb5/cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2", size = 3984635, upload-time = "2026-02-10T19:17:18.792Z" }, - { url = "https://files.pythonhosted.org/packages/84/29/65b55622bde135aedf4565dc509d99b560ee4095e56989e815f8fd2aa910/cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2", size = 4277038, upload-time = "2026-02-10T19:17:20.256Z" }, - { url = "https://files.pythonhosted.org/packages/bc/36/45e76c68d7311432741faf1fbf7fac8a196a0a735ca21f504c75d37e2558/cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0", size = 4912181, upload-time = "2026-02-10T19:17:21.825Z" }, - { url = "https://files.pythonhosted.org/packages/6d/1a/c1ba8fead184d6e3d5afcf03d569acac5ad063f3ac9fb7258af158f7e378/cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731", size = 4456482, upload-time = "2026-02-10T19:17:25.133Z" }, - { url = "https://files.pythonhosted.org/packages/f9/e5/3fb22e37f66827ced3b902cf895e6a6bc1d095b5b26be26bd13c441fdf19/cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82", size = 4405497, upload-time = "2026-02-10T19:17:26.66Z" }, - { url = "https://files.pythonhosted.org/packages/1a/df/9d58bb32b1121a8a2f27383fabae4d63080c7ca60b9b5c88be742be04ee7/cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1", size = 4667819, upload-time = "2026-02-10T19:17:28.569Z" }, - { url = 
"https://files.pythonhosted.org/packages/ea/ed/325d2a490c5e94038cdb0117da9397ece1f11201f425c4e9c57fe5b9f08b/cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48", size = 3028230, upload-time = "2026-02-10T19:17:30.518Z" }, - { url = "https://files.pythonhosted.org/packages/e9/5a/ac0f49e48063ab4255d9e3b79f5def51697fce1a95ea1370f03dc9db76f6/cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4", size = 3480909, upload-time = "2026-02-10T19:17:32.083Z" }, - { url = "https://files.pythonhosted.org/packages/00/13/3d278bfa7a15a96b9dc22db5a12ad1e48a9eb3d40e1827ef66a5df75d0d0/cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2", size = 7119287, upload-time = "2026-02-10T19:17:33.801Z" }, - { url = "https://files.pythonhosted.org/packages/67/c8/581a6702e14f0898a0848105cbefd20c058099e2c2d22ef4e476dfec75d7/cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678", size = 4265728, upload-time = "2026-02-10T19:17:35.569Z" }, - { url = "https://files.pythonhosted.org/packages/dd/4a/ba1a65ce8fc65435e5a849558379896c957870dd64fecea97b1ad5f46a37/cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87", size = 4408287, upload-time = "2026-02-10T19:17:36.938Z" }, - { url = "https://files.pythonhosted.org/packages/f8/67/8ffdbf7b65ed1ac224d1c2df3943553766914a8ca718747ee3871da6107e/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee", size = 4270291, upload-time = "2026-02-10T19:17:38.748Z" }, - { url = 
"https://files.pythonhosted.org/packages/f8/e5/f52377ee93bc2f2bba55a41a886fd208c15276ffbd2569f2ddc89d50e2c5/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981", size = 4927539, upload-time = "2026-02-10T19:17:40.241Z" }, - { url = "https://files.pythonhosted.org/packages/3b/02/cfe39181b02419bbbbcf3abdd16c1c5c8541f03ca8bda240debc467d5a12/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9", size = 4442199, upload-time = "2026-02-10T19:17:41.789Z" }, - { url = "https://files.pythonhosted.org/packages/c0/96/2fcaeb4873e536cf71421a388a6c11b5bc846e986b2b069c79363dc1648e/cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648", size = 3960131, upload-time = "2026-02-10T19:17:43.379Z" }, - { url = "https://files.pythonhosted.org/packages/d8/d2/b27631f401ddd644e94c5cf33c9a4069f72011821cf3dc7309546b0642a0/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4", size = 4270072, upload-time = "2026-02-10T19:17:45.481Z" }, - { url = "https://files.pythonhosted.org/packages/f4/a7/60d32b0370dae0b4ebe55ffa10e8599a2a59935b5ece1b9f06edb73abdeb/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0", size = 4892170, upload-time = "2026-02-10T19:17:46.997Z" }, - { url = "https://files.pythonhosted.org/packages/d2/b9/cf73ddf8ef1164330eb0b199a589103c363afa0cf794218c24d524a58eab/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663", size = 4441741, upload-time = "2026-02-10T19:17:48.661Z" }, - { url = 
"https://files.pythonhosted.org/packages/5f/eb/eee00b28c84c726fe8fa0158c65afe312d9c3b78d9d01daf700f1f6e37ff/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826", size = 4396728, upload-time = "2026-02-10T19:17:50.058Z" }, - { url = "https://files.pythonhosted.org/packages/65/f4/6bc1a9ed5aef7145045114b75b77c2a8261b4d38717bd8dea111a63c3442/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d", size = 4652001, upload-time = "2026-02-10T19:17:51.54Z" }, - { url = "https://files.pythonhosted.org/packages/86/ef/5d00ef966ddd71ac2e6951d278884a84a40ffbd88948ef0e294b214ae9e4/cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a", size = 3003637, upload-time = "2026-02-10T19:17:52.997Z" }, - { url = "https://files.pythonhosted.org/packages/b7/57/f3f4160123da6d098db78350fdfd9705057aad21de7388eacb2401dceab9/cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4", size = 3469487, upload-time = "2026-02-10T19:17:54.549Z" }, - { url = "https://files.pythonhosted.org/packages/e2/fa/a66aa722105ad6a458bebd64086ca2b72cdd361fed31763d20390f6f1389/cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31", size = 7170514, upload-time = "2026-02-10T19:17:56.267Z" }, - { url = "https://files.pythonhosted.org/packages/0f/04/c85bdeab78c8bc77b701bf0d9bdcf514c044e18a46dcff330df5448631b0/cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18", size = 4275349, upload-time = "2026-02-10T19:17:58.419Z" }, - { url = 
"https://files.pythonhosted.org/packages/5c/32/9b87132a2f91ee7f5223b091dc963055503e9b442c98fc0b8a5ca765fab0/cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235", size = 4420667, upload-time = "2026-02-10T19:18:00.619Z" }, - { url = "https://files.pythonhosted.org/packages/a1/a6/a7cb7010bec4b7c5692ca6f024150371b295ee1c108bdc1c400e4c44562b/cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a", size = 4276980, upload-time = "2026-02-10T19:18:02.379Z" }, - { url = "https://files.pythonhosted.org/packages/8e/7c/c4f45e0eeff9b91e3f12dbd0e165fcf2a38847288fcfd889deea99fb7b6d/cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76", size = 4939143, upload-time = "2026-02-10T19:18:03.964Z" }, - { url = "https://files.pythonhosted.org/packages/37/19/e1b8f964a834eddb44fa1b9a9976f4e414cbb7aa62809b6760c8803d22d1/cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614", size = 4453674, upload-time = "2026-02-10T19:18:05.588Z" }, - { url = "https://files.pythonhosted.org/packages/db/ed/db15d3956f65264ca204625597c410d420e26530c4e2943e05a0d2f24d51/cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229", size = 3978801, upload-time = "2026-02-10T19:18:07.167Z" }, - { url = "https://files.pythonhosted.org/packages/41/e2/df40a31d82df0a70a0daf69791f91dbb70e47644c58581d654879b382d11/cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1", size = 4276755, upload-time = "2026-02-10T19:18:09.813Z" }, - { url = 
"https://files.pythonhosted.org/packages/33/45/726809d1176959f4a896b86907b98ff4391a8aa29c0aaaf9450a8a10630e/cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d", size = 4901539, upload-time = "2026-02-10T19:18:11.263Z" }, - { url = "https://files.pythonhosted.org/packages/99/0f/a3076874e9c88ecb2ecc31382f6e7c21b428ede6f55aafa1aa272613e3cd/cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c", size = 4452794, upload-time = "2026-02-10T19:18:12.914Z" }, - { url = "https://files.pythonhosted.org/packages/02/ef/ffeb542d3683d24194a38f66ca17c0a4b8bf10631feef44a7ef64e631b1a/cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4", size = 4404160, upload-time = "2026-02-10T19:18:14.375Z" }, - { url = "https://files.pythonhosted.org/packages/96/93/682d2b43c1d5f1406ed048f377c0fc9fc8f7b0447a478d5c65ab3d3a66eb/cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9", size = 4667123, upload-time = "2026-02-10T19:18:15.886Z" }, - { url = "https://files.pythonhosted.org/packages/45/2d/9c5f2926cb5300a8eefc3f4f0b3f3df39db7f7ce40c8365444c49363cbda/cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72", size = 3010220, upload-time = "2026-02-10T19:18:17.361Z" }, - { url = "https://files.pythonhosted.org/packages/48/ef/0c2f4a8e31018a986949d34a01115dd057bf536905dca38897bacd21fac3/cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595", size = 3467050, upload-time = "2026-02-10T19:18:18.899Z" }, - { url = 
"https://files.pythonhosted.org/packages/eb/dd/2d9fdb07cebdf3d51179730afb7d5e576153c6744c3ff8fded23030c204e/cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c", size = 3476964, upload-time = "2026-02-10T19:18:20.687Z" }, - { url = "https://files.pythonhosted.org/packages/e9/6f/6cc6cc9955caa6eaf83660b0da2b077c7fe8ff9950a3c5e45d605038d439/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a", size = 4218321, upload-time = "2026-02-10T19:18:22.349Z" }, - { url = "https://files.pythonhosted.org/packages/3e/5d/c4da701939eeee699566a6c1367427ab91a8b7088cc2328c09dbee940415/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356", size = 4381786, upload-time = "2026-02-10T19:18:24.529Z" }, - { url = "https://files.pythonhosted.org/packages/ac/97/a538654732974a94ff96c1db621fa464f455c02d4bb7d2652f4edc21d600/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da", size = 4217990, upload-time = "2026-02-10T19:18:25.957Z" }, - { url = "https://files.pythonhosted.org/packages/ae/11/7e500d2dd3ba891197b9efd2da5454b74336d64a7cc419aa7327ab74e5f6/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257", size = 4381252, upload-time = "2026-02-10T19:18:27.496Z" }, - { url = "https://files.pythonhosted.org/packages/bc/58/6b3d24e6b9bc474a2dcdee65dfd1f008867015408a271562e4b690561a4d/cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7", size = 3407605, upload-time = "2026-02-10T19:18:29.233Z" }, +sdist = { url = 
"https://files.pythonhosted.org/packages/47/93/ac8f3d5ff04d54bc814e961a43ae5b0b146154c89c61b47bb07557679b18/cryptography-46.0.7.tar.gz", hash = "sha256:e4cfd68c5f3e0bfdad0d38e023239b96a2fe84146481852dffbcca442c245aa5", size = 750652, upload-time = "2026-04-08T01:57:54.692Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/5d/4a8f770695d73be252331e60e526291e3df0c9b27556a90a6b47bccca4c2/cryptography-46.0.7-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:ea42cbe97209df307fdc3b155f1b6fa2577c0defa8f1f7d3be7d31d189108ad4", size = 7179869, upload-time = "2026-04-08T01:56:17.157Z" }, + { url = "https://files.pythonhosted.org/packages/5f/45/6d80dc379b0bbc1f9d1e429f42e4cb9e1d319c7a8201beffd967c516ea01/cryptography-46.0.7-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b36a4695e29fe69215d75960b22577197aca3f7a25b9cf9d165dcfe9d80bc325", size = 4275492, upload-time = "2026-04-08T01:56:19.36Z" }, + { url = "https://files.pythonhosted.org/packages/4a/9a/1765afe9f572e239c3469f2cb429f3ba7b31878c893b246b4b2994ffe2fe/cryptography-46.0.7-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ad9ef796328c5e3c4ceed237a183f5d41d21150f972455a9d926593a1dcb308", size = 4426670, upload-time = "2026-04-08T01:56:21.415Z" }, + { url = "https://files.pythonhosted.org/packages/8f/3e/af9246aaf23cd4ee060699adab1e47ced3f5f7e7a8ffdd339f817b446462/cryptography-46.0.7-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:73510b83623e080a2c35c62c15298096e2a5dc8d51c3b4e1740211839d0dea77", size = 4280275, upload-time = "2026-04-08T01:56:23.539Z" }, + { url = "https://files.pythonhosted.org/packages/0f/54/6bbbfc5efe86f9d71041827b793c24811a017c6ac0fd12883e4caa86b8ed/cryptography-46.0.7-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cbd5fb06b62bd0721e1170273d3f4d5a277044c47ca27ee257025146c34cbdd1", size = 4928402, upload-time = "2026-04-08T01:56:25.624Z" }, + { url = 
"https://files.pythonhosted.org/packages/2d/cf/054b9d8220f81509939599c8bdbc0c408dbd2bdd41688616a20731371fe0/cryptography-46.0.7-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:420b1e4109cc95f0e5700eed79908cef9268265c773d3a66f7af1eef53d409ef", size = 4459985, upload-time = "2026-04-08T01:56:27.309Z" }, + { url = "https://files.pythonhosted.org/packages/f9/46/4e4e9c6040fb01c7467d47217d2f882daddeb8828f7df800cb806d8a2288/cryptography-46.0.7-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:24402210aa54baae71d99441d15bb5a1919c195398a87b563df84468160a65de", size = 3990652, upload-time = "2026-04-08T01:56:29.095Z" }, + { url = "https://files.pythonhosted.org/packages/36/5f/313586c3be5a2fbe87e4c9a254207b860155a8e1f3cca99f9910008e7d08/cryptography-46.0.7-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8a469028a86f12eb7d2fe97162d0634026d92a21f3ae0ac87ed1c4a447886c83", size = 4279805, upload-time = "2026-04-08T01:56:30.928Z" }, + { url = "https://files.pythonhosted.org/packages/69/33/60dfc4595f334a2082749673386a4d05e4f0cf4df8248e63b2c3437585f2/cryptography-46.0.7-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:9694078c5d44c157ef3162e3bf3946510b857df5a3955458381d1c7cfc143ddb", size = 4892883, upload-time = "2026-04-08T01:56:32.614Z" }, + { url = "https://files.pythonhosted.org/packages/c7/0b/333ddab4270c4f5b972f980adef4faa66951a4aaf646ca067af597f15563/cryptography-46.0.7-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:42a1e5f98abb6391717978baf9f90dc28a743b7d9be7f0751a6f56a75d14065b", size = 4459756, upload-time = "2026-04-08T01:56:34.306Z" }, + { url = "https://files.pythonhosted.org/packages/d2/14/633913398b43b75f1234834170947957c6b623d1701ffc7a9600da907e89/cryptography-46.0.7-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91bbcb08347344f810cbe49065914fe048949648f6bd5c2519f34619142bbe85", size = 4410244, upload-time = "2026-04-08T01:56:35.977Z" }, + { url = 
"https://files.pythonhosted.org/packages/10/f2/19ceb3b3dc14009373432af0c13f46aa08e3ce334ec6eff13492e1812ccd/cryptography-46.0.7-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5d1c02a14ceb9148cc7816249f64f623fbfee39e8c03b3650d842ad3f34d637e", size = 4674868, upload-time = "2026-04-08T01:56:38.034Z" }, + { url = "https://files.pythonhosted.org/packages/1a/bb/a5c213c19ee94b15dfccc48f363738633a493812687f5567addbcbba9f6f/cryptography-46.0.7-cp311-abi3-win32.whl", hash = "sha256:d23c8ca48e44ee015cd0a54aeccdf9f09004eba9fc96f38c911011d9ff1bd457", size = 3026504, upload-time = "2026-04-08T01:56:39.666Z" }, + { url = "https://files.pythonhosted.org/packages/2b/02/7788f9fefa1d060ca68717c3901ae7fffa21ee087a90b7f23c7a603c32ae/cryptography-46.0.7-cp311-abi3-win_amd64.whl", hash = "sha256:397655da831414d165029da9bc483bed2fe0e75dde6a1523ec2fe63f3c46046b", size = 3488363, upload-time = "2026-04-08T01:56:41.893Z" }, + { url = "https://files.pythonhosted.org/packages/7b/56/15619b210e689c5403bb0540e4cb7dbf11a6bf42e483b7644e471a2812b3/cryptography-46.0.7-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:d151173275e1728cf7839aaa80c34fe550c04ddb27b34f48c232193df8db5842", size = 7119671, upload-time = "2026-04-08T01:56:44Z" }, + { url = "https://files.pythonhosted.org/packages/74/66/e3ce040721b0b5599e175ba91ab08884c75928fbeb74597dd10ef13505d2/cryptography-46.0.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:db0f493b9181c7820c8134437eb8b0b4792085d37dbb24da050476ccb664e59c", size = 4268551, upload-time = "2026-04-08T01:56:46.071Z" }, + { url = "https://files.pythonhosted.org/packages/03/11/5e395f961d6868269835dee1bafec6a1ac176505a167f68b7d8818431068/cryptography-46.0.7-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ebd6daf519b9f189f85c479427bbd6e9c9037862cf8fe89ee35503bd209ed902", size = 4408887, upload-time = "2026-04-08T01:56:47.718Z" }, + { url = 
"https://files.pythonhosted.org/packages/40/53/8ed1cf4c3b9c8e611e7122fb56f1c32d09e1fff0f1d77e78d9ff7c82653e/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:b7b412817be92117ec5ed95f880defe9cf18a832e8cafacf0a22337dc1981b4d", size = 4271354, upload-time = "2026-04-08T01:56:49.312Z" }, + { url = "https://files.pythonhosted.org/packages/50/46/cf71e26025c2e767c5609162c866a78e8a2915bbcfa408b7ca495c6140c4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:fbfd0e5f273877695cb93baf14b185f4878128b250cc9f8e617ea0c025dfb022", size = 4905845, upload-time = "2026-04-08T01:56:50.916Z" }, + { url = "https://files.pythonhosted.org/packages/c0/ea/01276740375bac6249d0a971ebdf6b4dc9ead0ee0a34ef3b5a88c1a9b0d4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ffca7aa1d00cf7d6469b988c581598f2259e46215e0140af408966a24cf086ce", size = 4444641, upload-time = "2026-04-08T01:56:52.882Z" }, + { url = "https://files.pythonhosted.org/packages/3d/4c/7d258f169ae71230f25d9f3d06caabcff8c3baf0978e2b7d65e0acac3827/cryptography-46.0.7-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:60627cf07e0d9274338521205899337c5d18249db56865f943cbe753aa96f40f", size = 3967749, upload-time = "2026-04-08T01:56:54.597Z" }, + { url = "https://files.pythonhosted.org/packages/b5/2a/2ea0767cad19e71b3530e4cad9605d0b5e338b6a1e72c37c9c1ceb86c333/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:80406c3065e2c55d7f49a9550fe0c49b3f12e5bfff5dedb727e319e1afb9bf99", size = 4270942, upload-time = "2026-04-08T01:56:56.416Z" }, + { url = "https://files.pythonhosted.org/packages/41/3d/fe14df95a83319af25717677e956567a105bb6ab25641acaa093db79975d/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:c5b1ccd1239f48b7151a65bc6dd54bcfcc15e028c8ac126d3fada09db0e07ef1", size = 4871079, upload-time = "2026-04-08T01:56:58.31Z" }, + { url = 
"https://files.pythonhosted.org/packages/9c/59/4a479e0f36f8f378d397f4eab4c850b4ffb79a2f0d58704b8fa0703ddc11/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:d5f7520159cd9c2154eb61eb67548ca05c5774d39e9c2c4339fd793fe7d097b2", size = 4443999, upload-time = "2026-04-08T01:57:00.508Z" }, + { url = "https://files.pythonhosted.org/packages/28/17/b59a741645822ec6d04732b43c5d35e4ef58be7bfa84a81e5ae6f05a1d33/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:fcd8eac50d9138c1d7fc53a653ba60a2bee81a505f9f8850b6b2888555a45d0e", size = 4399191, upload-time = "2026-04-08T01:57:02.654Z" }, + { url = "https://files.pythonhosted.org/packages/59/6a/bb2e166d6d0e0955f1e9ff70f10ec4b2824c9cfcdb4da772c7dd69cc7d80/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:65814c60f8cc400c63131584e3e1fad01235edba2614b61fbfbfa954082db0ee", size = 4655782, upload-time = "2026-04-08T01:57:04.592Z" }, + { url = "https://files.pythonhosted.org/packages/95/b6/3da51d48415bcb63b00dc17c2eff3a651b7c4fed484308d0f19b30e8cb2c/cryptography-46.0.7-cp314-cp314t-win32.whl", hash = "sha256:fdd1736fed309b4300346f88f74cd120c27c56852c3838cab416e7a166f67298", size = 3002227, upload-time = "2026-04-08T01:57:06.91Z" }, + { url = "https://files.pythonhosted.org/packages/32/a8/9f0e4ed57ec9cebe506e58db11ae472972ecb0c659e4d52bbaee80ca340a/cryptography-46.0.7-cp314-cp314t-win_amd64.whl", hash = "sha256:e06acf3c99be55aa3b516397fe42f5855597f430add9c17fa46bf2e0fb34c9bb", size = 3475332, upload-time = "2026-04-08T01:57:08.807Z" }, + { url = "https://files.pythonhosted.org/packages/a7/7f/cd42fc3614386bc0c12f0cb3c4ae1fc2bbca5c9662dfed031514911d513d/cryptography-46.0.7-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:462ad5cb1c148a22b2e3bcc5ad52504dff325d17daf5df8d88c17dda1f75f2a4", size = 7165618, upload-time = "2026-04-08T01:57:10.645Z" }, + { url = 
"https://files.pythonhosted.org/packages/a5/d0/36a49f0262d2319139d2829f773f1b97ef8aef7f97e6e5bd21455e5a8fb5/cryptography-46.0.7-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:84d4cced91f0f159a7ddacad249cc077e63195c36aac40b4150e7a57e84fffe7", size = 4270628, upload-time = "2026-04-08T01:57:12.885Z" }, + { url = "https://files.pythonhosted.org/packages/8a/6c/1a42450f464dda6ffbe578a911f773e54dd48c10f9895a23a7e88b3e7db5/cryptography-46.0.7-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:128c5edfe5e5938b86b03941e94fac9ee793a94452ad1365c9fc3f4f62216832", size = 4415405, upload-time = "2026-04-08T01:57:14.923Z" }, + { url = "https://files.pythonhosted.org/packages/9a/92/4ed714dbe93a066dc1f4b4581a464d2d7dbec9046f7c8b7016f5286329e2/cryptography-46.0.7-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5e51be372b26ef4ba3de3c167cd3d1022934bc838ae9eaad7e644986d2a3d163", size = 4272715, upload-time = "2026-04-08T01:57:16.638Z" }, + { url = "https://files.pythonhosted.org/packages/b7/e6/a26b84096eddd51494bba19111f8fffe976f6a09f132706f8f1bf03f51f7/cryptography-46.0.7-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cdf1a610ef82abb396451862739e3fc93b071c844399e15b90726ef7470eeaf2", size = 4918400, upload-time = "2026-04-08T01:57:19.021Z" }, + { url = "https://files.pythonhosted.org/packages/c7/08/ffd537b605568a148543ac3c2b239708ae0bd635064bab41359252ef88ed/cryptography-46.0.7-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1d25aee46d0c6f1a501adcddb2d2fee4b979381346a78558ed13e50aa8a59067", size = 4450634, upload-time = "2026-04-08T01:57:21.185Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/0cd51dd86ab5b9befe0d031e276510491976c3a80e9f6e31810cce46c4ad/cryptography-46.0.7-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:cdfbe22376065ffcf8be74dc9a909f032df19bc58a699456a21712d6e5eabfd0", size = 3985233, upload-time = "2026-04-08T01:57:22.862Z" }, + { url = 
"https://files.pythonhosted.org/packages/92/49/819d6ed3a7d9349c2939f81b500a738cb733ab62fbecdbc1e38e83d45e12/cryptography-46.0.7-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:abad9dac36cbf55de6eb49badd4016806b3165d396f64925bf2999bcb67837ba", size = 4271955, upload-time = "2026-04-08T01:57:24.814Z" }, + { url = "https://files.pythonhosted.org/packages/80/07/ad9b3c56ebb95ed2473d46df0847357e01583f4c52a85754d1a55e29e4d0/cryptography-46.0.7-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:935ce7e3cfdb53e3536119a542b839bb94ec1ad081013e9ab9b7cfd478b05006", size = 4879888, upload-time = "2026-04-08T01:57:26.88Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c7/201d3d58f30c4c2bdbe9b03844c291feb77c20511cc3586daf7edc12a47b/cryptography-46.0.7-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:35719dc79d4730d30f1c2b6474bd6acda36ae2dfae1e3c16f2051f215df33ce0", size = 4449961, upload-time = "2026-04-08T01:57:29.068Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ef/649750cbf96f3033c3c976e112265c33906f8e462291a33d77f90356548c/cryptography-46.0.7-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:7bbc6ccf49d05ac8f7d7b5e2e2c33830d4fe2061def88210a126d130d7f71a85", size = 4401696, upload-time = "2026-04-08T01:57:31.029Z" }, + { url = "https://files.pythonhosted.org/packages/41/52/a8908dcb1a389a459a29008c29966c1d552588d4ae6d43f3a1a4512e0ebe/cryptography-46.0.7-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a1529d614f44b863a7b480c6d000fe93b59acee9c82ffa027cfadc77521a9f5e", size = 4664256, upload-time = "2026-04-08T01:57:33.144Z" }, + { url = "https://files.pythonhosted.org/packages/4b/fa/f0ab06238e899cc3fb332623f337a7364f36f4bb3f2534c2bb95a35b132c/cryptography-46.0.7-cp38-abi3-win32.whl", hash = "sha256:f247c8c1a1fb45e12586afbb436ef21ff1e80670b2861a90353d9b025583d246", size = 3013001, upload-time = "2026-04-08T01:57:34.933Z" }, + { url = 
"https://files.pythonhosted.org/packages/d2/f1/00ce3bde3ca542d1acd8f8cfa38e446840945aa6363f9b74746394b14127/cryptography-46.0.7-cp38-abi3-win_amd64.whl", hash = "sha256:506c4ff91eff4f82bdac7633318a526b1d1309fc07ca76a3ad182cb5b686d6d3", size = 3472985, upload-time = "2026-04-08T01:57:36.714Z" }, + { url = "https://files.pythonhosted.org/packages/63/0c/dca8abb64e7ca4f6b2978769f6fea5ad06686a190cec381f0a796fdcaaba/cryptography-46.0.7-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:fc9ab8856ae6cf7c9358430e49b368f3108f050031442eaeb6b9d87e4dcf4e4f", size = 3476879, upload-time = "2026-04-08T01:57:38.664Z" }, + { url = "https://files.pythonhosted.org/packages/3a/ea/075aac6a84b7c271578d81a2f9968acb6e273002408729f2ddff517fed4a/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d3b99c535a9de0adced13d159c5a9cf65c325601aa30f4be08afd680643e9c15", size = 4219700, upload-time = "2026-04-08T01:57:40.625Z" }, + { url = "https://files.pythonhosted.org/packages/6c/7b/1c55db7242b5e5612b29fc7a630e91ee7a6e3c8e7bf5406d22e206875fbd/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d02c738dacda7dc2a74d1b2b3177042009d5cab7c7079db74afc19e56ca1b455", size = 4385982, upload-time = "2026-04-08T01:57:42.725Z" }, + { url = "https://files.pythonhosted.org/packages/cb/da/9870eec4b69c63ef5925bf7d8342b7e13bc2ee3d47791461c4e49ca212f4/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:04959522f938493042d595a736e7dbdff6eb6cc2339c11465b3ff89343b65f65", size = 4219115, upload-time = "2026-04-08T01:57:44.939Z" }, + { url = "https://files.pythonhosted.org/packages/f4/72/05aa5832b82dd341969e9a734d1812a6aadb088d9eb6f0430fc337cc5a8f/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3986ac1dee6def53797289999eabe84798ad7817f3e97779b5061a95b0ee4968", size = 4385479, upload-time = "2026-04-08T01:57:46.86Z" }, + { url = 
"https://files.pythonhosted.org/packages/20/2a/1b016902351a523aa2bd446b50a5bc1175d7a7d1cf90fe2ef904f9b84ebc/cryptography-46.0.7-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:258514877e15963bd43b558917bc9f54cf7cf866c38aa576ebf47a77ddbc43a4", size = 3412829, upload-time = "2026-04-08T01:57:48.874Z" }, ] [[package]] From 941089e06aa4f7ea09c58f8cae42e28059bb9c9b Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Mon, 13 Apr 2026 12:50:04 +0100 Subject: [PATCH 37/60] docs: modernize development guidelines and rename to AGENTS.md (#2413) --- AGENTS.md | 138 ++++++++++++++++++++++++++++++++++++++ CLAUDE.md | 175 +------------------------------------------------ pyproject.toml | 1 - 3 files changed, 139 insertions(+), 175 deletions(-) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000..9692271044 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,138 @@ +# Development Guidelines + +## Branching Model + + + +- `main` is currently the V2 rework. Breaking changes are expected here — when removing or + replacing an API, delete it outright and document the change in + `docs/migration.md`. Do not add `@deprecated` shims or backward-compat layers + on `main`. +- `v1.x` is the release branch for the current stable line. Backport PRs target + this branch and use a `[v1.x]` title prefix. +- `README.md` is frozen at v1 (a pre-commit hook rejects edits). Edit + `README.v2.md` instead. + +## Package Management + +- ONLY use uv, NEVER pip +- Installation: `uv add <package>` +- Running tools: `uv run --frozen <tool>`. Always pass `--frozen` so uv doesn't + rewrite `uv.lock` as a side effect. +- Cross-version testing: `uv run --frozen --python 3.10 pytest ...` to run + against a specific interpreter (CI covers 3.10–3.14). +- Upgrading: `uv lock --upgrade-package <package>` +- FORBIDDEN: `uv pip install`, `@latest` syntax +- Don't raise dependency floors for CVEs alone. The `>=` constraint already + lets users upgrade.
Only raise a floor when the SDK needs functionality from + the newer version, and don't add SDK code to work around a dependency's + vulnerability. See Kludex/uvicorn#2643 and python-sdk #1552 for reasoning. + +## Code Quality + +- Type hints required for all code +- Public APIs must have docstrings. When a public API raises exceptions a + caller would reasonably catch, document them in a `Raises:` section. Don't + list exceptions from argument validation or programmer error. +- `src/mcp/__init__.py` defines the public API surface via `__all__`. Adding a + symbol there is a deliberate API decision, not a convenience re-export. +- IMPORTANT: All imports go at the top of the file — inline imports hide + dependencies and obscure circular-import bugs. Only exception: when a + top-level import genuinely can't work (lazy-loading optional deps, or + tests that re-import a module). + +## Testing + +- Framework: `uv run --frozen pytest` +- Async testing: use anyio, not asyncio +- Do not use `Test` prefixed classes, use functions +- IMPORTANT: Tests should be fast and deterministic. Prefer in-memory async execution; + reach for threads only when necessary, and subprocesses only as a last resort. +- For end-to-end behavior, an in-memory `Client(server)` is usually the + cleanest approach (see `tests/client/test_client.py` for the canonical + pattern). For narrower changes, testing the function directly is fine. Use + judgment. +- Test files mirror the source tree: `src/mcp/client/stdio.py` → + `tests/client/test_stdio.py`. Add tests to the existing file for that module. +- Avoid `anyio.sleep()` with a fixed duration to wait for async operations. 
Instead: + - Use `anyio.Event` — set it in the callback/handler, `await event.wait()` in the test + - For stream messages, use `await stream.receive()` instead of `sleep()` + `receive_nowait()` + - Exception: `sleep()` is appropriate when testing time-based features (e.g., timeouts) +- Wrap indefinite waits (`event.wait()`, `stream.receive()`) in `anyio.fail_after(5)` to prevent hangs +- Pytest is configured with `filterwarnings = ["error"]`, so warnings fail + tests. Don't silence warnings from your own code; fix the underlying cause. + Scoped `ignore::` entries for upstream libraries are acceptable in + `pyproject.toml` with a comment explaining why. + +### Coverage + +CI requires 100% (`fail_under = 100`, `branch = true`). + +- Full check: `./scripts/test` (~23s). Runs coverage + `strict-no-cover` on the + default Python. Not identical to CI: CI runs 3.10–3.14 × {ubuntu, windows} + × {locked, lowest-direct}, and some branch-coverage quirks only surface on + specific matrix entries. +- Targeted check while iterating (~4s, deterministic): + + ```bash + uv run --frozen coverage erase + uv run --frozen coverage run -m pytest tests/path/test_foo.py + uv run --frozen coverage combine + uv run --frozen coverage report --include='src/mcp/path/foo.py' --fail-under=0 + # UV_FROZEN=1 propagates --frozen to the uv subprocess strict-no-cover spawns + UV_FROZEN=1 uv run --frozen strict-no-cover + ``` + + Partial runs can't hit 100% (coverage tracks `tests/` too), so `--fail-under=0` + and `--include` scope the report. `strict-no-cover` has no false positives on + partial runs — if your new test executes a line marked `# pragma: no cover`, + even a single-file run catches it. + +Avoid adding new `# pragma: no cover`, `# type: ignore`, or `# noqa` comments. +In tests, use `assert isinstance(x, T)` to narrow types instead of +`# type: ignore`. In library code (`src/`), a `# pragma: no cover` needs very +good reasoning — it usually means a test is missing. 
Audit before pushing: + +```bash +git diff origin/main... | grep -E '^\+.*(pragma|type: ignore|noqa)' +``` + +What the existing pragmas mean: + +- `# pragma: no cover` — line is never executed. CI's `strict-no-cover` (skipped + on Windows runners) fails if it IS executed. When your test starts covering + such a line, remove the pragma. +- `# pragma: lax no cover` — excluded from coverage but not checked by + `strict-no-cover`. Use for lines covered on some platforms/versions but not + others. +- `# pragma: no branch` — excludes branch arcs only. coverage.py misreports the + `->exit` arc for nested `async with` on Python 3.11+ (worse on 3.14/Windows). + +## Breaking Changes + +When making breaking changes, document them in `docs/migration.md`. Include: + +- What changed +- Why it changed +- How to migrate existing code + +Search for related sections in the migration guide and group related changes together +rather than adding new standalone sections. + +## Formatting & Type Checking + +- Format: `uv run --frozen ruff format .` +- Lint: `uv run --frozen ruff check . --fix` +- Type check: `uv run --frozen pyright` +- Pre-commit runs all of the above plus markdownlint, a `uv.lock` consistency + check, and README checks — see `.pre-commit-config.yaml` + +## Exception Handling + +- **Always use `logger.exception()` instead of `logger.error()` when catching exceptions** + - Don't include the exception in the message: `logger.exception("Failed")` not `logger.exception(f"Failed: {e}")` +- **Catch specific exceptions** where possible: + - File ops: `except (OSError, PermissionError):` + - JSON: `except json.JSONDecodeError:` + - Network: `except (ConnectionError, TimeoutError):` +- **FORBIDDEN** `except Exception:` - unless in top-level handlers diff --git a/CLAUDE.md b/CLAUDE.md index 2eee085e13..43c994c2d3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,174 +1 @@ -# Development Guidelines - -This document contains critical information about working with this codebase. 
Follow these guidelines precisely. - -## Core Development Rules - -1. Package Management - - ONLY use uv, NEVER pip - - Installation: `uv add <package>` - - Running tools: `uv run <tool>` - - Upgrading: `uv lock --upgrade-package <package>` - - FORBIDDEN: `uv pip install`, `@latest` syntax - -2. Code Quality - - Type hints required for all code - - Public APIs must have docstrings - - Functions must be focused and small - - Follow existing patterns exactly - - Line length: 120 chars maximum - - FORBIDDEN: imports inside functions. THEY SHOULD BE AT THE TOP OF THE FILE. - -3. Testing Requirements - - Framework: `uv run --frozen pytest` - - Async testing: use anyio, not asyncio - - Do not use `Test` prefixed classes, use functions - - Coverage: test edge cases and errors - - New features require tests - - Bug fixes require regression tests - - IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns. - - IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible. - - Coverage: CI requires 100% (`fail_under = 100`, `branch = true`). - - Full check: `./scripts/test` (~23s). Runs coverage + `strict-no-cover` on the - default Python. Not identical to CI: CI also runs 3.10–3.14 × {ubuntu, windows}, - and some branch-coverage quirks only surface on specific matrix entries. - - Targeted check while iterating (~4s, deterministic): - - ```bash - uv run --frozen coverage erase - uv run --frozen coverage run -m pytest tests/path/test_foo.py - uv run --frozen coverage combine - uv run --frozen coverage report --include='src/mcp/path/foo.py' --fail-under=0 - UV_FROZEN=1 uv run --frozen strict-no-cover - ``` - - Partial runs can't hit 100% (coverage tracks `tests/` too), so `--fail-under=0` - and `--include` scope the report. `strict-no-cover` has no false positives on - partial runs — if your new test executes a line marked `# pragma: no cover`, - even a single-file run catches it.
- - Coverage pragmas: - - `# pragma: no cover` — line is never executed. CI's `strict-no-cover` fails if - it IS executed. When your test starts covering such a line, remove the pragma. - - `# pragma: lax no cover` — excluded from coverage but not checked by - `strict-no-cover`. Use for lines covered on some platforms/versions but not - others. - - `# pragma: no branch` — excludes branch arcs only. coverage.py misreports the - `->exit` arc for nested `async with` on Python 3.11+ (worse on 3.14/Windows). - - Avoid `anyio.sleep()` with a fixed duration to wait for async operations. Instead: - - Use `anyio.Event` — set it in the callback/handler, `await event.wait()` in the test - - For stream messages, use `await stream.receive()` instead of `sleep()` + `receive_nowait()` - - Exception: `sleep()` is appropriate when testing time-based features (e.g., timeouts) - - Wrap indefinite waits (`event.wait()`, `stream.receive()`) in `anyio.fail_after(5)` to prevent hangs - -Test files mirror the source tree: `src/mcp/client/streamable_http.py` → `tests/client/test_streamable_http.py` -Add tests to the existing file for that module. - -- For commits fixing bugs or adding features based on user reports add: - - ```bash - git commit --trailer "Reported-by:<name>" - ``` - - Where `<name>` is the name of the user. - -- For commits related to a Github issue, add - - ```bash - git commit --trailer "Github-Issue:#<number>" - ``` - -- NEVER ever mention a `co-authored-by` or similar aspects. In particular, never - mention the tool used to create the commit message or PR. - -## Pull Requests - -- Create a detailed message of what changed. Focus on the high level description of - the problem it tries to solve, and how it is solved. Don't go into the specifics of the - code unless it adds clarity. - -- NEVER ever mention a `co-authored-by` or similar aspects. In particular, never - mention the tool used to create the commit message or PR.
- -## Breaking Changes - -When making breaking changes, document them in `docs/migration.md`. Include: - -- What changed -- Why it changed -- How to migrate existing code - -Search for related sections in the migration guide and group related changes together -rather than adding new standalone sections. - -## Python Tools - -## Code Formatting - -1. Ruff - - Format: `uv run --frozen ruff format .` - - Check: `uv run --frozen ruff check .` - - Fix: `uv run --frozen ruff check . --fix` - - Critical issues: - - Line length (88 chars) - - Import sorting (I001) - - Unused imports - - Line wrapping: - - Strings: use parentheses - - Function calls: multi-line with proper indent - - Imports: try to use a single line - -2. Type Checking - - Tool: `uv run --frozen pyright` - - Requirements: - - Type narrowing for strings - - Version warnings can be ignored if checks pass - -3. Pre-commit - - Config: `.pre-commit-config.yaml` - - Runs: on git commit - - Tools: Prettier (YAML/JSON), Ruff (Python) - - Ruff updates: - - Check PyPI versions - - Update config rev - - Commit config first - -## Error Resolution - -1. CI Failures - - Fix order: - 1. Formatting - 2. Type errors - 3. Linting - - Type errors: - - Get full line context - - Check Optional types - - Add type narrowing - - Verify function signatures - -2. Common Issues - - Line length: - - Break strings with parentheses - - Multi-line function calls - - Split imports - - Types: - - Add None checks - - Narrow string types - - Match existing patterns - -3. 
Best Practices - - Check git status before commits - - Run formatters before type checks - - Keep changes minimal - - Follow existing patterns - - Document public APIs - - Test thoroughly - -## Exception Handling - -- **Always use `logger.exception()` instead of `logger.error()` when catching exceptions** - - Don't include the exception in the message: `logger.exception("Failed")` not `logger.exception(f"Failed: {e}")` -- **Catch specific exceptions** where possible: - - File ops: `except (OSError, PermissionError):` - - JSON: `except json.JSONDecodeError:` - - Network: `except (ConnectionError, TimeoutError):` -- **FORBIDDEN** `except Exception:` - unless in top-level handlers +@AGENTS.md diff --git a/pyproject.toml b/pyproject.toml index be1200cff0..a5d2c3d80a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,7 +185,6 @@ filterwarnings = [ # This should be fixed on Uvicorn's side. "ignore::DeprecationWarning:websockets", "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", - "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", # pywin32 internal deprecation warning "ignore:getargs.*The 'u' format is deprecated:DeprecationWarning", ] From 2dfb51a4d1af35ed0462c49ea77cca0a3e68766d Mon Sep 17 00:00:00 2001 From: Felix Weinberger <3823880+felixweinberger@users.noreply.github.com> Date: Mon, 13 Apr 2026 15:42:35 +0100 Subject: [PATCH 38/60] fix(auth): coerce empty-string optional URL fields to None in OAuthClientMetadata (#2404) --- AGENTS.md | 4 +- src/mcp/shared/auth.py | 18 +++++++++ tests/shared/test_auth.py | 82 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 102 insertions(+), 2 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 9692271044..307bd81b3e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -45,7 +45,9 @@ - Framework: `uv run --frozen pytest` - Async testing: use anyio, not asyncio -- Do not use `Test` prefixed classes, use functions +- Do not use `Test` prefixed classes — write plain 
top-level `test_*` functions. + Legacy files still contain `Test*` classes; do NOT follow that pattern for new + tests even when adding to such a file. - IMPORTANT: Tests should be fast and deterministic. Prefer in-memory async execution; reach for threads only when necessary, and subprocesses only as a last resort. - For end-to-end behavior, an in-memory `Client(server)` is usually the diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index ca5b7b45ab..ebf534d792 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -67,6 +67,24 @@ class OAuthClientMetadata(BaseModel): software_id: str | None = None software_version: str | None = None + @field_validator( + "client_uri", + "logo_uri", + "tos_uri", + "policy_uri", + "jwks_uri", + mode="before", + ) + @classmethod + def _empty_string_optional_url_to_none(cls, v: object) -> object: + # RFC 7591 §2 marks these URL fields OPTIONAL. Some authorization servers + # echo omitted metadata back as "" instead of dropping the keys, which + # AnyHttpUrl would otherwise reject — throwing away an otherwise valid + # registration response. Treat "" as absent. + if v == "": + return None + return v + def validate_scope(self, requested_scope: str | None) -> list[str] | None: if requested_scope is None: return None diff --git a/tests/shared/test_auth.py b/tests/shared/test_auth.py index cd3c35332f..7463bc5a8a 100644 --- a/tests/shared/test_auth.py +++ b/tests/shared/test_auth.py @@ -1,6 +1,9 @@ """Tests for OAuth 2.0 shared code.""" -from mcp.shared.auth import OAuthMetadata +import pytest +from pydantic import ValidationError + +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata def test_oauth(): @@ -58,3 +61,80 @@ def test_oauth_with_jarm(): "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], } ) + + +# RFC 7591 §2 marks client_uri/logo_uri/tos_uri/policy_uri/jwks_uri as OPTIONAL. 
+# Some authorization servers echo the client's omitted metadata back as "" +# instead of dropping the keys; without coercion, AnyHttpUrl rejects "" and +# the whole registration response is thrown away even though the server +# returned a valid client_id. + + +@pytest.mark.parametrize( + "empty_field", + ["client_uri", "logo_uri", "tos_uri", "policy_uri", "jwks_uri"], +) +def test_optional_url_empty_string_coerced_to_none(empty_field: str): + data = { + "redirect_uris": ["https://example.com/callback"], + empty_field: "", + } + metadata = OAuthClientMetadata.model_validate(data) + assert getattr(metadata, empty_field) is None + + +def test_all_optional_urls_empty_together(): + data = { + "redirect_uris": ["https://example.com/callback"], + "client_uri": "", + "logo_uri": "", + "tos_uri": "", + "policy_uri": "", + "jwks_uri": "", + } + metadata = OAuthClientMetadata.model_validate(data) + assert metadata.client_uri is None + assert metadata.logo_uri is None + assert metadata.tos_uri is None + assert metadata.policy_uri is None + assert metadata.jwks_uri is None + + +def test_valid_url_passes_through_unchanged(): + data = { + "redirect_uris": ["https://example.com/callback"], + "client_uri": "https://udemy.com/", + } + metadata = OAuthClientMetadata.model_validate(data) + assert str(metadata.client_uri) == "https://udemy.com/" + + +def test_information_full_inherits_coercion(): + """OAuthClientInformationFull subclasses OAuthClientMetadata, so the + same coercion applies to DCR responses parsed via the full model.""" + data = { + "client_id": "abc123", + "redirect_uris": ["https://example.com/callback"], + "client_uri": "", + "logo_uri": "", + "tos_uri": "", + "policy_uri": "", + "jwks_uri": "", + } + info = OAuthClientInformationFull.model_validate(data) + assert info.client_id == "abc123" + assert info.client_uri is None + assert info.logo_uri is None + assert info.tos_uri is None + assert info.policy_uri is None + assert info.jwks_uri is None + + +def 
test_invalid_non_empty_url_still_rejected(): + """Coercion must only touch empty strings — garbage URLs still raise.""" + data = { + "redirect_uris": ["https://example.com/callback"], + "client_uri": "not a url", + } + with pytest.raises(ValidationError): + OAuthClientMetadata.model_validate(data) From 5cbd259c3bc97a9fe6ea661941feb637e72551be Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Mon, 13 Apr 2026 17:03:40 +0100 Subject: [PATCH 39/60] fix: catch PydanticUserError when generating output schema (pydantic 2.13 compat) (#2434) --- src/mcp/server/mcpserver/utilities/func_metadata.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 062b47d0ff..4a76106371 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -9,7 +9,7 @@ import anyio import anyio.to_thread import pydantic_core -from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema, create_model +from pydantic import BaseModel, ConfigDict, Field, PydanticUserError, WithJsonSchema, create_model from pydantic.fields import FieldInfo from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind from typing_extensions import is_typeddict @@ -402,9 +402,16 @@ def _try_create_model_and_schema( # Use StrictJsonSchema to raise exceptions instead of warnings try: schema = model.model_json_schema(schema_generator=StrictJsonSchema) - except (TypeError, ValueError, pydantic_core.SchemaError, pydantic_core.ValidationError) as e: + except ( + PydanticUserError, + TypeError, + ValueError, + pydantic_core.SchemaError, + pydantic_core.ValidationError, + ) as e: # These are expected errors when a type can't be converted to a Pydantic schema - # TypeError: When Pydantic can't handle the type + # PydanticUserError: When Pydantic can't handle the type (e.g. 
PydanticInvalidForJsonSchema); + # subclasses TypeError on pydantic <2.13 and RuntimeError on pydantic >=2.13 # ValueError: When there are issues with the type definition (including our custom warnings) # SchemaError: When Pydantic can't build a schema # ValidationError: When validation fails From 437d15aa7126b86ffecfa6fd8c0158851a340ced Mon Sep 17 00:00:00 2001 From: Wils Dawson Date: Tue, 14 Apr 2026 03:48:07 -0700 Subject: [PATCH 40/60] SEP-2207: Refresh token guidance (#2039) --- src/mcp/client/auth/oauth2.py | 15 +- src/mcp/client/auth/utils.py | 40 ++-- tests/client/test_auth.py | 354 ++++++++++++++++++++++++++++++++++ 3 files changed, 392 insertions(+), 17 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 25075dec3b..72309f5775 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -320,7 +320,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: - auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover + auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) else: auth_base_url = self.context.get_authorization_base_url(self.context.server_url) auth_endpoint = urljoin(auth_base_url, "/authorize") @@ -343,11 +343,16 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: # Only include resource param if conditions are met if self.context.should_include_resource_param(self.context.protocol_version): - auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover + auth_params["resource"] = self.context.get_resource_url() # RFC 8707 if self.context.client_metadata.scope: # pragma: no branch auth_params["scope"] = self.context.client_metadata.scope + # OIDC requires prompt=consent when 
offline_access is requested + # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess + if "offline_access" in self.context.client_metadata.scope.split(): + auth_params["prompt"] = "consent" + authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" await self.context.redirect_handler(authorization_url) @@ -576,6 +581,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. extract_scope_from_www_auth(response), self.context.protected_resource_metadata, self.context.oauth_metadata, + self.context.client_metadata.grant_types, ) # Step 4: Register client or use URL-based client ID (CIMD) @@ -622,7 +628,10 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. try: # Step 2a: Update the required scopes self.context.client_metadata.scope = get_client_metadata_scopes( - extract_scope_from_www_auth(response), self.context.protected_resource_metadata + extract_scope_from_www_auth(response), + self.context.protected_resource_metadata, + self.context.oauth_metadata, + self.context.client_metadata.grant_types, ) # Step 2b: Perform (re-)authorization and token exchange diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index 0ca36b98d8..d75324f2f0 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -99,24 +99,36 @@ def get_client_metadata_scopes( www_authenticate_scope: str | None, protected_resource_metadata: ProtectedResourceMetadata | None, authorization_server_metadata: OAuthMetadata | None = None, + client_grant_types: list[str] | None = None, ) -> str | None: - """Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec.""" - # Per MCP spec, scope selection priority order: - # 1. Use scope from WWW-Authenticate header (if provided) - # 2. Use all scopes from PRM scopes_supported (if available) - # 3. 
Omit scope parameter if neither is available - + """Select effective scopes and augment for refresh token support.""" + selected_scope: str | None = None + + # MCP spec scope selection priority: + # 1. WWW-Authenticate header scope + # 2. PRM scopes_supported + # 3. AS scopes_supported (SDK fallback) + # 4. Omit scope parameter if www_authenticate_scope is not None: - # Priority 1: WWW-Authenticate header scope - return www_authenticate_scope + selected_scope = www_authenticate_scope elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None: - # Priority 2: PRM scopes_supported - return " ".join(protected_resource_metadata.scopes_supported) + selected_scope = " ".join(protected_resource_metadata.scopes_supported) elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None: - return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover - else: - # Priority 3: Omit scope parameter - return None + selected_scope = " ".join(authorization_server_metadata.scopes_supported) + + # SEP-2207: append offline_access when the AS supports it and the client can use refresh tokens + if ( + selected_scope is not None + and authorization_server_metadata is not None + and authorization_server_metadata.scopes_supported is not None + and "offline_access" in authorization_server_metadata.scopes_supported + and client_grant_types is not None + and "refresh_token" in client_grant_types + and "offline_access" not in selected_scope.split() + ): + selected_scope = f"{selected_scope} offline_access" + + return selected_scope def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 5aa985e360..bb0bce4c92 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -2264,3 +2264,357 @@ async def callback_handler() -> tuple[str, 
str | None]: await auth_flow.asend(final_response) except StopAsyncIteration: pass + + +class TestSEP2207OfflineAccessScope: + """Test SEP-2207: offline_access scope augmentation for OIDC-flavored refresh tokens.""" + + def _make_as_metadata(self, scopes_supported: list[str] | None = None) -> OAuthMetadata: + return OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + scopes_supported=scopes_supported, + ) + + def _make_prm(self, scopes_supported: list[str] | None = None) -> ProtectedResourceMetadata: + return ProtectedResourceMetadata( + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + scopes_supported=scopes_supported, + ) + + def test_offline_access_added_when_as_supports_and_client_has_refresh_token(self): + """offline_access is appended when AS advertises it and client supports refresh_token grant.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write offline_access" + + def test_offline_access_added_with_www_authenticate_scope(self): + """offline_access is appended even when scopes come from WWW-Authenticate header.""" + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope="read write", + protected_resource_metadata=None, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write offline_access" + + def 
test_offline_access_not_added_when_as_does_not_support(self): + """offline_access is not added when AS does not advertise it in scopes_supported.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write" + + def test_offline_access_not_added_when_client_has_no_refresh_token_grant(self): + """offline_access is not added when client does not support refresh_token grant.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code"], + ) + assert scopes == "read write" + + def test_offline_access_not_duplicated_when_already_present(self): + """offline_access is not added again if it already appears in the selected scopes.""" + prm = self._make_prm(scopes_supported=["read", "offline_access", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read offline_access write" + + def test_offline_access_not_added_when_no_scopes_selected(self): + """offline_access is not added when no base scopes are available (None).""" + asm = self._make_as_metadata(scopes_supported=["offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=None, + authorization_server_metadata=asm, + 
client_grant_types=["authorization_code", "refresh_token"], + ) + # When AS scopes are the only source and include offline_access, + # the base scope is "offline_access" and no duplication happens + assert scopes == "offline_access" + + def test_offline_access_not_added_when_as_scopes_supported_is_none(self): + """offline_access is not added when AS scopes_supported is None.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=None) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write" + + def test_offline_access_not_added_when_no_as_metadata(self): + """offline_access is not added when AS metadata is not available.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=None, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write" + + def test_offline_access_not_added_when_no_grant_types_provided(self): + """offline_access is not added when client_grant_types is None.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=None, + ) + assert scopes == "read write" + + def test_default_client_metadata_includes_refresh_token_grant(self): + """Default OAuthClientMetadata includes refresh_token in grant_types (SEP-2207 guidance).""" + metadata = OAuthClientMetadata(redirect_uris=[AnyUrl("http://localhost:3030/callback")]) + assert "refresh_token" in metadata.grant_types + + @pytest.mark.anyio 
+ async def test_auth_flow_adds_offline_access_when_as_advertises( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """E2E: auth flow includes offline_access in authorization request when AS advertises it.""" + + captured_auth_url: str | None = None + captured_state: str | None = None + + async def redirect_handler(url: str) -> None: + nonlocal captured_auth_url, captured_state + captured_auth_url = url + parsed = urlparse(url) + params = parse_qs(parsed.query) + captured_state = params.get("state", [None])[0] + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", captured_state + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + # Pre-set client info to skip DCR + provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send 401 + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp",' + b' "authorization_servers": ["https://auth.example.com"],' + b' "scopes_supported": ["read", "write"]}' + ), + request=prm_request, + ) + + # OAuth metadata discovery - AS advertises offline_access + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + 
content=( + b'{"issuer": "https://auth.example.com",' + b' "authorization_endpoint": "https://auth.example.com/authorize",' + b' "token_endpoint": "https://auth.example.com/token",' + b' "scopes_supported": ["read", "write", "offline_access"]}' + ), + request=oauth_request, + ) + + # This triggers authorization, which calls redirect_handler + token_request = await auth_flow.asend(oauth_response) + + # Verify the authorization URL included offline_access in the scope + assert captured_auth_url is not None + parsed = urlparse(captured_auth_url) + params = parse_qs(parsed.query) + scope_value = params["scope"][0] + scope_parts = scope_value.split() + assert "offline_access" in scope_parts + assert "read" in scope_parts + assert "write" in scope_parts + + # OIDC requires prompt=consent when offline_access is requested + assert params["prompt"][0] == "consent" + + # Complete the token exchange + token_response = httpx.Response( + 200, + content=( + b'{"access_token": "new_access_token", "token_type": "Bearer",' + b' "expires_in": 3600, "refresh_token": "new_refresh_token"}' + ), + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer new_access_token" + + # Close the generator + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass + + @pytest.mark.anyio + async def test_auth_flow_no_offline_access_when_as_does_not_advertise( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """E2E: auth flow does NOT include offline_access when AS doesn't advertise it.""" + + captured_auth_url: str | None = None + captured_state: str | None = None + + async def redirect_handler(url: str) -> None: + nonlocal captured_auth_url, captured_state + captured_auth_url = url + parsed = urlparse(url) + params = parse_qs(parsed.query) + captured_state = params.get("state", [None])[0] + + 
async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", captured_state + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + # Pre-set client info to skip DCR + provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + await auth_flow.__anext__() + + # Send 401 + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp",' + b' "authorization_servers": ["https://auth.example.com"],' + b' "scopes_supported": ["read", "write"]}' + ), + request=prm_request, + ) + + # OAuth metadata discovery - AS does NOT advertise offline_access + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com",' + b' "authorization_endpoint": "https://auth.example.com/authorize",' + b' "token_endpoint": "https://auth.example.com/token",' + b' "scopes_supported": ["read", "write"]}' + ), + request=oauth_request, + ) + + # This triggers authorization, which calls redirect_handler + token_request = await auth_flow.asend(oauth_response) + + # Verify the authorization URL does NOT include offline_access + assert captured_auth_url is not None + parsed = urlparse(captured_auth_url) + params = parse_qs(parsed.query) + scope_value = params["scope"][0] + 
scope_parts = scope_value.split() + assert "offline_access" not in scope_parts + assert "read" in scope_parts + assert "write" in scope_parts + + # prompt=consent should NOT be present without offline_access + assert "prompt" not in params + + # Complete the token exchange + token_response = httpx.Response( + 200, + content=b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer new_access_token" + + # Close the generator + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass From 3d7b311de07aade1281d18aa7b04689a81ab8793 Mon Sep 17 00:00:00 2001 From: "Gyeongjun Paik (Kent)" Date: Wed, 15 Apr 2026 06:41:51 +0900 Subject: [PATCH 41/60] fix: align Context logging methods with MCP spec data type (#2366) Co-authored-by: Claude Opus 4.6 Co-authored-by: Max Isbey <224885523+maxisbey@users.noreply.github.com> --- README.v2.md | 10 +++---- docs/migration.md | 20 +++++++++++++ src/mcp/server/mcpserver/context.py | 41 +++++++++++---------------- tests/client/test_logging_callback.py | 33 +++++++++------------ 4 files changed, 55 insertions(+), 49 deletions(-) diff --git a/README.v2.md b/README.v2.md index 55d867586d..d0851c04e5 100644 --- a/README.v2.md +++ b/README.v2.md @@ -681,11 +681,11 @@ The Context object provides the following capabilities: - `ctx.mcp_server` - Access to the MCPServer server instance (see [MCPServer Properties](#mcpserver-properties)) - `ctx.session` - Access to the underlying session for advanced communication (see [Session Properties and Methods](#session-properties-and-methods)) - `ctx.request_context` - Access to request-specific data and lifespan resources (see [Request Context Properties](#request-context-properties)) -- `await ctx.debug(message)` - Send debug log message -- `await 
ctx.info(message)` - Send info log message -- `await ctx.warning(message)` - Send warning log message -- `await ctx.error(message)` - Send error log message -- `await ctx.log(level, message, logger_name=None)` - Send log with custom level +- `await ctx.debug(data)` - Send debug log message +- `await ctx.info(data)` - Send info log message +- `await ctx.warning(data)` - Send warning log message +- `await ctx.error(data)` - Send error log message +- `await ctx.log(level, data, logger_name=None)` - Send log with custom level - `await ctx.report_progress(progress, total=None, message=None)` - Report operation progress - `await ctx.read_resource(uri)` - Read a resource by URI - `await ctx.elicit(message, schema)` - Request additional information from user with validation diff --git a/docs/migration.md b/docs/migration.md index 2528f046c6..8b70885e8d 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -467,6 +467,26 @@ mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscrib This is a private API and may change. A public way to register these handlers on `MCPServer` is planned; until then, use this workaround or use the lowlevel `Server` directly. +### `MCPServer`'s `Context` logging: `message` renamed to `data`, `extra` removed + +On the high-level `Context` object (`mcp.server.mcpserver.Context`), `log()`, `.debug()`, `.info()`, `.warning()`, and `.error()` now take `data: Any` instead of `message: str`, matching the MCP spec's `LoggingMessageNotificationParams.data` field which allows any JSON-serializable value. The `extra` parameter has been removed — pass structured data directly as `data`. + +The lowlevel `ServerSession.send_log_message(data: Any)` already accepted arbitrary data and is unchanged. + +`Context.log()` also now accepts all eight RFC-5424 log levels (`debug`, `info`, `notice`, `warning`, `error`, `critical`, `alert`, `emergency`) via the `LoggingLevel` type, not just the four it previously allowed. 
+ +```python +# Before +await ctx.info("Connection failed", extra={"host": "localhost", "port": 5432}) +await ctx.log(level="info", message="hello") + +# After +await ctx.info({"message": "Connection failed", "host": "localhost", "port": 5432}) +await ctx.log(level="info", data="hello") +``` + +Positional calls (`await ctx.info("hello")`) are unaffected. + ### Replace `RootModel` by union types with `TypeAdapter` validation The following union types are no longer `RootModel` subclasses: diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 1538adc7c7..e87388eee9 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Generic, Literal +from typing import TYPE_CHECKING, Any, Generic from pydantic import AnyUrl, BaseModel @@ -14,6 +14,7 @@ elicit_with_validation, ) from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.types import LoggingLevel if TYPE_CHECKING: from mcp.server.mcpserver.server import MCPServer @@ -186,29 +187,23 @@ async def elicit_url( async def log( self, - level: Literal["debug", "info", "warning", "error"], - message: str, + level: LoggingLevel, + data: Any, *, logger_name: str | None = None, - extra: dict[str, Any] | None = None, ) -> None: """Send a log message to the client. Args: - level: Log level (debug, info, warning, error) - message: Log message + level: Log level (debug, info, notice, warning, error, critical, + alert, emergency) + data: The data to be logged. Any JSON serializable type is allowed + (string, dict, list, number, bool, etc.) per the MCP specification. 
logger_name: Optional logger name - extra: Optional dictionary with additional structured data to include """ - - if extra: - log_data = {"message": message, **extra} - else: - log_data = message - await self.request_context.session.send_log_message( level=level, - data=log_data, + data=data, logger=logger_name, related_request_id=self.request_id, ) @@ -261,20 +256,18 @@ async def close_standalone_sse_stream(self) -> None: await self._request_context.close_standalone_sse_stream() # Convenience methods for common log levels - async def debug(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: + async def debug(self, data: Any, *, logger_name: str | None = None) -> None: """Send a debug log message.""" - await self.log("debug", message, logger_name=logger_name, extra=extra) + await self.log("debug", data, logger_name=logger_name) - async def info(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: + async def info(self, data: Any, *, logger_name: str | None = None) -> None: """Send an info log message.""" - await self.log("info", message, logger_name=logger_name, extra=extra) + await self.log("info", data, logger_name=logger_name) - async def warning( - self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None - ) -> None: + async def warning(self, data: Any, *, logger_name: str | None = None) -> None: """Send a warning log message.""" - await self.log("warning", message, logger_name=logger_name, extra=extra) + await self.log("warning", data, logger_name=logger_name) - async def error(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: + async def error(self, data: Any, *, logger_name: str | None = None) -> None: """Send an error log message.""" - await self.log("error", message, logger_name=logger_name, extra=extra) + await self.log("error", data, logger_name=logger_name) diff --git 
a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 1598fd55f6..454c1d3382 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -1,4 +1,4 @@ -from typing import Any, Literal +from typing import Literal import pytest @@ -36,24 +36,20 @@ async def test_tool_with_log( message: str, level: Literal["debug", "info", "warning", "error"], logger: str, ctx: Context ) -> bool: """Send a log notification to the client.""" - await ctx.log(level=level, message=message, logger_name=logger) + await ctx.log(level=level, data=message, logger_name=logger) return True - @server.tool("test_tool_with_log_extra") - async def test_tool_with_log_extra( - message: str, + @server.tool("test_tool_with_log_dict") + async def test_tool_with_log_dict( level: Literal["debug", "info", "warning", "error"], logger: str, - extra_string: str, - extra_dict: dict[str, Any], ctx: Context, ) -> bool: - """Send a log notification to the client with extra fields.""" + """Send a log notification with a dict payload.""" await ctx.log( level=level, - message=message, + data={"message": "Test log message", "extra_string": "example", "extra_dict": {"a": 1, "b": 2, "c": 3}}, logger_name=logger, - extra={"extra_string": extra_string, "extra_dict": extra_dict}, ) return True @@ -84,18 +80,15 @@ async def message_handler( "logger": "test_logger", }, ) - log_result_with_extra = await client.call_tool( - "test_tool_with_log_extra", + log_result_with_dict = await client.call_tool( + "test_tool_with_log_dict", { - "message": "Test log message", "level": "info", "logger": "test_logger", - "extra_string": "example", - "extra_dict": {"a": 1, "b": 2, "c": 3}, }, ) assert log_result.is_error is False - assert log_result_with_extra.is_error is False + assert log_result_with_dict.is_error is False assert len(logging_collector.log_messages) == 2 # Create meta object with related_request_id added dynamically log = logging_collector.log_messages[0] @@ 
-103,10 +96,10 @@ async def message_handler( assert log.logger == "test_logger" assert log.data == "Test log message" - log_with_extra = logging_collector.log_messages[1] - assert log_with_extra.level == "info" - assert log_with_extra.logger == "test_logger" - assert log_with_extra.data == { + log_with_dict = logging_collector.log_messages[1] + assert log_with_dict.level == "info" + assert log_with_dict.logger == "test_logger" + assert log_with_dict.data == { "message": "Test log message", "extra_string": "example", "extra_dict": {"a": 1, "b": 2, "c": 3}, From 2b0da5631a7cb460c17d60b8cc735626d5e1b3f8 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Thu, 7 May 2026 17:32:30 +0100 Subject: [PATCH 42/60] build: pin PEP 517 build dependencies (#2547) --- pyproject.toml | 18 ++++++++++++++++++ uv.lock | 13 +++++++++++++ 2 files changed, 31 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a5d2c3d80a..364b9add0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,24 @@ mcp = "mcp.cli:app [cli]" [tool.uv] default-groups = ["dev", "docs"] required-version = ">=0.9.5" +# PEP 517 build isolation fetches [build-system].requires (and transitives) at +# floating-latest with no hash check on every fresh sync; uv does not lock them +# (astral-sh/uv#5190). Pinning here narrows that to known-good versions. Covers +# the workspace builds (hatchling + uv-dynamic-versioning) and the legacy +# setuptools fallback used by the strict-no-cover git dep. 
+build-constraint-dependencies = [ + "hatchling==1.29.0", + "uv-dynamic-versioning==0.14.0", + "dunamai==1.26.1", + "jinja2==3.1.6", + "markupsafe==3.0.3", + "packaging==26.1", + "pathspec==1.0.4", + "pluggy==1.6.0", + "tomlkit==0.14.0", + "trove-classifiers==2026.1.14.14", + "setuptools==82.0.1", +] [dependency-groups] dev = [ diff --git a/uv.lock b/uv.lock index 705d014aa5..b396898b66 100644 --- a/uv.lock +++ b/uv.lock @@ -28,6 +28,19 @@ members = [ "mcp-sse-polling-demo", "mcp-structured-output-lowlevel", ] +build-constraints = [ + { name = "dunamai", specifier = "==1.26.1" }, + { name = "hatchling", specifier = "==1.29.0" }, + { name = "jinja2", specifier = "==3.1.6" }, + { name = "markupsafe", specifier = "==3.0.3" }, + { name = "packaging", specifier = "==26.1" }, + { name = "pathspec", specifier = "==1.0.4" }, + { name = "pluggy", specifier = "==1.6.0" }, + { name = "setuptools", specifier = "==82.0.1" }, + { name = "tomlkit", specifier = "==0.14.0" }, + { name = "trove-classifiers", specifier = "==2026.1.14.14" }, + { name = "uv-dynamic-versioning", specifier = "==0.14.0" }, +] [[package]] name = "annotated-types" From bf3e0010b87a6a2535711500eb9590741b1d1940 Mon Sep 17 00:00:00 2001 From: Dayna Blackwell Date: Fri, 8 May 2026 06:28:36 -0700 Subject: [PATCH 43/60] fix: chain exceptions in get_prompt and read_resource handlers (#2542) --- src/mcp/server/mcpserver/server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index be77705da6..b3471163b7 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -447,8 +447,8 @@ async def read_resource( context = Context(mcp_server=self) try: resource = await self._resource_manager.get_resource(uri, context) - except ValueError: - raise ResourceError(f"Unknown resource: {uri}") + except ValueError as exc: + raise ResourceError(f"Unknown resource: {uri}") from exc try: content = await 
resource.read() @@ -1109,4 +1109,4 @@ async def get_prompt( ) except Exception as e: logger.exception(f"Error getting prompt {name}") - raise ValueError(str(e)) + raise ValueError(str(e)) from e From 161834d4aee2633c42d3976c8f8751b6c4d947d5 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Fri, 8 May 2026 17:42:44 +0100 Subject: [PATCH 44/60] refactor: import SSEError from httpx_sse public API (#2560) --- src/mcp/client/sse.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 193204a153..74e5ba8062 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -7,8 +7,7 @@ import anyio import httpx from anyio.abc import TaskStatus -from httpx_sse import aconnect_sse -from httpx_sse._exceptions import SSEError +from httpx_sse import SSEError, aconnect_sse from mcp import types from mcp.shared._context_streams import create_context_streams From f4753440dac8b2b6fa6407808e06c51258b78322 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Mon, 18 May 2026 15:28:13 +0100 Subject: [PATCH 45/60] ci: deploy docs to py.sdk.modelcontextprotocol.io via Pages artifact (v1 at /, v2 at /v2/) (#2634) --- .github/workflows/deploy-docs.yml | 57 +++++++++++++++++++++ .github/workflows/publish-docs-manually.yml | 35 ------------- .github/workflows/publish-pypi.yml | 29 ----------- .gitignore | 1 + docs/index.md | 3 ++ mkdocs.yml | 2 +- pyproject.toml | 1 + scripts/build-docs.sh | 54 +++++++++++++++++++ 8 files changed, 117 insertions(+), 65 deletions(-) create mode 100644 .github/workflows/deploy-docs.yml delete mode 100644 .github/workflows/publish-docs-manually.yml create mode 100755 scripts/build-docs.sh diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml new file mode 100644 index 0000000000..d9362afd57 --- /dev/null +++ b/.github/workflows/deploy-docs.yml @@ -0,0 +1,57 @@ +name: Deploy Docs + +on: + push: + 
branches: + - main + - v1.x + paths: + - docs/** + - mkdocs.yml + - src/mcp/** + - scripts/build-docs.sh + - pyproject.toml + - uv.lock + - .github/workflows/deploy-docs.yml + workflow_dispatch: + +concurrency: + group: deploy-docs + cancel-in-progress: false + +jobs: + deploy-docs: + runs-on: ubuntu-latest + + permissions: + contents: read + pages: write + id-token: write + + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + + - name: Install uv + uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 + with: + enable-cache: true + version: 0.9.5 + + - name: Build combined docs (v1.x at /, main at /v2/) + run: bash scripts/build-docs.sh site + + - name: Configure Pages + uses: actions/configure-pages@45bfe0192ca1faeb007ade9deae92b16b8254a0d # v6.0.0 + + - name: Upload Pages artifact + uses: actions/upload-pages-artifact@fc324d3547104276b827a68afc52ff2a11cc49c9 # v5.0.0 + with: + path: site + + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@cd2ce8fcbc39b97be8ca5fce6e763baed58fa128 # v5.0.0 diff --git a/.github/workflows/publish-docs-manually.yml b/.github/workflows/publish-docs-manually.yml deleted file mode 100644 index ee45ab5c8a..0000000000 --- a/.github/workflows/publish-docs-manually.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Publish Docs manually - -on: - workflow_dispatch: - -jobs: - docs-publish: - runs-on: ubuntu-latest - permissions: - contents: write - steps: - - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 - - name: Configure Git Credentials - run: | - git config user.name github-actions[bot] - git config user.email 41898282+github-actions[bot]@users.noreply.github.com - - - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 - with: - enable-cache: true - version: 0.9.5 - - - run: echo "cache_id=$(date --utc '+%V')" 
>> $GITHUB_ENV - - uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 - with: - key: mkdocs-material-${{ env.cache_id }} - path: .cache - restore-keys: | - mkdocs-material- - - - run: uv sync --frozen --group docs - - run: uv run --frozen --no-sync mkdocs gh-deploy --force - env: - ENABLE_SOCIAL_CARDS: "true" diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index bfc3cc64e1..7ba11e86fa 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -51,32 +51,3 @@ jobs: - name: Publish package distributions to PyPI uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1 - - docs-publish: - runs-on: ubuntu-latest - needs: ["pypi-publish"] - permissions: - contents: write - steps: - - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 - - name: Configure Git Credentials - run: | - git config user.name github-actions[bot] - git config user.email 41898282+github-actions[bot]@users.noreply.github.com - - - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 - with: - enable-cache: true - version: 0.9.5 - - - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - - uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 - with: - key: mkdocs-material-${{ env.cache_id }} - path: .cache - restore-keys: | - mkdocs-material- - - - run: uv sync --frozen --group docs - - run: uv run --frozen --no-sync mkdocs gh-deploy --force diff --git a/.gitignore b/.gitignore index 5ff4ce9771..3443adf7c8 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,7 @@ venv.bak/ # mkdocs documentation /site +/.worktrees/ # mypy .mypy_cache/ diff --git a/docs/index.md b/docs/index.md index 436d1c8fcd..6a937da67f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,5 +1,8 @@ # MCP Python SDK +!!! 
info "You are viewing the in-development v2 documentation" + For the current stable release, see the [v1.x documentation](https://py.sdk.modelcontextprotocol.io/). + The **Model Context Protocol (MCP)** allows applications to provide context for LLMs in a standardized way, separating the concerns of providing context from the actual LLM interaction. This Python SDK implements the full MCP specification, making it easy to: diff --git a/mkdocs.yml b/mkdocs.yml index 3a555785a7..e48c64242d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -5,7 +5,7 @@ strict: true repo_name: modelcontextprotocol/python-sdk repo_url: https://github.com/modelcontextprotocol/python-sdk edit_uri: edit/main/docs/ -site_url: https://modelcontextprotocol.github.io/python-sdk +site_url: https://py.sdk.modelcontextprotocol.io/v2/ # TODO(Marcelo): Add Anthropic copyright? # copyright: © Model Context Protocol 2025 to present diff --git a/pyproject.toml b/pyproject.toml index 364b9add0b..d88869da1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,7 @@ bump = true [project.urls] Homepage = "https://modelcontextprotocol.io" +Documentation = "https://py.sdk.modelcontextprotocol.io/v2/" Repository = "https://github.com/modelcontextprotocol/python-sdk" Issues = "https://github.com/modelcontextprotocol/python-sdk/issues" diff --git a/scripts/build-docs.sh b/scripts/build-docs.sh new file mode 100755 index 0000000000..5a61309acf --- /dev/null +++ b/scripts/build-docs.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +# +# Build combined v1 + v2 MkDocs documentation for GitHub Pages. +# +# v1 docs (from the v1.x branch) are placed at the site root. +# v2 docs (from main) are placed under /v2/. +# +# Both branches are fetched fresh from origin, so the output is identical +# regardless of which branch triggered the workflow. This script is intended +# to run in CI; for local single-branch preview use `uv run mkdocs serve`. 
+# +# Usage: +# scripts/build-docs.sh [output-dir] +# +# Default output directory: site +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +OUTPUT_DIR="$(cd "$REPO_ROOT" && mkdir -p "${1:-site}" && cd "${1:-site}" && pwd)" +V1_WORKTREE="$REPO_ROOT/.worktrees/v1-docs" +V2_WORKTREE="$REPO_ROOT/.worktrees/v2-docs" + +cleanup() { + cd "$REPO_ROOT" + git worktree remove --force "$V1_WORKTREE" 2>/dev/null || true + git worktree remove --force "$V2_WORKTREE" 2>/dev/null || true + rmdir "$REPO_ROOT/.worktrees" 2>/dev/null || true +} +trap cleanup EXIT + +rm -rf "${OUTPUT_DIR:?}"/* + +build_branch() { + local branch="$1" worktree="$2" dest="$3" + + echo "=== Building docs for ${branch} ===" + git fetch origin "$branch" + git worktree remove --force "$worktree" 2>/dev/null || true + rm -rf "$worktree" + git worktree add --detach "$worktree" "origin/${branch}" + + ( + cd "$worktree" + uv sync --frozen --group docs + uv run --frozen --no-sync mkdocs build --site-dir "$dest" + ) +} + +build_branch v1.x "$V1_WORKTREE" "$OUTPUT_DIR" +build_branch main "$V2_WORKTREE" "$OUTPUT_DIR/v2" + +echo "=== Combined docs built at $OUTPUT_DIR ===" From e8e64842781c66b613872cf394de6e7d6f6925bf Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 20 May 2026 17:06:57 +0200 Subject: [PATCH 46/60] ci: add zizmor for GitHub Actions security analysis (#2648) --- .github/workflows/claude.yml | 3 ++- .github/workflows/comment-on-release.yml | 22 ++++++++++++----- .github/workflows/conformance.yml | 4 ++++ .github/workflows/deploy-docs.yml | 2 ++ .github/workflows/publish-pypi.yml | 7 +++++- .github/workflows/shared.yml | 6 +++++ .github/workflows/weekly-lockfile-update.yml | 2 ++ .github/workflows/zizmor.yml | 25 ++++++++++++++++++++ 8 files changed, 63 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/zizmor.yml diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 
59dac99dcb..18f5377cb6 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -30,12 +30,13 @@ jobs: uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 1 + persist-credentials: false - name: Run Claude Code id: claude uses: anthropics/claude-code-action@2f8ba26a219c06cfb0f468eef8d97055fa814f97 # v1.0.53 with: - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} # zizmor: ignore[secrets-outside-env] use_commit_signing: true additional_permissions: | actions: read diff --git a/.github/workflows/comment-on-release.yml b/.github/workflows/comment-on-release.yml index 15d6a1d26a..66f1fcc32a 100644 --- a/.github/workflows/comment-on-release.yml +++ b/.github/workflows/comment-on-release.yml @@ -16,13 +16,16 @@ jobs: uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 with: fetch-depth: 0 + persist-credentials: false - name: Get previous release id: previous_release uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + env: + CURRENT_TAG: ${{ github.event.release.tag_name }} with: script: | - const currentTag = '${{ github.event.release.tag_name }}'; + const currentTag = process.env.CURRENT_TAG; // Get all releases const { data: releases } = await github.rest.repos.listReleases({ @@ -54,10 +57,13 @@ jobs: - name: Get merged PRs between releases id: get_prs uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + env: + CURRENT_TAG: ${{ github.event.release.tag_name }} + PREVIOUS_TAG_JSON: ${{ steps.previous_release.outputs.result }} with: script: | - const currentTag = '${{ github.event.release.tag_name }}'; - const previousTag = ${{ steps.previous_release.outputs.result }}; + const currentTag = process.env.CURRENT_TAG; + const previousTag = JSON.parse(process.env.PREVIOUS_TAG_JSON); if (!previousTag) { console.log('No previous release found, skipping'); @@ -104,11 +110,15 @@ jobs: - 
name: Comment on PRs uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + env: + PR_NUMBERS_JSON: ${{ steps.get_prs.outputs.result }} + RELEASE_TAG: ${{ github.event.release.tag_name }} + RELEASE_URL: ${{ github.event.release.html_url }} with: script: | - const prNumbers = ${{ steps.get_prs.outputs.result }}; - const releaseTag = '${{ github.event.release.tag_name }}'; - const releaseUrl = '${{ github.event.release.html_url }}'; + const prNumbers = JSON.parse(process.env.PR_NUMBERS_JSON); + const releaseTag = process.env.RELEASE_TAG; + const releaseUrl = process.env.RELEASE_URL; const comment = `This pull request is included in [${releaseTag}](${releaseUrl})`; diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index d876da00b0..9c33d2936b 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -19,6 +19,8 @@ jobs: continue-on-error: true steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: enable-cache: true @@ -34,6 +36,8 @@ jobs: continue-on-error: true steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: enable-cache: true diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml index d9362afd57..fcf7a6c1a7 100644 --- a/.github/workflows/deploy-docs.yml +++ b/.github/workflows/deploy-docs.yml @@ -34,6 +34,8 @@ jobs: steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - name: Install uv uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 7ba11e86fa..c72061d12b 100644 --- 
a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -4,6 +4,9 @@ on: release: types: [published] +permissions: + contents: read + jobs: release-build: name: Build distribution @@ -11,11 +14,13 @@ jobs: needs: [checks] steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - name: Install uv uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: - enable-cache: true + enable-cache: false version: 0.9.5 - name: Set up Python 3.12 diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index efb45c8898..5f115aef89 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -14,6 +14,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: @@ -57,6 +59,8 @@ jobs: steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - name: Install uv uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 @@ -83,6 +87,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: diff --git a/.github/workflows/weekly-lockfile-update.yml b/.github/workflows/weekly-lockfile-update.yml index 5d79d06d52..c30c72991a 100644 --- a/.github/workflows/weekly-lockfile-update.yml +++ b/.github/workflows/weekly-lockfile-update.yml @@ -15,6 +15,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: diff --git a/.github/workflows/zizmor.yml 
b/.github/workflows/zizmor.yml new file mode 100644 index 0000000000..83d3eb3817 --- /dev/null +++ b/.github/workflows/zizmor.yml @@ -0,0 +1,25 @@ +name: GitHub Actions Security Analysis + +on: + push: + branches: ["main"] + pull_request: + branches: ["**"] + +permissions: {} + +jobs: + zizmor: + runs-on: ubuntu-latest + + permissions: + security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files. + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Run zizmor 🌈 + uses: zizmorcore/zizmor-action@5f14fd08f7cf1cb1609c1e344975f152c7ee938d # v0.5.6 From 3eb579948a4719d606d2adbd1f3f69371c9c0f48 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 26 May 2026 14:13:55 +0100 Subject: [PATCH 47/60] Add subject and claims to AccessToken (#2686) --- .../servers/simple-auth/mcp_simple_auth/auth_server.py | 2 ++ .../mcp_simple_auth/simple_auth_provider.py | 2 ++ .../simple-auth/mcp_simple_auth/token_verifier.py | 2 ++ src/mcp/server/auth/provider.py | 6 +++++- src/mcp/server/mcpserver/context.py | 7 ++++++- tests/server/mcpserver/auth/test_auth_integration.py | 10 ++++++++++ 6 files changed, 27 insertions(+), 2 deletions(-) diff --git a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py index 9d13fffe42..26c87c5ef2 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py @@ -120,6 +120,8 @@ async def introspect_handler(request: Request) -> Response: "iat": int(time.time()), "token_type": "Bearer", "aud": access_token.resource, # RFC 8707 audience claim + "sub": access_token.subject, # RFC 7662 subject + "iss": str(server_settings.server_url), } ) diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py 
b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index 3a3895cc57..48eb9a8414 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -181,6 +181,7 @@ async def handle_simple_callback(self, username: str, password: str, state: str) scopes=[self.settings.mcp_scope], code_challenge=code_challenge, resource=resource, # RFC 8707 + subject=username, ) self.auth_codes[new_code] = auth_code @@ -219,6 +220,7 @@ async def exchange_authorization_code( scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, resource=authorization_code.resource, # RFC 8707 + subject=authorization_code.subject, ) # Store user data mapping for this token diff --git a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py index 5228d034e4..641095a125 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py +++ b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py @@ -75,6 +75,8 @@ async def verify_token(self, token: str) -> AccessToken | None: scopes=data.get("scope", "").split() if data.get("scope") else [], expires_at=data.get("exp"), resource=data.get("aud"), # Include resource in token + subject=data.get("sub"), # RFC 7662 subject (resource owner) + claims=data, ) except Exception as e: logger.warning(f"Token introspection failed: {e}") diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 957082a854..4ce1137575 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Generic, Literal, Protocol, TypeVar +from typing import Any, Generic, Literal, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyUrl, BaseModel @@ -25,6 +25,7 @@ class AuthorizationCode(BaseModel): 
redirect_uri: AnyUrl redirect_uri_provided_explicitly: bool resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None # resource owner; propagate to the issued AccessToken class RefreshToken(BaseModel): @@ -32,6 +33,7 @@ class RefreshToken(BaseModel): client_id: str scopes: list[str] expires_at: int | None = None + subject: str | None = None # resource owner; propagate to refreshed AccessTokens class AccessToken(BaseModel): @@ -40,6 +42,8 @@ class AccessToken(BaseModel): scopes: list[str] expires_at: int | None = None resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None # RFC 7662/9068 `sub`: resource owner; unique only per issuer + claims: dict[str, Any] | None = None # additional claims (e.g. `iss`, `act`) RegistrationErrorCode = Literal[ diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index e87388eee9..39bba839bd 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -208,9 +208,14 @@ async def log( related_request_id=self.request_id, ) + # TODO(maxisbey): see if this is needed otherwise remove @property def client_id(self) -> str | None: - """Get the client ID if available.""" + """Get the client ID if available. + + Note: this reads from the MCP request's `_meta` params, not the OAuth + bearer token. For that, use `get_access_token().client_id`. 
+ """ return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover @property diff --git a/tests/server/mcpserver/auth/test_auth_integration.py b/tests/server/mcpserver/auth/test_auth_integration.py index 602f5cc752..35fec1c57e 100644 --- a/tests/server/mcpserver/auth/test_auth_integration.py +++ b/tests/server/mcpserver/auth/test_auth_integration.py @@ -53,6 +53,7 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, expires_at=time.time() + 300, scopes=params.scopes or ["read", "write"], + subject="test-user", ) self.auth_codes[code.code] = code @@ -79,6 +80,7 @@ async def exchange_authorization_code( client_id=client.client_id, scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, + subject=authorization_code.subject, ) self.refresh_tokens[refresh_token] = access_token @@ -108,6 +110,7 @@ async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_t client_id=token_info.client_id, scopes=token_info.scopes, expires_at=token_info.expires_at, + subject=token_info.subject, ) return refresh_obj @@ -141,6 +144,7 @@ async def exchange_refresh_token( client_id=client.client_id, scopes=scopes or token_info.scopes, expires_at=int(time.time()) + 3600, + subject=refresh_token.subject, ) self.refresh_tokens[new_refresh_token] = new_access_token @@ -169,6 +173,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: client_id=token_info.client_id, scopes=token_info.scopes, expires_at=token_info.expires_at, + subject=token_info.subject, ) async def revoke_token(self, token: AccessToken | RefreshToken) -> None: @@ -832,6 +837,7 @@ async def test_authorization_get( assert auth_info.client_id == client_info["client_id"] assert "read" in auth_info.scopes assert "write" in auth_info.scopes + assert auth_info.subject == "test-user" # 6. 
Refresh the token response = await test_client.post( @@ -852,6 +858,10 @@ async def test_authorization_get( assert new_token_response["access_token"] != access_token assert new_token_response["refresh_token"] != refresh_token + refreshed_auth_info = await mock_oauth_provider.load_access_token(new_token_response["access_token"]) + assert refreshed_auth_info + assert refreshed_auth_info.subject == "test-user" + # 7. Revoke the token response = await test_client.post( "/revoke", From 24725633f11276e94c5f7310ee0296a0e2117618 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Thu, 28 May 2026 19:48:30 +0100 Subject: [PATCH 48/60] test: interaction-model end-to-end suite with a requirements manifest (#2691) --- pyproject.toml | 5 +- src/mcp/client/auth/oauth2.py | 20 +- src/mcp/client/client.py | 2 +- src/mcp/client/session.py | 10 +- src/mcp/client/streamable_http.py | 16 +- src/mcp/server/auth/handlers/authorize.py | 2 +- src/mcp/server/auth/middleware/bearer_auth.py | 2 +- src/mcp/server/lowlevel/server.py | 16 +- src/mcp/server/mcpserver/context.py | 4 +- src/mcp/server/mcpserver/prompts/base.py | 2 +- src/mcp/server/mcpserver/server.py | 2 +- src/mcp/server/session.py | 8 +- src/mcp/server/sse.py | 13 +- src/mcp/server/streamable_http.py | 98 +- src/mcp/server/streamable_http_manager.py | 2 +- src/mcp/server/transport_security.py | 34 +- src/mcp/shared/auth.py | 4 +- src/mcp/shared/session.py | 2 +- tests/interaction/README.md | 228 ++ tests/interaction/__init__.py | 0 tests/interaction/_connect.py | 360 +++ tests/interaction/_helpers.py | 107 + tests/interaction/_requirements.py | 2816 +++++++++++++++++ tests/interaction/auth/__init__.py | 0 tests/interaction/auth/_harness.py | 465 +++ tests/interaction/auth/_provider.py | 186 ++ tests/interaction/auth/test_as_handlers.py | 300 ++ .../interaction/auth/test_authorize_token.py | 399 +++ tests/interaction/auth/test_bearer.py | 189 ++ tests/interaction/auth/test_discovery.py | 333 ++ 
tests/interaction/auth/test_flow.py | 239 ++ tests/interaction/auth/test_lifecycle.py | 445 +++ tests/interaction/conftest.py | 23 + tests/interaction/lowlevel/__init__.py | 0 .../interaction/lowlevel/test_cancellation.py | 234 ++ tests/interaction/lowlevel/test_completion.py | 131 + .../interaction/lowlevel/test_elicitation.py | 662 ++++ tests/interaction/lowlevel/test_flows.py | 203 ++ tests/interaction/lowlevel/test_initialize.py | 384 +++ .../interaction/lowlevel/test_list_changed.py | 136 + tests/interaction/lowlevel/test_logging.py | 127 + tests/interaction/lowlevel/test_meta.py | 63 + tests/interaction/lowlevel/test_pagination.py | 242 ++ tests/interaction/lowlevel/test_ping.py | 53 + tests/interaction/lowlevel/test_progress.py | 301 ++ tests/interaction/lowlevel/test_prompts.py | 209 ++ tests/interaction/lowlevel/test_resources.py | 309 ++ tests/interaction/lowlevel/test_roots.py | 166 + tests/interaction/lowlevel/test_sampling.py | 687 ++++ tests/interaction/lowlevel/test_timeouts.py | 114 + tests/interaction/lowlevel/test_tools.py | 512 +++ tests/interaction/lowlevel/test_wire.py | 309 ++ tests/interaction/mcpserver/__init__.py | 0 .../interaction/mcpserver/test_completion.py | 38 + tests/interaction/mcpserver/test_context.py | 271 ++ tests/interaction/mcpserver/test_prompts.py | 195 ++ tests/interaction/mcpserver/test_resources.py | 183 ++ tests/interaction/mcpserver/test_tools.py | 432 +++ tests/interaction/test_coverage.py | 105 + tests/interaction/transports/__init__.py | 0 tests/interaction/transports/_bridge.py | 169 + tests/interaction/transports/_event_store.py | 55 + tests/interaction/transports/_stdio_server.py | 63 + tests/interaction/transports/test_bridge.py | 94 + .../transports/test_client_transport_http.py | 247 ++ tests/interaction/transports/test_flows.py | 129 + .../transports/test_hosting_http.py | 344 ++ .../transports/test_hosting_resume.py | 372 +++ .../transports/test_hosting_session.py | 202 ++ 
tests/interaction/transports/test_sse.py | 90 + tests/interaction/transports/test_stdio.py | 143 + .../transports/test_streamable_http.py | 168 + uv.lock | 2 +- 73 files changed, 14355 insertions(+), 121 deletions(-) create mode 100644 tests/interaction/README.md create mode 100644 tests/interaction/__init__.py create mode 100644 tests/interaction/_connect.py create mode 100644 tests/interaction/_helpers.py create mode 100644 tests/interaction/_requirements.py create mode 100644 tests/interaction/auth/__init__.py create mode 100644 tests/interaction/auth/_harness.py create mode 100644 tests/interaction/auth/_provider.py create mode 100644 tests/interaction/auth/test_as_handlers.py create mode 100644 tests/interaction/auth/test_authorize_token.py create mode 100644 tests/interaction/auth/test_bearer.py create mode 100644 tests/interaction/auth/test_discovery.py create mode 100644 tests/interaction/auth/test_flow.py create mode 100644 tests/interaction/auth/test_lifecycle.py create mode 100644 tests/interaction/conftest.py create mode 100644 tests/interaction/lowlevel/__init__.py create mode 100644 tests/interaction/lowlevel/test_cancellation.py create mode 100644 tests/interaction/lowlevel/test_completion.py create mode 100644 tests/interaction/lowlevel/test_elicitation.py create mode 100644 tests/interaction/lowlevel/test_flows.py create mode 100644 tests/interaction/lowlevel/test_initialize.py create mode 100644 tests/interaction/lowlevel/test_list_changed.py create mode 100644 tests/interaction/lowlevel/test_logging.py create mode 100644 tests/interaction/lowlevel/test_meta.py create mode 100644 tests/interaction/lowlevel/test_pagination.py create mode 100644 tests/interaction/lowlevel/test_ping.py create mode 100644 tests/interaction/lowlevel/test_progress.py create mode 100644 tests/interaction/lowlevel/test_prompts.py create mode 100644 tests/interaction/lowlevel/test_resources.py create mode 100644 tests/interaction/lowlevel/test_roots.py create mode 100644 
tests/interaction/lowlevel/test_sampling.py create mode 100644 tests/interaction/lowlevel/test_timeouts.py create mode 100644 tests/interaction/lowlevel/test_tools.py create mode 100644 tests/interaction/lowlevel/test_wire.py create mode 100644 tests/interaction/mcpserver/__init__.py create mode 100644 tests/interaction/mcpserver/test_completion.py create mode 100644 tests/interaction/mcpserver/test_context.py create mode 100644 tests/interaction/mcpserver/test_prompts.py create mode 100644 tests/interaction/mcpserver/test_resources.py create mode 100644 tests/interaction/mcpserver/test_tools.py create mode 100644 tests/interaction/test_coverage.py create mode 100644 tests/interaction/transports/__init__.py create mode 100644 tests/interaction/transports/_bridge.py create mode 100644 tests/interaction/transports/_event_store.py create mode 100644 tests/interaction/transports/_stdio_server.py create mode 100644 tests/interaction/transports/test_bridge.py create mode 100644 tests/interaction/transports/test_client_transport_http.py create mode 100644 tests/interaction/transports/test_flows.py create mode 100644 tests/interaction/transports/test_hosting_http.py create mode 100644 tests/interaction/transports/test_hosting_resume.py create mode 100644 tests/interaction/transports/test_hosting_session.py create mode 100644 tests/interaction/transports/test_sse.py create mode 100644 tests/interaction/transports/test_stdio.py create mode 100644 tests/interaction/transports/test_streamable_http.py diff --git a/pyproject.toml b/pyproject.toml index d88869da1c..6d2319621a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ dev = [ # We add mcp[cli,ws] so `uv sync` considers the extras. 
"mcp[cli,ws]", "pyright>=1.1.400", - "pytest>=8.3.4", + "pytest>=8.4.0", "ruff>=0.8.5", "trio>=0.26.2", "pytest-flakefinder>=1.1.0", @@ -193,6 +193,9 @@ strict-no-cover = { git = "https://github.com/pydantic/strict-no-cover" } [tool.pytest.ini_options] log_cli = true xfail_strict = true +markers = [ + "requirement(id): links a test to the entry in tests/interaction/_requirements.py it exercises", +] addopts = """ --color=yes --capture=fd diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 72309f5775..3c546fda2b 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -360,10 +360,10 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: auth_code, returned_state = await self.context.callback_handler() if returned_state is None or not secrets.compare_digest(returned_state, state): - raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover + raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") if not auth_code: - raise OAuthFlowError("No authorization code received") # pragma: no cover + raise OAuthFlowError("No authorization code received") # Return auth code and code verifier for token exchange return auth_code, pkce_params.code_verifier @@ -452,7 +452,7 @@ async def _refresh_token(self) -> httpx.Request: return httpx.Request("POST", token_url, data=refresh_data, headers=headers) - async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover + async def _handle_refresh_response(self, response: httpx.Response) -> bool: """Handle token refresh response. 
Returns True if successful.""" if response.status_code != 200: logger.warning(f"Token refresh failed: {response.status_code}") @@ -468,12 +468,12 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p await self.context.storage.set_tokens(token_response) return True - except ValidationError: + except ValidationError: # pragma: no cover logger.exception("Invalid refresh response") self.context.clear_tokens() return False - async def _initialize(self) -> None: # pragma: no cover + async def _initialize(self) -> None: """Load stored tokens and client info.""" self.context.current_tokens = await self.context.storage.get_tokens() self.context.client_info = await self.context.storage.get_client_info() @@ -507,17 +507,17 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. """HTTPX auth flow integration.""" async with self.context.lock: if not self._initialized: - await self._initialize() # pragma: no cover + await self._initialize() # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token - refresh_request = await self._refresh_token() # pragma: no cover - refresh_response = yield refresh_request # pragma: no cover + refresh_request = await self._refresh_token() + refresh_response = yield refresh_request - if not await self._handle_refresh_response(refresh_response): # pragma: no cover + if not await self._handle_refresh_response(refresh_response): # Refresh failed, need full re-authentication self._initialized = False @@ -612,7 +612,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. 
# Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) - except Exception: # pragma: no cover + except Exception: logger.exception("OAuth flow error") raise diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 34d6a360fa..b33fea4052 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -305,4 +305,4 @@ async def list_tools(self, *, cursor: str | None = None, meta: RequestParamsMeta async def send_roots_list_changed(self) -> None: """Send a notification that the roots list has changed.""" # TODO(Marcelo): Currently, there is no way for the server to handle this. We should add support. - await self.session.send_roots_list_changed() # pragma: no cover + await self.session.send_roots_list_changed() diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0cea454a77..86113874be 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -74,7 +74,7 @@ async def _default_elicitation_callback( context: RequestContext[ClientSession], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: - return types.ErrorData( # pragma: no cover + return types.ErrorData( code=types.INVALID_REQUEST, message="Elicitation not supported", ) @@ -337,9 +337,7 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) - from jsonschema import SchemaError, ValidationError, validate if result.structured_content is None: - raise RuntimeError( - f"Tool {name} has an output schema but did not return structured content" - ) # pragma: no cover + raise RuntimeError(f"Tool {name} has an output schema but did not return structured content") try: validate(result.structured_content, output_schema) except ValidationError as e: @@ -408,7 +406,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None return result - async def send_roots_list_changed(self) -> 
None: # pragma: no cover + async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" await self.send_notification(types.RootsListChangedNotification()) @@ -449,7 +447,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques client_response = ClientResponse.validate_python(response) await responder.respond(client_response) - case types.PingRequest(): # pragma: no cover + case types.PingRequest(): with responder: return await responder.respond(types.EmptyResult()) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9a119c6338..aa3e50e07e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -210,7 +210,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: # Stream ended normally (server closed) - reset attempt counter attempt = 0 - except Exception: # pragma: lax no cover + except Exception: logger.debug("GET stream error", exc_info=True) attempt += 1 @@ -267,8 +267,8 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: logger.debug("Received 202 Accepted") return - if response.status_code == 404: # pragma: no branch - if isinstance(message, JSONRPCRequest): # pragma: no branch + if response.status_code == 404: + if isinstance(message, JSONRPCRequest): error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) await ctx.read_stream_writer.send(session_message) @@ -492,17 +492,17 @@ async def handle_request_async(): async def terminate_session(self, client: httpx.AsyncClient) -> None: """Terminate the session by sending a DELETE request.""" - if not self.session_id: # pragma: lax no cover - return + if not self.session_id: + return # pragma: no cover try: headers = self._prepare_headers() response = await client.delete(self.url, headers=headers) - if response.status_code == 
405: # pragma: lax no cover + if response.status_code == 405: logger.debug("Server does not allow session termination") - elif response.status_code not in (200, 204): # pragma: lax no cover - logger.warning(f"Session termination failed: {response.status_code}") + elif response.status_code not in (200, 204): + logger.warning(f"Session termination failed: {response.status_code}") # pragma: no cover except Exception as exc: # pragma: no cover logger.warning(f"Session termination failed: {exc}") diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index dec6713b13..5cf93cf8c2 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -117,7 +117,7 @@ async def error_response( pass # the error response MUST contain the state specified by the client, if any - if state is None: # pragma: no cover + if state is None: # make last-ditch effort to load state state = best_effort_extract_string("state", params) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 6825c00b9e..2eafdc793e 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -95,7 +95,7 @@ async def _send_auth_error(self, send: Send, status_code: int, error: str, descr """Send an authentication error response with WWW-Authenticate header.""" # Build WWW-Authenticate header value www_auth_parts = [f'error="{error}"', f'error_description="{description}"'] - if self.resource_metadata_url: # pragma: no cover + if self.resource_metadata_url: www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"') www_authenticate = f"Bearer {', '.join(www_auth_parts)}" diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 59de0ace45..5e4e2e6f5b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -349,12 +349,12 @@ def 
session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - if self._session_manager is None: # pragma: no cover - raise RuntimeError( + if self._session_manager is None: + raise RuntimeError( # pragma: no cover "Session manager can only be accessed after calling streamable_http_app(). " "The session manager is created lazily to avoid unnecessary initialization." ) - return self._session_manager # pragma: no cover + return self._session_manager async def run( self, @@ -513,7 +513,7 @@ async def _handle_request( if raise_exceptions: # pragma: no cover raise err response = types.ErrorData(code=0, message=str(err)) - else: # pragma: no cover + else: response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") if isinstance(response, types.ErrorData) and span is not None: @@ -603,7 +603,7 @@ def streamable_http_app( required_scopes: list[str] = [] # Set up auth if configured - if auth: # pragma: no cover + if auth: required_scopes = auth.required_scopes or [] # Add auth middleware if token verifier is available @@ -629,10 +629,10 @@ def streamable_http_app( ) # Set up routes with or without auth - if token_verifier: # pragma: no cover + if token_verifier: # Determine resource metadata URL resource_metadata_url = None - if auth and auth.resource_server_url: + if auth and auth.resource_server_url: # pragma: no branch # Build compliant metadata URL for WWW-Authenticate header resource_metadata_url = build_resource_metadata_url(auth.resource_server_url) @@ -652,7 +652,7 @@ def streamable_http_app( ) # Add protected resource metadata endpoint if configured as RS - if auth and auth.resource_server_url: # pragma: no cover + if auth and auth.resource_server_url: routes.extend( create_protected_resource_routes( resource_url=auth.resource_server_url, diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 39bba839bd..92de074d34 100644 --- 
a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -94,7 +94,7 @@ async def report_progress(self, progress: float, total: float | None = None, mes """ progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None - if progress_token is None: # pragma: no cover + if progress_token is None: return await self.request_context.session.send_progress_notification( @@ -242,7 +242,7 @@ async def close_sse_stream(self) -> None: This is a no-op if not using StreamableHTTP transport with event_store. The callback is only available when event_store is configured. """ - if self._request_context and self._request_context.close_sse_stream: # pragma: no cover + if self._request_context and self._request_context.close_sse_stream: # pragma: no branch await self._request_context.close_sse_stream() async def close_standalone_sse_stream(self) -> None: diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index e5b2af7d82..2f778eb514 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -185,5 +185,5 @@ async def render( raise ValueError(f"Could not convert prompt result to message: {msg}") return messages - except Exception as e: # pragma: no cover + except Exception as e: raise ValueError(f"Error rendering prompt {self.name}: {e}") diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index b3471163b7..ec2365810e 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -244,7 +244,7 @@ def session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - return self._lowlevel_server.session_manager # pragma: no cover + return self._lowlevel_server.session_manager @overload def run(self, transport: Literal["stdio"] = ...) -> None: ... 
diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 20b640527a..fc2f97a9cb 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -223,7 +223,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover + async def send_resource_updated(self, uri: str | AnyUrl) -> None: """Send a resource updated notification.""" await self.send_notification( types.ResourceUpdatedNotification( @@ -447,7 +447,7 @@ async def elicit_url( metadata=ServerMessageMetadata(related_request_id=related_request_id), ) - async def send_ping(self) -> types.EmptyResult: # pragma: no cover + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( types.PingRequest(), @@ -479,11 +479,11 @@ async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" await self.send_notification(types.ResourceListChangedNotification()) - async def send_tool_list_changed(self) -> None: # pragma: no cover + async def send_tool_list_changed(self) -> None: """Send a tool list changed notification.""" await self.send_notification(types.ToolListChangedNotification()) - async def send_prompt_list_changed(self) -> None: # pragma: no cover + async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" await self.send_notification(types.PromptListChangedNotification()) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 48192ff612..be8e979c9d 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -116,15 +116,15 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover - if scope["type"] != "http": + async def connect_sse(self, scope: Scope, 
receive: Receive, send: Send): + if scope["type"] != "http": # pragma: no cover logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") # Validate request headers for DNS rebinding protection request = Request(scope, receive) error_response = await self._security.validate_request(request, is_post=False) - if error_response: + if error_response: # pragma: no cover await error_response(scope, receive, send) raise ValueError("Request validation failed") @@ -179,6 +179,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( scope, receive, send ) + await sse_stream_reader.aclose() await read_stream_writer.aclose() await write_stream_reader.aclose() self._read_stream_writers.pop(session_id, None) @@ -190,13 +191,13 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): logger.debug("Yielding read and write streams") yield (read_stream, write_stream) - async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") request = Request(scope, receive) # Validate request headers for DNS rebinding protection error_response = await self._security.validate_request(request, is_post=True) - if error_response: + if error_response: # pragma: no cover return await error_response(scope, receive, send) session_id_param = request.query_params.get("session_id") @@ -225,7 +226,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) try: message = types.jsonrpc_message_adapter.validate_json(body, by_name=False) logger.debug(f"Validated client message: {message}") - except ValidationError as err: + except ValidationError as err: # pragma: no cover logger.exception("Failed to parse message") response = 
Response("Could not parse message", status_code=400) await response(scope, receive, send) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f14201857c..f2f4407cea 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -179,7 +179,7 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated - def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover + def close_sse_stream(self, request_id: RequestId) -> None: """Close SSE connection for a specific request without terminating the stream. This method closes the HTTP connection for the specified request, triggering @@ -198,11 +198,11 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover the disconnect. """ writer = self._sse_stream_writers.pop(request_id, None) - if writer: + if writer: # pragma: no branch writer.close() # Also close and remove request streams - if request_id in self._request_streams: + if request_id in self._request_streams: # pragma: no branch send_stream, receive_stream = self._request_streams.pop(request_id) send_stream.close() receive_stream.close() @@ -242,7 +242,7 @@ def _create_session_message( # Only provide close callbacks when client supports resumability if self._event_store and protocol_version >= "2025-11-25": - async def close_stream_callback() -> None: # pragma: no cover + async def close_stream_callback() -> None: self.close_sse_stream(request_id) async def close_standalone_stream_callback() -> None: # pragma: no cover @@ -293,7 +293,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: no cover + if headers: response_headers.update(headers) if self.mcp_session_id: @@ -320,10 +320,10 @@ def _create_json_response( ) -> Response: """Create a JSON response from a 
JSONRPCMessage.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: lax no cover - response_headers.update(headers) + if headers: + response_headers.update(headers) # pragma: no cover - if self.mcp_session_id: # pragma: lax no cover + if self.mcp_session_id: response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id return Response( @@ -344,7 +344,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: } # If an event ID was provided, include it - if event_message.event_id: # pragma: no cover + if event_message.event_id: event_data["id"] = event_message.event_id return event_data @@ -374,7 +374,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await error_response(scope, receive, send) return - if self._terminated: # pragma: no cover + if self._terminated: # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -389,7 +389,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_get_request(request, send) elif request.method == "DELETE": await self._handle_delete_request(request, send) - else: # pragma: no cover + else: await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: @@ -421,7 +421,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se has_json, has_sse = self._check_accept_headers(request) if self.is_json_response_enabled: # For JSON-only responses, only require application/json - if not has_json: # pragma: lax no cover + if not has_json: # pragma: no cover response = self._create_error_response( "Not Acceptable: Client must accept application/json", HTTPStatus.NOT_ACCEPTABLE, @@ -469,7 +469,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re try: message = 
jsonrpc_message_adapter.validate_python(raw_message, by_name=False) - except ValidationError as e: # pragma: no cover + except ValidationError as e: response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, @@ -495,7 +495,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re ) await response(scope, receive, send) return - elif not await self._validate_request_headers(request, send): # pragma: no cover + elif not await self._validate_request_headers(request, send): return # For notifications and responses only, return 202 Accepted @@ -579,7 +579,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Store writer reference so close_sse_stream() can close it self._sse_stream_writers[request_id] = sse_stream_writer - async def sse_writer(): # pragma: lax no cover + async def sse_writer(): # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: @@ -595,10 +595,10 @@ async def sse_writer(): # pragma: lax no cover # If response, remove from pending streams and close if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): break - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in SSE writer") finally: logger.debug("Closing SSE writer") @@ -628,14 +628,14 @@ async def sse_writer(): # pragma: lax no cover # Then send the message to be processed by the server session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("SSE response error") await sse_stream_writer.aclose() await 
self._clean_up_memory_streams(request_id) finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: no cover + except Exception as err: logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -643,9 +643,9 @@ async def sse_writer(): # pragma: lax no cover INTERNAL_ERROR, ) await response(scope, receive, send) - if writer: + if writer: # pragma: no cover await writer.send(Exception(err)) - return + return # pragma: no cover async def _handle_get_request(self, request: Request, send: Send) -> None: """Handle GET request to establish SSE. @@ -661,7 +661,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) - if not has_sse: # pragma: no cover + if not has_sse: response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, @@ -673,7 +673,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: return # Handle resumability: check for Last-Event-ID header - if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): await self._replay_events(last_event_id, request, send) return @@ -683,11 +683,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Check if we already have an active GET stream - if GET_STREAM_KEY in self._request_streams: # pragma: no cover + if GET_STREAM_KEY in self._request_streams: response = self._create_error_response( "Conflict: Only one SSE stream is allowed per session", HTTPStatus.CONFLICT, @@ -707,7 +707,7 @@ async def standalone_sse_writer(): async with 
sse_stream_writer, standalone_stream_reader: # Process messages from the standalone stream - async for event_message in standalone_stream_reader: # pragma: lax no cover + async for event_message in standalone_stream_reader: # For the standalone stream, we handle: # - JSONRPCNotification (server sends notifications to client) # - JSONRPCRequest (server sends requests to client) @@ -716,8 +716,8 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except Exception: # pragma: no cover - logger.exception("Error in standalone SSE writer") + except Exception: + logger.exception("Error in standalone SSE writer") # pragma: no cover finally: logger.debug("Closing standalone SSE writer") await self._clean_up_memory_streams(GET_STREAM_KEY) @@ -775,7 +775,7 @@ async def terminate(self) -> None: request_stream_keys = list(self._request_streams.keys()) # Close all request streams asynchronously - for key in request_stream_keys: # pragma: lax no cover + for key in request_stream_keys: await self._clean_up_memory_streams(key) # Clear the request streams dictionary immediately @@ -793,13 +793,13 @@ async def terminate(self) -> None: # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, "Allow": "GET, POST, DELETE", } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id response = self._create_error_response( @@ -809,7 +809,7 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non ) await response(request.scope, request.receive, send) - 
async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: lax no cover + async def _validate_request_headers(self, request: Request, send: Send) -> bool: if not await self._validate_session(request, send): return False if not await self._validate_protocol_version(request, send): @@ -818,7 +818,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool: async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" - if not self.mcp_session_id: # pragma: no cover + if not self.mcp_session_id: # If we're not using session IDs, return True return True @@ -826,7 +826,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: request_session_id = self._get_session_id(request) # If no session ID provided but required, return error - if not request_session_id: # pragma: no cover + if not request_session_id: response = self._create_error_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, @@ -851,11 +851,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: # pragma: no cover + if protocol_version is None: protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) response = self._create_error_response( f"Bad Request: Unsupported protocol version: {protocol_version}. 
" @@ -867,14 +867,14 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. """ event_store = self._event_store if not event_store: - return + return # pragma: no cover try: headers = { @@ -883,7 +883,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Get protocol version from header (already validated in _validate_protocol_version) @@ -904,7 +904,7 @@ async def send_event(event_message: EventMessage) -> None: stream_id = await event_store.replay_events_after(last_event_id, send_event) # If stream ID not in mapping, create it - if stream_id and stream_id not in self._request_streams: + if stream_id and stream_id not in self._request_streams: # pragma: no branch # Register SSE writer so close_sse_stream() can close it self._sse_stream_writers[stream_id] = sse_stream_writer @@ -921,10 +921,10 @@ async def send_event(event_message: EventMessage) -> None: event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("Replay SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay sender") # Create and start EventSourceResponse @@ -936,13 +936,13 @@ async def send_event(event_message: EventMessage) -> None: try: await response(request.scope, request.receive, send) 
- except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay response") finally: await sse_stream_writer.aclose() await sse_stream_reader.aclose() - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error replaying events") response = self._create_error_response( "Error replaying events", @@ -993,7 +993,7 @@ async def message_router(): if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: target_request_id = str(message.id) # Extract related_request_id from meta if it exists - elif ( # pragma: no cover + elif ( session_message.metadata is not None and isinstance( session_message.metadata, @@ -1009,7 +1009,7 @@ async def message_router(): # regardless of whether a client is connected # messages will be replayed on the re-connect event_id = None - if self._event_store: # pragma: lax no cover + if self._event_store: event_id = await self._event_store.store_event(request_stream_id, message) logger.debug(f"Stored {event_id} from {request_stream_id}") @@ -1020,14 +1020,14 @@ async def message_router(): except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) - else: # pragma: no cover + else: logger.debug( f"""Request stream {request_stream_id} not found for message. 
Still processing message as the client might reconnect and replay.""" ) except anyio.ClosedResourceError: - if self._terminated: + if self._terminated: # pragma: lax no cover logger.debug("Read stream closed by client") else: logger.exception("Unexpected closure of read stream in message router") @@ -1041,7 +1041,7 @@ async def message_router(): # Yield the streams for the caller to use yield read_stream, write_stream finally: - for stream_id in list(self._request_streams.keys()): # pragma: lax no cover + for stream_id in list(self._request_streams.keys()): await self._clean_up_memory_streams(stream_id) self._request_streams.clear() diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index c25314eab6..39d434505c 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -173,7 +173,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA self.app.create_initialization_options(), stateless=True, ) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") # Assert task group is not None for type checking diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0e..707d4b61dd 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -40,19 +40,19 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def _validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" - if not host: + if not host: # pragma: no cover logger.warning("Missing Host header in request") return False # Check 
exact match first - if host in self.settings.allowed_hosts: + if host in self.settings.allowed_hosts: # pragma: no cover return True # Check wildcard port patterns for allowed in self.settings.allowed_hosts: - if allowed.endswith(":*"): + if allowed.endswith(":*"): # pragma: no branch # Extract base host from pattern base_host = allowed[:-2] # Check if the actual host starts with base host and has a port @@ -62,19 +62,19 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests - if not origin: + if not origin: # pragma: no cover return True # Check exact match first - if origin in self.settings.allowed_origins: + if origin in self.settings.allowed_origins: # pragma: no cover return True # Check wildcard port patterns for allowed in self.settings.allowed_origins: - if allowed.endswith(":*"): + if allowed.endswith(":*"): # pragma: no branch # Extract base origin from pattern base_origin = allowed[:-2] # Check if the actual origin starts with base origin and has a port @@ -103,14 +103,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover + # Validate Host header + host = request.headers.get("host") + if not self._validate_host(host): + return Response("Invalid Host header", status_code=421) - # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not 
self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover + # Validate Origin header + origin = request.headers.get("origin") + if not self._validate_origin(origin): + return Response("Invalid Origin header", status_code=403) - return None # pragma: no cover + return None diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index ebf534d792..3b48152d5b 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -91,9 +91,9 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None: requested_scopes = requested_scope.split(" ") allowed_scopes = [] if self.scope is None else self.scope.split(" ") for scope in requested_scopes: - if scope not in allowed_scopes: # pragma: no branch + if scope not in allowed_scopes: raise InvalidScopeError(f"Client was not registered with scope {scope}") - return requested_scopes # pragma: no cover + return requested_scopes def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: if redirect_uri is not None: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae6..9c72a23844 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -451,7 +451,7 @@ async def _handle_session_message(message: SessionMessage) -> None: try: await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) await stream.aclose() - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover # Stream might already be closed pass self._response_streams.clear() diff --git a/tests/interaction/README.md b/tests/interaction/README.md new file mode 100644 index 0000000000..be68c3b0f1 --- /dev/null +++ b/tests/interaction/README.md @@ -0,0 +1,228 @@ +# Interaction-model test suite + +This suite enumerates the MCP interaction model as end-to-end tests: one test per piece of +functionality, asserting the full client↔server round trip through the public API. 
It exists to +pin the SDK's observable behaviour — every request type, every notification direction, every +error plane — so that internal rewrites of the send/receive path can be proven equivalent by +running the suite before and after. + +```bash +uv run --frozen pytest tests/interaction/ +``` + +The whole suite is in-process and event-driven — including the streamable HTTP, SSE, and OAuth +flows — with a single subprocess test for stdio. + +## Ground rules + +- **Public API only.** Tests drive a `Client` connected to a `Server` or `MCPServer`. Nothing + reaches into session internals, so the suite keeps working when those internals change. + `ClientSession` is used directly only for behaviours `Client` cannot express (skipping + initialization, requesting a non-default protocol version). +- **Pin current behaviour.** Every test passes against the current `main`, including behaviours + that diverge from the specification. A failing or xfailed test proves nothing about whether a + rewrite preserved behaviour; a passing test that exactly pins the current output does, even + where that output is wrong. Known + divergences are recorded as data on the requirement (see below), not worked around in the test. +- **Spec-mandated assertions, not implementation quirks.** Error *codes* are asserted against + the constants in `mcp.types`; error *message strings* are pinned only where they are the + SDK's own deliberate output. +- **No sleeps, no real I/O.** Concurrency is coordinated with `anyio.Event`; every wait that + could hang is bounded by `anyio.fail_after(5)`. The HTTP and OAuth tests drive the Starlette + app in-process through the suite's streaming ASGI bridge (`transports/_bridge.py`), which + delivers each response chunk as the server produces it — full duplex, but still no sockets, + threads, or subprocesses anywhere outside the one stdio test.
+ +## Layout + +```text +tests/interaction/ + _requirements.py the requirements manifest (see below) + _helpers.py shared type aliases + the wire-recording transport + _connect.py the transport-parametrized connection factories + conftest.py the connect fixture (the transport matrix) + test_coverage.py enforces the manifest ↔ test contract + lowlevel/ one file per feature area, against the low-level Server + mcpserver/ the same feature areas in MCPServer's natural idiom + transports/ behaviour specific to one transport (sessions, resumability, framing) + auth/ OAuth flows against an in-process authorization server +``` + +The two server APIs produce genuinely different wire output for the same conceptual feature +(`MCPServer` generates schemas, converts exceptions to `isError` results, attaches structured +content), so they get parallel directories with mirrored file names rather than one parametrized +test body — each directory pins its flavour's true output exactly. + +### The transport matrix + +Transport-agnostic tests take the `connect` fixture instead of constructing `Client(server)` +directly, and therefore run once per transport: over the in-memory transport, over the server's +real streamable HTTP app driven in-process through the streaming bridge, and over the legacy SSE +transport the same way. A test connects with `async with connect(server, ...) as client:` and +asserts the same output on every leg, because the transport is not supposed to change observable +behaviour. Tests that are tied to one transport do not use the fixture: the wire-recording tests +(their seam is the in-memory stream pair), the bare-`ClientSession` lifecycle tests, the +real-clock timeout tests (the timeout machinery is transport-independent and must not race +transport latency), and everything under `transports/`, which pins behaviour only observable on +that transport. 
+ +A transport conformance test in `transports/` speaks raw `httpx` against the mounted ASGI app +**only** when its assertion is about HTTP semantics that `Client` cannot observe — status codes, +response headers, SSE event fields, which stream a message travels on. Any other behaviour is +asserted through a `Client`, connected to the mounted app via `client_via_http(http)` so several +clients can share one session manager. + +## The requirements manifest + +`_requirements.py` maps every behaviour the suite covers to the reason it must hold: + +```python +"tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content.", +), +``` + +- **`source`** is a deep link into the MCP specification for externally mandated behaviour, + the literal string `"sdk"` for behaviour the SDK chose where the spec is silent, or + `"issue:#n"` for a regression lock. +- **`behavior`** describes the *required* behaviour — what the specification (or the SDK's own + contract) says should happen. Tests always pin the SDK's current behaviour; where that falls + short of `behavior`, the gap is recorded as data rather than hidden in the test. +- **`divergence`** records that gap for entries whose tests pin the divergent current behaviour. +- **`deferred`** marks a behaviour that is tracked but has no test in this suite, with a precise + reason: the SDK does not implement it, the negative cannot be observed, the assertion is + schema-level rather than interaction-level, the feature is experimental (tasks), or the test + would require real-time waits the suite refuses. +- **`transports`** names the transports a behaviour applies to; omitted means transport-independent. +- **`issue`** carries the tracking link for a recorded gap once one is filed. 
+ +Tests link themselves to the manifest with a decorator: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: ... +``` + +`test_coverage.py` enforces the contract in both directions: every non-deferred requirement must +be exercised by at least one test, every deferred requirement by none, and an unknown ID fails at +import time. A behaviour without a manifest entry cannot be silently half-tested, and a manifest +entry without a test cannot be silently aspirational. + +### The divergence lifecycle + +1. A test reveals that the SDK does not do what the spec says. The test pins what the SDK + *actually does* and a `Divergence(note=..., issue=...)` goes on the requirement. +2. When the behaviour is eventually fixed, the pinned test fails. Whoever makes the change finds + the divergence note explaining that the old behaviour was a known gap, re-pins the test to the + spec-correct output, and deletes the `Divergence`. +3. An empty divergence list means the SDK is spec-conformant on every behaviour the suite covers. + +A requirement may carry both `divergence` and `deferred`: the divergence records that the SDK falls +short of the spec, and the deferral records why no test pins it (typically because the divergent +behaviour cannot be driven through the public API). Divergence alone implies a test pins the +divergent behaviour; divergence plus deferred means the gap is known but unpinned. + +This is also the triage key for any rewrite: a test that fails on the new code path either has a +divergence note (the rewrite accidentally fixed a known gap — decide whether to keep the fix) or +it does not (the rewrite broke something that was correct — fix the rewrite). + +### When a new spec revision is released + +1. Update `SPEC_REVISION` and walk the new revision's changelog. +2. 
For each changed interaction, find its requirements (the IDs use the wire method strings the + changelog speaks in), re-audit the tests against the new text, and update `source` links and + assertions where behaviour legitimately changed. +3. New interactions get new requirements and new tests; removed interactions get their + requirements deleted along with their tests. +4. A behaviour that is correct under both revisions needs no change beyond the `source` link. + +## Writing a test + +The shortest complete example of the conventions: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) +``` + +- **The server is defined inside the test** (or in a small fixture at the top of the file when + several tests genuinely share it). The whole observable behaviour fits on one screen. +- **Test names are behaviour sentences** — they state the observable outcome, not the feature + being poked. Docstrings add the one or two sentences of context a reviewer needs, including + whether the assertion is spec-mandated, SDK-defined, or a known divergence. +- **Handlers assert their dispatch identity first** (`assert params.name == "add"`), proving the + request that arrived is the request the test sent. 
+- **The result proves the round trip.** Server-side observations travel back to the test through + the protocol itself (a tool returns what it saw) or through a closure-captured list; the test + asserts after the call returns. +- **Order within a test**: server handlers → server construction → client callbacks → connect → + act → assert. The test reads in the order the conversation happens. +- A registered handler or tool that a test never invokes gets a `raise NotImplementedError` body + so it cannot silently become load-bearing. +- A test that needs a peer no real `Server` or `Client` can play (a server that answers initialize + with an unsupported version, a client that sends malformed params) plays that side of the wire by + hand over `create_client_server_memory_streams()`. This scripted-peer pattern is the suite's only + way to drive behaviour the typed API cannot produce, and the docstring of every such test says so. + +Stack a second `@requirement` decorator only when a test's natural assertions incidentally prove +another behaviour — one capabilities snapshot proving four `*:capability:declared` entries, one +input-schema identity check proving each preserved keyword. Do not build a test around covering +many requirements at once; if the assertions would be separate, write separate tests. 
+ +### Choosing an assertion + +| The property under test is… | Assert with | +|---|---| +| the result of a transformation (arguments → output, exception → error result) | `result == snapshot(...)` of the full object, so any field the implementation adds or drops fails the test | +| pass-through of an opaque value (`_meta`, cursors) | identity against the same variable that was sent — a snapshot of a pass-through value only matches the input because a human checked two literals correspond | +| an error | `pytest.raises(MCPError)` and a snapshot of `exc.value.error` when the message is the SDK's own; a plain `==` on `.code` against the `mcp.types` constant when it is not | +| third-party output embedded in a result (validation messages) | the stable prefix only — never pin text that changes with a dependency upgrade | + +### Notifications and concurrency + +The client's receive loop dispatches each incoming message to completion before reading the next, +and the in-memory transport delivers everything on one ordered stream. Together these guarantee +that every notification a server handler emits before its response reaches the client callback +before the originating request returns — so tests collect notifications into a plain list and +assert after the call, with no synchronisation. The exceptions: + +- a notification not triggered by a request the test is awaiting needs an `anyio.Event` set in + the receiving handler and awaited under `anyio.fail_after(5)`; +- the ordering guarantee does not survive transports that split messages across streams (the + streamable HTTP standalone GET stream) — see `transports/test_streamable_http.py`. + +### Coverage + +CI requires 100% line and branch coverage, including `tests/`, and `strict-no-cover` fails the +build if a line marked `# pragma: no cover` is ever executed. When a new test starts covering a +pragma'd line in `src/`, delete the pragma in the same change. 
Do not add new `# type: ignore` or +`# noqa` comments; restructure instead. Two pragmas are sanctioned in this suite's test code, both +for known-upstream tracer bugs and only after restructuring has been tried: `# pragma: no branch` +on a `with`/`async with` line whose only fault is coverage.py mis-tracing the exit arc of a nested +async context (reserve it for shapes that cannot collapse — a sync `with` adjacent to an +`async with`); and `# pragma: lax no cover` on a single statement that 3.11's tracer drops because +the preceding `async with` unwinds via `coro.throw()` (python/cpython#106749, wontfix on 3.11) — +this hits any test that must run statements after a `ClientSession`/`streamable_http_client` exits +but still inside an outer `async with`, and no restructure can avoid it. + +A handful of `# pragma: lax no cover` markers in `src/` cover teardown exception handlers whose +execution is timing-dependent under the in-process HTTP bridge — the POST-stream and +stateless-session `except Exception` handlers in `server/streamable_http*.py`, the `_terminated` +check in `message_router`, and the response-stream double-close guard in +`BaseSession._receive_loop`. `strict-no-cover` does not check `lax` lines; do not promote them to +strict `no cover` without first making the teardown ordering deterministic. The suite also relies +on a one-line `src/mcp/server/sse.py` fix (`sse_stream_reader.aclose()`) that closes a stream the +SSE leg would otherwise leak. diff --git a/tests/interaction/__init__.py b/tests/interaction/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py new file mode 100644 index 0000000000..1faf4aa8d6 --- /dev/null +++ b/tests/interaction/_connect.py @@ -0,0 +1,360 @@ +"""Transport-parametrized connection factories for the interaction suite. 
+ +The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body +runs over each transport without naming any of them: the factory is a drop-in replacement for +constructing `Client(server, ...)` and yields the connected client. The HTTP factories drive the +server's real Starlette app through the in-process streaming bridge, so the full transport layer +(session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses. +""" + +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, Protocol + +import httpx +from httpx_sse import ServerSentEvent, aconnect_sse +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route + +from mcp.client.client import Client +from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier +from mcp.server.auth.settings import AuthSettings +from mcp.server.mcpserver import MCPServer +from mcp.server.sse import SseServerTransport +from mcp.server.streamable_http import EventStore +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + jsonrpc_message_adapter, +) +from tests.interaction.transports._bridge import StreamingASGITransport + +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens 
here. +BASE_URL = "http://127.0.0.1:8000" + +# DNS-rebinding protection validates Host/Origin headers against a real network attack that cannot +# exist for an in-process ASGI app, so the in-process factories disable it; tests that exercise the +# protection itself pass explicit settings (or transport_security=None to get the localhost +# auto-enable behaviour). +NO_DNS_REBINDING_PROTECTION = TransportSecuritySettings(enable_dns_rebinding_protection=False) + + +class Connect(Protocol): + """Connect a Client to a server over the transport selected by the `connect` fixture. + + Accepts the same keyword arguments as `Client` and yields the connected client. + """ + + def __call__( + self, + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, + ) -> AbstractAsyncContextManager[Client]: ... 
+ + +@asynccontextmanager +async def connect_in_memory( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server over the in-memory transport.""" + async with Client( + server, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +@asynccontextmanager +async def connect_over_streamable_http( + server: Server | MCPServer, + *, + stateless_http: bool = False, + json_response: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server's streamable HTTP app, entirely in process. + + With the defaults this is the matrix leg (stateful sessions, SSE responses); the + transport-specific tests pass `stateless_http` or `json_response` to select the other + server modes, and the resumability tests pass an `event_store` (with `retry_interval=0` so + the client's reconnection wait is a no-op). 
+ """ + app = server.streamable_http_app( + stateless_http=stateless_http, + json_response=json_response, + event_store=event_store, + retry_interval=retry_interval, + transport_security=NO_DNS_REBINDING_PROTECTION, + ) + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client, + Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client, + ): + yield client + + +@asynccontextmanager +async def mounted_app( + server: Server | MCPServer, + *, + stateless_http: bool = False, + json_response: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + transport_security: TransportSecuritySettings | None = NO_DNS_REBINDING_PROTECTION, + on_request: Callable[[httpx.Request], Awaitable[None]] | None = None, + headers: dict[str, str] | None = None, + auth: AuthSettings | None = None, + token_verifier: TokenVerifier | None = None, + auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, +) -> AsyncIterator[tuple[httpx.AsyncClient, StreamableHTTPSessionManager]]: + """Mount the server's streamable HTTP app on the in-process bridge and yield an httpx client. + + Yields the httpx client (rooted at the in-process origin) and the live session manager. Tests + use this in two ways: for raw-httpx assertions (status codes, headers, SSE bytes) the test + speaks HTTP through the yielded client directly; for client-driven assertions the test wraps + that client in `client_via_http(http)`, which lets several `Client`s share the one mounted + session manager. 
`on_request` records every outgoing HTTP request before it leaves the + yielded client. + + DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the + localhost auto-enable behaviour) to test the protection itself. + """ + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + app = lowlevel.streamable_http_app( + stateless_http=stateless_http, + json_response=json_response, + event_store=event_store, + retry_interval=retry_interval, + transport_security=transport_security, + auth=auth, + token_verifier=token_verifier, + auth_server_provider=auth_server_provider, + ) + event_hooks = {"request": [on_request]} if on_request is not None else None + async with ( + server.session_manager.run(), + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, event_hooks=event_hooks, headers=headers + ) as http_client, + ): + yield http_client, server.session_manager + + +@asynccontextmanager +async def client_via_http( + http_client: httpx.AsyncClient, + *, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Connect a `Client` over an already-mounted streamable HTTP app. + + Use with `mounted_app(...)` so several `Client`s share the one session manager, or so a + client-driven assertion can sit alongside raw-httpx assertions in the same test. The + underlying `httpx.AsyncClient` is left open when the `Client` exits. 
+ """ + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + async with Client( + transport, + logging_callback=logging_callback, + message_handler=message_handler, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]: + """Decode SSE events into JSON-RPC messages, skipping priming events that carry no data.""" + return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data] + + +async def post_jsonrpc( + http: httpx.AsyncClient, body: dict[str, object], *, session_id: str | None = None +) -> tuple[httpx.Response, list[JSONRPCMessage]]: + """POST a JSON-RPC body and read its SSE response stream to completion. + + Returns the HTTP response (for header/status assertions) and the parsed JSON-RPC messages + that arrived on the response's SSE stream. Only meaningful for requests the server answers + with `text/event-stream`; for error responses or 202 notification acknowledgements, use + `httpx.AsyncClient.post` directly and assert on the response. + """ + async with aconnect_sse(http, "POST", "/mcp", json=body, headers=base_headers(session_id=session_id)) as source: + events = [event async for event in source.aiter_sse()] + return source.response, parse_sse_messages(events) + + +def base_headers(*, session_id: str | None = None) -> dict[str, str]: + """Standard request headers for raw-httpx streamable-HTTP tests. + + Every well-formed request carries these (Accept covering both response representations, + Content-Type for POST bodies, MCP-Protocol-Version at the latest revision, and the session + ID once one exists), so a test that wants to assert a specific rejection only varies the one + header under test. 
+ """ + headers = { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + "mcp-protocol-version": LATEST_PROTOCOL_VERSION, + } + if session_id is not None: + headers["mcp-session-id"] = session_id + return headers + + +def initialize_body(request_id: int = 1) -> dict[str, object]: + """A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it.""" + params = InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="raw", version="0.0.0"), + ) + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="initialize", params=params.model_dump(by_alias=True, exclude_none=True) + ).model_dump(by_alias=True, exclude_none=True) + + +async def initialize_via_http(http: httpx.AsyncClient) -> str: + """Perform the initialize handshake over a raw `httpx.AsyncClient` and return the session ID. + + Validates the SSE response and sends the `notifications/initialized` follow-up, so the server + is fully ready for subsequent feature requests when this returns. + """ + async with aconnect_sse(http, "POST", "/mcp", json=initialize_body(), headers=base_headers()) as source: + assert source.response.status_code == 200 + # An event-store-backed server opens the stream with a priming event (empty data); skip it. + events = [event async for event in source.aiter_sse() if event.data] + assert len(events) == 1 + assert JSONRPCResponse.model_validate_json(events[0].data).id == 1 + session_id = source.response.headers["mcp-session-id"] + initialized = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/initialized"}, + headers=base_headers(session_id=session_id), + ) + assert initialized.status_code == 202 + return session_id + + +def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]: + """Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/. 
+ + `MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which + the SSE-specific tests need; building the app explicitly here gives both server flavours the + same routing while keeping that handle. + """ + sse = SseServerTransport( + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + + async def handle_sse(request: Request) -> Response: + async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): + await lowlevel.run(read, write, lowlevel.create_initialization_options()) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse, methods=["GET"]), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + return app, sse + + +@asynccontextmanager +async def connect_over_sse( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" + app, _ = build_sse_app(server) + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + # The SSE server transport's connect_sse runs the entire MCP session inside the GET + # request and only releases its streams after that request observes a disconnect, so the + # bridge must let the application drain rather than cancelling at close. 
+ return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) + async with Client( + transport, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py new file mode 100644 index 0000000000..25833b0ca5 --- /dev/null +++ b/tests/interaction/_helpers.py @@ -0,0 +1,107 @@ +"""Shared helpers for the interaction suite. + +Keep this module small: it exists only for (a) types that every test would otherwise have to +assemble from the SDK's internals to annotate a client callback, and (b) the recording transport +used by the wire-level tests. Server fixtures and assertion helpers belong in the test that uses +them. +""" + +from types import TracebackType + +import anyio +from typing_extensions import Self + +from mcp.client._transport import ReadStream, Transport, TransportStreams, WriteStream +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ClientResult, ServerNotification, ServerRequest + +# TODO: this union is the parameter type of every client message handler (MessageHandlerFnT), +# but the SDK does not export a name for it -- writing a correctly-typed handler requires +# importing RequestResponder from mcp.shared.session and assembling the union by hand. It +# should be a named, exported alias next to MessageHandlerFnT (like ClientRequestContext is +# for the request callbacks), at which point this alias can be deleted. 
+IncomingMessage = RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception +"""Everything a client message handler can receive.""" + + +class _RecordingReadStream: + """Delegates to a read stream, appending every received message to a log.""" + + def __init__(self, inner: ReadStream[SessionMessage | Exception], log: list[SessionMessage | Exception]) -> None: + self._inner = inner + self._log = log + + async def receive(self) -> SessionMessage | Exception: + item = await self._inner.receive() + self._log.append(item) + return item + + async def aclose(self) -> None: + await self._inner.aclose() + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> SessionMessage | Exception: + try: + return await self.receive() + except anyio.EndOfStream: + raise StopAsyncIteration from None + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class _RecordingWriteStream: + """Delegates to a write stream, appending every sent message to a log.""" + + def __init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage]) -> None: + self._inner = inner + self._log = log + + async def send(self, item: SessionMessage, /) -> None: + self._log.append(item) + await self._inner.send(item) + + async def aclose(self) -> None: + await self._inner.aclose() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class RecordingTransport: + """Wraps a Transport and records every message crossing the client's transport boundary. + + `sent` holds everything the client wrote towards the server; `received` holds everything the + server delivered to the client. 
The recording sits at the transport seam -- the exact payloads + a real transport would serialise -- and never touches the session, so wire-level assertions + written against it survive changes to the receive path. + """ + + def __init__(self, inner: Transport) -> None: + self.inner = inner + self.sent: list[SessionMessage] = [] + self.received: list[SessionMessage | Exception] = [] + + async def __aenter__(self) -> TransportStreams: + read_stream, write_stream = await self.inner.__aenter__() + return _RecordingReadStream(read_stream, self.received), _RecordingWriteStream(write_stream, self.sent) + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + return await self.inner.__aexit__(exc_type, exc_val, exc_tb) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py new file mode 100644 index 0000000000..109b30fc77 --- /dev/null +++ b/tests/interaction/_requirements.py @@ -0,0 +1,2816 @@ +"""Requirements manifest for the interaction-model test suite. + +Every user-facing behaviour the SDK must satisfy, keyed by a stable `<area>:<feature>[:<detail>]` +ID. Each entry owns the tests that exercise it: tests declare `@requirement("<id>")` (a test that +proves several behaviours stacks several decorators) and `test_coverage.py` enforces the contract +in both directions: every non-deferred requirement has at least one test, and every test carries +at least one requirement. + +Sources: + spec URL -- externally mandated by the MCP specification (deep link to the section) + `sdk` -- a behavioural guarantee the SDK chose; not spec-mandated + `issue:#n` -- regression lock-in for a previously fixed bug + +The `behavior` sentence describes the REQUIRED behaviour -- what the specification (or the SDK's +own contract) says should happen. Tests always pin the SDK's current behaviour.
Where current +behaviour falls short of `behavior`, the gap is recorded as data: `divergence` on entries whose +tests pin the divergent behaviour, or `deferred` on entries that are tracked but not yet covered +by a test in this suite. An entry may carry both: `divergence` records the spec-compliance gap +(issue-able) and `deferred` records why no test exists; `divergence` alone implies a test pins +the divergent behaviour. `issue` carries the tracking link for a recorded gap once one is filed. + +`deferred` reasons take one of three shapes: where the behaviour is exercised elsewhere in this +repo the reason names the covering test path; where the SDK does not implement the behaviour at +all the reason starts with "Not implemented in the SDK"; and where an interaction-level test is +planned but not yet written the reason starts with "Not yet covered here". + +`transports` records which transports a behaviour applies to (or is observable on); None means +the behaviour is transport-independent. + +The ID vocabulary and entry granularity are aligned with the TypeScript SDK's end-to-end +requirements suite, so coverage and recorded divergences can be compared across the two SDKs +entry by entry; IDs that exist in only one SDK reflect genuinely different API surface. +""" + +import re +from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal, TypeVar + +import pytest + +SPEC_REVISION = "2025-11-25" +SPEC_BASE_URL = f"https://modelcontextprotocol.io/specification/{SPEC_REVISION}" + +Transport = Literal["in-memory", "stdio", "streamable-http", "sse"] + +_TestFn = TypeVar("_TestFn", bound=Callable[..., object]) + +_SOURCE_PATTERN = re.compile(r"https://modelcontextprotocol\.io/specification/.+|sdk|issue:#\d+") + +_TASKS_DEFERRAL = ( + "Tasks are experimental and the spec is being substantially revised; python task behaviour is " + "covered by tests/experimental/tasks/ until the next spec revision settles." 
+) + + +@dataclass(frozen=True, kw_only=True) +class Divergence: + """A documented gap between the SDK behaviour this suite pins and what `source` mandates.""" + + note: str + issue: str | None = None + + +@dataclass(frozen=True, kw_only=True) +class Requirement: + """A single testable behaviour and the provenance of why it must hold.""" + + source: str + behavior: str + transports: tuple[Transport, ...] | None = None + divergence: Divergence | None = None + deferred: str | None = None + issue: str | None = None + + def __post_init__(self) -> None: + if not _SOURCE_PATTERN.fullmatch(self.source): + raise ValueError(f"source must be a specification URL, 'sdk', or 'issue:#n', got {self.source!r}") + + +REQUIREMENTS: dict[str, Requirement] = { + # ═══════════════════════════════════════════════════════════════════════════ + # Lifecycle & version negotiation + # ═══════════════════════════════════════════════════════════════════════════ + "lifecycle:capability:client-not-declared": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", + behavior=( + "The client rejects sending notifications or registering handlers for capabilities it did not declare." + ), + divergence=Divergence( + note=( + "The client does not check its own declared capabilities before sending notifications or " + "serving callbacks; nothing prevents a caller from violating the spec's MUST." + ), + ), + deferred=( + "Not implemented in the SDK: the client does not check its own declared capabilities before " + "sending notifications or serving callbacks." + ), + ), + "lifecycle:capability:server-not-advertised": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", + behavior=( + "The client rejects calls to methods (e.g. resources/list) for capabilities the server did not advertise." 
+ ), + divergence=Divergence( + note=( + "The client sends any request regardless of the server's advertised capabilities and " + "surfaces whatever the server answers; the spec's MUST is not enforced." + ), + ), + deferred=( + "Not implemented in the SDK: the client sends any request regardless of the server's " + "advertised capabilities and surfaces whatever the server answers." + ), + ), + "lifecycle:initialize:basic": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Connecting sends initialize with the protocol version, client capabilities, and client " + "info; the server responds with its own and the connection is established." + ), + ), + "lifecycle:initialize:server-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The initialize result identifies the server: name and version, plus title when declared.", + ), + "lifecycle:initialize:instructions": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="A server may include an instructions string in the initialize result; the client exposes it.", + ), + "lifecycle:initialize:capabilities:from-handlers": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The server advertises a capability for each feature area it has a registered handler for, " + "and omits the capability for areas it does not." 
+ ), + ), + "lifecycle:initialize:capabilities:minimal": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior="A server with no feature handlers advertises no feature capabilities.", + ), + "lifecycle:initialize:client-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The client's name, version, and title are visible to server handlers after initialization.", + ), + "lifecycle:initialize:client-capabilities": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The client capabilities visible to the server reflect which client callbacks are configured " + "(sampling, elicitation, roots)." + ), + ), + "lifecycle:initialized-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "After successful initialization, the client sends exactly one initialized notification, " + "before any non-ping request." + ), + ), + "lifecycle:ping": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="ping in either direction returns an empty result.", + ), + "ping:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A client-initiated ping receives an empty result from the server.", + ), + "ping:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A server-initiated ping receives an empty result from the client.", + ), + "lifecycle:requests-before-initialized": Requirement( + source="sdk", + behavior=( + "A request other than ping sent before the initialization handshake completes is rejected with an error." 
+ ), + ), + "lifecycle:pre-initialization-ordering": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Before initialization completes, the client sends no requests other than pings, and the " + "server sends no requests other than pings and logging." + ), + divergence=Divergence( + note=( + "The server's send methods (create_message / elicit_form / list_roots) do not check " + "initialization state before sending; on the client side, Client always completes the " + "handshake before any caller code runs." + ), + ), + deferred=( + "Not implemented in the SDK: neither side enforces sender-side restraint. The server's send " + "methods (create_message / elicit_form / list_roots) do not check initialization state before " + "sending, and there is no natural hook to issue a server-to-client request between the " + "initialize response and the initialized notification through the public API; on the client " + "side, Client always completes the handshake before any caller code runs." + ), + ), + "lifecycle:version:downgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "When the server returns an older supported protocol version, the client downgrades to it " + "and the connection succeeds at that version." + ), + ), + "lifecycle:version:match": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "When the server supports the requested protocol version it echoes that version in the " + "initialize result, and the connection proceeds at that version." + ), + ), + "lifecycle:version:server-fallback-latest": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "An initialize request carrying a protocol version the server does not support is answered " + "with another version the server supports — the latest one — rather than an error." 
+ ), + ), + "lifecycle:version:reject-unsupported": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "A client that receives an initialize response carrying a protocol version it does not " + "support fails initialization with an error rather than proceeding with the session." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Protocol primitives: cancellation, timeout, progress, errors, _meta + # ═══════════════════════════════════════════════════════════════════════════ + "protocol:request-id:unique": Requirement( + source=f"{SPEC_BASE_URL}/basic#requests", + behavior=( + "Every request sent on a session carries a unique, non-null string or integer id; ids are " + "never reused within the session." + ), + ), + "protocol:notifications:no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic#notifications", + behavior=( + "Notifications are never answered: every message the server delivers is either the response " + "to a request the client sent or a notification carrying no id." + ), + ), + "protocol:cancel:abort-signal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#cancellation-flow", + behavior=( + "Cancelling an in-flight request through the client API sends notifications/cancelled with " + "the request id and fails the local call." + ), + deferred=( + "Not implemented in the SDK: there is no public client-side API to cancel an in-flight " + "request; cancellation requires hand-constructing the notification (which is how " + "protocol:cancel:in-flight exercises the receiving side)." 
+ ), + ), + "protocol:cancel:handler-abort-propagates": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="On the receiving side, a cancellation notification stops the running request handler.", + ), + "protocol:cancel:in-flight": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A cancellation notification for an in-flight request stops the server-side handler, and the " + "receiver does not send a response for the cancelled request." + ), + divergence=Divergence( + note=( + "The spec says receivers of a cancellation SHOULD NOT send a response for the cancelled " + "request; the server sends an error response (code 0, 'Request cancelled'), which is what " + "unblocks the SDK client's pending call." + ), + ), + ), + "protocol:cancel:initialize-not-cancellable": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="The client never sends notifications/cancelled for the initialize request.", + deferred=( + "Not implemented in the SDK: the client has no public cancellation API at all, so no pathway " + "exists that could cancel initialize; there is no distinct behaviour to pin beyond that absence." + ), + ), + "protocol:cancel:late-response-ignored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A response that arrives after the sender issued notifications/cancelled is ignored; the " + "request stays failed and no error is raised." + ), + divergence=Divergence( + note=( + "A response whose id matches no in-flight request is delivered to the message handler " + "as a RuntimeError rather than being silently ignored. The post-cancellation case is the " + "same code path; tested in its unknown-id form because that is deterministic without the " + "client-side cancellation API the SDK does not yet provide." 
+ ), + ), + ), + "protocol:cancel:server-survives": Requirement( + source="sdk", + behavior="The session continues to serve new requests after an earlier request was cancelled.", + ), + "protocol:cancel:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A server that abandons an in-flight server-initiated request (sampling, elicitation, roots) " + "cancels it, and the client stops processing the cancelled request." + ), + divergence=Divergence( + note=( + "Abandoning a server-side send_request emits no cancellation notification, and the client " + "could not act on one anyway: client callbacks run inline in the receive loop, so a " + "cancellation is not even read until the callback has finished." + ), + ), + deferred=( + "Not implemented in the SDK: abandoning a server-side send_request emits no cancellation " + "notification (the same sender-side gap recorded on protocol:timeout:sends-cancellation), and " + "the client could not act on one anyway because client callbacks run inline in the receive " + "loop, so a cancellation would not even be read until the callback had already finished." + ), + ), + "protocol:cancel:unknown-id-ignored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#error-handling", + behavior=( + "The receiver silently ignores a cancellation notification referencing an unknown or " + "already-completed request id; no error response is sent and no exception is raised." + ), + ), + "protocol:cancel:sender-targeting": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "Cancellation notifications reference only requests that were previously issued in the same " + "direction and are believed to still be in flight." 
+ ), + deferred=( + "Not implemented in the SDK: there is no public client-side cancel API to drive (see " + "protocol:cancel:abort-signal), so the sender-side targeting rule has nothing to pin." + ), + ), + "protocol:error:connection-closed": Requirement( + source="sdk", + behavior="Closing the transport fails all in-flight requests with a connection-closed error.", + ), + "protocol:error:internal-error": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior=( + "An unhandled exception in a request handler is returned to the caller as JSON-RPC error " + "-32603 Internal error." + ), + divergence=Divergence( + note=( + "The low-level Server returns code 0 (not a defined JSON-RPC code) instead of -32603 and " + "leaks str(exc) as the error message." + ), + ), + ), + "protocol:error:invalid-params": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request with malformed params is answered with JSON-RPC error -32602 Invalid params.", + ), + "protocol:error:method-not-found": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request whose method has no registered handler is answered with a METHOD_NOT_FOUND error.", + ), + "protocol:meta:related-task": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="Messages may carry related-task _meta associating them with a task.", + deferred=_TASKS_DEFERRAL, + ), + "meta:request-to-handler": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object the client attaches to a request is visible to the server handler.", + ), + "meta:result-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object a handler attaches to its result is delivered to the client.", + ), + "protocol:progress:callback": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a handler during a request are delivered to 
the caller's " + "progress callback, in order, with their progress, total, and message." + ), + ), + "protocol:progress:token-injected": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Supplying a progress callback attaches a progress token to the outgoing request, which the " + "server-side handler can observe in its request metadata." + ), + ), + "protocol:progress:token-unique": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=("Concurrent in-flight requests that each supply a progress callback carry distinct progress tokens."), + ), + "protocol:progress:monotonic": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "The progress value increases with each notification for a given token, even when the total is unknown." + ), + divergence=Divergence( + note=( + "The spec MUST is not enforced: progress values are not validated on either side, so a " + "handler that emits non-increasing values has them forwarded to the callback unchanged." + ), + ), + ), + "protocol:progress:stops-after-completion": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#behavior-requirements", + behavior="Progress notifications for a token stop once the associated request completes.", + divergence=Divergence( + note=( + "send_progress_notification does not check whether the token's request has already " + "completed; the late notification is sent and reaches the client." + ), + ), + ), + "protocol:progress:late-dropped-by-client": Requirement( + source="sdk", + behavior=( + "A progress notification that arrives after its request has completed is not delivered to the " + "original progress callback." 
+ ), + ), + "protocol:progress:no-token": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior="Without a progress callback the request carries no progress token.", + ), + "protocol:progress:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior="A progress notification sent by the client is delivered to the server's progress handler.", + ), + "protocol:timeout:basic": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior=( + "A request that exceeds its read timeout fails with a request-timeout error instead of " + "waiting forever for the response." + ), + ), + "protocol:timeout:max-total": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A maximum total timeout is enforced even when progress notifications keep arriving.", + divergence=Divergence( + note=( + "There is no maximum-total-timeout option; only the per-request read timeout exists, so the " + "spec's SHOULD that an overall maximum is always enforced cannot be satisfied." + ), + ), + deferred=( + "Not implemented in the SDK: there is no maximum-total-timeout option; only the per-request " + "read timeout exists." + ), + ), + "protocol:timeout:reset-on-progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="When configured to do so, each progress notification resets the request's read timeout.", + deferred=( + "Not implemented in the SDK: progress notifications do not reset the request read timeout and " + "no option exists to enable that." + ), + ), + "protocol:timeout:sends-cancellation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior=( + "When a request times out, the sender issues notifications/cancelled for that request before " + "failing the local call." 
+ ), + divergence=Divergence( + note=( + "The client only raises locally and sends nothing on timeout, so the server keeps running the handler." + ), + ), + ), + "protocol:timeout:session-survives": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="The session continues to serve new requests after an earlier request timed out.", + ), + "protocol:timeout:session-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A session-level read timeout applies to every request that does not override it.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools + # ═══════════════════════════════════════════════════════════════════════════ + "tools:call:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#audio-content", + behavior="A tool result can carry audio content: base64 data with a mimeType.", + ), + "tools:call:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#embedded-resources", + behavior="A tool result can carry an embedded resource with full text or blob contents.", + ), + "tools:call:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#image-content", + behavior="A tool result can carry image content: base64 data with a mimeType.", + ), + "tools:call:content:mixed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-result", + behavior="A tool result can carry multiple content blocks of different types; order is preserved.", + ), + "tools:call:content:resource-link": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior="A tool result can carry a resource_link content block referencing a resource by URI.", + ), + "tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content to the caller.", + ), + "tools:call:concurrent": Requirement( + 
source="sdk", + behavior=( + "Multiple tool calls in flight on one session are dispatched concurrently, and each caller " + "receives the response to its own request." + ), + ), + "tools:call:elicitation-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#user-interaction-model", + behavior=( + "A tool handler that issues an elicitation receives the client's result and can embed it in " + "the tool call result." + ), + ), + "tools:call:is-error": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "A tool execution failure is returned as a result with isError true and the failure described " + "in content, not as a JSON-RPC error." + ), + ), + "tools:call:logging-mid-execution": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "Log notifications emitted by a tool handler during execution reach the client's logging " + "callback before the tool result returns." + ), + ), + "tools:call:progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a tool handler reach the caller's progress callback before " + "the tool result returns." + ), + ), + "tools:call:sampling-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A tool handler that issues a sampling request receives the client's completion and can embed " + "it in the tool call result." 
+ ), + ), + "tools:call:structured-content": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool result can carry structuredContent alongside content; the client receives both.", + ), + "tools:call:structured-content:text-mirror": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool returning structured content also returns the serialized JSON as a text content block.", + ), + "tools:call:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.", + ), + "tools:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#capabilities", + behavior="A server with a list_tools handler advertises the tools capability in its initialize result.", + ), + "tools:input-schema:json-schema-2020-12": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "A tool registered with a JSON Schema 2020-12 inputSchema (nested objects, $defs references) " + "is discoverable and callable." + ), + ), + "tools:input-schema:preserve-additional-properties": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema additionalProperties as registered.", + ), + "tools:input-schema:preserve-defs": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema $defs as registered.", + ), + "tools:input-schema:preserve-schema-dialect": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves the inputSchema $schema dialect URI as registered.", + ), + "tools:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#list-changed-notification", + behavior=( + "When the tool set changes, the server sends notifications/tools/list_changed and it reaches " + "the client's handler." 
+ ), + ), + "tools:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#listing-tools", + behavior="tools/list returns the registered tools with name, description, and inputSchema.", + ), + "tools:list:metadata": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "Optional Tool fields supplied by the server (title, annotations, outputSchema, icons, _meta) " + "are delivered to the client unchanged." + ), + ), + "tools:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "tools/list supports cursor pagination: the nextCursor returned by a list handler round-trips " + "back to the handler as an opaque cursor until the listing is exhausted." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "client:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="The client skips structured-content validation when the tool result has isError true.", + ), + "client:output-schema:validate": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior=( + "A tool result whose structuredContent does not conform to the tool's declared outputSchema " + "is rejected by the client: the call raises instead of returning the invalid result." + ), + ), + "client:output-schema:missing-structured": Requirement( + source="sdk", + behavior="A tool that declares an output schema but returns no structuredContent fails client-side validation.", + ), + "client:output-schema:auto-list": Requirement( + source="sdk", + behavior=( + "Calling a tool whose output schema is not yet cached issues an implicit tools/list to " + "populate the cache; subsequent calls of the same tool do not." 
+ ), + divergence=Divergence( + note=( + "Design concern rather than spec violation: the implicit request is invisible to the " + "caller, and against a server that registers only on_call_tool a successful call surfaces " + "as METHOD_NOT_FOUND from a tools/list the caller never asked for." + ), + ), + ), + "mcpserver:output-schema:missing-structured": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior="A tool with an output schema whose function returns no structured content produces a server error.", + ), + "mcpserver:output-schema:server-validate": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior=( + "MCPServer validates structured content against the tool's output schema before returning; a " + "mismatch produces a server error." + ), + ), + "mcpserver:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="Server-side output schema validation is skipped when the tool returns an isError result.", + ), + "mcpserver:tool:duplicate-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-names", + behavior="Registering a tool with a name already in use is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; " + "warn_on_duplicate_tools defaults to True and warning is the only effect -- there is " + "no rejection mode." + ), + ), + ), + "mcpserver:tool:extra": Requirement( + source="sdk", + behavior=( + "Tool functions can access request metadata (request id, client params, session) through the " + "Context parameter." + ), + ), + "mcpserver:tool:handler-throws": Requirement( + source="sdk", + behavior=( + "An exception raised by a tool function (ToolError or otherwise) is caught and returned as a " + "tool result with isError true and the failure text in content; it does not become a JSON-RPC error." 
+ ), + ), + "mcpserver:tool:input-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "Arguments that fail the tool's input validation produce a tool execution error (isError true " + "with the validation failure described in content) without invoking the function." + ), + ), + "mcpserver:tool:naming-validation": Requirement( + source="sdk", + behavior=( + "Registering a tool whose name violates the spec's tool-naming conventions emits a warning; " + "registration still succeeds." + ), + ), + "mcpserver:tool:output-schema:model": Requirement( + source="sdk", + behavior=( + "A tool returning a typed model advertises a matching generated outputSchema and returns the " + "model's fields as structuredContent alongside a serialised text block." + ), + ), + "mcpserver:tool:output-schema:wrapped": Requirement( + source="sdk", + behavior=( + "A tool returning a non-object type (primitive or list) wraps the value as {'result': ...} in " + "structuredContent, with a matching generated outputSchema." + ), + ), + "mcpserver:tool:schema-variants": Requirement( + source="sdk", + behavior=( + "Tool input schemas generated from complex parameter types (unions, nested models, " + "constrained types) validate and coerce arguments before the function runs." + ), + ), + "mcpserver:tool:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name that was never registered returns a JSON-RPC error.", + divergence=Divergence( + note=( + "The spec classifies unknown tools as a protocol error (its example uses -32602 Invalid " + "params); MCPServer reports a tool execution error (isError true) instead. The low-level " + "path follows the spec example (see tools:call:unknown-name)." 
+ ), + ), + ), + "mcpserver:tool:url-elicitation-error": Requirement( + source="sdk", + behavior=( + "A tool function that raises the URL-elicitation-required error surfaces to the caller as " + "error -32042 with the elicitation parameters intact." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # MCPServer: Context helpers (SDK) + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:context:logging": Requirement( + source="sdk", + behavior=( + "The Context logging helpers (debug/info/warning/error) send log message notifications at the " + "corresponding severity." + ), + ), + "mcpserver:context:progress": Requirement( + source="sdk", + behavior=( + "Context.report_progress sends a progress notification against the requesting client's progress token." + ), + ), + "mcpserver:context:elicit": Requirement( + source="sdk", + behavior=( + "Context.elicit sends a form elicitation built from a typed schema and returns a typed " + "accepted/declined/cancelled result." + ), + ), + "mcpserver:context:read-resource": Requirement( + source="sdk", + behavior="Context.read_resource reads a resource registered on the same server from inside a tool.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Resources + # ═══════════════════════════════════════════════════════════════════════════ + "resources:annotations": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#annotations", + behavior="Resource annotations supplied by the server round-trip to the client in the list result.", + divergence=Divergence( + note=( + "The SDK Annotations model is missing the schema's lastModified field; MCPModel uses the " + "pydantic default extra='ignore', so the value is silently dropped on parse." 
+ ), + ), + ), + "resources:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#capabilities", + behavior=( + "A server with resource handlers advertises the resources capability, including the subscribe " + "sub-flag when a subscribe handler is registered." + ), + ), + "resources:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#list-changed-notification", + behavior=( + "When the resource set changes, the server sends notifications/resources/list_changed and it " + "reaches the client's handler." + ), + ), + "resources:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#listing-resources", + behavior=( + "resources/list returns the registered resources with uri, name, and the optional descriptive " + "fields supplied by the server." + ), + ), + "resources:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/list supports cursor pagination.", + ), + "resources:read:blob": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns binary contents base64-encoded in blob.", + ), + "resources:read:template-vars": Requirement( + source="sdk", + behavior="Variables extracted from a templated resource URI reach the resource function as typed arguments.", + ), + "resources:read:text": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns text contents carrying uri, mimeType, and the text.", + ), + "resources:read:unknown-uri": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for an unknown URI returns JSON-RPC error -32002 (resource not found).", + ), + "resources:subscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="resources/subscribe delivers the URI to the server's subscribe handler and returns an empty 
result.", + ), + "resources:subscribe:capability-required": Requirement( + source="sdk", + behavior=( + "resources/subscribe to a server that did not advertise the subscribe capability is rejected with an error." + ), + ), + "resources:subscribe:updated": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/subscribe, changes to that resource send notifications/resources/updated.", + deferred=( + "Not implemented in the SDK: the server keeps no subscription state linking subscribe to " + "updated notifications; emitting updates is entirely handler code. The two halves are pinned " + "separately by resources:subscribe and resources:updated-notification." + ), + ), + "resources:templates:list": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#resource-templates", + behavior=( + "resources/templates/list returns the registered templates with their uriTemplate and descriptive fields." + ), + ), + "resources:templates:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/templates/list supports cursor pagination.", + ), + "resources:unsubscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "resources/unsubscribe delivers the URI to the server's unsubscribe handler and returns an empty result." + ), + ), + "resources:unsubscribe:stops-updates": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/unsubscribe the server stops sending updated notifications for that URI.", + deferred=( + "Not implemented in the SDK: the server keeps no subscription state, so whether updated " + "notifications stop after unsubscribe is entirely handler code; there is no SDK behaviour to " + "pin beyond the unsubscribe request reaching the handler (covered by resources:unsubscribe)." 
+ ), + ), + "resources:updated-notification": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "A resources/updated notification sent by the server reaches the client carrying the URI of " + "the changed resource." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Resources: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:resource:duplicate-name": Requirement( + source="sdk", + behavior="Registering a resource or template with a duplicate identifier is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name). " + "Templates differ: a duplicate uri_template silently replaces the first with no warning." + ), + ), + ), + "mcpserver:resource:read-throws-surfaced": Requirement( + source="sdk", + behavior="A resource function that raises is surfaced to the caller as a JSON-RPC error response.", + ), + "mcpserver:resource:static": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.resource() for a fixed URI is listed by resources/list and " + "served by resources/read at that URI." + ), + ), + "mcpserver:resource:template": Requirement( + source="sdk", + behavior=( + "A function registered with a URI template is listed by resources/templates/list and matched " + "by resources/read, receiving the parameters extracted from the requested URI." 
+ ), + ), + "mcpserver:resource:unknown-uri": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for a URI matching no registered resource returns JSON-RPC error -32002.", + divergence=Divergence( + note=( + "The spec reserves -32002 for resource-not-found; MCPServer raises ResourceError, which " + "the low-level server converts to error code 0." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts + # ═══════════════════════════════════════════════════════════════════════════ + "prompts:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#capabilities", + behavior="A server with a list_prompts handler advertises the prompts capability in its initialize result.", + ), + "prompts:get:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#audio-content", + behavior="Prompt messages may contain audio content with base64 data and a mimeType.", + ), + "prompts:get:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#embedded-resources", + behavior="Prompt messages may contain embedded resource content.", + ), + "prompts:get:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#image-content", + behavior="Prompt messages may contain image content.", + ), + "prompts:get:missing-required-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get omitting a required argument returns JSON-RPC error -32602 (Invalid params).", + divergence=Divergence( + note=( + "MCPServer's prompt renderer raises a plain ValueError before the prompt function runs, " + "which the low-level server converts to error code 0 with the exception text as the message." 
+ ), + ), + ), + "prompts:get:multi-message": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="A prompt can return multiple messages mixing user and assistant roles; order is preserved.", + ), + "prompts:get:no-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get with no arguments returns the prompt's messages.", + ), + "prompts:get:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get for an unknown prompt name returns JSON-RPC error -32602 (Invalid params).", + ), + "prompts:get:with-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get delivers the supplied arguments to the prompt handler and returns its messages.", + ), + "prompts:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#list-changed-notification", + behavior=( + "When the prompt set changes, the server sends notifications/prompts/list_changed and it " + "reaches the client's handler." 
+ ), + ), + "prompts:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#listing-prompts", + behavior="prompts/list returns the registered prompts with name, description, and argument declarations.", + ), + "prompts:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="prompts/list supports cursor pagination.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:prompt:args-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#implementation-considerations", + behavior="prompts/get arguments that fail the prompt's argument schema are rejected before the function runs.", + ), + "mcpserver:prompt:decorated": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.prompt() is listed with arguments derived from its signature " + "and rendered into prompt messages by prompts/get." + ), + ), + "mcpserver:prompt:duplicate-name": Requirement( + source="sdk", + behavior="Registering a duplicate prompt name is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name)." 
+ ), + ), + ), + "mcpserver:prompt:optional-args": Requirement( + source="sdk", + behavior="A prompt with optional arguments can be fetched without supplying them.", + ), + "mcpserver:prompt:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get for a name that was never registered returns JSON-RPC error -32602 (Invalid params).", + divergence=Divergence( + note=( + "The spec's example uses -32602 Invalid params for unknown prompts; MCPServer raises " + "ValueError, which the low-level server converts to error code 0." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Completion + # ═══════════════════════════════════════════════════════════════════════════ + "completion:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior="A server with a completion handler advertises the completions capability in its initialize result.", + ), + "completion:complete:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior=( + "A server with no completion handler does not advertise the completions capability and rejects " + "completion/complete with METHOD_NOT_FOUND." + ), + ), + "completion:context-arguments": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", + behavior="Previously-resolved argument values supplied in context.arguments reach the completion handler.", + ), + "completion:error:invalid-ref": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#error-handling", + behavior=( + "completion/complete with a ref naming an unknown prompt or non-matching resource URI returns " + "JSON-RPC error -32602 (Invalid params)." 
+ ), + ), + "completion:prompt-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/prompt returns suggested values for the named prompt argument.", + ), + "completion:resource-template-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/resource returns suggested values for a URI template variable.", + ), + "completion:result-shape": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#completion-results", + behavior="The completion result carries values (at most 100), an optional total, and an optional hasMore flag.", + ), + "mcpserver:completion:capability-auto": Requirement( + source="sdk", + behavior=( + "MCPServer advertises the completions capability when at least one completion source is " + "registered, and omits it otherwise." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Logging + # ═══════════════════════════════════════════════════════════════════════════ + "logging:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#capabilities", + behavior=( + "A server that emits log message notifications declares the logging capability in its initialize result." + ), + divergence=Divergence( + note=( + "MCPServer registers no setLevel handler, so capability derivation leaves logging unset " + "even though the Context helpers send log message notifications." 
+ ), + ), + ), + "logging:message:all-levels": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", + behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", + ), + "logging:message:fields": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "A log message sent by a server handler is delivered to the client's logging callback with its " + "severity level, logger name, and data." + ), + ), + "logging:message:filtered": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="After logging/setLevel, log messages below the configured level are not sent.", + divergence=Divergence( + note=( + "Neither MCPServer (which rejects logging/setLevel with method-not-found) nor the " + "low-level Server (which leaves the handler entirely to the author) implements any " + "filtering; messages are delivered at every severity regardless of the requested level." + ), + ), + ), + "logging:set-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="logging/setLevel delivers the requested level to the server's handler and returns an empty result.", + ), + "logging:set-level:invalid-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#error-handling", + behavior="logging/setLevel with an invalid level value returns JSON-RPC error -32602 (Invalid params).", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Sampling (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "sampling:capability:declare": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A client that handles sampling requests advertises the sampling capability in its initialize request." 
+ ), + ), + "sampling:create:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A sampling/createMessage request from a server handler is answered by the client's sampling " + "callback, and the callback's result (role, content, model, stopReason) is returned to the handler." + ), + ), + "sampling:create:include-context": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior="The includeContext value supplied by the server reaches the client callback intact.", + ), + "sampling:context:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "The server does not use includeContext values thisServer or allServers unless the client " + "declared the sampling.context capability." + ), + divergence=Divergence( + note=( + "include_context is forwarded regardless of the client's declared sampling.context " + "capability; the server-side validator only checks tools/tool_choice." + ), + ), + ), + "sampling:create:model-preferences": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#model-preferences", + behavior=( + "The model preferences supplied by the server (hints and the cost, speed, and intelligence " + "priorities) reach the client callback intact." + ), + ), + "sampling:create:system-prompt": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior="The system prompt supplied by the server reaches the client callback intact.", + ), + "sampling:create:tools": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", + behavior=( + "A sampling request carrying tools and toolChoice reaches the client, and a tool_use response " + "with a toolUse stop reason returns to the requesting handler." 
+ ), + deferred=( + "Not implemented in the SDK: Client does not expose ClientSession's sampling_capabilities " + "parameter, so a client can never declare sampling.tools and the server-side validator " + "rejects every tool-enabled request before it is sent." + ), + ), + "sampling:create-message:audio-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#audio-content", + behavior="Sampling messages can carry audio content: base64 data with a mimeType.", + ), + "sampling:create-message:image-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#image-content", + behavior="Sampling messages can carry image content: base64 data with a mimeType.", + ), + "sampling:create-message:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A sampling request to a client that did not declare the sampling capability fails with an " + "error rather than hanging or being silently dropped; the spec names no error code for this case." + ), + ), + "sampling:error:user-rejected": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#error-handling", + behavior=( + "A sampling request the user rejects is answered with a JSON-RPC error (the spec's code for " + "this case is -1, 'User rejected sampling request'), surfaced to the requesting handler as an MCPError." + ), + ), + "sampling:message:content-cardinality": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling", + behavior="A sampling message's content may be a single block or an array of blocks.", + ), + "sampling:result:no-tools-single-content": Requirement( + source="sdk", + behavior=( + "When the request carries no tools, a sampling callback result whose content is an array is " + "rejected by the client." 
+ ), + divergence=Divergence( + note=( + "The client does not validate the callback result against the request shape; an array-content " + "result for a tool-free request is accepted client-side and surfaces as a raw " + "pydantic.ValidationError from the server's response parsing (send_request) instead." + ), + ), + ), + "sampling:result:with-tools-array-content": Requirement( + source="sdk", + behavior=( + "When the request includes tools, the client accepts a callback result whose content is an " + "array including tool_use blocks." + ), + deferred=( + "Not implemented in the SDK: requires declaring sampling.tools, which the high-level client " + "cannot do (see sampling:create:tools)." + ), + ), + "sampling:tool-result:no-mixed-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-result-messages", + behavior=( + "A user sampling message that carries tool_result content contains only tool_result blocks; " + "mixing tool_result with text, image, or audio content is rejected as invalid." + ), + ), + "sampling:tool-use:result-balance": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-use-and-result-balance", + behavior=( + "In a sampling/createMessage request, every assistant tool_use block in messages MUST be " + "matched by a tool_result with the same toolUseId in the immediately-following user message; " + "an unmatched tool_use is rejected with -32602 Invalid params." + ), + divergence=Divergence( + note=( + "The client does not validate inbound tool_use/tool_result balance; the SDK enforces " + "the rule server-side instead, before the request leaves the server (see " + "sampling:tool-use:server-preflight)." + ), + ), + deferred=( + "Not implemented on the client receive path: validation runs only on the server send path " + "(pinned by sampling:tool-use:server-preflight)." 
+ ), + ), + "sampling:tool-use:server-preflight": Requirement( + source="sdk", + behavior=( + "The server validates tool_use/tool_result balance before sending a sampling/createMessage " + "request; an unmatched tool_use raises ValueError and the request never reaches the wire." + ), + ), + "sampling:tools:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", + behavior=( + "A tool-enabled sampling request to a client that did not declare sampling.tools is rejected " + "by the server before anything reaches the wire (the SDK surfaces this as an Invalid params error)." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Elicitation (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "elicitation:capability:empty-is-form": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior="A client advertising an empty elicitation capability accepts form-mode elicitation requests.", + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares explicit " + "form and url sub-capabilities, so an empty elicitation capability cannot be produced through " + "the public API." + ), + ), + "elicitation:capability:mode-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "The client answers elicitation requests for a mode it did not advertise with JSON-RPC error " + "-32602 (Invalid params)." + ), + deferred=( + "Not implemented in the SDK: a client cannot be configured form-only or url-only, so the " + "per-mode mismatch error cannot arise (see elicitation:url:not-supported)." 
+ ), + ), + "elicitation:capability:server-respects-mode": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior=( + "The server refuses to send an elicitation request with a mode the connected client did not " + "declare in its capabilities." + ), + divergence=Divergence( + note=( + "The server does not check the client's declared elicitation modes before sending " + "elicitation/create; the spec's MUST NOT is not enforced." + ), + ), + ), + "elicitation:form:action:accept": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior=( + "A form-mode elicitation answered with action 'accept' returns the user's content to the " + "requesting handler." + ), + ), + "elicitation:form:action:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'cancel' returns no content to the handler.", + ), + "elicitation:form:action:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'decline' returns no content to the handler.", + ), + "elicitation:form:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", + behavior=( + "A form-mode elicitation delivers the message and requested schema to the client callback " + "exactly as the server sent them." + ), + ), + "elicitation:form:defaults": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Optional default values declared in a form-mode requested schema are pre-populated into the " + "form presented to the user." + ), + deferred=( + "Not implemented in the SDK: there is no form-rendering layer that could pre-populate " + "defaults; client callbacks receive the requested schema as-is." 
+ ), + ), + "elicitation:form:mode-omitted-default": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#elicitation-requests", + behavior="An elicitation request with no mode field is treated as form mode by the client.", + ), + "elicitation:form:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "An elicitation request to a client that did not declare the elicitation capability is " + "answered with -32602 Invalid params." + ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32602.", + ), + ), + "elicitation:form:schema:enum-variants": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Requested-schema enum fields (including titled and multi-select variants) reach the client " + "callback as sent." + ), + ), + "elicitation:form:schema:primitives": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior="Requested-schema fields may be string (with format), number or integer, or boolean.", + ), + "elicitation:form:schema:restricted-subset": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Form-mode requested schemas are flat objects with primitive-typed properties only; nested " + "structures and arrays of objects are not used." + ), + divergence=Divergence( + note=( + "ServerSession.elicit_form forwards an arbitrary dict[str, Any] schema unchanged; no shape " + "validation at the low-level session layer (the high-level Context.elicit / " + "elicit_with_validation helper enforces primitive-only fields before generating the schema)." 
+ ), + ), + ), + "elicitation:form:response-validation": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-security", + behavior=( + "Accepted form-mode content is validated against the requested schema: the client validates " + "the response before sending and the server validates the content it receives." + ), + divergence=Divergence( + note=( + "The client never validates outbound content; ServerSession.elicit_form returns received " + "content unvalidated (the high-level Context.elicit / elicit_with_validation helper " + "validates server-side, but the low-level session API does not)." + ), + ), + ), + "elicitation:url:action:accept-no-content": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior=( + "A URL-mode elicitation delivers the message, URL, and elicitationId to the client; an accept " + "response carries no content (accept means the user agreed to visit the URL, not that the " + "interaction completed)." + ), + ), + "elicitation:url:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation-requests", + behavior=( + "A url-mode elicitation delivers the elicitation id and URL to the client callback exactly as " + "the server sent them." + ), + ), + "elicitation:url:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with cancel returns the action with no content.", + ), + "elicitation:url:complete-notification": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "An elicitation/complete notification sent by the server after an out-of-band elicitation " + "finishes reaches the client carrying the elicitationId." 
+ ), + ), + "elicitation:url:complete-unknown-ignored": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "The client ignores an elicitation/complete notification referencing an unknown or " + "already-completed elicitationId without error." + ), + ), + "elicitation:url:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with decline returns the action with no content.", + ), + "elicitation:url:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "A URL-mode elicitation to a client that declared only form-mode support is rejected with an " + "Invalid params error." + ), + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares both the " + "form and url sub-capabilities, so a form-only client cannot be constructed." + ), + ), + "elicitation:url:required-error": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A handler that cannot proceed without a URL elicitation rejects the request with error " + "-32042, carrying the pending elicitations in the error data." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Roots (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "roots:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", + ), + "roots:list-changed:client-emits": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior=( + "A client that declared roots.listChanged sends notifications/roots/list_changed when its set " + "of roots changes." 
+ ), + deferred=( + "Not implemented in the SDK: the client does not own the root set (it calls back to the host " + "via list_roots_callback), so there is no mutation it could observe to auto-emit on; the SDK " + "provides send_roots_list_changed() for the host to call when its roots change, and that " + "emission path is covered by roots:list-changed." + ), + ), + "roots:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior=( + "A roots/list request from a server handler is answered by the client's roots callback, and " + "the returned roots (uri, name) reach the handler." + ), + ), + "roots:list:client-error": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior="A roots callback that answers with an error surfaces to the requesting handler as an MCPError.", + ), + "roots:list:empty": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior="An empty roots list is a valid response and reaches the handler as such.", + ), + "roots:list:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior=( + "A roots/list request to a client that did not declare the roots capability is answered with " + "-32601 Method not found." + ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32601.", + ), + ), + "roots:uri:file-scheme": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root", + behavior="Every root returned by the client identifies itself with a file:// URI.", + deferred=( + "Schema-level validation: the FileUrl type on Root.uri rejects any non-file:// scheme at " + "construction and at parse, so a non-conforming root cannot reach the wire from either side; " + "type-level coverage belongs in tests/test_types.py rather than this interaction suite." 
+ ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # list_changed & dynamic registration + # ═══════════════════════════════════════════════════════════════════════════ + "client:list-changed:auto-refresh": Requirement( + source="sdk", + behavior=( + "A client configured to react to list_changed notifications automatically re-fetches the " + "corresponding list and delivers the fresh result to its callback." + ), + deferred=( + "Not implemented in the SDK: the client has no list-changed auto-refresh mechanism; " + "notifications are only delivered to the message handler." + ), + ), + "client:list-changed:capability-gated": Requirement( + source="sdk", + behavior=( + "The client does not activate list-changed handling for a kind the server did not advertise " + "with listChanged true." + ), + deferred="Not implemented in the SDK: no client-side list-changed handling exists to gate.", + ), + "client:list-changed:signal-only": Requirement( + source="sdk", + behavior="A client configured for signal-only list-changed handling is notified without auto-refreshing.", + deferred="Not implemented in the SDK: no client-side list-changed handling exists.", + ), + "mcpserver:list-changed:debounce": Requirement( + source="sdk", + behavior=( + "Bursts of registration changes on MCPServer are debounced into one list_changed notification per kind." + ), + deferred=( + "Not implemented in the SDK: MCPServer does not send list_changed notifications on " + "registration changes at all (see mcpserver:register:post-connect), so there is nothing to " + "debounce." + ), + ), + "mcpserver:register:post-connect": Requirement( + source="sdk", + behavior=( + "A tool, resource, or prompt registered or removed after the client connected appears in (or " + "disappears from) the corresponding list results, and the change is announced with a " + "list_changed notification." 
+ ), + divergence=Divergence( + note=( + "MCPServer never sends list_changed notifications on registration changes, so a connected " + "client cannot learn that the set changed without polling." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Pagination + # ═══════════════════════════════════════════════════════════════════════════ + "pagination:exhaustion": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "Following nextCursor until it is absent yields every page exactly once; a result without " + "nextCursor ends the sequence." + ), + ), + "pagination:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#error-handling", + behavior="A list request with an invalid cursor returns JSON-RPC error -32602 (Invalid params).", + ), + "pagination:client:cursor-handling": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#implementation-guidelines", + behavior=( + "The client treats cursors as opaque tokens — it does not parse, modify, or persist them — " + "and does not assume a fixed page size." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tasks (experimental) + # ═══════════════════════════════════════════════════════════════════════════ + "tasks:auth:context-isolation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-isolation-and-access-control", + behavior=( + "When an authorization context is available, task operations are scoped to the context that " + "created the task: other contexts cannot get it, retrieve its result, cancel it, or see it in " + "tasks/list." 
+ ), + transports=("streamable-http",), + deferred=_TASKS_DEFERRAL, + ), + "tasks:bidirectional": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#definitions", + behavior="Task APIs are bidirectional: the server may create, get, list, and cancel tasks on the client.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:no-handler-abort": Requirement( + source="sdk", + behavior=( + "tasks/cancel marks the task cancelled without aborting the originating request handler " + "(the spec says receivers SHOULD attempt to stop execution)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:remains-cancelled": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior=( + "After tasks/cancel, the task remains cancelled even if the underlying handler subsequently " + "completes or fails." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:terminal-rejected": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a task already in a terminal state returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a working task transitions it to cancelled and returns the updated task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:ttl-honored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#ttl-and-resource-management", + behavior=( + "tasks/get responses include the actual ttl applied by the receiver (or null for unlimited); " + "the create-task result carries the same value." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:via-tool-call": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#creating-tasks", + behavior="A task-augmented tools/call returns a create-task result instead of the tool result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:get": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#getting-tasks", + behavior="tasks/get returns the task's current status, ttl, timestamps, and status message.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:initial-working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-lifecycle", + behavior="A newly created task has status 'working'.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:input-required": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "While a task awaits a side-channel client response its status is input_required; once the " + "response arrives the task leaves input_required (typically returning to working)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/list with an invalid cursor returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#listing-tasks", + behavior="tasks/list returns created tasks and supports cursor pagination.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:no-capability:ignore-task-param": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-support-and-handling", + behavior=( + "A receiver that did not declare task capability for a request type processes the request " + "normally and returns the ordinary result, ignoring the task augmentation." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:progress:after-create": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-progress-notifications", + behavior=( + "After the create-task result, progress notifications keyed to the original progress token " + "continue to reach the caller until the task is terminal." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:request-cancel:no-task-cancel": Requirement( + source="sdk", + behavior="A cancellation notification for the originating request does not auto-cancel the created task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:failed": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-execution-errors", + behavior="tasks/result for a failed task returns the failure result (isError true).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:related-task-meta": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="The tasks/result response carries related-task _meta naming the requested task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:terminal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#result-retrieval", + behavior="tasks/result for a completed task returns the stored result of the original request type.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drain-fifo": Requirement( + source="sdk", + behavior="tasks/result drains queued related-task messages in FIFO order before returning the final result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drop-on-cancel": Requirement( + source="sdk", + behavior="When a task is cancelled before tasks/result, queued related-task messages are dropped.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:elicitation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "An elicitation issued mid-task is delivered through the tasks/result side-channel, and the " + "client's response routes back to the 
handler." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:queue": Requirement( + source="sdk", + behavior=( + "Server-to-client requests with related-task metadata sent while no tasks/result is open are queued." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:sampling": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "A sampling request issued mid-task is delivered through the tasks/result side-channel, and " + "the client's response routes back to the task." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:stream": Requirement( + source="sdk", + behavior=( + "Calling tasks/result while the task is working streams related-task messages as they are " + "produced, then returns the result." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:status-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-notification", + behavior="Task status notifications deliver status updates carrying the full task fields.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:forbidden-with-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=( + "A task-augmented tools/call on a tool that does not support tasks returns Method not found (-32601)." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:required-no-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=("A plain tools/call on a tool that requires task augmentation returns Method not found (-32601)."), + deferred=_TASKS_DEFERRAL, + ), + "tasks:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/get, tasks/result, and tasks/cancel for an unknown task id return Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Transports (in-suite coverage) + # ═══════════════════════════════════════════════════════════════════════════ + "transport:streamable-http:stateful": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip (initialize, tool calls, tool errors) works through the " + "streamable HTTP framing in its default stateful SSE-response mode." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:json-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="The interaction round trip works when the server answers with plain JSON instead of SSE.", + transports=("streamable-http",), + ), + "transport:streamable-http:stateless": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip works in stateless mode, where every request is served by a " + "fresh transport with no session id." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:notifications": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "Notifications emitted during a request are delivered on that request's SSE stream and reach " + "the client's callbacks, in order, before the response." 
+ ), + transports=("streamable-http",), + ), + "transport:streamable-http:stateless-restrictions": Requirement( + source="sdk", + behavior=( + "A handler that attempts a server-initiated request in stateless mode fails with an error " + "result, because there is no session to call back through." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:unrelated-messages": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-to-client message that is not related to an in-flight request is routed to the " + "standalone GET stream and delivered to the client listening on it, not to any request's " + "own stream." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-initiated request nested inside an in-flight call round-trips over stateful streamable HTTP." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:resumability": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="A client that reconnects with Last-Event-ID receives the events it missed.", + transports=("streamable-http",), + ), + "transport:streamable-http:origin-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior="Requests with an invalid Origin header are rejected with 403 before reaching the session.", + transports=("streamable-http",), + ), + "transport:sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "A client connected over the legacy HTTP+SSE transport completes the handshake and round-trips " + "requests, with server messages delivered on the SSE stream." 
+ ), + transports=("sse",), + ), + "transport:sse:endpoint-event": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior="Opening the SSE stream delivers an `endpoint` event naming the message-POST URL as the first event.", + transports=("sse",), + ), + "transport:sse:post:session-routing": Requirement( + source="sdk", + behavior=( + "The endpoint URL carries a fresh session identifier; the server registers the session before " + "the endpoint event is sent and releases it when the stream disconnects, and a POST that names " + "no session id, a malformed session id, or an unknown session id is rejected (400/400/404)." + ), + transports=("sse",), + ), + "transport:stdio": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior=( + "A Client connected to a real SDK Server over stdio initializes, calls a tool with arguments, " + "and receives notifications and results over the child process's stdin/stdout." + ), + transports=("stdio",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: session lifecycle + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:session:cors-expose": Requirement( + source="sdk", + behavior="CORS configuration exposes the Mcp-Session-Id header so browser clients can read it.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: CORS configuration is left to the hosting ASGI application.", + ), + "hosting:session:create": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "An initialize POST without a session id creates a session and returns Mcp-Session-Id in the " + "response headers." 
+ ), + transports=("streamable-http",), + ), + "hosting:session:delete": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="DELETE with a valid Mcp-Session-Id terminates the session.", + transports=("streamable-http",), + ), + "hosting:session:id-charset": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Generated Mcp-Session-Id values contain only visible ASCII characters.", + transports=("streamable-http",), + ), + "hosting:session:isolation": Requirement( + source="sdk", + behavior="Each session gets its own server instance; closing one session does not affect others.", + transports=("streamable-http",), + ), + "hosting:session:missing-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A non-initialize POST without Mcp-Session-Id in stateful mode returns 400.", + transports=("streamable-http",), + ), + "hosting:session:post-termination-404": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "After a session is terminated, any further request carrying that session ID is answered with " + "404 Not Found." + ), + transports=("streamable-http",), + ), + "hosting:session:reinitialize": Requirement( + source="sdk", + behavior="A second initialize on an already-initialized session transport is rejected.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The transport forwards a second initialize carrying the existing session ID to the running " + "server, which answers it as a fresh handshake; nothing rejects re-initialization." 
+ ), + ), + ), + "hosting:session:reuse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST carrying a valid Mcp-Session-Id routes to that session's transport with state preserved.", + transports=("streamable-http",), + ), + "hosting:session:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST, GET, or DELETE with an unknown Mcp-Session-Id returns 404.", + transports=("streamable-http",), + ), + "hosting:stateless:concurrent-clients": Requirement( + source="sdk", + behavior="Multiple independent clients can connect to a stateless server concurrently.", + transports=("streamable-http",), + ), + "hosting:stateless:no-reuse": Requirement( + source="sdk", + behavior="A stateless per-request transport cannot be reused for a second request.", + transports=("streamable-http",), + ), + "hosting:stateless:no-session-id": Requirement( + source="sdk", + behavior="In stateless mode no Mcp-Session-Id is emitted and no session validation is performed.", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: auth + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:auth:as-router": Requirement( + source="sdk", + behavior=( + "The authorization-server routes expose the authorize, token, and registration endpoints " + "(and revocation when supported)." + ), + transports=("streamable-http",), + ), + "hosting:auth:aud-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#access-token-usage", + behavior="The resource server validates that the token audience matches its resource identifier.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "BearerAuthBackend never inspects AccessToken.resource; a token issued for a different " + "resource is accepted. Spec MUST." 
+ ), + ), + ), + "hosting:auth:authinfo-propagates": Requirement( + source="sdk", + behavior="A valid token's auth info is exposed to request handlers.", + transports=("streamable-http",), + ), + "hosting:auth:expired-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="An expired token returns 401 invalid_token.", + transports=("streamable-http",), + divergence=Divergence( + note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", + ), + ), + "hosting:auth:invalid-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="A malformed bearer token or token-verification failure returns 401 with WWW-Authenticate.", + transports=("streamable-http",), + divergence=Divergence( + note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", + ), + ), + "hosting:auth:metadata-endpoints": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The MCP server publishes protected-resource metadata at its well-known endpoint, and the " + "authorization server (which the SDK can also host) publishes authorization-server metadata " + "at its own." + ), + transports=("streamable-http",), + ), + "hosting:auth:missing-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior=( + "A request without an Authorization header is rejected with 401; the WWW-Authenticate header " + "carries resource_metadata (one of the spec's two permitted discovery mechanisms)." 
+ ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK never emits a `scope` parameter in any WWW-Authenticate challenge — neither the " + "discovery-time 401 (#protected-resource-metadata-discovery-requirements SHOULD) nor the " + "runtime 403 (#runtime-insufficient-scope-errors SHOULD); and for the no-credentials case " + 'it emits error="invalid_token", which RFC 6750 Section 3.1 says SHOULD NOT appear when no ' + "authentication information was presented." + ), + ), + ), + "hosting:auth:prm:authorization-servers-field": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The protected-resource metadata document includes an authorization_servers array with at least one entry." + ), + transports=("streamable-http",), + ), + "hosting:auth:query-token-ignored": Requirement( + source="sdk", + behavior=( + "An access token presented in the URI query string is not accepted; the request is treated as " + "unauthenticated." + ), + transports=("streamable-http",), + ), + "hosting:auth:scope-403": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#runtime-insufficient-scope-errors", + behavior=( + "A token lacking a required scope returns 403 with WWW-Authenticate carrying " + "insufficient_scope, the required scope, and resource_metadata." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + 'The SDK emits error="insufficient_scope" and error_description but never the `scope` ' + "parameter the spec SHOULD include; the SDK client reads `scope` from this header to drive " + "step-up (utils.py extract_scope_from_www_auth) — a resource-server/client asymmetry." + ), + ), + ), + "hosting:auth:as:authorize-requires-pkce": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The bundled authorization endpoint rejects an authorize request that omits " + "`code_challenge` with `invalid_request`." 
+ ), + transports=("streamable-http",), + ), + "hosting:auth:as:verifier-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The bundled token endpoint rejects an authorization-code exchange whose `code_verifier` " + "does not hash to the stored `code_challenge` with `invalid_grant`." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:code-single-use": Requirement( + source="sdk", + behavior=( + "An authorization code can be exchanged exactly once; a second exchange of the same code " + "is rejected with `invalid_grant`. Enforced by the provider deleting the code on first use; " + "the handler relies on `load_authorization_code` returning None." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:redirect-uri-binding": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", + behavior=( + "The bundled token endpoint rejects an authorization-code exchange whose `redirect_uri` " + "differs from the one used at authorize; the bundled authorize endpoint rejects a " + "`redirect_uri` not in the client's registered list without redirecting to it." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "RFC 6749 §5.2 assigns redirect_uri mismatch at the token endpoint to invalid_grant; " + "the SDK's TokenHandler returns invalid_request (src/mcp/server/auth/handlers/token.py:157). " + "The rejection itself is the security-relevant property and is correct." + ), + ), + ), + "hosting:auth:as:redirect-uri-scheme": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#communication-security", + behavior=( + "The bundled registration endpoint accepts only redirect URIs that use HTTPS or target a loopback host." 
+ ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "Not enforced: the registration handler models redirect_uris as AnyUrl with no scheme or " + "host check, so http://evil.example/callback is accepted and registered. The spec's " + "localhost-or-HTTPS rule is left to the provider implementation." + ), + ), + ), + "hosting:auth:as:token-cache-headers": Requirement( + source="sdk", + behavior=("Every token-endpoint response carries `Cache-Control: no-store` and `Pragma: no-cache`."), + transports=("streamable-http",), + ), + "hosting:auth:as:register-error-response": Requirement( + source="sdk", + behavior=( + "The bundled registration endpoint answers invalid client metadata with HTTP 400 and an " + "RFC 7591 error body." + ), + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: resumability + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:resume:bad-event-id": Requirement( + source="sdk", + behavior="A Last-Event-ID that cannot be mapped to a stream is rejected.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The replay path returns an empty SSE stream rather than rejecting an unknown " + "Last-Event-ID; the client cannot tell an unknown ID apart from a stream with no missed " + "events." 
+ ), + ), + ), + "hosting:resume:buffered-replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Notifications emitted while no client is connected are replayed in order on reconnect.", + transports=("streamable-http",), + ), + "hosting:resume:close-stream": Requirement( + source="sdk", + behavior="Handlers can close an SSE stream cleanly when an event store is configured.", + transports=("streamable-http",), + ), + "hosting:resume:event-ids": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="With an event store configured, every SSE event carries an id field.", + transports=("streamable-http",), + ), + "hosting:resume:priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A server-initiated SSE stream begins with a priming event carrying an event ID and an empty " + "data field; a server that closes the connection before terminating the stream sends an SSE " + "retry field first." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The retry hint is attached to the priming event itself rather than sent as a separate " + "event before the connection closes, and a priming event is only sent when an event store " + "is configured and the negotiated protocol version is at least 2025-11-25." 
+ ), + ), + ), + "hosting:resume:replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="GET with Last-Event-ID replays stored events for that stream after the given id.", + transports=("streamable-http",), + ), + "hosting:resume:stream-scoped": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Replay via Last-Event-ID returns only messages from the stream that event id belongs to.", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: HTTP semantics + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:http:accept-406": Requirement( + source="sdk", + behavior="A request whose Accept header does not allow the response representation returns 406.", + transports=("streamable-http",), + ), + "hosting:http:batch": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body is a single JSON-RPC message; batched arrays are rejected for protocol revisions " + "that forbid them." + ), + transports=("streamable-http",), + ), + "hosting:http:content-type-415": Requirement( + source="sdk", + behavior="A POST with a Content-Type other than application/json returns 415.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The transport-security middleware rejects a non-JSON Content-Type with 400 'Invalid " + "Content-Type header' before the request reaches the transport, so the transport's own 415 " + "path is unreachable through any public entry point." + ), + ), + ), + "hosting:http:disconnect-not-cancel": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A client connection drop during an in-flight request does not cancel the server-side " + "handler; the request continues and its result remains retrievable." 
+ ), + transports=("streamable-http",), + ), + "hosting:http:dns-rebinding": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior=( + "The Origin header is validated on every incoming connection; a request with an invalid " + "Origin is rejected with 403 Forbidden." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The spec's Origin validation is an unconditional MUST; the SDK enables it only when the " + "host is a localhost address or explicit TransportSecuritySettings are passed (with no " + "settings, no Origin validation runs), and additionally validates the Host header " + "(returning 421 on mismatch), which the spec does not require." + ), + ), + ), + "hosting:http:json-response-mode": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="With JSON response mode enabled, POST returns application/json instead of SSE.", + transports=("streamable-http",), + ), + "hosting:http:method-405": Requirement( + source="sdk", + behavior="An unsupported HTTP method on the MCP endpoint returns 405.", + transports=("streamable-http",), + ), + "hosting:http:no-broadcast": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#multiple-connections", + behavior=( + "When multiple SSE streams are open for a session, each server-originated message is sent on " + "exactly one stream, never duplicated." 
+ ), + transports=("streamable-http",), + ), + "hosting:http:notifications-202": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A POST containing only notifications or responses returns 202 with no body.", + transports=("streamable-http",), + ), + "hosting:http:onerror": Requirement( + source="sdk", + behavior="Transport-level rejections are reported through an error callback on the server transport.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: the server transport has no error callback; rejections are logged.", + ), + "hosting:http:parse-error-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body that is not valid JSON or not a valid JSON-RPC message is rejected with HTTP 400; " + "the body may carry a JSON-RPC error response (the SDK sends a Parse error body)." + ), + transports=("streamable-http",), + ), + "hosting:http:protocol-version-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior="An invalid or unsupported MCP-Protocol-Version header returns 400 Bad Request.", + transports=("streamable-http",), + ), + "hosting:http:protocol-version-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior=( + "When no MCP-Protocol-Version header is received and the version cannot be determined another " + "way, the server assumes protocol version 2025-03-26." + ), + transports=("streamable-http",), + ), + "hosting:http:response-same-connection": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A response is delivered on the SSE stream opened by the POST that carried its request (or " + "that stream's resumed continuation), not on an unrelated stream." 
+ ), + transports=("streamable-http",), + ), + "hosting:http:second-sse-rejected": Requirement( + source="sdk", + behavior="A second concurrent standalone GET SSE stream on the same session is rejected.", + transports=("streamable-http",), + ), + "hosting:http:sse-close-after-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="The server terminates a POST-initiated SSE stream after writing the JSON-RPC response.", + transports=("streamable-http",), + ), + "hosting:http:standalone-sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="GET opens a standalone SSE stream that receives server-initiated messages.", + transports=("streamable-http",), + ), + "hosting:http:standalone-sse-no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior=( + "The standalone GET SSE stream carries server requests and notifications but never a JSON-RPC " + "response, except when resuming a prior request stream." + ), + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Client transport: streamable HTTP + # ═══════════════════════════════════════════════════════════════════════════ + "client-transport:http:404-surfaces": Requirement( + source="sdk", + behavior="A 404 (session expired) on a request surfaces as an error to the caller.", + transports=("streamable-http",), + ), + "client-transport:http:session-404-reinitialize": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "A 404 in response to a request carrying a session ID makes the client start a new session " + "with a fresh InitializeRequest and no session ID attached." 
+ ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The client surfaces the 404 as an error to the caller instead of re-initializing a new " + "session; the spec's MUST is not satisfied." + ), + ), + deferred=( + "Not implemented in the SDK: the client surfaces a Session terminated error instead of " + "re-initializing (the surfaced error is pinned by client-transport:http:404-surfaces)." + ), + ), + "client-transport:http:accept-header-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="The client GET to the MCP endpoint includes an Accept header listing text/event-stream.", + transports=("streamable-http",), + ), + "client-transport:http:accept-header-post": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "Every client POST to the MCP endpoint includes an Accept header listing both application/json " + "and text/event-stream." + ), + transports=("streamable-http",), + ), + "client-transport:http:concurrent-streams": Requirement( + source="sdk", + behavior="Multiple concurrent POST-initiated SSE streams each deliver their response to the right caller.", + transports=("streamable-http",), + ), + "client-transport:http:custom-client": Requirement( + source="sdk", + behavior=( + "A caller-supplied HTTP client (and its event hooks and headers) is used for all MCP traffic, " + "including auth flows." 
+ ), + transports=("streamable-http",), + ), + "client-transport:http:custom-headers": Requirement( + source="sdk", + behavior="Caller-supplied headers are sent on every POST, GET, and DELETE to the MCP endpoint.", + transports=("streamable-http",), + ), + "client-transport:http:json-response-parsed": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A Content-Type application/json response is parsed as a single JSON-RPC message.", + transports=("streamable-http",), + ), + "client-transport:http:no-reconnect-after-close": Requirement( + source="sdk", + behavior="After the transport is closed, no further reconnection attempts are scheduled.", + transports=("streamable-http",), + ), + "client-transport:http:no-reconnect-after-response": Requirement( + source="sdk", + behavior="A POST-initiated stream that already delivered its response is not reconnected when it closes.", + transports=("streamable-http",), + ), + "client-transport:http:protocol-version-header": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior=( + "After initialization, the client sends the negotiated MCP-Protocol-Version header on every " + "subsequent HTTP request." + ), + transports=("streamable-http",), + ), + "client-transport:http:protocol-version-stored": Requirement( + source="sdk", + behavior=( + "The client transport stores the negotiated protocol version and sends it on every subsequent request." + ), + transports=("streamable-http",), + ), + "client-transport:http:reconnect-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior=( + "A standalone GET SSE stream that errors is reconnected with the Last-Event-ID of the last received event." 
+ ), + transports=("streamable-http",), + deferred=( + "The server's standalone GET stream emits no priming event or retry hint, so the client's " + "reconnection path always sleeps the hard-coded 1 s default; a deterministic in-process test " + "would require accepting that real-time wait. The POST-stream reconnection path is covered " + "by client-transport:http:reconnect-post-priming." + ), + ), + "client-transport:http:reconnect-post-priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST-initiated SSE stream that errors before delivering its response is reconnected only " + "if a priming event (an event carrying an ID) was received on it." + ), + transports=("streamable-http",), + ), + "client-transport:http:reconnect-retry-value": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="Reconnection delay honours the server-provided SSE retry value when one was sent.", + transports=("streamable-http",), + ), + "client-transport:http:resume-stream-api": Requirement( + source="sdk", + behavior=( + "The client can capture a resumption token, reconnect with the same session id, and receive " + "the notifications it missed." + ), + transports=("streamable-http",), + ), + "client-transport:http:session-stored": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "The Mcp-Session-Id returned by initialize is stored by the client transport and sent on " + "every subsequent request." 
+ ), + transports=("streamable-http",), + ), + "client-transport:http:sse-405-tolerated": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="Opening the standalone GET SSE stream tolerates a 405 response without failing the connection.", + transports=("streamable-http",), + ), + "client-transport:http:terminate-405-ok": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Session termination succeeds without error if the server answers 405 (termination unsupported).", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Client auth + # ═══════════════════════════════════════════════════════════════════════════ + "client-auth:401-after-auth-throws": Requirement( + source="sdk", + behavior=( + "If the server still returns 401 after a successful authorization, the client fails instead of looping." + ), + transports=("streamable-http",), + ), + "client-auth:401-triggers-flow": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior="A 401 on a request triggers the OAuth authorization flow once.", + transports=("streamable-http",), + ), + "client-auth:403-scope-upgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#step-up-authorization-flow", + behavior=( + "A 403 with WWW-Authenticate triggers a scope-upgrade authorization attempt; repeated 403s do not loop." + ), + transports=("streamable-http",), + ), + "client-auth:as-metadata-discovery:priority-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", + behavior=( + "The client discovers authorization-server metadata by trying, in order, the OAuth " + "path-inserted, OIDC path-inserted, and OIDC path-appended well-known URLs (with the " + "root-path forms when the issuer URL has no path)." 
+ ), + transports=("streamable-http",), + ), + "client-auth:as-metadata-discovery:issuer-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", + behavior=( + "The client rejects authorization-server metadata whose issuer does not match the URL the " + "metadata was retrieved from (RFC 8414 section 3.3)." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK parses authorization-server metadata without comparing issuer to the discovery " + "URL; a mismatched issuer is accepted and the flow proceeds." + ), + ), + ), + "client-auth:authorize:error-surfaces": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-flow-steps", + behavior=( + "An OAuth error redirect from the authorize endpoint aborts the flow before any token " + "request is issued, surfacing as an error to the caller." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The callback contract has no error form, so the client surfaces 'No authorization code " + "received' rather than the redirect's `error`/`error_description` values." + ), + ), + ), + "client-auth:authorize:offline-access-consent": Requirement( + source="sdk", + behavior=( + "When the authorization server's metadata advertises offline_access in scopes_supported and " + "the client uses the refresh_token grant, offline_access is appended to the requested scope " + "and prompt=consent is added to the authorize request." + ), + transports=("streamable-http",), + ), + "client-auth:bearer-header:every-request": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-requirements", + behavior=( + "Once authorized, the client sends the bearer token in the Authorization header on every HTTP " + "request to the MCP server, never in the query string." 
+ ), + transports=("streamable-http",), + ), + "client-auth:cimd": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#client-id-metadata-documents", + behavior="The client can use a client-ID metadata document URL as its OAuth client_id instead of registration.", + transports=("streamable-http",), + ), + "client-auth:client-credentials": Requirement( + source="sdk", + behavior=( + "A client-credentials provider obtains a token without user interaction and the resulting " + "bearer token authorizes subsequent requests." + ), + transports=("streamable-http",), + ), + "client-auth:dcr:registration-error-surfaces": Requirement( + source="sdk", + behavior=( + "A 400 from the registration endpoint surfaces to the caller as an OAuthRegistrationError " + "carrying the status and the server's RFC 7591 error body." + ), + transports=("streamable-http",), + ), + "client-auth:dcr": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#dynamic-client-registration", + behavior=( + "The client performs dynamic client registration against the authorization server when no " + "client_id is preconfigured." + ), + transports=("streamable-http",), + ), + "client-auth:invalid-client-clears-all": Requirement( + source="sdk", + behavior=( + "An invalid-client or unauthorized-client error during authorization invalidates all stored credentials." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The token-response handlers do not parse the error body; an invalid_client or " + "unauthorized_client response leaves stored client_info untouched. The TypeScript SDK " + "clears it." + ), + ), + deferred=( + "Not implemented in the SDK: no token-response path inspects the error code to decide " + "whether to clear client_info." 
+ ), + ), + "client-auth:invalid-grant-clears-tokens": Requirement( + source="sdk", + behavior="An invalid-grant error during authorization invalidates only the stored tokens.", + transports=("streamable-http",), + ), + "client-auth:pkce:refuse-if-unsupported": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The client refuses to proceed when the authorization server's metadata does not include " + "code_challenge_methods_supported, since PKCE support cannot be verified." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The client never inspects code_challenge_methods_supported and proceeds with PKCE S256 " + "regardless; the spec MUST is not enforced." + ), + ), + ), + "client-auth:pkce:s256": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The authorization request includes a PKCE S256 code challenge and the token request includes " + "the matching verifier." + ), + transports=("streamable-http",), + ), + "client-auth:pre-registration": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#preregistration", + behavior=( + "A client with statically preconfigured credentials skips dynamic registration and uses them directly." + ), + transports=("streamable-http",), + ), + "client-auth:private-key-jwt": Requirement( + source="sdk", + behavior="The client can authenticate the client-credentials grant with a signed JWT assertion.", + transports=("streamable-http",), + ), + "client-auth:prm-discovery:fallback-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior=( + "The client uses resource_metadata from WWW-Authenticate when present, then falls back to the " + "well-known protected-resource locations in the documented order." 
+ ), + transports=("streamable-http",), + ), + "client-auth:prm-discovery:no-prm-fallback": Requirement( + source="sdk", + behavior=( + "When every protected-resource metadata probe fails, the client falls back to discovering " + "authorization-server metadata directly at the MCP server's origin (the legacy 2025-03-26 path) " + "rather than aborting." + ), + transports=("streamable-http",), + ), + "client-auth:prm-resource-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The client refuses to proceed when the protected-resource metadata's resource field does not " + "match the server URL it is connecting to." + ), + transports=("streamable-http",), + ), + "client-auth:refresh:transparent": Requirement( + source="sdk", + behavior=( + "An access token the client considers expired is transparently refreshed before the next " + "request, using the stored refresh token; the refresh request includes the resource indicator " + "and the new token is persisted." + ), + transports=("streamable-http",), + ), + "client-auth:resource-parameter": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#resource-parameter-implementation", + behavior=( + "The client includes the canonical server URI as the resource parameter in both the " + "authorization request and the token request." + ), + transports=("streamable-http",), + ), + "client-auth:scope-selection:priority": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#scope-selection-strategy", + behavior=( + "Client selects requested scope from the WWW-Authenticate scope param if present; otherwise " + "uses scopes_supported from the PRM document; otherwise omits scope." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK inserts an extra fallback step between PRM and omit: if the authorization " + "server metadata advertises scopes_supported, that list is used (client/auth/utils.py). 
" + "This is beyond the spec's two-step chain." + ), + ), + ), + "client-auth:state:verify": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", + behavior=( + "A state parameter is included in the authorization URL, and authorization results with a " + "missing or mismatched state are discarded." + ), + transports=("streamable-http",), + ), + "client-auth:token-endpoint-auth-method": Requirement( + source="sdk", + behavior="The client authenticates to the token endpoint using the auth method established at registration.", + transports=("streamable-http",), + ), + "client-auth:token-provenance": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior=( + "The client sends the MCP server only tokens issued by that server's authorization server, " + "never tokens obtained elsewhere." + ), + transports=("streamable-http",), + deferred=( + "Untestable negative through the public API: there is no path to inject a token obtained " + "elsewhere into the auth provider's state, so the absence cannot be observed end to end." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # stdio transport + # ═══════════════════════════════════════════════════════════════════════════ + "transport:stdio:clean-shutdown": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#shutdown", + behavior="Closing the client transport closes the child process's stdin and the server exits cleanly.", + transports=("stdio",), + ), + "transport:stdio:stream-purity": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior=( + "Nothing that is not a valid MCP message is written to the server's stdout, and nothing that " + "is not a valid MCP message is written to its stdin." 
+ ), + transports=("stdio",), + divergence=Divergence( + note=( + "stdio_server's own writes satisfy this, but it does not redirect or guard sys.stdout: " + "handler code that calls print() writes directly to the protocol stream and corrupts the " + "framing. The spec MUST is satisfied only as long as application code behaves." + ), + ), + ), + "transport:stdio:no-embedded-newlines": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior="Serialized JSON-RPC messages on stdio contain no embedded newlines; one message per line.", + transports=("stdio",), + ), + "transport:stdio:shutdown-escalation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#stdio", + behavior=( + "If the server process does not exit after stdin is closed, the client transport terminates " + "it (and kills it if still alive) after a grace period." + ), + transports=("stdio",), + deferred=( + "A server that ignores stdin close takes the full PROCESS_TERMINATION_TIMEOUT (2.0 s) grace " + "period plus up to a further 2.0 s for SIGTERM/SIGKILL escalation; testing that path is " + "real-time-bound (the constant is module-level with no public override) and so is deliberately " + "excluded from this suite. Covered by tests/client/test_stdio.py." + ), + ), + "transport:stdio:stderr-passthrough": Requirement( + source="sdk", + behavior="Server stderr is available to the client and is not consumed by the transport.", + transports=("stdio",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Composite end-to-end flows + # ═══════════════════════════════════════════════════════════════════════════ + "flow:compat:dual-transport-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "A single server instance can serve streamable HTTP and the legacy SSE transport " + "concurrently; clients on either transport can call the same tools." 
+ ), + transports=("streamable-http", "sse"), + ), + "flow:compat:streamable-then-sse-fallback": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "When a streamable HTTP initialize fails with 400, 404, or 405, falling back to the legacy " + "SSE client transport against the same server connects successfully." + ), + transports=("streamable-http", "sse"), + divergence=Divergence( + note=( + "The SDK provides no automatic streamable-HTTP-to-SSE client fallback; the spec's " + "client-side SHOULD is left to the application to compose from streamable_http_client " + "and sse_client. Both halves are independently proven by the matrix." + ), + ), + deferred=( + "A demonstration test would only re-prove what the matrix already covers (an SSE-only " + "server is reachable via sse_client; an unmounted route returns 404), with the application " + "doing the fallback in between rather than the SDK." + ), + ), + "flow:elicitation:multi-step-form": Requirement( + source="sdk", + behavior=( + "A single tool handler issues sequential elicitations; an accept on one step feeds the next, " + "and a decline or cancel at any step short-circuits to a final result." + ), + ), + "flow:elicitation:url-at-session-init": Requirement( + source="sdk", + behavior=( + "The server can issue a URL-mode elicitation over the standalone GET stream immediately after " + "session initialization, before any client request." + ), + transports=("streamable-http",), + deferred=( + "Not implemented in the SDK: no public per-session post-initialization hook exists on either " + "server flavour (Server.lifespan runs at server startup, not per session; ServerSession " + "handles the initialized notification internally with no callback). Driving 'before any " + "client request' deterministically would also require knowing the standalone GET stream is " + "established, which has no synchronization signal." 
+ ), + ), + "flow:elicitation:url-required-then-retry": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A tool call rejected with the URL-elicitation-required error can be retried successfully " + "after the client completes the URL flow and the server announces completion." + ), + ), + "flow:multi-client:stateful-isolation": Requirement( + source="sdk", + behavior=( + "Independent clients connected to one stateful server each receive a distinct session and " + "only the notifications produced by their own requests." + ), + transports=("streamable-http",), + ), + "flow:oauth:authorization-code-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-flow-steps", + behavior=( + "Connecting to a protected server walks the authorization-code flow end to end: the first " + "attempt requires authorization, the code is exchanged, and a subsequent connection succeeds." + ), + transports=("streamable-http",), + ), + "flow:resume:tool-call-resumption-token": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior=( + "A tool call interrupted mid-stream is transparently resumed by the client transport using " + "the last-seen event id, delivering only the remaining notifications and the final result." + ), + transports=("streamable-http",), + ), + "flow:session:terminate-then-reconnect": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=("After terminating a session, a fresh connection obtains a new session id and operations succeed."), + transports=("streamable-http",), + ), + "flow:tool-result:resource-link-follow": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior=( + "A resource_link returned by a tool call can be followed with resources/read on the linked " + "URI to retrieve the referenced contents." 
+ ), + ), +} + + +def requirement(requirement_id: str) -> Callable[[_TestFn], _TestFn]: + """Mark a test as exercising a requirement from :data:`REQUIREMENTS`. + + Applies the `requirement` pytest marker and records the coverage link checked by + `test_coverage.py`. Unknown IDs fail at import time so a typo surfaces as a collection + error on the offending test, not as a missing-coverage report later. + """ + if requirement_id not in REQUIREMENTS: + raise KeyError(f"Unknown requirement id {requirement_id!r}: add it to REQUIREMENTS in {__name__}") + + def apply(test_fn: _TestFn) -> _TestFn: + covered_by(requirement_id).append(f"{test_fn.__module__}.{test_fn.__qualname__}") + return pytest.mark.requirement(requirement_id)(test_fn) + + return apply + + +_COVERAGE: dict[str, list[str]] = {} + + +def covered_by(requirement_id: str) -> list[str]: + """Return the (mutable) list of test names recorded as exercising `requirement_id`.""" + return _COVERAGE.setdefault(requirement_id, []) diff --git a/tests/interaction/auth/__init__.py b/tests/interaction/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py new file mode 100644 index 0000000000..d013364f33 --- /dev/null +++ b/tests/interaction/auth/_harness.py @@ -0,0 +1,465 @@ +"""In-process harness for the auth interaction tests. + +Co-hosts the SDK's authorization-server routes, protected-resource metadata route, and the +bearer-gated MCP endpoint on one Starlette app via `Server.streamable_http_app(auth=..., +token_verifier=..., auth_server_provider=...)`, drives that app through the streaming bridge +on a single `httpx.AsyncClient` carrying `auth=OAuthClientProvider(...)`, and completes the +authorize redirect headlessly by GETing the URL through the same bridge and parsing the code +from the 302 `Location`. The whole authorization-code flow runs in one event loop with no +sockets, no threads, and no real time. 
+""" + +import json +from collections.abc import AsyncIterator, Callable, Mapping, Sequence +from contextlib import AsyncExitStack, asynccontextmanager +from dataclasses import dataclass, field +from typing import Any +from urllib.parse import parse_qs, parse_qsl, urlsplit + +import httpx +from pydantic import AnyHttpUrl, AnyUrl, BaseModel +from starlette.types import ASGIApp, Receive, Scope, Send + +from mcp.client.auth import OAuthClientProvider +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server +from mcp.server.auth.provider import AccessToken, ProviderTokenVerifier +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions, RevocationOptions +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider +from tests.interaction.transports._bridge import StreamingASGITransport + +REDIRECT_URI = f"{BASE_URL}/oauth/callback" + +AppShim = Callable[[ASGIApp], ASGIApp] + + +@dataclass +class RecordedRequest: + """A snapshot of an `httpx.Request` at the moment it was sent. + + The auth flow re-yields the same `httpx.Request` object after mutating its headers in + place for the retry, so tests that need to assert on the first attempt's headers must + capture a copy rather than a live reference. `record_requests` produces these. 
+ """ + + method: str + url: httpx.URL + headers: dict[str, str] + content: bytes + + @property + def path(self) -> str: + return self.url.path + + +def record_requests() -> tuple[list[RecordedRequest], Callable[[httpx.Request], None]]: + """Build an `on_request` callback that snapshots each request, and the list it appends to.""" + recorded: list[RecordedRequest] = [] + + def on_request(request: httpx.Request) -> None: + recorded.append( + RecordedRequest( + method=request.method, + url=request.url, + headers=dict(request.headers), + content=bytes(request.content), + ) + ) + + return recorded, on_request + + +def metadata_body(model: BaseModel, **extra: object) -> bytes: + """Serialize a metadata model to a JSON body for `shimmed_app(serve=...)`. + + `extra` keys are merged into the serialized object so a test can inject fields the model + does not declare (e.g. an unknown extension field, to prove the client's parser tolerates + unrecognized members per RFC 8414/9728 §3.2). The model itself would silently drop such + fields at construction, so they have to be added after serialization. + """ + document = model.model_dump(by_alias=True, mode="json", exclude_none=True) + document.update(extra) + return json.dumps(document).encode() + + +class StaticTokenVerifier: + """A `TokenVerifier` backed by a fixed token→`AccessToken` mapping. + + Any token string not in the mapping verifies to `None`, which the bearer middleware treats + as an unrecognized token. Tests seed the mapping with the exact token shapes (valid, expired, + wrong scope, wrong audience) they need so the resource-server gate's behaviour is asserted in + isolation from the authorization-server provider. 
+ """ + + def __init__(self, tokens: Mapping[str, AccessToken]) -> None: + self._tokens = dict(tokens) + + async def verify_token(self, token: str) -> AccessToken | None: + return self._tokens.get(token) + + +class InMemoryTokenStorage: + """A `TokenStorage` that holds tokens and client info as instance attributes. + + Tests pre-seed `client_info` (via the constructor or by assignment) to drive the + pre-registered path, and read both attributes after the flow to assert what the SDK + persisted. + """ + + def __init__(self, *, client_info: OAuthClientInformationFull | None = None) -> None: + self.tokens: OAuthToken | None = None + self.client_info: OAuthClientInformationFull | None = client_info + + async def get_tokens(self) -> OAuthToken | None: + return self.tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self.tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self.client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self.client_info = client_info + + +class HeadlessOAuth: + """Completes the authorize step in-process by following the redirect through the bridge. + + `redirect_handler` GETs the authorize URL on the bound client (with `auth=None` so the + request does not re-enter the locked auth flow), parses `code` and `state` from the 302 + `Location`, and stashes them; `callback_handler` returns the stashed pair. Tests inspect + `authorize_url` to assert what the SDK put on the authorize request. + + `state_override`: when set, `callback_handler` returns this value as the state instead of + the one parsed from the redirect, so tests can drive the state-mismatch path. 
+ """ + + def __init__(self, *, state_override: str | None = None) -> None: + self.authorize_url: str | None = None + self.authorize_urls: list[str] = [] + self.error: str | None = None + self._state_override = state_override + self._http: httpx.AsyncClient | None = None + self._code: str = "" + self._state: str | None = None + + def bind(self, http_client: httpx.AsyncClient) -> None: + self._http = http_client + + async def redirect_handler(self, authorization_url: str) -> None: + assert self._http is not None + self.authorize_url = authorization_url + self.authorize_urls.append(authorization_url) + # auth=None is load-bearing: without it the GET re-enters OAuthClientProvider.async_auth_flow + # through its context lock and the flow deadlocks. + response = await self._http.get(authorization_url, follow_redirects=False, auth=None) + assert response.status_code == 302, f"authorize endpoint returned {response.status_code}: {response.text}" + params = parse_qs(urlsplit(response.headers["location"]).query) + self._code = params.get("code", [""])[0] + self._state = params.get("state", [None])[0] + self.error = params.get("error", [None])[0] + + async def callback_handler(self) -> tuple[str, str | None]: + return self._code, self._state_override if self._state_override is not None else self._state + + +def auth_settings( + *, required_scopes: Sequence[str] = ("mcp",), valid_scopes: Sequence[str] | None = None +) -> AuthSettings: + """Build `AuthSettings` for the co-hosted authorization + resource server. + + The issuer and resource URLs use the suite's loopback origin, which `validate_issuer_url` + accepts in lieu of HTTPS. Dynamic client registration is enabled. 
`valid_scopes` defaults + to `required_scopes` so a client requesting exactly those passes registration scope + validation; tests pass a wider set when they need the protected-resource metadata's + `scopes_supported` (which mirrors `required_scopes`) to differ from what the client may + register or when AS metadata should advertise additional scopes such as `offline_access`. + """ + required = list(required_scopes) + valid = list(valid_scopes) if valid_scopes is not None else required + return AuthSettings( + issuer_url=AnyHttpUrl(BASE_URL), + resource_server_url=AnyHttpUrl(f"{BASE_URL}/mcp"), + required_scopes=required, + client_registration_options=ClientRegistrationOptions( + enabled=True, valid_scopes=valid, default_scopes=required + ), + revocation_options=RevocationOptions(enabled=False), + ) + + +def oauth_client_metadata() -> OAuthClientMetadata: + """Build the client's registration metadata. + + `scope` is left unset so the SDK's scope-selection strategy chooses one from the server's + metadata before registration. + """ + return OAuthClientMetadata( + client_name="interaction-suite", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + ) + + +def shimmed_app( + app: ASGIApp, + *, + not_found: frozenset[str] = frozenset(), + serve: Mapping[str, bytes | tuple[int, bytes]] | None = None, +) -> ASGIApp: + """Wrap an ASGI app so specific paths return canned responses before reaching the real app. + + Paths in `serve` return the given body as `application/json` (status 200, or the supplied + status when the value is a `(status, body)` pair); paths in `not_found` return 404; + everything else reaches the wrapped app unchanged. Used by the discovery tests to make a + well-known endpoint 404 or return alternate metadata while keeping the real authorization + and MCP endpoints behind it. 
+ """ + overrides: dict[str, tuple[int, bytes]] = { + path: value if isinstance(value, tuple) else (200, value) for path, value in (serve or {}).items() + } + + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + path = scope["path"] + if path in overrides: + status, body = overrides[path] + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body)).encode()), + ], + } + ) + await send({"type": "http.response.body", "body": body}) + return + if path in not_found: + await send({"type": "http.response.start", "status": 404, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + await app(scope, receive, send) + + return wrapped + + +def shim( + *, not_found: frozenset[str] = frozenset(), serve: Mapping[str, bytes | tuple[int, bytes]] | None = None +) -> AppShim: + """Build an `app_shim` for `connect_with_oauth` that applies `shimmed_app` with these overrides.""" + return lambda app: shimmed_app(app, not_found=not_found, serve=serve) + + +@dataclass +class _FirstChallenge: + """ASGI shim that answers the first request to a path with 401 + a given WWW-Authenticate. + + Subsequent requests pass through to the wrapped app. Used to make the initial 401 carry + parameters (such as `scope=`) that the SDK's own bearer middleware cannot be configured + to emit, so client behaviour driven by those parameters is reachable end to end. Reserve + this pattern for behaviour the real server cannot be made to produce. 
+ """ + + app: ASGIApp + path: str + www_authenticate: str + _seen: set[str] = field(default_factory=set[str]) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["path"] == self.path and self.path not in self._seen: + self._seen.add(self.path) + await send( + { + "type": "http.response.start", + "status": 401, + "headers": [(b"www-authenticate", self.www_authenticate.encode())], + } + ) + await send({"type": "http.response.body", "body": b""}) + return + await self.app(scope, receive, send) + + +def first_challenge_shim(www_authenticate: str, *, path: str = "/mcp") -> Callable[[ASGIApp], ASGIApp]: + """Build an `app_shim` that 401s the first request to `path` with the given header value.""" + return lambda app: _FirstChallenge(app, path, www_authenticate) + + +def step_up_shim(www_authenticate: str, *, on_nth_authenticated_post: int = 2) -> AppShim: + """Build an `app_shim` that 403s the Nth authenticated POST to `/mcp` with the given challenge. + + Subsequent requests pass through. Used to drive the client's `insufficient_scope` step-up + handling: the SDK's bearer middleware never emits `scope=` in its 403 challenge (see the + divergence on `hosting:auth:scope-403`), so the test supplies the 403 itself. Reserve this + pattern for behaviour the real server cannot be made to produce. + + The default `on_nth_authenticated_post=2` targets the `notifications/initialized` POST: the + first authenticated POST is the auth flow's retry of the original initialize request (yielded + after the 401 branch, where the generator ends without inspecting the response), so a 403 + there would not reach the step-up handler. 
+ """ + seen = 0 + fired = False + + def factory(app: ASGIApp) -> ASGIApp: + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + nonlocal seen, fired + if ( + not fired + and scope["type"] == "http" + and scope["path"] == "/mcp" + and scope["method"] == "POST" + and any(name == b"authorization" for name, _ in scope["headers"]) + ): + seen += 1 + if seen < on_nth_authenticated_post: + await app(scope, receive, send) + return + fired = True + await send( + { + "type": "http.response.start", + "status": 403, + "headers": [(b"www-authenticate", www_authenticate.encode())], + } + ) + await send({"type": "http.response.body", "body": b""}) + return + await app(scope, receive, send) + + return wrapped + + return factory + + +def m2m_token_shim(provider: InMemoryAuthorizationServerProvider, *, scopes: list[str]) -> AppShim: + """Build an `app_shim` that handles `grant_type=client_credentials` at `/token`. + + The SDK server's `TokenHandler` only routes `authorization_code` and `refresh_token`, so a + `client_credentials` request would fail discriminator validation. This shim mints a token via + `provider.mint_access_token` so the M2M client providers can complete e2e against the real + bearer middleware. The shim is harness; the SDK-under-test is the client provider, whose + outbound `/token` body the test asserts. The shim does not authenticate the client (no + credential check) because the test asserts the credentials on the recorded request, not on + the server's acceptance. + """ + + def factory(app: ASGIApp) -> ASGIApp: + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["path"] == "/token" and scope["method"] == "POST": + # The streaming bridge buffers the request body and delivers it in a single + # http.request event, so one receive is sufficient. 
+ message = await receive() + assert not message.get("more_body", False) + form = dict(parse_qsl(message.get("body", b"").decode())) + assert form.get("grant_type") == "client_credentials", ( + f"m2m_token_shim only handles client_credentials; got {form.get('grant_type')!r}" + ) + access = provider.mint_access_token(client_id="m2m", scopes=scopes, resource=form.get("resource")) + token = OAuthToken(access_token=access, token_type="Bearer", expires_in=3600, scope=" ".join(scopes)) + response_body = token.model_dump_json(exclude_none=True).encode() + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(response_body)).encode()), + (b"cache-control", b"no-store"), + ], + } + ) + await send({"type": "http.response.body", "body": response_body}) + return + await app(scope, receive, send) + + return wrapped + + return factory + + +@asynccontextmanager +async def connect_with_oauth( + server: Server, + *, + provider: InMemoryAuthorizationServerProvider, + settings: AuthSettings | None = None, + storage: InMemoryTokenStorage | None = None, + client_metadata: OAuthClientMetadata | None = None, + client_metadata_url: str | None = None, + headless: HeadlessOAuth | None = None, + auth: httpx.Auth | None = None, + verify_tokens: bool = True, + app_shim: Callable[[ASGIApp], ASGIApp] | None = None, + on_request: Callable[[httpx.Request], None] | None = None, +) -> AsyncIterator[tuple[Client, HeadlessOAuth]]: + """Connect a `Client` to a server's bearer-gated streamable-HTTP app, completing OAuth in process. + + Yields the connected `Client` and the `HeadlessOAuth` whose `authorize_url` records what the + SDK put on the authorize request. `on_request` records every HTTP request the underlying + `httpx.AsyncClient` issues, including those yielded from inside the auth flow. 
+ + `headless`: supply a pre-configured `HeadlessOAuth` to override the callback behaviour + (state mismatch, error redirects). `verify_tokens=False` mounts the MCP endpoint without + the bearer middleware so a flow driven by a shimmed 401 completes regardless of the granted + scopes. `app_shim` wraps the built Starlette app before it reaches the bridge transport, + for tests that need to intercept or rewrite specific server responses. + + `auth`: supply a pre-built `httpx.Auth` (such as `ClientCredentialsOAuthProvider`) to use + instead of constructing the default `OAuthClientProvider`; in that case `storage`, + `client_metadata`, `client_metadata_url`, and `headless` are unused (the yielded + `HeadlessOAuth` is never invoked and its `authorize_url` stays None). + """ + settings = settings if settings is not None else auth_settings() + storage = storage if storage is not None else InMemoryTokenStorage() + client_metadata = client_metadata if client_metadata is not None else oauth_client_metadata() + headless = headless if headless is not None else HeadlessOAuth() + + oauth = ( + auth + if auth is not None + else OAuthClientProvider( + server_url=f"{BASE_URL}/mcp", + client_metadata=client_metadata, + storage=storage, + redirect_handler=headless.redirect_handler, + callback_handler=headless.callback_handler, + client_metadata_url=client_metadata_url, + ) + ) + + app: ASGIApp = server.streamable_http_app( + auth=settings, + token_verifier=ProviderTokenVerifier(provider) if verify_tokens else None, + auth_server_provider=provider, + transport_security=NO_DNS_REBINDING_PROTECTION, + ) + if app_shim is not None: + app = app_shim(app) + + event_hooks: dict[str, list[Callable[..., Any]]] | None = None + if on_request is not None: + record = on_request + + async def hook(request: httpx.Request) -> None: + record(request) + + event_hooks = {"request": [hook]} + + async with AsyncExitStack() as stack: + await stack.enter_async_context(server.session_manager.run()) + 
http_client = await stack.enter_async_context( + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks + ) + ) + headless.bind(http_client) + client = await stack.enter_async_context( + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client)) + ) + yield client, headless diff --git a/tests/interaction/auth/_provider.py b/tests/interaction/auth/_provider.py new file mode 100644 index 0000000000..5c88995a30 --- /dev/null +++ b/tests/interaction/auth/_provider.py @@ -0,0 +1,186 @@ +"""An in-memory implementation of the SDK's OAuth authorization-server provider protocol. + +The provider holds clients, authorization codes, refresh tokens and access tokens in plain +instance dicts so tests can inspect them; tokens are minted from `secrets.token_hex` so the +values are unique without being predictable. The behaviour mirrors what the SDK's authorization +handlers expect: `authorize` immediately mints a code and returns the redirect, `exchange_*` +issue and rotate tokens, and `load_*` are simple lookups. Only the parts the auth interaction +suite drives are implemented; methods the suite does not exercise raise `NotImplementedError`. +""" + +import secrets +import time + +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthAuthorizationServerProvider, + RefreshToken, + TokenError, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +_TOKEN_LIFETIME_SECONDS = 3600 + + +class InMemoryAuthorizationServerProvider( + OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken] +): + """An OAuth authorization-server provider backed by in-memory dicts. + + Holds registered clients, issued codes, refresh tokens and access tokens as instance state + so tests can both drive the SDK's authorization handlers and inspect what was issued. 
+ + Knobs: + `default_scopes`: scopes granted when an authorize request supplies none. + `deny_authorize`: every authorize request returns an `error=access_denied` redirect. + `issue_expired_first`: the first issued token's `expires_in` is in the past so the + client immediately considers it expired and refreshes; the server-side + `AccessToken.expires_at` stays in the future so the bearer middleware accepts it + on the retry that completes the connect. + `fail_next_refresh`: the next refresh-token exchange raises `invalid_grant` once. + `reject_all_tokens`: `load_access_token` returns None for every token, so the bearer + middleware 401s every authenticated request. + """ + + def __init__( + self, + *, + default_scopes: list[str] | None = None, + deny_authorize: bool = False, + issue_expired_first: bool = False, + fail_next_refresh: bool = False, + reject_all_tokens: bool = False, + ) -> None: + self._default_scopes = list(default_scopes) if default_scopes is not None else ["mcp"] + self._issuer = "http://127.0.0.1:8000" + self._deny_authorize = deny_authorize + self._issue_expired_first = issue_expired_first + self._fail_next_refresh = fail_next_refresh + self._reject_all_tokens = reject_all_tokens + self._tokens_issued = 0 + self.clients: dict[str, OAuthClientInformationFull] = {} + self.codes: dict[str, AuthorizationCode] = {} + self.refresh_tokens: dict[str, RefreshToken] = {} + self.access_tokens: dict[str, AccessToken] = {} + + def _next_expires_in(self) -> int: + self._tokens_issued += 1 + if self._issue_expired_first and self._tokens_issued == 1: + return -_TOKEN_LIFETIME_SECONDS + return _TOKEN_LIFETIME_SECONDS + + def mint_access_token(self, *, client_id: str, scopes: list[str], resource: str | None = None) -> str: + """Mint and store an access token, returning its value. + + Used by the auth-code and refresh exchanges and by the M2M `/token` shim. 
The + server-side `expires_at` is always in the future regardless of `issue_expired_first`, + which only affects what the client is told. + """ + access = f"access_{secrets.token_hex(16)}" + self.access_tokens[access] = AccessToken( + token=access, + client_id=client_id, + scopes=scopes, + expires_at=int(time.time()) + _TOKEN_LIFETIME_SECONDS, + resource=resource, + ) + return access + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + assert client_info.client_id is not None + self.clients[client_info.client_id] = client_info + + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: + """Mint an authorization code immediately and return the redirect carrying it. + + A real provider would interpose user consent here; the test provider grants + unconditionally so the headless redirect handler can complete the flow in-process. + When `deny_authorize` is set, returns an `error=access_denied` redirect instead. + """ + assert client.client_id is not None + if self._deny_authorize: + return construct_redirect_uri( + str(params.redirect_uri), error="access_denied", error_description="user denied", state=params.state + ) + code = AuthorizationCode( + code=f"code_{secrets.token_hex(16)}", + client_id=client.client_id, + scopes=params.scopes or self._default_scopes, + expires_at=time.time() + 300, + code_challenge=params.code_challenge, + redirect_uri=params.redirect_uri, + redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, + resource=params.resource, + ) + self.codes[code.code] = code + # `iss` is RFC 9207's authorization-response issuer identifier — an extra parameter many + # real authorization servers send. 
Including it on every success redirect proves the + # client tolerates unrecognized callback parameters (RFC 6749 §4.1.2 MUST) by virtue of + # every flow test passing unchanged. + return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state, iss=self._issuer) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + return self.codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Mint an access token and a refresh token for a valid authorization code, then consume the code.""" + assert client.client_id is not None + access = self.mint_access_token( + client_id=client.client_id, scopes=authorization_code.scopes, resource=authorization_code.resource + ) + refresh = f"refresh_{secrets.token_hex(16)}" + self.refresh_tokens[refresh] = RefreshToken( + token=refresh, + client_id=client.client_id, + scopes=authorization_code.scopes, + ) + del self.codes[authorization_code.code] + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._next_expires_in(), + scope=" ".join(authorization_code.scopes), + refresh_token=refresh, + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + if self._reject_all_tokens: + return None + return self.access_tokens.get(token) + + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + return self.refresh_tokens.get(refresh_token) + + async def exchange_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: RefreshToken, scopes: list[str] + ) -> OAuthToken: + """Mint a new access token and rotate the refresh token, consuming the old one.""" + assert client.client_id is not None + if self._fail_next_refresh: + self._fail_next_refresh = False + raise 
TokenError(error="invalid_grant", error_description="refresh denied by harness") + access = self.mint_access_token(client_id=client.client_id, scopes=scopes) + new_refresh = f"refresh_{secrets.token_hex(16)}" + self.refresh_tokens[new_refresh] = RefreshToken(token=new_refresh, client_id=client.client_id, scopes=scopes) + del self.refresh_tokens[refresh_token.token] + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._next_expires_in(), + scope=" ".join(scopes), + refresh_token=new_refresh, + ) + + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: + """Not exercised by this suite; revocation is out of scope for the interaction tests.""" + raise NotImplementedError diff --git a/tests/interaction/auth/test_as_handlers.py b/tests/interaction/auth/test_as_handlers.py new file mode 100644 index 0000000000..5cb4e92d86 --- /dev/null +++ b/tests/interaction/auth/test_as_handlers.py @@ -0,0 +1,300 @@ +"""Error-plane behaviour of the SDK's bundled OAuth authorization-server handlers. + +The end-to-end OAuth tests prove the handlers' happy paths; these tests drive the same +mounted authorization server directly with raw httpx so the assertions are the HTTP +semantics (status, redirect target, error body, headers) the OAuth RFCs mandate. Almost +every behaviour here is enforced by the SDK's own handlers; where the pinned output +deviates from the RFC, the manifest entry carries the divergence. 
+""" + +import base64 +import hashlib +import secrets +from collections.abc import AsyncIterator +from urllib.parse import parse_qs, urlsplit + +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server +from mcp.server.auth.provider import ProviderTokenVerifier +from mcp.shared.auth import OAuthClientInformationFull +from tests.interaction._connect import mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import REDIRECT_URI, auth_settings, oauth_client_metadata +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +async def as_app() -> AsyncIterator[tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider]]: + """Co-host the SDK's authorization-server routes and yield a raw httpx client against them.""" + provider = InMemoryAuthorizationServerProvider() + settings = auth_settings() + async with mounted_app( + Server("guarded"), + auth=settings, + token_verifier=ProviderTokenVerifier(provider), + auth_server_provider=provider, + ) as (http, _): + yield http, provider + + +def _pkce_pair() -> tuple[str, str]: + """Generate a (code_verifier, code_challenge) pair the same way the SDK client does.""" + verifier = secrets.token_urlsafe(48)[:64] + challenge = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") + return verifier, challenge + + +async def _register_client(http: httpx.AsyncClient) -> OAuthClientInformationFull: + """Dynamically register a client and return its full credentials.""" + response = await http.post("/register", content=oauth_client_metadata().model_dump_json()) + assert response.status_code == 201 + return OAuthClientInformationFull.model_validate_json(response.content) + + +async def _mint_code(http: httpx.AsyncClient) -> tuple[OAuthClientInformationFull, str, str]: + """Register a client, complete a valid authorize step, and 
return (client_info, code, verifier).""" + client_info = await _register_client(http) + assert client_info.client_id is not None + verifier, challenge = _pkce_pair() + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": REDIRECT_URI, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": "s", + }, + follow_redirects=False, + ) + assert response.status_code == 302 + redirect = urlsplit(response.headers["location"]) + assert f"{redirect.scheme}://{redirect.netloc}{redirect.path}" == REDIRECT_URI + code = parse_qs(redirect.query)["code"][0] + return client_info, code, verifier + + +def _token_form(client_info: OAuthClientInformationFull, **overrides: str) -> dict[str, str]: + """Build the form body for an authorization-code token request, with the defaults a real client would send.""" + assert client_info.client_id is not None + assert client_info.client_secret is not None + form = { + "grant_type": "authorization_code", + "client_id": client_info.client_id, + "client_secret": client_info.client_secret, + "redirect_uri": REDIRECT_URI, + } + form.update(overrides) + return form + + +@requirement("hosting:auth:as:authorize-requires-pkce") +async def test_authorize_without_a_code_challenge_is_rejected_with_invalid_request( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorize request omitting `code_challenge` is redirected back with `error=invalid_request`. + + PKCE is mandatory: the bundled authorize handler models `code_challenge` as a required field, so + a code without a stored challenge can never be issued. That makes the PKCE-downgrade attack (a + token request carrying a verifier for a code minted without a challenge) structurally impossible + through these handlers, so no separate downgrade-guard test is needed. 
+ """ + http, _ = as_app + client_info = await _register_client(http) + assert client_info.client_id is not None + + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": REDIRECT_URI, + "state": "abc", + }, + follow_redirects=False, + ) + + assert response.status_code == 302 + redirect = urlsplit(response.headers["location"]) + assert f"{redirect.scheme}://{redirect.netloc}{redirect.path}" == REDIRECT_URI + params = parse_qs(redirect.query) + assert params["error"] == ["invalid_request"] + assert params["state"] == ["abc"] + assert "code_challenge" in params["error_description"][0] + + +@requirement("hosting:auth:as:verifier-mismatch") +async def test_a_mismatched_code_verifier_is_rejected_with_invalid_grant( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A token exchange whose `code_verifier` does not hash to the stored challenge is rejected.""" + http, _ = as_app + client_info, code, _ = await _mint_code(http) + + response = await http.post("/token", data=_token_form(client_info, code=code, code_verifier="0" * 64)) + + assert response.status_code == 400 + assert response.json() == snapshot({"error": "invalid_grant", "error_description": "incorrect code_verifier"}) + + +@requirement("hosting:auth:as:code-single-use") +async def test_reusing_an_authorization_code_is_rejected_with_invalid_grant( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorization code can be exchanged exactly once; a second exchange is `invalid_grant`. + + The handler does not track used codes itself: it returns `invalid_grant` whenever the provider's + `load_authorization_code` returns None, and the in-memory provider deletes the code on first + exchange. The test proves the combination enforces single-use; a provider that did not consume + codes would not get this guarantee from the handler. 
+ """ + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + form = _token_form(client_info, code=code, code_verifier=verifier) + + first = await http.post("/token", data=form) + assert first.status_code == 200 + assert first.json()["token_type"] == "Bearer" + + second = await http.post("/token", data=form) + assert second.status_code == 400 + assert second.json() == snapshot( + {"error": "invalid_grant", "error_description": "authorization code does not exist"} + ) + + +@requirement("hosting:auth:as:redirect-uri-binding") +async def test_a_redirect_uri_differing_from_authorize_is_rejected_at_the_token_endpoint( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A token exchange whose `redirect_uri` differs from the one used at authorize is rejected. + + This is the security-critical half of redirect-URI binding: a code intercepted via redirect + substitution cannot be redeemed because the attacker cannot reproduce the original authorize + redirect URI at the token endpoint. RFC 6749 §5.2 specifies `invalid_grant` for this case; + the SDK returns `invalid_request` (see the divergence on the requirement). The rejection + itself is the security property and is correct. 
+ """ + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + + response = await http.post( + "/token", + data=_token_form(client_info, code=code, code_verifier=verifier, redirect_uri=f"{REDIRECT_URI}/different"), + ) + + assert response.status_code == 400 + assert response.json() == snapshot( + { + "error": "invalid_request", + "error_description": "redirect_uri did not match the one used when creating auth code", + } + ) + + +@requirement("hosting:auth:as:token-cache-headers") +async def test_token_responses_carry_cache_control_no_store( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """Every token-endpoint response (success and error) carries `Cache-Control: no-store`.""" + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + form = _token_form(client_info, code=code, code_verifier=verifier) + + success = await http.post("/token", data=form) + assert success.status_code == 200 + assert success.headers["cache-control"] == "no-store" + assert success.headers["pragma"] == "no-cache" + + failure = await http.post("/token", data=form) + assert failure.status_code == 400 + assert failure.headers["cache-control"] == "no-store" + assert failure.headers["pragma"] == "no-cache" + + +@requirement("hosting:auth:as:register-error-response") +async def test_registration_with_invalid_metadata_is_rejected_with_400( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """Invalid client metadata at the registration endpoint returns 400 with an RFC 7591 error body.""" + http, _ = as_app + + malformed = await http.post("/register", json={"redirect_uris": ["not-a-url"]}) + assert malformed.status_code == 400 + assert malformed.json()["error"] == "invalid_client_metadata" + + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + + no_auth_code = await http.post("/register", json=body | {"grant_types": ["refresh_token"]}) + assert 
no_auth_code.status_code == 400 + assert no_auth_code.json() == snapshot( + {"error": "invalid_client_metadata", "error_description": "grant_types must include 'authorization_code'"} + ) + + bad_scope = await http.post("/register", json=body | {"scope": "forbidden"}) + assert bad_scope.status_code == 400 + body = bad_scope.json() + assert body["error"] == "invalid_client_metadata" + # The description embeds a set difference whose ordering is not stable, so assert the prefix. + assert body["error_description"].startswith("Requested scopes are not valid: ") + + +@requirement("hosting:auth:as:redirect-uri-binding") +async def test_authorize_with_an_unregistered_redirect_uri_is_rejected_directly( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorize request naming an unregistered `redirect_uri` returns 400 without redirecting to it. + + The security property is that the authorization server never redirects to an unvalidated URI: + the response is a direct JSON error to the user agent, not a 302 to the attacker's host. 
+ """ + http, _ = as_app + client_info = await _register_client(http) + assert client_info.client_id is not None + _, challenge = _pkce_pair() + + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": "http://127.0.0.1:8000/evil", + "code_challenge": challenge, + "code_challenge_method": "S256", + }, + follow_redirects=False, + ) + + assert response.status_code == 400 + assert "location" not in response.headers + body = response.json() + assert body["error"] == "invalid_request" + assert "not registered" in body["error_description"] + + +@requirement("hosting:auth:as:redirect-uri-scheme") +async def test_a_non_loopback_http_redirect_uri_is_accepted_at_registration( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A registration carrying a non-HTTPS, non-loopback redirect URI is accepted. + + The spec requires every redirect URI to be either HTTPS or a loopback host; the bundled + registration handler does not enforce this and registers `http://evil.example/callback` + successfully. See the divergence on the requirement. + """ + http, provider = as_app + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + body["redirect_uris"] = ["http://evil.example/callback"] + + response = await http.post("/register", json=body) + + assert response.status_code == 201 + info = OAuthClientInformationFull.model_validate_json(response.content) + assert [str(u) for u in (info.redirect_uris or [])] == ["http://evil.example/callback"] + assert info.client_id in provider.clients diff --git a/tests/interaction/auth/test_authorize_token.py b/tests/interaction/auth/test_authorize_token.py new file mode 100644 index 0000000000..cb8524c097 --- /dev/null +++ b/tests/interaction/auth/test_authorize_token.py @@ -0,0 +1,399 @@ +"""Authorization-request, token-request, and PKCE wire-level invariants of the SDK's OAuth client. 
+ +Every test connects a real `Client` end to end via `connect_with_oauth`; the assertions are on +the parsed authorize URL and the recorded `/token` form body, because those wire shapes are what +the spec mandates and `Client` cannot observe them. The recording uses `record_requests`, which +snapshots each request at send time so the auth flow's in-place header mutation on retry never +affects what was captured for the first attempt. + +Tests #1/#2/#4/#5 share one `recorded_oauth_flow` fixture (one connect, several disjoint +assertions on its recording); the others connect fresh because each needs a different harness +configuration. +""" + +import base64 +import hashlib +import json +import re +from collections.abc import AsyncIterator +from dataclasses import dataclass +from urllib.parse import parse_qsl, quote, urlsplit + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl, AnyUrl + +from mcp import types +from mcp.client.auth import OAuthFlowError +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata +from mcp.types import ListToolsResult, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + HeadlessOAuth, + InMemoryTokenStorage, + RecordedRequest, + auth_settings, + connect_with_oauth, + first_challenge_shim, + record_requests, + shimmed_app, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH = "/.well-known/oauth-protected-resource/mcp" +ASM_PATH = "/.well-known/oauth-authorization-server" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + +def authorize_params(authorize_url: str) -> 
dict[str, str]: + """Parse the authorize URL's query string into a flat dict (one value per key).""" + return dict(parse_qsl(urlsplit(authorize_url).query)) + + +def form_body(request: RecordedRequest) -> dict[str, str]: + """Parse an `application/x-www-form-urlencoded` request body into a flat dict.""" + return dict(parse_qsl(request.content.decode())) + + +def find(recorded: list[RecordedRequest], method: str, path: str) -> list[RecordedRequest]: + """Filter recorded requests by method and exact path.""" + return [r for r in recorded if r.method == method and r.path == path] + + +@dataclass +class RecordedFlow: + """One completed OAuth connect: every recorded request, plus the parsed authorize URL params.""" + + requests: list[RecordedRequest] + authorize_url: str + + @property + def authorize(self) -> dict[str, str]: + return authorize_params(self.authorize_url) + + @property + def token_request(self) -> RecordedRequest: + token_posts = find(self.requests, "POST", "/token") + assert len(token_posts) == 1 + return token_posts[0] + + +@pytest.fixture +async def recorded_oauth_flow() -> AsyncIterator[RecordedFlow]: + """Run one full OAuth connect with default configuration and yield its recorded wire traffic. + + `valid_scopes` includes `offline_access` so the AS metadata advertises it and the SDK's + SEP-2207 auto-append (and the resulting `prompt=consent`) is exercised; `required_scopes` + stays at `["mcp"]` so the issued token still passes the bearer middleware. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "offline_access"]) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, settings=settings, on_request=on_request) as ( + client, + headless, + ): + await client.list_tools() + + assert headless.authorize_url is not None + yield RecordedFlow(requests=recorded, authorize_url=headless.authorize_url) + + +@requirement("client-auth:pkce:s256") +@requirement("client-auth:resource-parameter") +@requirement("client-auth:authorize:offline-access-consent") +async def test_the_authorize_url_carries_s256_pkce_and_the_resource_indicator( + recorded_oauth_flow: RecordedFlow, +) -> None: + """Every spec-mandated parameter appears on the authorize URL with the right value. + + The full key set is snapshotted so a parameter added or dropped fails the test. The + `code_challenge` length bound is the RFC 7636 §4.2 grammar; an S256 challenge is in + practice always 43 characters, so the upper bound is never approached. + """ + params = recorded_oauth_flow.authorize + + assert sorted(params) == snapshot( + [ + "client_id", + "code_challenge", + "code_challenge_method", + "prompt", + "redirect_uri", + "resource", + "response_type", + "scope", + "state", + ] + ) + assert params["response_type"] == "code" + assert params["code_challenge_method"] == "S256" + assert 43 <= len(params["code_challenge"]) <= 128 + # The exact resource value depends on canonical-URI normalisation (a spec ambiguity); pin + # the stable prefix so the test does not lock in a trailing-slash decision. 
+ assert params["resource"].startswith(BASE_URL) + assert params["state"] != "" + + assert params["scope"].split(" ") == snapshot(["mcp", "offline_access"]) + assert params["prompt"] == "consent" + + +@requirement("client-auth:pkce:s256") +async def test_the_code_verifier_on_the_token_request_hashes_to_the_code_challenge( + recorded_oauth_flow: RecordedFlow, +) -> None: + """The PKCE verifier sent on /token is the S256 pre-image of the challenge sent on /authorize. + + The verifier is also checked against RFC 7636 §4.1's length and `unreserved` charset. + """ + challenge = recorded_oauth_flow.authorize["code_challenge"] + verifier = form_body(recorded_oauth_flow.token_request)["code_verifier"] + + assert re.fullmatch(r"[A-Za-z0-9._~-]{43,128}", verifier) + assert base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") == challenge + + +@requirement("client-auth:state:verify") +async def test_a_mismatched_state_on_the_callback_aborts_the_flow() -> None: + """A callback whose state does not match the value sent on /authorize raises and stops the flow. + + The auth flow runs inside the streamable-HTTP client's task group, so the `OAuthFlowError` + reaches the test wrapped in nested single-element exception groups; `pytest.RaisesGroup` + asserts the leaf type and the SDK-authored message prefix (the full message embeds two + random tokens). + """ + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + headless = HeadlessOAuth(state_override="wrong-state") + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^State parameter mismatch:"), flatten_subgroups=True + ): + # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), + # so an `async with` body would be unreachable; entering explicitly avoids dead code. 
+ await connect_with_oauth(server, provider=provider, headless=headless).__aenter__() + + +@requirement("client-auth:resource-parameter") +async def test_the_authorization_code_token_request_carries_grant_type_code_redirect_and_resource( + recorded_oauth_flow: RecordedFlow, +) -> None: + """The /token form body has exactly the auth-code grant fields, with redirect_uri and resource matching /authorize. + + `client_secret` is present because the SDK's dynamic-registration handler issues a secret + and the client defaults to `client_secret_post`. + """ + token_req = recorded_oauth_flow.token_request + body = form_body(token_req) + + assert sorted(body) == snapshot( + ["client_id", "client_secret", "code", "code_verifier", "grant_type", "redirect_uri", "resource"] + ) + assert body["grant_type"] == "authorization_code" + assert body["code"] != "" + assert body["redirect_uri"] == recorded_oauth_flow.authorize["redirect_uri"] + assert body["resource"] == recorded_oauth_flow.authorize["resource"] + assert token_req.headers["content-type"] == "application/x-www-form-urlencoded" + + +@requirement("client-auth:bearer-header:every-request") +async def test_every_mcp_request_after_auth_carries_the_bearer_header_and_never_a_query_token( + recorded_oauth_flow: RecordedFlow, +) -> None: + """Every MCP request after the flow has `Authorization: Bearer ...` and never `?access_token=`. + + The first /mcp POST is the unauthenticated trigger and is asserted to carry no Authorization + header; that assertion is only meaningful because the recording snapshots requests at send + time (the SDK mutates the same request object in place for the retry). 
+ """ + mcp_posts = find(recorded_oauth_flow.requests, "POST", "/mcp") + assert len(mcp_posts) >= 3 + + assert "authorization" not in mcp_posts[0].headers + for r in mcp_posts[1:]: + assert r.headers["authorization"].startswith("Bearer ") + assert r.headers["authorization"] != "Bearer " + assert "access_token" not in dict(r.url.params) + + +@requirement("client-auth:token-endpoint-auth-method") +async def test_a_client_with_a_secret_authenticates_the_token_request_with_http_basic() -> None: + """A `client_secret_basic` client sends URL-encoded credentials in HTTP Basic, not the body. + + Credentials are URL-encoded before base64 per RFC 6749 §2.3.1; the secret contains `/` so + the encoding is observable. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + client_info = OAuthClientInformationFull( + client_id="cid", + client_secret="s/cret", + token_endpoint_auth_method="client_secret_basic", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="mcp", + ) + await provider.register_client(client_info) + storage = InMemoryTokenStorage(client_info=client_info) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + await client.list_tools() + + assert find(recorded, "POST", "/register") == [] + [token_req] = find(recorded, "POST", "/token") + + decoded = base64.b64decode(token_req.headers["authorization"].removeprefix("Basic ")).decode() + assert decoded == f"{quote('cid', safe='')}:{quote('s/cret', safe='')}" + assert "client_secret" not in form_body(token_req) + + +@requirement("client-auth:token-endpoint-auth-method") +async def test_the_registered_auth_method_is_used_regardless_of_as_metadata_advertised_methods() -> None: + """The token-endpoint auth method comes from the registered client info, not from AS metadata. 
+ + The shim serves AS metadata advertising only `client_secret_basic`; the client dynamically + registers and the SDK's registration handler issues `client_secret_post`. The client uses + `client_secret_post` (secret in the body, no Basic header) because the SDK reads the + registered `token_endpoint_auth_method`, not `token_endpoint_auth_methods_supported`. Other + SDKs (TypeScript, Go) do consult the AS metadata; this test pins where the python SDK's + selection point lives. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + override = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + token_endpoint_auth_methods_supported=["client_secret_basic"], + ) + serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, app_shim=lambda app: shimmed_app(app, serve=serve), on_request=on_request + ) as (client, _): + await client.list_tools() + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content).get("token_endpoint_auth_method") is None + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert "client_secret" in body + assert body["client_secret"] != "" + assert "authorization" not in token_req.headers + + +@requirement("client-auth:scope-selection:priority") +async def test_scope_is_selected_from_the_www_authenticate_challenge_over_prm_metadata() -> None: + """When the 401 challenge carries `scope=`, that value is requested instead of the PRM scopes. 
+ + The SDK's bearer middleware never emits `scope=` in WWW-Authenticate (see the divergence + on `hosting:auth:scope-403`), so the test supplies the first 401 itself via + `first_challenge_shim` and disables token verification so the post-auth retry succeeds + regardless of the granted scope. PRM advertises `["from-prm"]` (it mirrors + `required_scopes`); the challenge says `from-header`; the authorize URL must carry + `from-header`. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(default_scopes=["from-header"]) + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["from-prm"], valid_scopes=["from-header", "from-prm"]) + challenge = f'Bearer scope="from-header", resource_metadata="{BASE_URL}{PRM_PATH}"' + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + settings=settings, + verify_tokens=False, + app_shim=first_challenge_shim(challenge), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["scope"] == "from-header" + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content)["scope"] == "from-header" + + +@requirement("client-auth:pkce:refuse-if-unsupported") +async def test_pkce_is_still_sent_when_as_metadata_omits_code_challenge_methods_supported() -> None: + """AS metadata without `code_challenge_methods_supported` does not stop the client sending PKCE. + + The spec says the client MUST refuse to proceed in this case; the SDK proceeds and the flow + completes. See the divergence on the requirement. 
+ """ + override = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + ) + assert override.code_challenge_methods_supported is None + serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} + + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, app_shim=lambda app: shimmed_app(app, serve=serve) + ) as (client, headless): + result = await client.list_tools() + + assert headless.authorize_url is not None + params = authorize_params(headless.authorize_url) + assert params["code_challenge_method"] == "S256" + assert params["code_challenge"] != "" + assert result.tools[0].name == "echo" + + +@requirement("client-auth:authorize:error-surfaces") +async def test_an_authorize_error_on_the_callback_aborts_the_flow_before_the_token_request() -> None: + """An `error=` redirect from /authorize aborts the flow with no /token request issued. + + The SDK's callback contract is `() -> (code, state)` with no error form, so the failure is + observed as an empty code reaching the SDK and `OAuthFlowError("No authorization code + received")` being raised. The actual `error` value from the redirect is not surfaced to the + caller; that gap is noted in the manifest. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(deny_authorize=True) + server = Server("guarded", on_list_tools=list_tools) + headless = HeadlessOAuth() + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^No authorization code received$"), flatten_subgroups=True + ): + await connect_with_oauth(server, provider=provider, headless=headless, on_request=on_request).__aenter__() + + assert headless.error == "access_denied" + assert find(recorded, "POST", "/token") == [] diff --git a/tests/interaction/auth/test_bearer.py b/tests/interaction/auth/test_bearer.py new file mode 100644 index 0000000000..341a8e0db9 --- /dev/null +++ b/tests/interaction/auth/test_bearer.py @@ -0,0 +1,189 @@ +"""Resource-server bearer-token gate: status codes and `WWW-Authenticate` for each token shape. + +These tests mount only the resource-server side of the auth wiring (a `StaticTokenVerifier` +seeded with hand-built tokens, no authorization-server provider) and speak raw HTTP, since +every assertion is about HTTP semantics the SDK `Client` cannot observe: the 401/403 status, +the `WWW-Authenticate` header structure, and that a wrong-audience token reaches the MCP +endpoint behind the gate. The flow side of the same 401 is `test_flow.py`'s flagship test. 
+""" + +import time +from collections.abc import AsyncIterator + +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server +from mcp.server.auth.provider import AccessToken +from mcp.types import JSONRPCResponse +from tests.interaction._connect import base_headers, initialize_body, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import StaticTokenVerifier, auth_settings + +pytestmark = pytest.mark.anyio + +REQUIRED_SCOPE = "mcp:read" +RESOURCE_METADATA_URL = "http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp" + +_FUTURE = int(time.time()) + 3600 +_PAST = int(time.time()) - 3600 + +TOKENS = { + "tok-valid": AccessToken(token="tok-valid", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_FUTURE), + "tok-expired": AccessToken(token="tok-expired", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_PAST), + "tok-noscope": AccessToken(token="tok-noscope", client_id="c", scopes=["other:thing"], expires_at=_FUTURE), + "tok-wrong-aud": AccessToken( + token="tok-wrong-aud", + client_id="c", + scopes=[REQUIRED_SCOPE], + expires_at=_FUTURE, + resource="https://other.example/mcp", + ), +} + + +@pytest.fixture +async def protected() -> AsyncIterator[httpx.AsyncClient]: + """A bearer-gated streamable-HTTP app (resource server only) on the in-process bridge.""" + server = Server("rs") + settings = auth_settings(required_scopes=[REQUIRED_SCOPE]) + async with mounted_app(server, auth=settings, token_verifier=StaticTokenVerifier(TOKENS)) as (http, _): + yield http + + +async def post_mcp( + http: httpx.AsyncClient, *, bearer: str | None = None, query: dict[str, str] | None = None +) -> httpx.Response: + """POST an initialize body to `/mcp`, optionally with a bearer token and/or a query string.""" + headers = base_headers() + if bearer is not None: + headers["authorization"] = f"Bearer {bearer}" + return await http.post("/mcp", headers=headers, params=query, 
json=initialize_body()) + + +def parse_www_authenticate(value: str) -> dict[str, str]: + """Parse a `Bearer k="v", k="v"` challenge into a dict. + + The SDK emits each parameter exactly once, comma-space separated, with double-quoted + values that contain no quotes themselves; this helper relies on that and would fail + visibly if the format changed. + """ + scheme, _, params = value.partition(" ") + assert scheme == "Bearer" + return {key: quoted.strip('"') for key, _, quoted in (pair.partition("=") for pair in params.split(", "))} + + +@requirement("hosting:auth:missing-401") +async def test_a_request_with_no_authorization_header_is_challenged_with_resource_metadata( + protected: httpx.AsyncClient, +) -> None: + """No `Authorization` header → 401 with a `WWW-Authenticate` carrying `resource_metadata`. + + The snapshot pins current behaviour: the SDK collapses the no-header, unknown-token, and + expired-token cases into one challenge (`error="invalid_token"`, no `scope` parameter). The + spec says the discovery-time challenge SHOULD include `scope` and RFC 6750 says the + no-credentials case SHOULD NOT carry an error code; both gaps are recorded as the divergence + on this requirement. Asserting the dict equals an exact key set also pins that no parameter + appears twice. 
+ """ + response = await post_mcp(protected) + + assert response.status_code == 401 + assert response.headers["www-authenticate"] == snapshot( + 'Bearer error="invalid_token", error_description="Authentication required", ' + 'resource_metadata="http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp"' + ) + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "Authentication required", + "resource_metadata": RESOURCE_METADATA_URL, + } + assert response.json() == snapshot({"error": "invalid_token", "error_description": "Authentication required"}) + + +@requirement("hosting:auth:invalid-401") +async def test_an_unrecognized_bearer_token_is_answered_401_invalid_token(protected: httpx.AsyncClient) -> None: + """A token the verifier does not recognize is answered 401 `invalid_token`. + + The challenge is identical to the no-header case (the backend returns `None` for both); the + missing `scope` parameter is the recorded divergence on this requirement. + """ + response = await post_mcp(protected, bearer="tok-unknown") + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "Authentication required", + "resource_metadata": RESOURCE_METADATA_URL, + } + + +@requirement("hosting:auth:expired-401") +async def test_an_expired_token_is_answered_401(protected: httpx.AsyncClient) -> None: + """A token whose `expires_at` is in the past is answered 401 `invalid_token`. + + The expiry check is the bearer backend's, against the wall clock; the test seeds a concrete + past timestamp so no time mocking is involved. The missing `scope` parameter is the recorded + divergence on this requirement. 
+ """ + response = await post_mcp(protected, bearer="tok-expired") + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" + + +@requirement("hosting:auth:scope-403") +async def test_a_token_missing_a_required_scope_is_answered_403_insufficient_scope_without_a_scope_param( + protected: httpx.AsyncClient, +) -> None: + """A token lacking the required scope is answered 403 `insufficient_scope`, with no `scope` parameter. + + The spec's runtime-insufficient-scope guidance says the challenge SHOULD include `scope` + naming the required scope; the SDK never emits it, recorded as the divergence on this + requirement. The SDK client reads `scope` from this header to drive step-up, so the gap is + a resource-server/client asymmetry. + """ + response = await post_mcp(protected, bearer="tok-noscope") + + assert response.status_code == 403 + parsed = parse_www_authenticate(response.headers["www-authenticate"]) + assert parsed == { + "error": "insufficient_scope", + "error_description": f"Required scope: {REQUIRED_SCOPE}", + "resource_metadata": RESOURCE_METADATA_URL, + } + assert "scope" not in parsed + + +@requirement("hosting:auth:aud-validation") +async def test_a_token_with_a_mismatched_audience_is_accepted(protected: httpx.AsyncClient) -> None: + """A token whose `resource` does not match the server's resource identifier is accepted. + + The spec mandates the resource server validate the token's audience; the bearer backend + never inspects `AccessToken.resource`, so the request passes the gate and the MCP endpoint + serves it. This pins current behaviour with the divergence recorded on the requirement. + """ + response = await post_mcp(protected, bearer="tok-wrong-aud") + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/event-stream") + # The body is finite SSE: a result event followed by stream close. 
Pull the JSON-RPC response + # out of the buffered text to prove the MCP endpoint actually answered the initialize request. + [data] = [line.removeprefix("data: ") for line in response.text.splitlines() if line.startswith("data: ")] + assert "protocolVersion" in JSONRPCResponse.model_validate_json(data).result + + +@requirement("hosting:auth:query-token-ignored") +async def test_an_access_token_in_the_query_string_is_not_accepted(protected: httpx.AsyncClient) -> None: + """A valid token presented in the URI query string is treated as no authentication. + + The bearer backend reads only the `Authorization` header, so `?access_token=...` is never + consulted; the request is treated as unauthenticated and answered 401. This satisfies, by + absence, the security best-practice that resource servers must not accept query-string + tokens. + """ + response = await post_mcp(protected, query={"access_token": "tok-valid"}) + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" diff --git a/tests/interaction/auth/test_discovery.py b/tests/interaction/auth/test_discovery.py new file mode 100644 index 0000000000..68c33c8a2d --- /dev/null +++ b/tests/interaction/auth/test_discovery.py @@ -0,0 +1,333 @@ +"""Protected-resource and authorization-server metadata discovery, end to end. + +Every client-side test connects a real `Client` via `connect_with_oauth` and asserts on the +recorded request paths the discovery probes produced; the discovery URL ordering is a wire +detail `Client` cannot observe directly but the recording can. Tests that need a metadata +endpoint to 404 or return alternate content wrap the SDK's app in `shimmed_app` while leaving +the real authorize and token endpoints behind it, so the rest of the flow runs unaltered. 
+ +The two server-side tests (#5, #6) drive raw httpx against `mounted_app` because their +assertions are the metadata response bodies and headers, which `Client` does not surface. +""" + +import json + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl + +from mcp import types +from mcp.client.auth import OAuthFlowError, OAuthRegistrationError +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata +from mcp.types import ListToolsResult, Tool +from tests.interaction._connect import BASE_URL, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + RecordedRequest, + auth_settings, + connect_with_oauth, + metadata_body, + record_requests, + shim, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH_SUFFIXED = "/.well-known/oauth-protected-resource/mcp" +PRM_ROOT = "/.well-known/oauth-protected-resource" +ASM_ROOT = "/.well-known/oauth-authorization-server" +OIDC_ROOT = "/.well-known/openid-configuration" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="probe", input_schema={"type": "object"})]) + + +def discovery_gets(recorded: list[RecordedRequest]) -> list[str]: + """Return the well-known GET paths in recorded order, ignoring everything else.""" + return [r.path for r in recorded if r.method == "GET" and "/.well-known/" in r.path] + + +def real_asm() -> OAuthMetadata: + """Build an authorization-server metadata document pointing at the real co-hosted endpoints.""" + return OAuthMetadata( + issuer=AnyHttpUrl(BASE_URL), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + 
scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + ) + + +@requirement("client-auth:prm-discovery:fallback-order") +async def test_prm_discovery_uses_the_resource_metadata_url_from_www_authenticate() -> None: + """The first protected-resource probe is the URL the 401's `WWW-Authenticate` header supplied. + + With co-hosted defaults the header carries the path-suffixed well-known URL; the client + fetches that one first and, because it succeeds, never falls back. The single-probe + sequence proves priority 1. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, on_request=on_request) as (client, _): + await client.list_tools() + + assert discovery_gets(recorded) == snapshot([PRM_PATH_SUFFIXED, ASM_ROOT]) + assert (recorded[0].method, recorded[0].path) == ("POST", "/mcp") + assert (recorded[1].method, recorded[1].path) == ("GET", PRM_PATH_SUFFIXED) + + +@requirement("client-auth:prm-discovery:fallback-order") +async def test_prm_discovery_falls_back_from_path_well_known_to_root_on_404() -> None: + """When the path-suffixed PRM well-known 404s, the client falls back to the root well-known. + + The exact GET count is not asserted: the WWW-Authenticate URL equals the path well-known + here, so the SDK probes it twice (once as priority 1, once as priority 2) before reaching + root. Asserting "path before root, root reached, then the flow proceeds" pins the spec + invariant; the duplicate probe is an implementation detail. The served PRM body carries an + unrecognized field to prove the client's parser ignores unknown members (RFC 9728 §3.2). 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + app_shim = shim( + not_found=frozenset({PRM_PATH_SUFFIXED}), + serve={PRM_ROOT: metadata_body(prm, x_unknown_extension="ignored")}, + ) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + await client.list_tools() + + well_known = discovery_gets(recorded) + assert PRM_PATH_SUFFIXED in well_known + assert PRM_ROOT in well_known + assert well_known.index(PRM_PATH_SUFFIXED) < well_known.index(PRM_ROOT) + assert any(r.path == "/authorize" for r in recorded) + + +@requirement("client-auth:prm-discovery:no-prm-fallback") +async def test_when_every_prm_probe_fails_the_client_discovers_as_metadata_at_the_server_origin() -> None: + """When every protected-resource metadata probe 404s, the client falls back to the legacy path. + + The legacy 2025-03-26 behaviour: with no PRM document available, treat the MCP server's + origin as the authorization server and fetch its `/.well-known/oauth-authorization-server` + directly. The real co-hosted ASM endpoint is at exactly that location, so the flow completes. + The recorded sequence shows both PRM well-known paths probed (and failed) before ASM_ROOT. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + app_shim = shim(not_found=frozenset({PRM_PATH_SUFFIXED, PRM_ROOT})) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + result = await client.list_tools() + + well_known = discovery_gets(recorded) + assert PRM_PATH_SUFFIXED in well_known + assert PRM_ROOT in well_known + assert well_known[-1] == ASM_ROOT + assert all(well_known.index(prm) < well_known.index(ASM_ROOT) for prm in (PRM_PATH_SUFFIXED, PRM_ROOT)) + assert result.tools[0].name == "probe" + + +@requirement("client-auth:dcr:registration-error-surfaces") +async def test_a_400_from_the_registration_endpoint_surfaces_as_a_registration_error() -> None: + """A 400 from `/register` surfaces as `OAuthRegistrationError` carrying the server's body. + + The shim makes `/register` return RFC 7591's `invalid_client_metadata`; the SDK reads the + body and raises with the status and text in the message, before any authorize or token + request is made. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + error_body = json.dumps({"error": "invalid_client_metadata", "error_description": "no"}).encode() + app_shim = shim(serve={"/register": (400, error_body)}) + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthRegistrationError, match=r"^Registration failed: 400 .*invalid_client_metadata"), + flatten_subgroups=True, + ): + await connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request).__aenter__() + + assert [r.path for r in recorded if r.path in ("/authorize", "/token")] == [] + + +@requirement("client-auth:prm-resource-mismatch") +async def test_prm_with_a_mismatched_resource_aborts_the_flow_before_authorize() -> None: + """A PRM document whose `resource` does not cover the server URL aborts the flow. + + The shim serves PRM at the URL the WWW-Authenticate header supplies, but with a `resource` + on a different path; `check_resource_allowed` rejects it and `OAuthFlowError` is raised + before any authorize or token request is made. The error reaches the test wrapped in nested + single-element exception groups by the streamable-HTTP client's task group. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/other"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + app_shim = shim(serve={PRM_PATH_SUFFIXED: metadata_body(prm)}) + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^Protected resource .* does not match expected"), + flatten_subgroups=True, + ): + await connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request).__aenter__() + + assert [r.path for r in recorded if r.path in ("/authorize", "/token")] == [] + + +@requirement("client-auth:as-metadata-discovery:priority-order") +@pytest.mark.parametrize( + ("authorization_server", "not_found", "serve_at", "expected_order"), + [ + pytest.param( + f"{BASE_URL}/", + frozenset({ASM_ROOT}), + OIDC_ROOT, + [ASM_ROOT, OIDC_ROOT], + id="root-issuer", + ), + pytest.param( + f"{BASE_URL}/tenant", + frozenset({f"{ASM_ROOT}/tenant", f"{OIDC_ROOT}/tenant"}), + "/tenant/.well-known/openid-configuration", + [f"{ASM_ROOT}/tenant", f"{OIDC_ROOT}/tenant", "/tenant/.well-known/openid-configuration"], + id="path-issuer", + ), + ], +) +async def test_as_metadata_discovery_falls_back_through_the_spec_endpoint_order( + authorization_server: str, not_found: frozenset[str], serve_at: str, expected_order: list[str] +) -> None: + """Authorization-server metadata is fetched at the spec's endpoints in the spec's order. + + The shim 404s every endpoint before the last so the recording proves each probe and its + position. For an issuer URL with no path the order is OAuth root then OIDC root; for an + issuer URL with a path component it is OAuth path-inserted, OIDC path-inserted, then OIDC + path-appended (the spec's three-endpoint MUST). 
The path-issuer case is driven by serving + a PRM whose `authorization_servers` carries the path; the SDK's own AS routes stay at root + (the served body points at the real `/authorize` and `/token`). The served bodies carry an + unrecognized field to prove the client's parser ignores unknown members (RFC 8414 §3.2). + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(authorization_server)] + ) + app_shim = shim( + not_found=not_found, + serve={ + PRM_PATH_SUFFIXED: metadata_body(prm), + serve_at: metadata_body(real_asm(), x_unknown_extension="ignored"), + }, + ) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + await client.list_tools() + + assert discovery_gets(recorded) == [PRM_PATH_SUFFIXED, *expected_order] + + +@requirement("hosting:auth:metadata-endpoints") +@requirement("hosting:auth:prm:authorization-servers-field") +async def test_the_prm_endpoint_serves_the_resource_url_and_at_least_one_authorization_server() -> None: + """The protected-resource metadata document the SDK serves identifies the resource and an authorization server. + + Also asserts the response is `application/json` (RFC 9728 §3.2) and that fields the SDK has + no value for are absent rather than null (`PydanticJSONResponse` serializes with + `exclude_none=True`, satisfying RFC 9728 §3.2's omit-zero-value rule). 
+ """ + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + + async with mounted_app(server, auth=auth_settings(), auth_server_provider=provider) as (http, _): + response = await http.get(PRM_PATH_SUFFIXED) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("application/json") + + document = json.loads(response.content) + assert "resource_documentation" not in document + assert "scopes_supported" in document + + metadata = ProtectedResourceMetadata.model_validate(document) + assert str(metadata.resource).rstrip("/") == f"{BASE_URL}/mcp" + assert len(metadata.authorization_servers) >= 1 + assert metadata.bearer_methods_supported == ["header"] + + +@requirement("hosting:auth:as-router") +async def test_as_metadata_advertises_authorize_token_registration_and_s256() -> None: + """The authorization-server metadata document the SDK serves names the required endpoints and S256.""" + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + + async with mounted_app(server, auth=auth_settings(), auth_server_provider=provider) as (http, _): + response = await http.get(ASM_ROOT) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("application/json") + + metadata = OAuthMetadata.model_validate_json(response.content) + assert str(metadata.issuer).rstrip("/") == BASE_URL + assert str(metadata.authorization_endpoint) == f"{BASE_URL}/authorize" + assert str(metadata.token_endpoint) == f"{BASE_URL}/token" + assert str(metadata.registration_endpoint) == f"{BASE_URL}/register" + assert metadata.response_types_supported == ["code"] + assert metadata.code_challenge_methods_supported is not None + assert "S256" in metadata.code_challenge_methods_supported + + +@requirement("client-auth:as-metadata-discovery:issuer-validation") +async def test_as_metadata_with_a_mismatched_issuer_is_accepted_and_the_flow_proceeds() -> None: + """Authorization-server metadata whose 
`issuer` does not match the discovery URL is accepted. + + RFC 8414 §3.3 requires the client to reject the document; the SDK parses and uses it + without comparing `issuer` to the URL it was fetched from. See the divergence on the + requirement. The served body carries an unrecognized field as a fold-in proof of + unknown-field tolerance. + """ + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + metadata = real_asm() + metadata.issuer = AnyHttpUrl(f"{BASE_URL}/wrong-issuer") + app_shim = shim(serve={ASM_ROOT: metadata_body(metadata, x_unknown_extension="ignored")}) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "probe" diff --git a/tests/interaction/auth/test_flow.py b/tests/interaction/auth/test_flow.py new file mode 100644 index 0000000000..968fc5f980 --- /dev/null +++ b/tests/interaction/auth/test_flow.py @@ -0,0 +1,239 @@ +"""End-to-end OAuth authorization-code flow against the SDK's own server, fully in process. + +Auth is HTTP-only so these tests are not transport-parametrized; each connects via +`connect_with_oauth`, which co-hosts the SDK's authorization server, protected-resource +metadata, and bearer-gated MCP endpoint on one bridge-backed Starlette app and drives the +whole flow through one `httpx.AsyncClient` carrying the SDK's `OAuthClientProvider`. The +authorize redirect completes headlessly through the same bridge, so every request the flow +makes is observable via `on_request`. 
+""" + +import json +from collections import Counter +from urllib.parse import parse_qs, urlsplit + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot +from pydantic import AnyUrl + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import get_access_token +from mcp.shared.auth import OAuthClientInformationFull +from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + InMemoryTokenStorage, + auth_settings, + connect_with_oauth, + oauth_client_metadata, + shimmed_app, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})]) + + +@requirement("flow:oauth:authorization-code-roundtrip") +@requirement("client-auth:401-triggers-flow") +@requirement("hosting:auth:missing-401") +async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow_connects() -> None: + """Connecting to a bearer-gated server walks the full authorization-code flow and succeeds. + + Three requirements are proven by one connect: the flow runs end to end (authorization-code + roundtrip), it was triggered by a 401 on the first MCP request (401-triggers-flow), and + that 401 carried `resource_metadata` in `WWW-Authenticate` for discovery (missing-401). + The flagship test pins the recorded request sequence so the discovery → registration → + authorize → token → retry order is asserted explicitly. + + Steps the SDK is expected to perform: + 1. 
POST /mcp without a token → 401 with `WWW-Authenticate: Bearer resource_metadata=...`. + 2. GET the protected-resource metadata. + 3. GET the authorization-server metadata. + 4. POST /register (dynamic client registration). + 5. GET /authorize → 302 with code+state (completed by the headless redirect). + 6. POST /token (authorization-code exchange). + 7. Retry POST /mcp with `Authorization: Bearer <token>` → succeeds. + """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as ( + client, + headless, + ): + result = await client.list_tools() + + assert result == snapshot(ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})])) + assert headless.authorize_url is not None + + paths = [(r.method, r.url.path) for r in requests] + assert Counter(paths) == snapshot( + Counter( + { + ("POST", "/mcp"): 4, + ("GET", "/.well-known/oauth-protected-resource/mcp"): 1, + ("GET", "/.well-known/oauth-authorization-server"): 1, + ("POST", "/register"): 1, + ("GET", "/authorize"): 1, + ("POST", "/token"): 1, + ("GET", "/mcp"): 1, + ("DELETE", "/mcp"): 1, + } + ) + ) + + assert (requests[0].method, requests[0].url.path) == ("POST", "/mcp") + # The recorded Request objects are live references: the auth flow mutates the original + # request's headers in place when it adds the bearer token for the retry, so the first + # entry's headers cannot be used to assert "no Authorization on the first attempt". The + # path multiset above proving discovery happened is the evidence the first attempt was 401. + + # The first PRM discovery GET carries the protocol-version header (an SDK behaviour, not a + # spec requirement on discovery requests).
+ prm_get = next(r for r in requests if r.url.path == "/.well-known/oauth-protected-resource/mcp") + assert prm_get.headers.get("mcp-protocol-version") == snapshot("2025-11-25") + + authorize = parse_qs(urlsplit(headless.authorize_url).query) + assert authorize["response_type"] == ["code"] + assert authorize["code_challenge_method"] == ["S256"] + assert authorize["client_id"][0] in provider.clients + + assert storage.tokens is not None + bearer = f"Bearer {storage.tokens.access_token}" + authed_mcp = [r for r in requests if r.url.path == "/mcp" and r.headers.get("authorization") == bearer] + assert len(authed_mcp) > 0 + assert storage.tokens.access_token in provider.access_tokens + + +@requirement("hosting:auth:authinfo-propagates") +async def test_the_access_token_reaches_the_tool_handler_via_get_access_token() -> None: + """A tool handler reads the request's access token through `get_access_token()`.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "whoami" + token = get_access_token() + assert token is not None + return CallToolResult(content=[TextContent(text=" ".join(token.scopes))]) + + server = Server("guarded", on_list_tools=list_tools, on_call_tool=call_tool) + provider = InMemoryAuthorizationServerProvider() + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider) as (client, _): + result = await client.call_tool("whoami", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mcp")])) + + +@requirement("client-auth:pre-registration") +async def test_a_preregistered_client_skips_registration() -> None: + """A client whose storage already holds client info uses it instead of registering. + + The provider holds the same registration server-side so the authorize and token steps + accept it; the recorded requests prove no `/register` call was made. 
+ """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + client_info = OAuthClientInformationFull( + client_id="preregistered", + client_secret="s3cret", + token_endpoint_auth_method="client_secret_post", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="mcp", + ) + await provider.register_client(client_info) + storage.client_info = client_info + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as ( + client, + _, + ): + await client.list_tools() + + assert [r.url.path for r in requests].count("/register") == 0 + assert list(provider.clients) == ["preregistered"] + + +@requirement("client-auth:dcr") +async def test_the_dcr_request_carries_the_client_metadata() -> None: + """Dynamic registration sends the client's metadata and persists what the server issued. + + The body of the recorded `/register` POST carries the metadata the test supplied (with the + scope filled in from server discovery), and the server's issued client_id and secret are + persisted to storage and held by the provider. 
+ """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + client_metadata = oauth_client_metadata() + client_metadata.software_id = "interaction-test-suite" + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, storage=storage, client_metadata=client_metadata, on_request=requests.append + ) as (client, _): + await client.list_tools() + + register = next(r for r in requests if r.url.path == "/register") + assert register.headers["content-type"] == "application/json" + body = json.loads(register.content) + assert body == snapshot( + { + "redirect_uris": ["http://127.0.0.1:8000/oauth/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "scope": "mcp", + "client_name": "interaction-suite", + "software_id": "interaction-test-suite", + } + ) + + assert storage.client_info is not None + assert storage.client_info.client_id is not None + assert storage.client_info.client_secret is not None + assert list(provider.clients) == [storage.client_info.client_id] + + +async def test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_wrapped_app() -> None: + """Harness self-test: `shimmed_app` serves canned bodies, 404s, and forwards everything else. + + Wraps a real auth-hosting Starlette app so the forward path is exercised against the SDK's + own routing; provided here so the discovery tests can rely on the shim without each adding + their own contract test. 
+ """ + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + real_app = server.streamable_http_app(auth=auth_settings(), auth_server_provider=provider) + app = shimmed_app(real_app, not_found=frozenset({"/missing"}), serve={"/override": b'{"shimmed": true}'}) + async with server.session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + served = await http.get("/override") + assert served.status_code == 200 + assert served.headers["content-type"] == "application/json" + assert served.json() == {"shimmed": True} + + assert (await http.get("/missing")).status_code == 404 + + forwarded = await http.get("/.well-known/oauth-authorization-server") + assert forwarded.status_code == 200 + assert forwarded.json()["issuer"] == "http://127.0.0.1:8000/" diff --git a/tests/interaction/auth/test_lifecycle.py b/tests/interaction/auth/test_lifecycle.py new file mode 100644 index 0000000000..aa552ae8a6 --- /dev/null +++ b/tests/interaction/auth/test_lifecycle.py @@ -0,0 +1,445 @@ +"""Token lifecycle, step-up, and registration-variant flows of the SDK's OAuth client. + +Every test connects end to end via `connect_with_oauth`; the assertions are recording-first +(the recorded request sequence is asserted before, or independently of, the call result), so a +surprise in the refresh or step-up paths produces a readable diff of what fired rather than an +opaque failure. The provider knobs that drive each scenario are documented per test. 
+""" + +import base64 +from collections import Counter +from urllib.parse import parse_qsl, urlsplit + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl, AnyUrl + +from mcp import MCPError, types +from mcp.client.auth.extensions.client_credentials import ClientCredentialsOAuthProvider, PrivateKeyJWTOAuthProvider +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata +from mcp.types import INTERNAL_ERROR, ListToolsResult, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + InMemoryTokenStorage, + RecordedRequest, + auth_settings, + connect_with_oauth, + m2m_token_shim, + metadata_body, + record_requests, + shim, + step_up_shim, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH = "/.well-known/oauth-protected-resource/mcp" +ASM_PATH = "/.well-known/oauth-authorization-server" +CIMD_URL = "https://client.example/.well-known/mcp-client" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + +def form_body(request: RecordedRequest) -> dict[str, str]: + """Parse an `application/x-www-form-urlencoded` request body into a flat dict.""" + return dict(parse_qsl(request.content.decode())) + + +def authorize_params(authorize_url: str) -> dict[str, str]: + """Parse the authorize URL's query string into a flat dict.""" + return dict(parse_qsl(urlsplit(authorize_url).query)) + + +def find(recorded: list[RecordedRequest], method: str, path: str) -> list[RecordedRequest]: + return [r for r in recorded if r.method == method and r.path == path] + + +def path_counts(recorded: list[RecordedRequest]) -> 
Counter[tuple[str, str]]: + return Counter((r.method, r.path) for r in recorded) + + +def cimd_supported_metadata() -> bytes: + """AS metadata advertising `client_id_metadata_document_supported: true` (the SDK server never sets it).""" + metadata = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + response_types_supported=["code"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + client_id_metadata_document_supported=True, + ) + return metadata_body(metadata) + + +def seeded_client(provider: InMemoryAuthorizationServerProvider, **kwargs: object) -> OAuthClientInformationFull: + """Register a client with the provider and return its info, for pre-registration and CIMD scenarios.""" + base: dict[str, object] = { + "client_id": "preregistered", + "token_endpoint_auth_method": "none", + "redirect_uris": [AnyUrl(REDIRECT_URI)], + "grant_types": ["authorization_code", "refresh_token"], + "scope": "mcp", + } + base.update(kwargs) + info = OAuthClientInformationFull.model_validate(base) + assert info.client_id is not None + provider.clients[info.client_id] = info + return info + + +@requirement("client-auth:refresh:transparent") +async def test_an_expired_access_token_is_transparently_refreshed_before_the_next_request() -> None: + """An access token the client considers expired is refreshed and the new bearer is used. + + The provider tells the client `expires_in=-3600` for the first token while keeping the + server-side `expires_at` in the future, so the connect's retry succeeds and the next + request finds the token expired and refreshes. 
The recorded requests prove exactly one + `grant_type=refresh_token` exchange carrying the resource indicator, and the bearer used + after the refresh is the second access token, which is the one persisted to storage. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(issue_expired_first=True) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + token_posts = find(recorded, "POST", "/token") + bodies = [form_body(r) for r in token_posts] + assert [b["grant_type"] for b in bodies] == snapshot(["authorization_code", "refresh_token"]) + + refresh_body = bodies[1] + assert sorted(refresh_body) == snapshot(["client_id", "client_secret", "grant_type", "refresh_token", "resource"]) + assert refresh_body["refresh_token"].startswith("refresh_") + assert refresh_body["resource"].startswith(BASE_URL) + + bearers = {r.headers["authorization"] for r in recorded if r.path == "/mcp" and "authorization" in r.headers} + assert len(bearers) == 2 + assert storage.tokens is not None + assert f"Bearer {storage.tokens.access_token}" in bearers + assert storage.tokens.expires_in == 3600 + + +@requirement("client-auth:403-scope-upgrade") +async def test_a_403_insufficient_scope_triggers_one_reauthorize_with_the_challenged_scope() -> None: + """A 403 `insufficient_scope` challenge is answered by one re-authorize with the challenge's scope. + + The shim 403s the second authenticated `/mcp` POST (the `notifications/initialized` request, + which reaches the auth flow's step-up handler; the first authenticated POST is the post-401 + retry, after which the generator ends without inspecting the response). 
The challenge names a + wider scope; step-up reuses cached metadata and the existing client registration, + re-authorizes with the new scope, and the connect completes. The client is pre-registered + with both scopes so the server's authorize handler accepts the wider second request. One + re-authorize, one retry; the spec's SHOULD-retry-limit ("a few") is not enforced. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage(client_info=seeded_client(provider, scope="mcp write")) + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "write"]) + challenge = 'Bearer error="insufficient_scope", scope="mcp write"' + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + settings=settings, + app_shim=step_up_shim(challenge), + on_request=on_request, + ) as (client, headless): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + assert len(headless.authorize_urls) == 2 + assert authorize_params(headless.authorize_urls[0])["scope"] == "mcp" + assert authorize_params(headless.authorize_urls[1])["scope"] == "mcp write" + + counts = path_counts(recorded) + assert counts[("GET", PRM_PATH)] == 1 + assert counts[("GET", ASM_PATH)] == 1 + assert counts[("POST", "/register")] == 0 + assert counts[("GET", "/authorize")] == 2 + assert counts[("POST", "/token")] == 2 + + +@requirement("client-auth:401-after-auth-throws") +async def test_a_second_401_after_a_completed_oauth_flow_surfaces_without_looping() -> None: + """A 401 on the post-auth retry surfaces as an error rather than re-entering discovery. + + The provider rejects every token at verification, so the full flow runs once and the retry + is 401'd. 
The auth-flow generator ends after that retry, so the 401 propagates and the + transport converts it to an INTERNAL_ERROR result, raising during connect. Discovery, + registration, authorize, and token each ran exactly once: no loop. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(reject_all_tokens=True) + server = Server("guarded", on_list_tools=list_tools) + + def is_internal_error(error: MCPError) -> bool: + return error.error.code == INTERNAL_ERROR + + with anyio.fail_after(5): + with pytest.RaisesGroup(pytest.RaisesExc(MCPError, check=is_internal_error), flatten_subgroups=True): + # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), + # so an `async with` body would be unreachable; entering explicitly avoids dead code. + await connect_with_oauth(server, provider=provider, on_request=on_request).__aenter__() + + counts = path_counts(recorded) + assert counts[("GET", PRM_PATH)] == 1 + assert counts[("GET", ASM_PATH)] == 1 + assert counts[("POST", "/register")] == 1 + assert counts[("GET", "/authorize")] == 1 + assert counts[("POST", "/token")] == 1 + assert counts[("POST", "/mcp")] == 2 + + +@requirement("client-auth:cimd") +async def test_cimd_is_selected_when_the_as_advertises_support_and_a_metadata_url_is_supplied() -> None: + """A client-ID metadata-document URL is used as `client_id` instead of registering. + + AS metadata is shimmed to advertise `client_id_metadata_document_supported: true`; the + provider is pre-seeded so the server's authorize and token handlers accept the URL as a + client_id (the SDK server has no CIMD-aware client lookup of its own). The recorded + requests prove no `/register` call, the authorize URL's `client_id` is the CIMD URL, the + token request uses `token_endpoint_auth_method=none`, and storage persists the URL as + `client_id`. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + seeded_client(provider, client_id=CIMD_URL) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + client_metadata_url=CIMD_URL, + app_shim=shim(serve={ASM_PATH: cimd_supported_metadata()}), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert find(recorded, "POST", "/register") == [] + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["client_id"] == CIMD_URL + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body["client_id"] == CIMD_URL + assert "client_secret" not in body + assert "authorization" not in token_req.headers + + assert storage.client_info is not None + assert storage.client_info.client_id == CIMD_URL + assert storage.client_info.token_endpoint_auth_method == "none" + + +@requirement("client-auth:invalid-grant-clears-tokens") +async def test_a_failed_refresh_clears_stored_tokens_and_restarts_the_full_flow() -> None: + """A non-200 refresh response clears the in-memory tokens and the flow re-runs from discovery. + + The first token is reported expired so the next request refreshes; the provider denies the + refresh once with `invalid_grant`, the auth flow clears its tokens, the unauthenticated + request 401s, and discovery, authorize, and token run again. The original registration is + preserved (`client_info` is not cleared). The SDK clears tokens on any non-200 refresh + response, not specifically `error=invalid_grant`; `source="sdk"` so this is a precision + note rather than a divergence. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(issue_expired_first=True, fail_next_refresh=True) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + token_posts = find(recorded, "POST", "/token") + assert [form_body(r)["grant_type"] for r in token_posts] == snapshot( + ["authorization_code", "refresh_token", "authorization_code"] + ) + + counts = path_counts(recorded) + assert counts[("POST", "/register")] == 1 + assert counts[("GET", "/authorize")] == 2 + assert counts[("GET", PRM_PATH)] == 2 + assert counts[("GET", ASM_PATH)] == 2 + + assert storage.client_info is not None + assert storage.tokens is not None + assert storage.tokens.access_token in provider.access_tokens + + +@requirement("client-auth:client-credentials") +async def test_client_credentials_provider_obtains_a_token_without_an_authorize_step() -> None: + """The client-credentials provider connects with no authorize step and a `client_credentials` grant. + + The SDK server's `TokenHandler` does not route `client_credentials`, so the harness shim + handles it (the shim is harness; the SDK-under-test is the client provider). The recorded + `/token` body proves the grant type, scope, resource indicator, and HTTP-Basic client + authentication; no `/authorize` or `/register` request was made. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + auth = ClientCredentialsOAuthProvider( + server_url=f"{BASE_URL}/mcp", + storage=InMemoryTokenStorage(), + client_id="m2m-client", + client_secret="m2m-secret", + scopes="mcp", + ) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + auth=auth, + app_shim=m2m_token_shim(provider, scopes=["mcp"]), + on_request=on_request, + ) as (client, headless): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + assert headless.authorize_url is None + assert find(recorded, "GET", "/authorize") == [] + assert find(recorded, "POST", "/register") == [] + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body == snapshot( + {"grant_type": "client_credentials", "resource": "http://127.0.0.1:8000/mcp", "scope": "mcp"} + ) + decoded = base64.b64decode(token_req.headers["authorization"].removeprefix("Basic ")).decode() + assert decoded == "m2m-client:m2m-secret" + + +@requirement("client-auth:private-key-jwt") +async def test_private_key_jwt_provider_authenticates_the_token_request_with_an_assertion() -> None: + """The private-key-JWT provider sends a `client_assertion` on the token request, with the issuer as audience. + + The assertion provider is a closure that records the audience it was called with and returns + a fixed opaque value (the JWT contents are not the SDK's concern here); the test asserts the + `client_assertion`/`client_assertion_type` form fields and that the audience matches the AS + metadata's issuer. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + audiences: list[str] = [] + + async def assertion_provider(audience: str) -> str: + audiences.append(audience) + return "header.payload.sig" + + auth = PrivateKeyJWTOAuthProvider( + server_url=f"{BASE_URL}/mcp", + storage=InMemoryTokenStorage(), + client_id="m2m-jwt-client", + assertion_provider=assertion_provider, + scopes="mcp", + ) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + auth=auth, + app_shim=m2m_token_shim(provider, scopes=["mcp"]), + on_request=on_request, + ) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + assert audiences == [f"{BASE_URL}/"] + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body == snapshot( + { + "grant_type": "client_credentials", + "client_assertion": "header.payload.sig", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "resource": "http://127.0.0.1:8000/mcp", + "scope": "mcp", + } + ) + assert "client_secret" not in body + assert "authorization" not in token_req.headers + + +@pytest.mark.parametrize( + ("case", "preseed_storage", "advertise_cimd"), + [("cimd_unsupported_falls_through_to_dcr", False, False), ("preregistered_beats_cimd", True, True)], + ids=["cimd_unsupported_falls_through_to_dcr", "preregistered_beats_cimd"], +) +@requirement("client-auth:cimd") +async def test_registration_priority_prefers_preregistered_then_cimd_then_dcr( + case: str, preseed_storage: bool, advertise_cimd: bool +) -> None: + """The client picks pre-registration over CIMD over DCR, falling through when each is unavailable. 
+ + Two priority edges are exercised: with a CIMD URL configured but no AS support, DCR runs and + the registered `client_id` is used; with a CIMD URL configured and AS support but a + pre-registered client in storage, the stored `client_id` is used and neither CIMD nor DCR + runs. (The positive CIMD case and pre-registration over DCR are covered by their own tests.) + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + storage = InMemoryTokenStorage() + + expected_client_id: str + if preseed_storage: + info = seeded_client(provider) + storage.client_info = info + assert info.client_id is not None + expected_client_id = info.client_id + else: + expected_client_id = "" + + app_shim = shim(serve={ASM_PATH: cimd_supported_metadata()}) if advertise_cimd else None + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + client_metadata_url=CIMD_URL, + app_shim=app_shim, + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + chosen_client_id = authorize_params(headless.authorize_url)["client_id"] + assert chosen_client_id != CIMD_URL + + if case == "cimd_unsupported_falls_through_to_dcr": + assert len(find(recorded, "POST", "/register")) == 1 + assert chosen_client_id in provider.clients + else: + assert find(recorded, "POST", "/register") == [] + assert chosen_client_id == expected_client_id diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py new file mode 100644 index 0000000000..c2ace45077 --- /dev/null +++ b/tests/interaction/conftest.py @@ -0,0 +1,23 @@ +"""Shared fixtures for the interaction suite.""" + +import pytest + +from tests.interaction._connect import Connect, connect_in_memory, connect_over_sse, connect_over_streamable_http + +_FACTORIES: dict[str, Connect] = { + "in-memory": connect_in_memory, + 
"streamable-http": connect_over_streamable_http, + "sse": connect_over_sse, +} + + +@pytest.fixture(params=sorted(_FACTORIES)) +def connect(request: pytest.FixtureRequest) -> Connect: + """The transport-parametrized connection factory: a test using it runs once per transport. + + Tests that are tied to one transport (the wire-recording tests, the bare-ClientSession tests, + the transport-specific tests under transports/) do not use this fixture and connect directly. + """ + transport_name = request.param + assert isinstance(transport_name, str) + return _FACTORIES[transport_name] diff --git a/tests/interaction/lowlevel/__init__.py b/tests/interaction/lowlevel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py new file mode 100644 index 0000000000..6f1454e58a --- /dev/null +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -0,0 +1,234 @@ +"""Cancellation interactions against the low-level Server, driven through the public Client API. + +There is no client-side cancellation API: cancelling means sending a CancelledNotification +carrying the request id, which only the server-side handler can observe (`ctx.request_id`), so +these tests capture the id from inside the blocked handler before cancelling. The handler blocks +on an Event rather than a sleep, and every wait is bounded by `anyio.fail_after`. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + EmptyResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + PingRequest, + ServerCapabilities, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:cancel:in-flight") +@requirement("protocol:cancel:handler-abort-propagates") +async def test_cancellation_stops_in_flight_handler(connect: Connect) -> None: + """Cancelling an in-flight request interrupts its handler and fails the pending call. + + The server answers the cancelled request with an error response (the spec says it should + not respond at all; see the divergence note on the requirement), so the caller's pending + request raises rather than hanging. 
+ """ + started = anyio.Event() + handler_cancelled = anyio.Event() + request_ids: list[types.RequestId] = [] + errors: list[ErrorData] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + assert ctx.request_id is not None + request_ids.append(ctx.request_id) + started.set() + try: + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + except anyio.get_cancelled_exc_class(): + handler_cancelled.set() + raise + raise NotImplementedError # unreachable: the wait above never completes normally + + server = Server("blocker", on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_capture_error() -> None: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}) + errors.append(exc_info.value.error) + + task_group.start_soon(call_and_capture_error) + await started.wait() + await client.session.send_notification( + types.CancelledNotification( + params=types.CancelledNotificationParams(request_id=request_ids[0], reason="user aborted") + ) + ) + + await handler_cancelled.wait() + + assert errors == snapshot([ErrorData(code=0, message="Request cancelled")]) + + +@requirement("protocol:cancel:server-survives") +async def test_session_serves_requests_after_cancellation(connect: Connect) -> None: + """A request cancelled mid-flight does not poison the session: the next request succeeds.""" + started = anyio.Event() + request_ids: list[types.RequestId] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool(name="block", input_schema={"type": "object"}), + types.Tool(name="echo", input_schema={"type": "object"}), + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: 
types.CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + return CallToolResult(content=[TextContent(text="still alive")]) + assert ctx.request_id is not None + request_ids.append(ctx.request_id) + started.set() + await anyio.Event().wait() # blocks until cancelled + raise NotImplementedError # unreachable + + server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_swallow_cancellation_error() -> None: + with pytest.raises(MCPError): + await client.call_tool("block", {}) + + task_group.start_soon(call_and_swallow_cancellation_error) + await started.wait() + await client.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=request_ids[0])) + ) + + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + + +@requirement("protocol:cancel:unknown-id-ignored") +async def test_cancellation_for_unknown_request_is_ignored(connect: Connect) -> None: + """A cancellation referencing a request id that is not in flight is ignored without error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + return CallToolResult(content=[TextContent(text="unbothered")]) + + server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=9999)) + ) + result = await client.call_tool("echo", 
{}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="unbothered")])) + + +@requirement("protocol:cancel:late-response-ignored") +async def test_a_response_for_an_unknown_request_id_surfaces_to_the_message_handler() -> None: + """A response whose id matches no in-flight request is surfaced to the message handler as a RuntimeError. + + The spec says a sender SHOULD ignore a response that arrives after it issued a cancellation; + that is the same client-side code path as any response with an unknown id, and that form is + deterministic to test without depending on the cancellation API the SDK does not yet provide. + See the divergence note on the requirement. + + A real Server cannot be made to answer with a fabricated id, so the test plays the server's + side of the wire by hand. Reserve this pattern for behaviour no real server can produce. The + other tests in this file run over the transport matrix; this one is in-memory only because the + scripted-peer mechanism is the in-memory stream pair, not because the behaviour is + transport-specific. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + + def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage: + return SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + # Serialized exactly as a real server serializes results onto the wire. 
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + + init = await server_read.receive() + assert isinstance(init, SessionMessage) + assert isinstance(init.message, JSONRPCRequest) + assert init.message.method == "initialize" + await server_write.send( + respond( + init.message.id, + InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="scripted", version="0.0.1"), + ), + ) + ) + + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + + ping = await server_read.receive() + assert isinstance(ping, SessionMessage) + assert isinstance(ping.message, JSONRPCRequest) + assert ping.message.method == "ping" + # First answer with a fabricated id that matches nothing in flight, then the real id. + await server_write.send(respond(9999, EmptyResult())) + await server_write.send(respond(ping.message.id, EmptyResult())) + + incoming: list[IncomingMessage] = [] + + async def message_handler(message: IncomingMessage) -> None: + incoming.append(message) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as task_group, + ClientSession(client_read, client_write, message_handler=message_handler) as session, + ): + task_group.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + await session.initialize() + pong = await session.send_request(PingRequest(), EmptyResult) + + assert pong == snapshot(EmptyResult()) + assert len(incoming) == 1 + assert isinstance(incoming[0], RuntimeError) + # The full message embeds the response object's repr; only the prefix is stable. 
+ assert str(incoming[0]).startswith("Received response with an unknown request ID:") diff --git a/tests/interaction/lowlevel/test_completion.py b/tests/interaction/lowlevel/test_completion.py new file mode 100644 index 0000000000..6a35404df3 --- /dev/null +++ b/tests/interaction/lowlevel/test_completion.py @@ -0,0 +1,131 @@ +"""Completion interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + METHOD_NOT_FOUND, + CompleteResult, + Completion, + ErrorData, + PromptReference, + ResourceTemplateReference, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("completion:prompt-arg") +@requirement("completion:result-shape") +async def test_complete_prompt_argument(connect: Connect) -> None: + """Completing a prompt argument delivers the ref, argument name, and current value to the handler. + + The returned values are filtered by the argument's value, proving the value reached the handler. 
+ """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, PromptReference) + assert params.ref.name == "code_review" + assert params.argument.name == "language" + candidates = ["python", "pytorch", "ruby"] + matches = [candidate for candidate in candidates if candidate.startswith(params.argument.value)] + return CompleteResult(completion=Completion(values=matches, total=len(matches), has_more=False)) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + PromptReference(name="code_review"), argument={"name": "language", "value": "py"} + ) + + assert result == snapshot( + CompleteResult(completion=Completion(values=["python", "pytorch"], total=2, has_more=False)) + ) + + +@requirement("completion:resource-template-arg") +async def test_complete_resource_template_variable(connect: Connect) -> None: + """Completing a URI template variable delivers the template URI and variable name to the handler.""" + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, ResourceTemplateReference) + assert params.ref.uri == "github://repos/{owner}/{repo}" + assert params.argument.name == "owner" + return CompleteResult(completion=Completion(values=[f"{params.argument.value}contextprotocol"])) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + argument={"name": "owner", "value": "model"}, + ) + + assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol"]))) + + +@requirement("completion:context-arguments") +async def test_complete_receives_context_arguments(connect: Connect) -> None: + """Previously-resolved arguments passed as completion context 
reach the handler. + + The returned value is derived from the context, proving it arrived. + """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert params.argument.name == "repo" + assert params.context is not None + assert params.context.arguments is not None + return CompleteResult(completion=Completion(values=[f"{params.context.arguments['owner']}/python-sdk"])) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "modelcontextprotocol"}, + ) + + assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol/python-sdk"]))) + + +@requirement("completion:error:invalid-ref") +async def test_completion_against_an_unknown_ref_is_rejected_with_invalid_params(connect: Connect) -> None: + """completion/complete with a ref naming an unknown prompt is answered with -32602 Invalid params. + + The lowlevel server does not validate refs itself (it has no prompt/template registry to check + against); rejecting an unknown ref is the handler's job, and this test pins the spec-recommended + way to do it. 
+ """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, PromptReference) + raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.ref.name!r}") + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.complete(PromptReference(name="ghost"), argument={"name": "x", "value": ""}) + + assert exc_info.value.error.code == INVALID_PARAMS + + +@requirement("completion:complete:not-supported") +@requirement("protocol:error:method-not-found") +async def test_complete_without_handler_is_method_not_found(connect: Connect) -> None: + """A server with no completion handler advertises no completions capability and rejects the request.""" + server = Server("incomplete") + + async with connect(server) as client: + assert client.initialize_result.capabilities.completions is None + + with pytest.raises(MCPError) as exc_info: + await client.complete(PromptReference(name="anything"), argument={"name": "topic", "value": ""}) + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py new file mode 100644 index 0000000000..b8edf601d0 --- /dev/null +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -0,0 +1,662 @@ +"""Form- and URL-mode elicitation against the low-level Server, driven through the public Client API. + +The final test plays the server's side of the wire by hand to issue an elicitation request with no +mode field, because the typed server API (`elicit_form`/`elicit_url`) always serializes one. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, UrlElicitationRequiredError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + ElicitCompleteNotification, + ElicitCompleteNotificationParams, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +REQUESTED_SCHEMA: dict[str, object] = { + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], +} + + +@requirement("elicitation:form:action:accept") +@requirement("elicitation:form:basic") +@requirement("tools:call:elicitation-roundtrip") +async def test_elicit_form_accepted_content_returns_to_handler(connect: Connect) -> None: + """An accepted form elicitation returns the user's content to the requesting handler. + + The tool reports the action as text and the received content as structured content, proving + the client's answer made it back into the tool's own result. 
+ """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form("Choose a username.", REQUESTED_SCHEMA) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"username": "ada", "newsletter": True}) + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Choose a username.", + requested_schema={ + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept")], + structured_content={"username": "ada", "newsletter": True}, + ) + ) + + +@requirement("elicitation:form:action:decline") +async def test_elicit_form_decline_returns_no_content(connect: Connect) -> None: + """A declined form elicitation returns the decline action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", 
input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:form:action:cancel") +async def test_elicit_form_cancel_returns_no_content(connect: Connect) -> None: + """A cancelled form elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async with connect(server, elicitation_callback=answer_form) as client: + result = await 
client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:form:not-supported") +@requirement("elicitation:capability:server-respects-mode") +async def test_elicit_form_without_callback_is_error(connect: Connect) -> None: + """Eliciting from a client that configured no elicitation callback fails with an error. + + The client's default callback answers with an Invalid request error, which the server-side + elicit call raises as an MCPError; the tool reports the code and message it caught. The spec + requires -32602 for an undeclared mode (see the divergence note on the requirement). The + request reaching the client also shows the server does not check the client's declared + elicitation capability before sending (see the divergence on `server-respects-mode`). + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="ask", description="Ask the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask" + try: + await ctx.session.elicit_form("Anyone there?", {"type": "object", "properties": {}}) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # elicit_form cannot succeed without a client callback + + server = Server("asker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ask", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Elicitation not supported")])) + + +@requirement("elicitation:url:action:accept-no-content") +@requirement("elicitation:url:basic") +async def 
test_elicit_url_delivers_url_and_returns_accept_without_content(connect: Connect) -> None: + """A URL elicitation delivers the message, URL, and elicitation id to the client; accepting it + returns the action with no content. + + Accept means the user agreed to visit the URL, not that the out-of-band interaction finished, + so there is never form content to return. + """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert received == snapshot( + [ + ElicitRequestURLParams( + _meta={}, + message="Authorize access to your calendar.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + assert result == snapshot(CallToolResult(content=[TextContent(text="accept content=None")])) + + +@requirement("elicitation:url:decline") +async def test_elicit_url_decline_returns_no_content(connect: Connect) -> None: + """A declined URL elicitation returns the decline action to the handler with no content.""" + + async def list_tools( 
+ ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:url:cancel") +async def test_elicit_url_cancel_returns_no_content(connect: Connect) -> None: + """A cancelled URL elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, 
on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:url:complete-notification") +async def test_elicitation_complete_notification_carries_the_elicited_id_back_to_the_client(connect: Connect) -> None: + """After a URL elicitation finishes, the server announces it with a notification carrying the same id. + + The lifecycle under test: the tool elicits a URL interaction with an elicitationId, the user + agrees to visit the URL, the out-of-band interaction finishes, and the server emits + elicitation/complete so the client can correlate the completion with the elicitation it + accepted earlier. The completion notification carries ``related_request_id`` so over + streamable HTTP it rides the tool call's own stream and reaches the client before the call + returns; the same ordering already holds on in-memory and SSE transports. 
+ """ + elicitation_id = "auth-001" + elicited_ids: list[str] = [] + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="link_account", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "link_account" + answer = await ctx.session.elicit_url( + "Authorize access to your files.", "https://example.com/oauth/authorize", elicitation_id + ) + assert answer.action == "accept" + await ctx.session.send_elicit_complete(elicitation_id, related_request_id=ctx.request_id) + return CallToolResult(content=[TextContent(text="linked")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestURLParams) + elicited_ids.append(params.elicitation_id) + return ElicitResult(action="accept") + + async with connect(server, message_handler=collect, elicitation_callback=answer_url) as client: + await client.call_tool("link_account", {}) + + # The completion notification refers to the same elicitation the client accepted. + assert elicited_ids == [elicitation_id] + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="auth-001"))] + ) + + +@requirement("elicitation:url:required-error") +async def test_url_elicitation_required_error_carries_pending_elicitations(connect: Connect) -> None: + """A request that cannot proceed until a URL interaction completes is rejected with error -32042. 
+ + This is the non-interactive alternative to elicit_url: instead of asking and waiting, the + handler rejects the whole request and lists the required URL elicitations in the error data. + The client is expected to present those URLs, wait for the matching elicitation/complete + notifications, and retry the original request. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + server = Server("authorizer", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) + + +@requirement("elicitation:form:schema:primitives") +@requirement("elicitation:form:schema:enum-variants") +async def test_elicit_form_schema_with_every_primitive_and_enum_type_reaches_the_callback_as_sent( + connect: Connect, +) -> None: + """A requested schema covering every spec-listed property kind is delivered to the callback unchanged. + + One schema with one property per kind: a formatted string, an integer with bounds, a number, + a boolean, a plain enum, a oneOf-const titled enum, and a multi-select array-of-enum. The + callback observing the same schema as the handler sent proves both the primitive coverage and + the enum-variant coverage in one snapshot. 
+ """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "email": {"type": "string", "format": "email", "title": "Email", "description": "Contact address."}, + "age": {"type": "integer", "minimum": 0, "maximum": 150}, + "score": {"type": "number"}, + "subscribe": {"type": "boolean", "default": False}, + "tier": {"type": "string", "enum": ["free", "pro", "team"]}, + "region": { + "oneOf": [ + {"const": "eu", "title": "Europe"}, + {"const": "na", "title": "North America"}, + ], + }, + "channels": {"type": "array", "items": {"type": "string", "enum": ["email", "sms", "push"]}}, + }, + "required": ["email"], + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="onboard", description="Onboard the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + answer = await ctx.session.elicit_form("Tell us about yourself.", schema) + return CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("onboarder", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"email": "ada@example.com"}) + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("onboard", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:schema:restricted-subset") +async def test_elicit_form_with_a_nested_schema_is_forwarded_unchanged(connect: Connect) -> None: + """A requested schema with nested-object and 
array-of-object properties passes through unchanged. + + The spec restricts form-mode requested schemas to flat objects with primitive-typed properties; + this test pins that the SDK does not enforce that restriction on either side (see the + divergence on the requirement). + """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, + }, + "contacts": { + "type": "array", + "items": {"type": "object", "properties": {"name": {"type": "string"}}}, + }, + }, + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="profile", description="Collect a profile.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "profile" + answer = await ctx.session.elicit_form("Profile details.", schema) + return CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("profiler", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("profile", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:response-validation") +async def test_accepted_elicitation_content_that_violates_the_schema_reaches_the_handler_unchanged( + connect: Connect, +) -> None: + """Accepted form content that contradicts the requested schema is delivered to the handler unchanged. 
+ + The schema requires a string `name`; the callback answers with a wrong-type value and an extra + field. Nothing on either side validates the response against the schema (see the divergence on + the requirement), so the handler observes exactly what the callback sent. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form( + "Choose a name.", + {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, + ) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content={"name": 42, "extra": "field"}) + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="accept")], structured_content={"name": 42, "extra": "field"}) + ) + + +@requirement("elicitation:url:complete-unknown-ignored") +async def test_elicitation_complete_for_an_unknown_id_is_received_without_error(connect: Connect) -> None: + """An elicitation/complete for an id the client never elicited is delivered and does not fail anything. + + No URL elicitation precedes the notification; the client neither tracks elicitation ids nor + rejects unknown ones, so the call completes normally and the message handler observes the + notification as-is. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="noop", description="Send a stray complete.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "noop" + await ctx.session.send_elicit_complete("never-elicited", related_request_id=ctx.request_id) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("notifier", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(server, message_handler=collect) as client: + result = await client.call_tool("noop", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="never-elicited"))] + ) + + +@requirement("elicitation:form:mode-omitted-default") +async def test_a_mode_less_elicitation_request_is_treated_as_form_mode() -> None: + """An elicitation/create request with no mode field reaches the client callback as form-mode. + + The typed server API always serializes a mode (`elicit_form` writes 'form', `elicit_url` writes + 'url'), so this test plays the server's side of the wire by hand to send a request body without + one. Reserve this pattern for behaviour the typed server API cannot produce. 
+ """ + received: list[types.ElicitRequestParams] = [] + answered = anyio.Event() + server_received: list[JSONRPCMessage] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={}) + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + initialize = await server_read.receive() + assert isinstance(initialize, SessionMessage) + request = initialize.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="legacy", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + # No mode key: a server speaking a pre-mode revision of the spec sends only message + schema. 
+ await server_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="elicitation/create", + params={"message": "Legacy ask.", "requestedSchema": {"type": "object", "properties": {}}}, + ) + ) + ) + response = await server_read.receive() + assert isinstance(response, SessionMessage) + server_received.append(response.message) + answered.set() + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write, elicitation_callback=answer_form) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + await session.initialize() + await answered.wait() + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta=None, + message="Legacy ask.", + requested_schema={"type": "object", "properties": {}}, + ) + ] + ) + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].mode == "form" + assert len(server_received) == 1 + assert isinstance(server_received[0], JSONRPCResponse) + assert server_received[0].id == 2 diff --git a/tests/interaction/lowlevel/test_flows.py b/tests/interaction/lowlevel/test_flows.py new file mode 100644 index 0000000000..8d96582341 --- /dev/null +++ b/tests/interaction/lowlevel/test_flows.py @@ -0,0 +1,203 @@ +"""Composed multi-feature flows against the low-level Server, driven through the public Client API. + +Each test reads as the scenario it proves: the steps run top to bottom in the order a real client +would perform them, composing two or more feature areas (a tool call followed by a resource read; +a chain of elicitations inside one tool call; the full URL-elicitation-required retry loop). The +individual features are pinned by their own tests; these prove they compose. 
+""" + +from collections.abc import Awaitable, Callable + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, UrlElicitationRequiredError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.server.session import ServerSession +from mcp.types import ( + URL_ELICITATION_REQUIRED, + CallToolResult, + ElicitCompleteNotification, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + EmptyResult, + ListToolsResult, + ReadResourceResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ListToolsHandler = Callable[ + [ServerRequestContext, types.PaginatedRequestParams | None], Awaitable[types.ListToolsResult] +] + + +def _list_tools(*names: str) -> ListToolsHandler: + """A list_tools handler advertising the named tools, so call_tool's implicit list succeeds.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name=name, input_schema={"type": "object"}) for name in names]) + + return list_tools + + +@requirement("flow:tool-result:resource-link-follow") +async def test_a_resource_link_returned_by_a_tool_can_be_followed_with_read(connect: Connect) -> None: + """A tool returns a resource_link; reading that link's URI returns the referenced contents. + + Steps: (1) call the tool, (2) extract the link from its content, (3) read_resource on the + link's URI, (4) the read result carries the linked contents. 
+ """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "generate" + return CallToolResult(content=[ResourceLink(uri="file:///report.txt", name="report")]) + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + assert str(params.uri) == "file:///report.txt" + return ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + + server = Server( + "linker", on_list_tools=_list_tools("generate"), on_call_tool=call_tool, on_read_resource=read_resource + ) + + async with connect(server) as client: + called = await client.call_tool("generate", {}) + link = called.content[0] + assert isinstance(link, ResourceLink) + read = await client.read_resource(link.uri) + + assert called == snapshot(CallToolResult(content=[ResourceLink(name="report", uri="file:///report.txt")])) + assert read == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + ) + + +@requirement("flow:elicitation:multi-step-form") +async def test_a_tool_handler_chains_form_elicitations_feeding_each_answer_forward(connect: Connect) -> None: + """Sequential form elicitations inside one tool call: each accepted answer feeds the next step. + + Steps: (1) call the tool, (2) the handler issues a step-one form elicitation that the client + accepts with content, (3) the handler issues a step-two elicitation whose message references + the step-one answer, (4) the client accepts step two, (5) the tool result summarises both + answers. The callback is invoked exactly twice with the expected messages and schemas. The + short-circuit on decline is the application's choice (proven separately by the per-action + elicitation tests); what this flow pins is that the chain itself works end to end. 
+ """ + received: list[ElicitRequestFormParams] = [] + answers: list[dict[str, str | int | float | bool | list[str] | None]] = [{"name": "ada"}, {"age": 37}] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + first = await ctx.session.elicit_form( + "Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}} + ) + assert first.action == "accept" and first.content is not None + second = await ctx.session.elicit_form( + f"Step 2: confirm age for {first.content['name']}.", + {"type": "object", "properties": {"age": {"type": "integer"}}}, + ) + assert second.action == "accept" and second.content is not None + return CallToolResult(content=[TextContent(text=f"{first.content['name']} is {second.content['age']}")]) + + server = Server("onboarder", on_list_tools=_list_tools("onboard"), on_call_tool=call_tool) + + async def answer(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestFormParams) + received.append(params) + return ElicitResult(action="accept", content=answers[len(received) - 1]) + + async with connect(server, elicitation_callback=answer) as client: + result = await client.call_tool("onboard", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ada is 37")])) + assert [(p.message, p.requested_schema) for p in received] == snapshot( + [ + ("Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}}), + ("Step 2: confirm age for ada.", {"type": "object", "properties": {"age": {"type": "integer"}}}), + ] + ) + + +@requirement("flow:elicitation:url-required-then-retry") +async def test_a_tool_rejected_with_url_elicitation_required_succeeds_on_retry_after_completion( + connect: Connect, +) -> None: + """The full URL-elicitation-required retry loop: -32042, completion announced, retry succeeds. 
+ + Steps: (1) the first call is rejected with -32042 carrying the required URL elicitation in + its error data, (2) the client extracts the elicitation id from the error, (3) the server + announces completion via the elicitation/complete notification (driven via the captured + session, the same way a real out-of-band callback would reach a held session reference), + (4) the client observes the matching completion notification and retries, (5) the retry + succeeds. The handler distinguishes the two calls by a closure flag the test flips between + them; the test waits on the completion notification with an event so the retry only happens + after the announcement has arrived. + """ + elicitation_id = "auth-001" + authorised: list[bool] = [False] + captured: list[ServerSession] = [] + completed = anyio.Event() + notifications: list[ElicitCompleteNotification] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + captured.append(ctx.session) + if not authorised[0]: + # The log line gives the message handler a non-completion notification, so the test's + # filtering branch is exercised in both directions and the wait remains specific. 
+ await ctx.session.send_log_message(level="warning", data="authorisation required", logger="gate") + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorize file access.", + url="https://example.com/oauth/authorize", + elicitation_id=elicitation_id, + ) + ] + ) + return CallToolResult(content=[TextContent(text="contents")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server( + "gatekeeper", + on_list_tools=_list_tools("read_files"), + on_call_tool=call_tool, + on_set_logging_level=set_logging_level, + ) + + async def collect(message: IncomingMessage) -> None: + if isinstance(message, ElicitCompleteNotification): + notifications.append(message) + completed.set() + + async with connect(server, message_handler=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + required = UrlElicitationRequiredError.from_error(exc_info.value.error) + assert [e.elicitation_id for e in required.elicitations] == [elicitation_id] + + # The out-of-band interaction completes; the server announces it on the same session. 
+ await captured[0].send_elicit_complete(elicitation_id) + with anyio.fail_after(5): + await completed.wait() + assert notifications[0].params.elicitation_id == elicitation_id + + authorised[0] = True + result = await client.call_tool("read_files", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="contents")])) diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py new file mode 100644 index 0000000000..91adbf5611 --- /dev/null +++ b/tests/interaction/lowlevel/test_initialize.py @@ -0,0 +1,384 @@ +"""Initialization handshake against the low-level Server, driven through the public Client API. + +The later tests drive a bare ClientSession over an InMemoryTransport instead: Client always +performs the full handshake with the latest protocol version, so skipping initialization or +requesting a different version can only be expressed one level down. The final test goes one step +further and plays the server's side of the wire by hand, because no real Server can be made to +answer initialize with an unsupported protocol version. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.client._memory import InMemoryTransport +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + INVALID_PARAMS, + CallToolResult, + ClientCapabilities, + CompletionsCapability, + EmptyResult, + ErrorData, + Icon, + Implementation, + InitializeRequest, + InitializeRequestParams, + InitializeResult, + JSONRPCRequest, + JSONRPCResponse, + ListToolsRequest, + ListToolsResult, + LoggingCapability, + PromptsCapability, + ResourcesCapability, + ServerCapabilities, + TextContent, + ToolsCapability, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("lifecycle:initialize:basic") +@requirement("lifecycle:initialize:server-info") +async def test_initialize_returns_server_info(connect: Connect) -> None: + """Every identity field the server declares is returned to the client in server_info.""" + server = Server( + "greeter", + version="1.2.3", + title="Greeter", + description="Greets people.", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + + async with connect(server) as client: + server_info = client.initialize_result.server_info + + assert server_info == snapshot( + Implementation( + name="greeter", + title="Greeter", + description="Greets people.", + version="1.2.3", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + ) + + +@requirement("lifecycle:initialize:instructions") +async def test_initialize_returns_instructions(connect: Connect) -> None: + 
"""Instructions are returned when the server declares them and omitted when it does not.""" + async with connect(Server("guided", instructions="Call the add tool.")) as client: + assert client.initialize_result.instructions == snapshot("Call the add tool.") + + async with connect(Server("unguided")) as client: + assert client.initialize_result.instructions is None + + +@requirement("lifecycle:initialize:capabilities:from-handlers") +@requirement("tools:capability:declared") +@requirement("resources:capability:declared") +@requirement("prompts:capability:declared") +@requirement("completion:capability:declared") +async def test_initialize_capabilities_reflect_registered_handlers(connect: Connect) -> None: + """Each feature area with a registered handler is advertised as a capability. + + The in-memory transport connects with default initialization options, so the + list_changed flags are always False regardless of the server's notification behaviour. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> types.EmptyResult: + """Registered only so the subscribe sub-capability is advertised; never called.""" + raise NotImplementedError + + async def list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + """Registered only so the prompts capability is advertised; never called.""" + raise NotImplementedError + + async def set_logging_level(ctx: ServerRequestContext, params: 
types.SetLevelRequestParams) -> types.EmptyResult: + """Registered only so the logging capability is advertised; never called.""" + raise NotImplementedError + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + server = Server( + "full", + on_list_tools=list_tools, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + on_list_prompts=list_prompts, + on_set_logging_level=set_logging_level, + on_completion=completion, + ) + + async with connect(server) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot( + ServerCapabilities( + experimental={}, + logging=LoggingCapability(), + prompts=PromptsCapability(list_changed=False), + resources=ResourcesCapability(subscribe=True, list_changed=False), + tools=ToolsCapability(list_changed=False), + completions=CompletionsCapability(), + ) + ) + + +@requirement("lifecycle:initialize:capabilities:minimal") +async def test_initialize_minimal_server_advertises_no_capabilities(connect: Connect) -> None: + """A server with no feature handlers advertises no feature capabilities.""" + async with connect(Server("bare")) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot(ServerCapabilities(experimental={})) + + +@requirement("lifecycle:initialize:client-info") +async def test_initialize_server_sees_client_info(connect: Connect) -> None: + """The client identity supplied to Client is visible to server handlers after initialization.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="whoami", description="Report the caller.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: 
ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "whoami" + assert ctx.session.client_params is not None + client_info = ctx.session.client_params.client_info + return CallToolResult(content=[TextContent(text=f"{client_info.name} {client_info.version}")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + async with connect(server, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: + result = await client.call_tool("whoami", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="acme-agent 9.9.9")])) + + +@requirement("lifecycle:initialize:client-capabilities") +async def test_initialize_server_sees_client_capabilities(connect: Connect) -> None: + """The client capabilities visible to the server reflect which callbacks the client configured.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="abilities", description="Report capabilities.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "abilities" + assert ctx.session.client_params is not None + capabilities = ctx.session.client_params.capabilities + declared = [ + name + for name, value in ( + ("sampling", capabilities.sampling), + ("elicitation", capabilities.elicitation), + ) + if value is not None + ] + if capabilities.roots is not None: + declared.append(f"roots(list_changed={capabilities.roots.list_changed})") + return CallToolResult(content=[TextContent(text=",".join(declared) or "none")]) + + async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: + """Registered only so the client declares the roots capability; never called.""" + raise NotImplementedError + + server = Server("introspector", 
on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="none")])) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="roots(list_changed=True)")])) + + +@requirement("lifecycle:requests-before-initialized") +async def test_request_before_initialization_is_rejected() -> None: + """A feature request sent before the handshake completes is rejected; ping is exempt. + + Client always initializes on entry, so this drives a bare ClientSession that never sends + initialize. The server's stated reason for the rejection never reaches the client: the error + is reported as a generic invalid-params failure. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + """Registered so the request is routed to a real handler; never reached.""" + raise NotImplementedError + + server = Server("strict", on_list_tools=list_tools) + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc_info: + await session.send_request(ListToolsRequest(), ListToolsResult) + + # Ping is explicitly permitted before initialization completes. 
+ pong = await session.send_ping() + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + ) + assert pong == snapshot(EmptyResult()) + + +@requirement("lifecycle:version:match") +@requirement("lifecycle:version:server-fallback-latest") +async def test_initialize_negotiates_protocol_version() -> None: + """The server echoes a supported requested version and answers an unsupported one with its latest. + + Client always requests the latest version, so each half hand-builds an InitializeRequest on a + bare ClientSession to control the requested version. + """ + server = Server("negotiator") + + def initialize_request(protocol_version: str) -> InitializeRequest: + return InitializeRequest( + params=InitializeRequestParams( + protocol_version=protocol_version, + capabilities=ClientCapabilities(), + client_info=Implementation(name="time-traveller", version="0.0.1"), + ) + ) + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + result = await session.send_request(initialize_request("2025-03-26"), InitializeResult) + assert result.protocol_version == snapshot("2025-03-26") + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + result = await session.send_request(initialize_request("1999-01-01"), InitializeResult) + assert result.protocol_version == snapshot("2025-11-25") + + +@requirement("lifecycle:version:reject-unsupported") +async def test_unsupported_server_protocol_version_fails_initialization() -> None: + """An initialize response carrying a protocol version the client does not support fails initialization. 
+ + A real Server only ever answers with a version it supports, so this test alone plays the + server's side of the wire by hand: it reads the initialize request off the raw stream and + answers it with a hand-built result. Reserve this pattern for behaviour no real server can + be made to produce. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="1991-08-06", + capabilities=ServerCapabilities(), + server_info=Implementation(name="relic", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + with pytest.raises(RuntimeError) as exc_info: + await session.initialize() + + assert str(exc_info.value) == snapshot("Unsupported protocol version from the server: 1991-08-06") + + +@requirement("lifecycle:version:downgrade") +async def test_an_older_supported_protocol_version_from_the_server_is_accepted() -> None: + """An initialize response carrying an older supported protocol version completes the handshake at that version. + + A real Server answers with the version the client requested (or its own latest), so this test + plays the server's side of the wire by hand to return a fixed older version regardless of what + was requested. 
Reserve this pattern for behaviour no real server can be made to produce. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-06-18", + capabilities=ServerCapabilities(), + server_info=Implementation(name="conservative", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + initialize_result = await session.initialize() + + assert initialize_result.protocol_version == snapshot("2025-06-18") diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py new file mode 100644 index 0000000000..a2f85eeacf --- /dev/null +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -0,0 +1,136 @@ +"""List-changed notifications from the low-level Server, driven through the public Client API. + +``send_*_list_changed`` does not take a ``related_request_id``, so over streamable HTTP the +notification routes to the standalone GET stream and is not guaranteed to arrive before the tool +result on its POST stream. Tests therefore wait on an event the collector sets, the same pattern +as ``transports/test_streamable_http.py::test_unrelated_server_messages_arrive_on_the_standalone_stream``. 
+The collector still records every message it receives, so the snapshot also proves nothing else +was delivered. + +The servers register the parent capability (resources/prompts) so that part of the spec's +precondition holds, but the ``listChanged`` sub-capability stays ``False``: ``NotificationOptions`` +is not threaded through any of the suite's connection paths. The tests therefore rely on the +recorded ``lifecycle:capability:server-not-advertised`` divergence and will need updating +alongside the fix that introduces capability gating. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolResult, + PromptListChangedNotification, + ResourceListChangedNotification, + TextContent, + ToolListChangedNotification, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:list-changed") +async def test_tool_list_changed_notification(connect: Connect) -> None: + """A tools/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="install", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "install" + await ctx.session.send_tool_list_changed() + return CallToolResult(content=[TextContent(text="installed")]) + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + + 
async with connect(server, message_handler=collect) as client: + await client.call_tool("install", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([ToolListChangedNotification()]) + + +@requirement("resources:list-changed") +async def test_resource_list_changed_notification(connect: Connect) -> None: + """A resources/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="mount", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "mount" + await ctx.session.send_resource_list_changed() + return CallToolResult(content=[TextContent(text="mounted")]) + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool, on_list_resources=list_resources) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("mount", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([ResourceListChangedNotification()]) + + +@requirement("prompts:list-changed") +async def test_prompt_list_changed_notification(connect: Connect) -> None: + """A prompts/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: 
IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="learn", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "learn" + await ctx.session.send_prompt_list_changed() + return CallToolResult(content=[TextContent(text="learned")]) + + async def list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + """Registered so the prompts capability is advertised; the client never lists prompts.""" + raise NotImplementedError + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool, on_list_prompts=list_prompts) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("learn", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([PromptListChangedNotification()]) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py new file mode 100644 index 0000000000..fba632ef4d --- /dev/null +++ b/tests/interaction/lowlevel/test_logging.py @@ -0,0 +1,127 @@ +"""Logging interactions against the low-level Server, driven through the public Client API. + +Notification ordering: the in-memory transport delivers every server-to-client message on one +ordered stream, and the client's receive loop dispatches each incoming message to completion +before reading the next one. Over streamable HTTP that ordered single-stream guarantee holds +only for messages that carry a ``related_request_id`` (they ride the originating request's POST +stream); without it the message routes to the standalone GET stream and may arrive after the +response. 
These tests pass ``related_request_id`` so they can collect into a plain list and +assert after the request completes on every transport leg -- no events, no waiting. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, EmptyResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ALL_LEVELS: tuple[types.LoggingLevel, ...] = ( + "debug", + "info", + "notice", + "warning", + "error", + "critical", + "alert", + "emergency", +) + + +@requirement("logging:set-level") +async def test_set_logging_level_reaches_handler(connect: Connect) -> None: + """The level requested by the client is delivered to the server's handler verbatim.""" + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + assert params.level == "warning" + return EmptyResult() + + server = Server("logger", on_set_logging_level=set_logging_level) + + async with connect(server) as client: + result = await client.set_logging_level("warning") + + assert result == snapshot(EmptyResult()) + + +@requirement("logging:message:fields") +@requirement("tools:call:logging-mid-execution") +async def test_log_messages_reach_logging_callback_in_order(connect: Connect) -> None: + """Log messages sent during a tool call arrive at the logging callback, in order, before the call returns. + + The two messages pin the full notification shape: severity, optional logger name, and both + string and structured data payloads. 
+ """ + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="chatty", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "chatty" + await ctx.session.send_log_message( + level="info", data="starting up", logger="app.lifecycle", related_request_id=ctx.request_id + ) + await ctx.session.send_log_message( + level="error", data={"code": 502, "retryable": True}, related_request_id=ctx.request_id + ) + return CallToolResult(content=[TextContent(text="done")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + async with connect(server, logging_callback=collect) as client: + result = await client.call_tool("chatty", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")])) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="info", logger="app.lifecycle", data="starting up"), + LoggingMessageNotificationParams(level="error", data={"code": 502, "retryable": True}), + ] + ) + + +@requirement("logging:message:all-levels") +async def test_log_messages_at_every_severity_level(connect: Connect) -> None: + """Each of the eight RFC 5424 severity levels is deliverable as a log message notification.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + 
received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="siren", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "siren" + for level in ALL_LEVELS: + await ctx.session.send_log_message( + level=level, data=f"a {level} message", related_request_id=ctx.request_id + ) + return CallToolResult(content=[TextContent(text="logged")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + async with connect(server, logging_callback=collect) as client: + await client.call_tool("siren", {}) + + assert [params.level for params in received] == list(ALL_LEVELS) diff --git a/tests/interaction/lowlevel/test_meta.py b/tests/interaction/lowlevel/test_meta.py new file mode 100644 index 0000000000..a9e4f994d8 --- /dev/null +++ b/tests/interaction/lowlevel/test_meta.py @@ -0,0 +1,63 @@ +"""Request and result _meta round trips against the low-level Server, through the public Client API. + +Meta is opaque pass-through data, so these tests assert identity against the value that was sent +rather than snapshotting a literal: the expected value and the sent value are the same variable, +which also proves the SDK injected nothing alongside it. 
+""" + +import pytest + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, RequestParamsMeta, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("meta:request-to-handler") +async def test_request_meta_reaches_handler(connect: Connect) -> None: + """The _meta object the client attaches to a request arrives at the tool handler unchanged.""" + request_meta: RequestParamsMeta = {"example.com/trace": "abc-123"} + observed_metas: list[dict[str, object]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="traced", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "traced" + assert ctx.meta is not None + observed_metas.append(dict(ctx.meta)) + return CallToolResult(content=[TextContent(text="traced")]) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.call_tool("traced", {}, meta=request_meta) + + assert observed_metas == [dict(request_meta)] + + +@requirement("meta:result-to-client") +async def test_result_meta_reaches_client(connect: Connect) -> None: + """The _meta object a handler attaches to its result is delivered to the client unchanged.""" + result_meta = {"example.com/cost": 3} + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="metered", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "metered" + 
return CallToolResult(content=[TextContent(text="done")], _meta=result_meta) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("metered", {}) + + assert result == CallToolResult(content=[TextContent(text="done")], _meta=result_meta) diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py new file mode 100644 index 0000000000..77db90401e --- /dev/null +++ b/tests/interaction/lowlevel/test_pagination.py @@ -0,0 +1,242 @@ +"""Cursor pagination of the list operations against the low-level Server. + +The cursor is an opaque string chosen by the server: the suite only asserts that whatever the +handler returns as next_cursor comes back verbatim on the client's next call, not any particular +pagination scheme. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + Prompt, + Resource, + ResourceTemplate, + Tool, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:list:pagination") +async def test_next_cursor_round_trips_through_the_client(connect: Connect) -> None: + """The next_cursor a list handler returns reaches the client, and the cursor the client sends + back on the following call reaches the handler verbatim. 
+ """ + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None # the client always sends params, even without a cursor + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListToolsResult( + tools=[Tool(name="alpha", input_schema={"type": "object"})], + next_cursor=cursor, + ) + return ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})]) + + server = Server("paginated", on_list_tools=list_tools) + + async with connect(server) as client: + first_page = await client.list_tools() + second_page = await client.list_tools(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [tool.name for tool in first_page.tools] == ["alpha"] + assert second_page == snapshot(ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})])) + + +@requirement("pagination:exhaustion") +@requirement("tools:list:pagination") +async def test_paginating_until_next_cursor_is_absent_yields_every_page(connect: Connect) -> None: + """Following next_cursor until it is absent visits every page exactly once, in order.""" + pages: dict[str | None, tuple[str, str | None]] = { + None: ("alpha", "page-2"), + "page-2": ("beta", "page-3"), + "page-3": ("gamma", None), + } + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + tool_name, next_cursor = pages[params.cursor] + return ListToolsResult(tools=[Tool(name=tool_name, input_schema={"type": "object"})], next_cursor=next_cursor) + + server = Server("paginated", on_list_tools=list_tools) + + collected: list[str] = [] + cursor: str | None = None + requests_made = 0 + async with connect(server) as client: + while True: + result = await client.list_tools(cursor=cursor) + requests_made += 1 + assert 
requests_made <= len(pages), "the server kept returning next_cursor past the last page" + collected.extend(tool.name for tool in result.tools) + if result.next_cursor is None: + break + cursor = result.next_cursor + + assert collected == snapshot(["alpha", "beta", "gamma"]) + assert requests_made == len(pages) + + +@requirement("pagination:client:cursor-handling") +async def test_the_client_follows_opaque_cursors_through_pages_of_varying_sizes(connect: Connect) -> None: + """The client passes a server-issued cursor back byte-for-byte and follows pages of varying sizes. + + The cursors are deliberately base64-looking strings (with padding and URL-unsafe characters) to + show the client treats them as opaque tokens; the page sizes [3, 1, 2] show the loop relies only + on next_cursor, not on a fixed page size. + """ + cursor_to_page_2 = "YWxwaGE+YnJhdm8/Y2hhcmxpZQ==" + cursor_to_page_3 = "ZGVsdGE=" + pages: dict[str | None, tuple[list[str], str | None]] = { + None: (["alpha", "beta", "gamma"], cursor_to_page_2), + cursor_to_page_2: (["delta"], cursor_to_page_3), + cursor_to_page_3: (["epsilon", "zeta"], None), + } + received_cursors: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + received_cursors.append(params.cursor) + names, next_cursor = pages[params.cursor] + return ListToolsResult( + tools=[Tool(name=name, input_schema={"type": "object"}) for name in names], next_cursor=next_cursor + ) + + server = Server("paginated", on_list_tools=list_tools) + + page_sizes: list[int] = [] + cursor: str | None = None + async with connect(server) as client: + while True: + result = await client.list_tools(cursor=cursor) + page_sizes.append(len(result.tools)) + if result.next_cursor is None: + break + cursor = result.next_cursor + + # Identity, not a snapshot: what arrived at the handler is exactly what the handler issued. 
+ assert received_cursors == [None, cursor_to_page_2, cursor_to_page_3] + assert page_sizes == [3, 1, 2] + + +@requirement("pagination:invalid-cursor") +async def test_an_unrecognized_pagination_cursor_is_rejected_with_invalid_params(connect: Connect) -> None: + """A list request with a cursor the server did not issue is answered with -32602 Invalid params. + + The lowlevel server does not validate cursors itself (they are opaque to it); rejecting an + unrecognized cursor is the handler's job, and this test pins the spec-recommended way to do it. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + assert params.cursor == "never-issued" + raise MCPError(code=INVALID_PARAMS, message=f"Unknown cursor: {params.cursor!r}") + + server = Server("paginated", on_list_tools=list_tools) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.list_tools(cursor="never-issued") + + assert exc_info.value.error.code == INVALID_PARAMS + + +@requirement("resources:list:pagination") +async def test_resources_list_supports_cursor_pagination(connect: Connect) -> None: + """resources/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListResourcesResult(resources=[Resource(uri="memo://1", name="first")], next_cursor=cursor) + return ListResourcesResult(resources=[Resource(uri="memo://2", name="second")]) + + server = Server("paginated", on_list_resources=list_resources) + + async with connect(server) as client: + first_page = await client.list_resources() + second_page = await client.list_resources(cursor=first_page.next_cursor) + + assert 
first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [resource.name for resource in first_page.resources] == ["first"] + assert [resource.name for resource in second_page.resources] == ["second"] + assert second_page.next_cursor is None + + +@requirement("resources:templates:pagination") +async def test_resource_templates_list_supports_cursor_pagination(connect: Connect) -> None: + """resources/templates/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_resource_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListResourceTemplatesResult( + resource_templates=[ResourceTemplate(name="first", uri_template="users://{id}")], + next_cursor=cursor, + ) + return ListResourceTemplatesResult( + resource_templates=[ResourceTemplate(name="second", uri_template="teams://{id}")] + ) + + server = Server("paginated", on_list_resource_templates=list_resource_templates) + + async with connect(server) as client: + first_page = await client.list_resource_templates() + second_page = await client.list_resource_templates(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [template.name for template in first_page.resource_templates] == ["first"] + assert [template.name for template in second_page.resource_templates] == ["second"] + assert second_page.next_cursor is None + + +@requirement("prompts:list:pagination") +async def test_prompts_list_supports_cursor_pagination(connect: Connect) -> None: + """prompts/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> 
ListPromptsResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListPromptsResult(prompts=[Prompt(name="first")], next_cursor=cursor) + return ListPromptsResult(prompts=[Prompt(name="second")]) + + server = Server("paginated", on_list_prompts=list_prompts) + + async with connect(server) as client: + first_page = await client.list_prompts() + second_page = await client.list_prompts(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [prompt.name for prompt in first_page.prompts] == ["first"] + assert [prompt.name for prompt in second_page.prompts] == ["second"] + assert second_page.next_cursor is None diff --git a/tests/interaction/lowlevel/test_ping.py b/tests/interaction/lowlevel/test_ping.py new file mode 100644 index 0000000000..797e20dc35 --- /dev/null +++ b/tests/interaction/lowlevel/test_ping.py @@ -0,0 +1,53 @@ +"""Ping interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, EmptyResult, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("lifecycle:ping") +@requirement("ping:client-to-server") +async def test_client_ping_returns_empty_result(connect: Connect) -> None: + """A client ping is answered with an empty result, even by a server with no handlers.""" + server = Server("silent") + + async with connect(server) as client: + result = await client.send_ping() + + assert result == snapshot(EmptyResult()) + + +@requirement("lifecycle:ping") +@requirement("ping:server-to-client") +async def test_server_ping_returns_empty_result(connect: Connect) -> None: + """A server-initiated ping sent while a request is in flight is 
answered by the client. + + The tool returns the type of the ping response, proving the round trip completed inside + the handler before the tool result was produced. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="ping_back", description="Ping the client.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ping_back" + pong = await ctx.session.send_ping() + return CallToolResult(content=[TextContent(text=type(pong).__name__)]) + + server = Server("pinger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ping_back", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="EmptyResult")])) diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py new file mode 100644 index 0000000000..6350c33a33 --- /dev/null +++ b/tests/interaction/lowlevel/test_progress.py @@ -0,0 +1,301 @@ +"""Progress interactions against the low-level Server, driven through the public Client API. + +Server-to-client progress emitted during a request follows the same ordering guarantee as +logging notifications (see test_logging.py) -- on the in-memory transport unconditionally, and +over streamable HTTP only when sent with ``related_request_id`` so the notification rides the +originating request's POST stream rather than the standalone GET stream. These tests pass +``related_request_id`` so no synchronisation is needed. The client-to-server direction is a +standalone notification with no response to await, so that test waits on an event set by the +server's handler. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.server.session import ServerSession +from mcp.shared.session import ProgressFnT +from mcp.types import CallToolResult, ProgressNotification, ProgressNotificationParams, ProgressToken, TextContent +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:progress:callback") +@requirement("tools:call:progress") +async def test_progress_during_tool_call_reaches_callback_in_order(connect: Connect) -> None: + """Progress notifications emitted by a tool handler reach the caller's progress callback in order.""" + received: list[tuple[float, float | None, str | None]] = [] + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="download", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "download" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + await ctx.session.send_progress_notification( + token, 1.0, total=3.0, message="first chunk", related_request_id=str(ctx.request_id) + ) + await ctx.session.send_progress_notification( + token, 2.0, total=3.0, message="second chunk", related_request_id=str(ctx.request_id) + ) + await ctx.session.send_progress_notification( + token, 3.0, total=3.0, message="done", related_request_id=str(ctx.request_id) + ) + return CallToolResult(content=[TextContent(text="downloaded")]) + + server = 
Server("downloader", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("download", {}, progress_callback=collect) + + assert result == snapshot(CallToolResult(content=[TextContent(text="downloaded")])) + assert received == snapshot([(1.0, 3.0, "first chunk"), (2.0, 3.0, "second chunk"), (3.0, 3.0, "done")]) + + +@requirement("protocol:progress:token-injected") +async def test_progress_token_visible_to_handler(connect: Connect) -> None: + """Supplying a progress callback attaches a progress token that the handler can read from the request meta.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "inspect" + assert ctx.meta is not None + return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def ignore(progress: float, total: float | None, message: str | None) -> None: + """A progress callback that is never invoked; the tool only inspects the token.""" + raise NotImplementedError + + async with connect(server) as client: + result = await client.call_tool("inspect", {}, progress_callback=ignore) + + # The token is the request id of the tools/call request itself (initialize is request 0). + assert result == snapshot(CallToolResult(content=[TextContent(text="1")])) + + +@requirement("protocol:progress:no-token") +async def test_no_progress_callback_means_no_token(connect: Connect) -> None: + """Without a progress callback the request carries no progress token. 
+ + The low-level API has no way to report request-scoped progress without a token, so a handler + that sees no token has nothing to send progress against. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "inspect" + assert ctx.meta is not None + return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("inspect", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="None")])) + + +@requirement("protocol:progress:client-to-server") +async def test_client_progress_notification_reaches_server_handler(connect: Connect) -> None: + """A progress notification sent by the client is delivered to the server's progress handler.""" + received: list[ProgressNotificationParams] = [] + delivered = anyio.Event() + + async def on_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None: + received.append(params) + delivered.set() + + server = Server("observer", on_progress=on_progress) + + async with connect(server) as client: + await client.send_progress_notification("upload-1", 0.5, total=1.0, message="halfway") + with anyio.fail_after(5): + await delivered.wait() + + assert received == snapshot( + [ProgressNotificationParams(progress_token="upload-1", progress=0.5, total=1.0, message="halfway")] + ) + + +@requirement("protocol:progress:token-unique") +async def test_concurrent_requests_carry_distinct_progress_tokens(connect: Connect) -> None: + """Two concurrent requests carry distinct progress tokens, and each callback sees only its 
own progress. + + Without the barrier the first call could run to completion before the second starts, so only one + token would be live at a time and the demultiplexing would never be exercised. The handlers each + block until both have started and then hand control back and forth so the four progress + notifications are emitted in strict a, b, a, b order on the wire. The two handlers send different + progress values so a stream swap (token A delivered to callback B and vice versa) would fail: each + callback receiving exactly its own values proves notifications are routed by token, not by arrival + order or by chance. + """ + progress_values = {"a": (1.0, 2.0), "b": (10.0, 20.0)} + tokens: dict[str, ProgressToken] = {} + entered = {"a": anyio.Event(), "b": anyio.Event()} + # turns[n] is set to release the nth emission; each emission releases the next. + turns = [anyio.Event() for _ in range(4)] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + assert params.arguments is not None + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + label = params.arguments["label"] + tokens[label] = token + entered[label].set() + # The two handlers interleave by waiting on alternating turns: a takes 0 and 2, b takes 1 and 3. 
+ first, second = (0, 2) if label == "a" else (1, 3) + await turns[first].wait() + await ctx.session.send_progress_notification( + token, progress_values[label][0], related_request_id=str(ctx.request_id) + ) + turns[first + 1].set() + await turns[second].wait() + await ctx.session.send_progress_notification( + token, progress_values[label][1], related_request_id=str(ctx.request_id) + ) + if second + 1 < len(turns): + turns[second + 1].set() + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) + + received_a: list[float] = [] + received_b: list[float] = [] + + async def collect_a(progress: float, total: float | None, message: str | None) -> None: + received_a.append(progress) + + async def collect_b(progress: float, total: float | None, message: str | None) -> None: + received_b.append(progress) + + async with connect(server) as client: + + async def call(label: str, collect: ProgressFnT) -> None: + await client.call_tool("report", {"label": label}, progress_callback=collect) + + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: # pragma: no branch + task_group.start_soon(call, "a", collect_a) + task_group.start_soon(call, "b", collect_b) + await entered["a"].wait() + await entered["b"].wait() + turns[0].set() + + assert tokens["a"] != tokens["b"] + assert received_a == [1.0, 2.0] + assert received_b == [10.0, 20.0] + + +@requirement("protocol:progress:stops-after-completion") +@requirement("protocol:progress:late-dropped-by-client") +async def test_progress_sent_after_the_response_is_not_delivered_to_the_callback(connect: Connect) -> None: + """A progress notification sent after the response is emitted, and the client drops it from the callback. 
+ + This single body proves both halves: the server's `send_progress_notification` happily sends for + a token whose request has already completed (the spec MUST that progress stops is not enforced; + see the divergence on `stops-after-completion`), and the client, having removed the callback when + the call returned, does not deliver the late notification to it. The message handler observes the + late notification arriving so the test knows when to assert without polling. + """ + captured: list[tuple[ServerSession, ProgressToken]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + captured.append((ctx.session, token)) + await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[float] = [] + late_progress_arrived = anyio.Event() + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append(progress) + + async def message_handler(message: IncomingMessage) -> None: + if isinstance(message, ProgressNotification) and message.params.progress == 1.0: + late_progress_arrived.set() + + async with connect(server, message_handler=message_handler) as client: + with anyio.fail_after(5): + await client.call_tool("report", {}, progress_callback=collect) + assert received == [0.5] + + server_session, token = captured[0] + await server_session.send_progress_notification(token, 1.0) + await late_progress_arrived.wait() + + 
assert received == [0.5] + + +@requirement("protocol:progress:monotonic") +async def test_non_increasing_progress_values_are_forwarded_unchanged(connect: Connect) -> None: + """A handler that emits non-increasing progress values has them forwarded to the callback unchanged. + + The spec says progress MUST increase with each notification; the SDK does not enforce that on + either side. See the divergence note on the requirement. + """ + received: list[float] = [] + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append(progress) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="zigzag", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "zigzag" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) + await ctx.session.send_progress_notification(token, 0.3, related_request_id=str(ctx.request_id)) + await ctx.session.send_progress_notification(token, 0.9, related_request_id=str(ctx.request_id)) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("zigzagger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.call_tool("zigzag", {}, progress_callback=collect) + + assert received == snapshot([0.5, 0.3, 0.9]) diff --git a/tests/interaction/lowlevel/test_prompts.py b/tests/interaction/lowlevel/test_prompts.py new file mode 100644 index 0000000000..868b82692c --- /dev/null +++ b/tests/interaction/lowlevel/test_prompts.py @@ -0,0 +1,209 @@ +"""Prompt interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from 
inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + AudioContent, + EmbeddedResource, + ErrorData, + GetPromptResult, + Icon, + ImageContent, + ListPromptsResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("prompts:list:basic") +async def test_list_prompts_returns_registered_prompts(connect: Connect) -> None: + """The prompts returned by the handler reach the client with their argument declarations intact.""" + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: + return ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", description="The code to review.", required=True), + PromptArgument(name="style_guide", description="Optional style guide to apply."), + ], + icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], + ), + Prompt(name="daily_standup"), + ] + ) + + server = Server("prompter", on_list_prompts=list_prompts) + + async with connect(server) as client: + result = await client.list_prompts() + + assert result == snapshot( + ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", description="The code to review.", required=True), + PromptArgument(name="style_guide", description="Optional style guide to apply."), + ], + icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], + ), + Prompt(name="daily_standup"), + ] + ) + ) + + +@requirement("prompts:get:with-args") +async def test_get_prompt_substitutes_arguments(connect: 
Connect) -> None: + """Arguments supplied by the client reach the prompt handler; the templated message comes back.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "greet" + assert params.arguments is not None + return GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text=f"Hello, {params.arguments['name']}!"))], + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("greet", {"name": "Ada"}) + + assert result == snapshot( + GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text="Hello, Ada!"))], + ) + ) + + +@requirement("prompts:get:multi-message") +async def test_get_prompt_multiple_messages_preserve_roles_and_order(connect: Connect) -> None: + """A prompt returning a user/assistant conversation reaches the client with roles and order intact.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "geography_quiz" + return GetPromptResult( + messages=[ + PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), + PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), + PromptMessage(role="user", content=TextContent(text="And of Italy?")), + ] + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("geography_quiz") + + assert result == snapshot( + GetPromptResult( + messages=[ + PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), + PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), + PromptMessage(role="user", content=TextContent(text="And of 
Italy?")), + ] + ) + ) + + +@requirement("prompts:get:no-args") +async def test_get_prompt_without_arguments_returns_the_messages(connect: Connect) -> None: + """A prompt fetched with no arguments delivers None as the handler's arguments and returns its messages.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "static" + assert params.arguments is None + return GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("static") + + assert result == snapshot( + GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + ) + + +@requirement("prompts:get:content:image") +@requirement("prompts:get:content:audio") +@requirement("prompts:get:content:embedded-resource") +async def test_get_prompt_with_non_text_content_round_trips(connect: Connect) -> None: + """Prompt messages can carry image, audio, and embedded-resource content; all reach the client. + + A single full-result snapshot proves all three content types round-trip: each block in the result + is one of the three behaviours under test. Tiny fixed base64 payloads ("aW1n" is b"img", "YXVk" + is b"aud") so the snapshot pins the exact bytes. 
+ """ + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "media" + return GetPromptResult( + messages=[ + PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), + PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage( + role="user", + content=EmbeddedResource( + resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + ), + ), + ] + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("media", {}) + + assert result == snapshot( + GetPromptResult( + messages=[ + PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), + PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage( + role="user", + content=EmbeddedResource( + resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + ), + ), + ] + ) + ) + + +@requirement("prompts:get:unknown-name") +async def test_get_prompt_unknown_name_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised prompt name with MCPError produces a JSON-RPC error. + + The error's code and message chosen by the handler reach the client verbatim. 
+ """ + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.name}") + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("nope") + + assert exc_info.value.error == snapshot(ErrorData(code=INVALID_PARAMS, message="Unknown prompt: nope")) diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py new file mode 100644 index 0000000000..4e369d3645 --- /dev/null +++ b/tests/interaction/lowlevel/test_resources.py @@ -0,0 +1,309 @@ +"""Resource interactions against the low-level Server, driven through the public Client API.""" + +import base64 + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + METHOD_NOT_FOUND, + Annotations, + BlobResourceContents, + CallToolResult, + EmptyResult, + ErrorData, + Icon, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + Resource, + ResourceTemplate, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + TextContent, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("resources:list:basic") +@requirement("resources:annotations") +async def test_list_resources_returns_registered_resources(connect: Connect) -> None: + """Listed resources reach the client with their URIs, names, and optional descriptive fields intact. + + The fully-populated entry includes annotations, so the snapshot also proves they round-trip. 
+ The SDK's Annotations model omits the schema's lastModified field (see the divergence on + resources:annotations); the input is built via model_validate with lastModified set so the + snapshot pins the drop and will fail once the SDK adds the field. + """ + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + Resource(uri="memo://minimal", name="minimal"), + Resource( + uri="file:///project/README.md", + name="readme", + title="Project README", + description="The project's front page.", + mime_type="text/markdown", + size=1024, + annotations=Annotations.model_validate( + {"audience": ["user", "assistant"], "priority": 0.8, "lastModified": "2025-01-01T00:00:00Z"} + ), + icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + + server = Server("library", on_list_resources=list_resources) + + async with connect(server) as client: + result = await client.list_resources() + + assert result == snapshot( + ListResourcesResult( + resources=[ + Resource(uri="memo://minimal", name="minimal"), + Resource( + uri="file:///project/README.md", + name="readme", + title="Project README", + description="The project's front page.", + mime_type="text/markdown", + size=1024, + annotations=Annotations(audience=["user", "assistant"], priority=0.8), + icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + ) + + +@requirement("resources:read:text") +async def test_read_resource_text(connect: Connect) -> None: + """Reading a text resource returns its contents with the URI, MIME type, and text supplied by the handler.""" + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=params.uri, mime_type="text/plain", text="Hello, world!")] + ) + + 
server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + result = await client.read_resource("file:///greeting.txt") + + assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="file:///greeting.txt", mime_type="text/plain", text="Hello, world!")] + ) + ) + + +@requirement("resources:read:blob") +async def test_read_resource_binary(connect: Connect) -> None: + """Reading a binary resource returns its contents base64-encoded in the blob field.""" + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[ + BlobResourceContents( + uri=params.uri, + mime_type="image/png", + blob=base64.b64encode(b"\x89PNG").decode(), + ) + ] + ) + + server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + result = await client.read_resource("file:///pixel.png") + + assert result == snapshot( + ReadResourceResult( + contents=[BlobResourceContents(uri="file:///pixel.png", mime_type="image/png", blob="iVBORw==")] + ) + ) + + +@requirement("resources:read:unknown-uri") +async def test_read_resource_unknown_uri_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised URI with MCPError produces a JSON-RPC error. + + The spec reserves -32002 for resource-not-found; the code is the handler's choice and reaches + the client verbatim. 
+ """ + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + raise MCPError(code=-32002, message=f"Resource not found: {params.uri}") + + server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("file:///missing.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=-32002, message="Resource not found: file:///missing.txt")) + + +@requirement("resources:templates:list") +async def test_list_resource_templates_returns_registered_templates(connect: Connect) -> None: + """Listed resource templates reach the client with their URI templates and descriptive fields intact.""" + + async def list_resource_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + return ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + + server = Server("library", on_list_resource_templates=list_resource_templates) + + async with connect(server) as client: + result = await client.list_resource_templates() + + assert result == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + ) + + +@requirement("resources:subscribe") 
+async def test_subscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: + """Subscribing to a resource delivers the URI to the server's subscribe handler and returns an empty result.""" + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_subscribe_resource=subscribe_resource) + + async with connect(server) as client: + result = await client.subscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:subscribe:capability-required") +async def test_subscribe_without_a_subscribe_handler_is_method_not_found(connect: Connect) -> None: + """Subscribing to a server that registered no subscribe handler is rejected with METHOD_NOT_FOUND. + + The rejection comes from no handler being registered, not from any capability check; see the + divergence on lifecycle:capability:server-not-advertised. 
+ """ + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + server = Server("library", on_list_resources=list_resources) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.subscribe_resource("file:///watched.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + + +@requirement("resources:unsubscribe") +async def test_unsubscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: + """Unsubscribing from a resource delivers the URI to the server's unsubscribe handler.""" + + async def unsubscribe_resource(ctx: ServerRequestContext, params: types.UnsubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_unsubscribe_resource=unsubscribe_resource) + + async with connect(server) as client: + result = await client.unsubscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:updated-notification") +async def test_resource_updated_notification_reaches_client(connect: Connect) -> None: + """A resources/updated notification sent during a tool call reaches the client with the resource URI. + + ``send_resource_updated`` does not take a ``related_request_id``, so over streamable HTTP the + notification routes to the standalone GET stream and is not guaranteed to arrive before the + tool result; the test waits on an event the collector sets. The collector records every + message the handler receives, so the assertion also proves nothing else was delivered. 
+ """ + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="touch", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "touch" + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="touched")]) + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" + raise NotImplementedError + + server = Server( + "library", + on_list_tools=list_tools, + on_call_tool=call_tool, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + ) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("touch", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot( + [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + ) diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py new file mode 100644 index 0000000000..8149e0befb --- /dev/null +++ b/tests/interaction/lowlevel/test_roots.py @@ -0,0 +1,166 @@ +"""Roots interactions against the low-level Server, driven through the public Client API.""" + +import anyio +import pytest +from inline_snapshot 
import snapshot +from pydantic import FileUrl + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.types import INTERNAL_ERROR, CallToolResult, ErrorData, ListRootsResult, Root, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("roots:list:basic") +async def test_list_roots_round_trip(connect: Connect) -> None: + """A roots/list request from a tool handler is answered by the client's roots callback. + + The tool reports the URIs and names it received, proving the client's roots reached the server. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + result = await ctx.session.list_roots() + lines = [f"{root.uri} name={root.name}" for root in result.roots] + return CallToolResult(content=[TextContent(text="\n".join(lines))]) + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult( + roots=[ + Root(uri=FileUrl("file:///home/alice/project"), name="project"), + Root(uri=FileUrl("file:///home/alice/scratch")), + ] + ) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="file:///home/alice/project name=project\nfile:///home/alice/scratch name=None")] + ) + ) + + +@requirement("roots:list:empty") +async def test_list_roots_empty(connect: Connect) -> None: + """A client 
with no roots to offer answers roots/list with an empty list, not an error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="count_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "count_roots" + result = await ctx.session.list_roots() + return CallToolResult(content=[TextContent(text=str(len(result.roots)))]) + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult(roots=[]) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("count_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="0")])) + + +@requirement("roots:list:not-supported") +async def test_list_roots_without_callback_is_error(connect: Connect) -> None: + """A roots/list request to a client with no roots callback fails with an error the handler can observe. + + The client's default callback answers with INVALID_REQUEST rather than leaving the server + hanging; the spec names -32601 for this case (see the divergence note on the requirement). 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # list_roots cannot succeed without a client callback + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: List roots not supported")])) + + +@requirement("roots:list:client-error") +async def test_list_roots_callback_error_surfaces_to_the_handler(connect: Connect) -> None: + """A roots callback that answers with an error fails the roots/list request with that exact error. + + The callback's code and message reach the requesting handler verbatim as an MCPError. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ErrorData: + return ErrorData(code=INTERNAL_ERROR, message="roots provider crashed") + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32603: roots provider crashed")])) + + +@requirement("roots:list-changed") +async def test_roots_list_changed_reaches_server_handler(connect: Connect) -> None: + """A roots/list_changed notification from the client is delivered to the server's handler. + + Unlike a request, a notification has no response to await: the handler sets an event and the + test waits on it, which is the only synchronisation point proving delivery. 
+ """ + delivered = anyio.Event() + received: list[types.NotificationParams | None] = [] + + async def roots_list_changed(ctx: ServerRequestContext, params: types.NotificationParams | None) -> None: + received.append(params) + delivered.set() + + server = Server("rooted", on_roots_list_changed=roots_list_changed) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + """Registered so the client declares the roots capability; the server never asks for roots.""" + raise NotImplementedError + + async with connect(server, list_roots_callback=list_roots) as client: + await client.send_roots_list_changed() + with anyio.fail_after(5): + await delivered.wait() + + assert received == snapshot([None]) diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py new file mode 100644 index 0000000000..260e564192 --- /dev/null +++ b/tests/interaction/lowlevel/test_sampling.py @@ -0,0 +1,687 @@ +"""Sampling interactions against the low-level Server, driven through the public Client API. + +Each test nests a sampling/createMessage request inside a tool call: the tool handler calls +ctx.session.create_message(), the client's sampling callback answers it, and the handler +round-trips what it received back to the test through its tool result. 
+""" + +import pydantic +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + AudioContent, + CallToolResult, + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, + ErrorData, + ImageContent, + ModelHint, + ModelPreferences, + SamplingCapability, + SamplingMessage, + TextContent, + ToolResultContent, + ToolUseContent, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("sampling:create:basic") +@requirement("tools:call:sampling-roundtrip") +async def test_create_message_round_trip(connect: Connect) -> None: + """A handler's sampling request is answered by the client callback, and the callback's result + (role, content, model, stop reason) is returned to the handler. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=f"{result.model}/{result.stop_reason}: {result.content.text}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult( + 
role="assistant", + content=TextContent(text="Hello to you too."), + model="mock-llm-1", + stop_reason="endTurn", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-llm-1/endTurn: Hello to you too.")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create:include-context") +@requirement("sampling:create:model-preferences") +@requirement("sampling:create:system-prompt") +@requirement("sampling:context:server-gated-by-capability") +async def test_create_message_params_reach_callback(connect: Connect) -> None: + """Every sampling parameter the handler supplies arrives at the client callback unchanged. + + The client has not declared the sampling.context capability (Client cannot declare it), yet + include_context="thisServer" reaches the callback regardless: the spec's SHOULD NOT is not + enforced. See the divergence note on `sampling:context:server-gated-by-capability`. 
+ """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + max_tokens=50, + system_prompt="You are terse.", + include_context="thisServer", + temperature=0.7, + stop_sequences=["\n\n", "END"], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult(role="assistant", content=TextContent(text="ok"), model="mock-llm-1") + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + system_prompt="You are terse.", + include_context="thisServer", + temperature=0.7, + max_tokens=50, + 
stop_sequences=["\n\n", "END"], + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_request_with_image_content_reaches_callback(connect: Connect) -> None: + """A sampling request message carrying image content arrives at the client callback intact. + + This is the server-to-client direction: the server includes an image in the conversation it + asks the client to sample from. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="describe_image", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "describe_image" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + image = params.messages[0].content + assert isinstance(image, ImageContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"described {image.mime_type} ({image.data})"), + model="mock-vision-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("describe_image", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="described image/png (aW1n)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", 
content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_result_with_image_content_returns_to_handler(connect: Connect) -> None: + """A sampling result whose content is an image is returned to the requesting handler intact. + + This is the client-to-server direction: the model's response is an image rather than text. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="draw", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "draw" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Draw a cat."))], + max_tokens=100, + ) + image = result.content + assert isinstance(image, ImageContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {image.mime_type} {image.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=ImageContent(data="Y2F0", mime_type="image/png"), + model="mock-vision-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("draw", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-vision-1: image/png Y2F0")])) + + +@requirement("sampling:error:user-rejected") +async def test_create_message_callback_error(connect: Connect) -> None: + """A sampling callback that answers with an error surfaces to the requesting handler as an MCPError. 
+ + The error here is the spec's own example for a user rejecting a sampling request (code -1); + the callback's code and message reach the handler verbatim, whatever they are. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback(context: ClientRequestContext, params: CreateMessageRequestParams) -> ErrorData: + return ErrorData(code=-1, message="User rejected sampling request") + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-1: User rejected sampling request")])) + + +@requirement("sampling:create-message:not-supported") +async def test_create_message_without_callback_is_error(connect: Connect) -> None: + """A sampling request to a client with no sampling callback fails with the SDK's default error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert 
params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # create_message cannot succeed without a client callback + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Sampling not supported")])) + + +@requirement("sampling:tools:server-gated-by-capability") +async def test_create_message_with_tools_is_rejected_for_unsupporting_client(connect: Connect) -> None: + """A tool-enabled sampling request to a client that has not declared sampling.tools never leaves the server. + + The client supports plain sampling but cannot declare the tools sub-capability (Client does not + expose it), so the server-side validator rejects the request before anything reaches the wire. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="What is the weather?"))], + max_tokens=100, + tools=[types.Tool(name="get_weather", input_schema={"type": "object"})], + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the validator rejects every tool-enabled request + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the plain sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="-32602: Client does not support sampling tools capability")]) + ) + + +@requirement("sampling:tool-result:no-mixed-content") +async def test_create_message_with_mixed_tool_result_content_is_rejected(connect: Connect) -> None: + """A sampling request whose user message mixes tool_result with other content never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, even + when no tools are passed, so the client callback is never invoked and the handler observes the + ValueError directly. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="summarise_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "summarise_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + ToolResultContent(tool_use_id="call-1", content=[TextContent(text="42")]), + TextContent(text="Also, a comment alongside the result."), + ], + ) + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("summarise_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="ValueError: The last message must contain only tool_result content if any is present") + ] + ) + ) + + +@requirement("sampling:capability:declare") +async def test_a_client_with_a_sampling_callback_declares_the_sampling_capability(connect: Connect) -> None: + """A client connecting with a sampling callback advertises the sampling capability to the server. + + Client cannot declare any sub-capabilities (it does not expose ClientSession's + sampling_capabilities parameter), so the snapshot pins an empty SamplingCapability. 
+ """ + captured: list[SamplingCapability | None] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="capabilities", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "capabilities" + assert ctx.session.client_params is not None + captured.append(ctx.session.client_params.capabilities.sampling) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Registered only so the sampling capability is advertised; never called.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + await client.call_tool("capabilities", {}) + + assert captured == snapshot([SamplingCapability()]) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_request_with_audio_content_reaches_callback(connect: Connect) -> None: + """A sampling request message carrying audio content arrives at the client callback intact. + + This is the server-to-client direction: the server includes audio in the conversation it asks + the client to sample from. 
+ """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="transcribe", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "transcribe" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + audio = params.messages[0].content + assert isinstance(audio, AudioContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"transcribed {audio.mime_type} ({audio.data})"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("transcribe", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="transcribed audio/wav (c25k)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_result_with_audio_content_returns_to_handler(connect: Connect) -> None: + """A sampling result whose content is audio is returned to the requesting handler intact. + + This is the client-to-server direction: the model's response is audio rather than text. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="speak", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "speak" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello, aloud."))], + max_tokens=100, + ) + audio = result.content + assert isinstance(audio, AudioContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {audio.mime_type} {audio.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=AudioContent(data="aGVsbG8=", mime_type="audio/wav"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("speak", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-audio-1: audio/wav aGVsbG8=")])) + + +@requirement("sampling:message:content-cardinality") +async def test_create_message_with_list_valued_message_content_reaches_callback(connect: Connect) -> None: + """A sampling message whose content is a list of blocks arrives at the client callback as a list.""" + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="caption", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "caption" + result = await 
ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + content = params.messages[0].content + assert isinstance(content, list) + return CreateMessageResult( + role="assistant", content=TextContent(text=f"{len(content)} blocks"), model="mock-llm-1" + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("caption", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="2 blocks")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:tool-use:server-preflight") +async def test_create_message_with_mismatched_tool_use_and_result_ids_is_rejected(connect: Connect) -> None: + """A sampling request whose tool_result ids do not match the preceding tool_use ids never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, so the + client callback is never invoked and the handler observes the ValueError directly. The spec's + client-side -32602 check is tracked separately at sampling:tool-use:result-balance. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="continue_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "continue_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="assistant", + content=[ToolUseContent(id="call-1", name="weather", input={})], + ), + SamplingMessage( + role="user", + content=[ToolResultContent(tool_use_id="call-WRONG", content=[TextContent(text="42")])], + ), + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("continue_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="ValueError: ids of tool_result blocks and tool_use blocks from previous message do not match" + ) + ] + ) + ) + + +@requirement("sampling:result:no-tools-single-content") +async def test_array_content_result_for_a_tool_free_request_surfaces_as_a_validation_error(connect: Connect) -> None: + """An array-content sampling result for a tool-free request is accepted by the client and fails server-side. + + Only the exception type is asserted: the message is pydantic's, which changes across releases. 
+ See the divergence note on the requirement: the intended behaviour is that the client rejects + the result; instead the client accepts it and the server's response parsing raises. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Two thoughts, please."))], + max_tokens=100, + ) + except pydantic.ValidationError as exc: + return CallToolResult(content=[TextContent(text=type(exc).__name__)]) + raise NotImplementedError # the array-content result fails server-side parsing every time + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResultWithTools: + return CreateMessageResultWithTools( + role="assistant", + content=[TextContent(text="First thought."), TextContent(text="Second thought.")], + model="mock-llm-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ValidationError")])) diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py new file mode 100644 index 0000000000..a9c83d641d --- /dev/null +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -0,0 +1,114 @@ +"""Request timeouts against the low-level Server, driven through the public Client API. 
+ +The handler blocks on an event that is never set, so the awaited response can never arrive and +any positive timeout fires deterministically on the next event-loop pass. The timeout is therefore +set to an effectively-zero duration: the tests add no wall-clock time to the suite. (Zero itself +cannot be used: a falsy read_timeout_seconds is silently treated as "no timeout".) +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:timeout:basic") +@requirement("protocol:timeout:sends-cancellation") +async def test_request_timeout_fails_the_pending_call() -> None: + """A request whose response does not arrive within its read timeout fails with a timeout error. + + No cancellation is sent to the server (see the divergence note on the requirement): the handler + starts and is still running after the caller has already given up. The test waits for the + handler to have started only after the timeout has fired, so the timeout itself races nothing. + """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + # The request was already on the wire: the handler still runs even though the caller gave up. 
+ with anyio.fail_after(5): + await handler_started.wait() + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. Waited 1e-06 seconds.", + ) + ) + + +@requirement("protocol:timeout:session-survives") +async def test_session_serves_requests_after_timeout() -> None: + """A timed-out request does not poison the session: the next request succeeds.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool(name="block", input_schema={"type": "object"}), + types.Tool(name="echo", input_schema={"type": "object"}), + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + return CallToolResult(content=[TextContent(text="still alive")]) + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError): + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + + +@requirement("protocol:timeout:session-default") +async def test_session_level_timeout_applies_to_every_request() -> None: + """A read timeout configured on the client applies to requests that do not set their own.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + # The one real wall-clock wait in the suite, and 
it cannot be made effectively zero like the + # per-request timeouts: a session-level timeout also governs the initialize handshake, so the + # value must be long enough for the in-process handshake to complete before the blocked tool + # call waits it out in full. 50ms buys a ~50x safety margin over the handshake's actual + # latency; lowering it only erodes the margin against CI scheduler jitter without saving + # anything perceptible. + async with Client(server, read_timeout_seconds=0.05) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. Waited 0.05 seconds.", + ) + ) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py new file mode 100644 index 0000000000..e8053fbaa7 --- /dev/null +++ b/tests/interaction/lowlevel/test_tools.py @@ -0,0 +1,512 @@ +"""Tool interactions against the low-level Server, driven through the public Client API.""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + AudioContent, + CallToolResult, + EmbeddedResource, + ErrorData, + Icon, + ImageContent, + ListToolsResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, + ToolAnnotations, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content(connect: Connect) -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + 
tools=[types.Tool(name="add", description="Add two integers.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) + + +@requirement("tools:call:is-error") +async def test_call_tool_execution_error_is_returned_as_result(connect: Connect) -> None: + """A tool reporting its own failure with is_error=True reaches the client as a result, not an exception. + + Tool execution errors are part of the result so the caller (typically a model) can see + them; only protocol-level failures become JSON-RPC errors. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "flux" + return CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + ) + + +@requirement("tools:call:unknown-name") +async def test_call_tool_unknown_tool_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised tool name with MCPError produces a JSON-RPC error. + + The error's code, message, and data chosen by the handler reach the client verbatim. 
+ """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + raise MCPError(code=INVALID_PARAMS, message=f"Unknown tool: {params.name}", data={"requested": params.name}) + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("nope", {}) + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Unknown tool: nope", data={"requested": "nope"}) + ) + + +@requirement("protocol:error:internal-error") +async def test_call_tool_uncaught_exception_becomes_error_response(connect: Connect) -> None: + """An uncaught exception in the tool handler surfaces to the client as a JSON-RPC error. + + The low-level server reports it with code 0 and the exception text as the message; see the + divergence note on the requirement. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "explode" + raise ValueError("boom") + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("explode", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="boom")) + + +@requirement("tools:list:basic") +async def test_list_tools_returns_registered_tools(connect: Connect) -> None: + """The tools advertised by the server's list handler arrive at the client unchanged.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] 
+ ) + + server = Server("calculator", on_list_tools=list_tools) + + async with connect(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] + ) + ) + + +@requirement("tools:input-schema:json-schema-2020-12") +@requirement("tools:input-schema:preserve-additional-properties") +@requirement("tools:input-schema:preserve-defs") +@requirement("tools:input-schema:preserve-schema-dialect") +async def test_tools_list_preserves_arbitrary_input_schema_keywords(connect: Connect) -> None: + """A rich JSON Schema 2020-12 inputSchema reaches the client unchanged and the tool is callable. + + The single identity assertion below proves all four pass-through behaviours at once: the same + dict literal that was registered is the dict that arrives, so $schema, $defs, the nested object + property, and additionalProperties are each preserved by virtue of the whole schema being + preserved. The follow-up call proves the rich-schema tool is callable end to end. 
+ """ + schema = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$defs": {"positive": {"type": "integer", "exclusiveMinimum": 0}}, + "properties": { + "count": {"$ref": "#/$defs/positive"}, + "options": { + "type": "object", + "properties": {"verbose": {"type": "boolean"}}, + "additionalProperties": False, + }, + }, + "required": ["count"], + "additionalProperties": False, + } + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="typed", input_schema=schema)]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "typed" + assert params.arguments == {"count": 3, "options": {"verbose": True}} + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("typed", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + listed = await client.list_tools() + called = await client.call_tool("typed", {"count": 3, "options": {"verbose": True}}) + + assert listed.tools[0].input_schema == schema + assert called == snapshot(CallToolResult(content=[TextContent(text="ok")])) + + +@requirement("tools:list:metadata") +async def test_list_tools_optional_fields_round_trip(connect: Connect) -> None: + """Every optional Tool field the server supplies reaches the client unchanged.""" + + tool = Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + + async def list_tools(ctx: ServerRequestContext, params: 
types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[tool]) + + server = Server("annotated", on_list_tools=list_tools) + + async with connect(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + ] + ) + ) + + +@requirement("tools:call:content:mixed") +@requirement("tools:call:content:image") +@requirement("tools:call:content:audio") +@requirement("tools:call:content:resource-link") +@requirement("tools:call:content:embedded-resource") +async def test_call_tool_multiple_content_block_types(connect: Connect) -> None: + """A tool result can mix every content block type; all of them arrive in order. + + The payloads are tiny fixed base64 strings ("aW1n" is b"img", "YXVk" is b"aud") so the + snapshot pins the exact bytes the client receives. 
+ """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="render", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "render" + return CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + + server = Server("renderer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("render", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + ) + + +@requirement("tools:call:structured-content") +async def test_call_tool_structured_content(connect: Connect) -> None: + """A tool result carrying structured content alongside content delivers both to the client.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="sum", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "sum" + return CallToolResult(content=[TextContent(text="the 
sum is 5")], structured_content={"sum": 5}) + + server = Server("calculator", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("sum", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5})) + + +@requirement("tools:call:concurrent") +async def test_concurrent_tool_calls_complete_independently(connect: Connect) -> None: + """Two tool calls in flight at once run concurrently and each caller gets its own answer. + + Both handlers are held on a shared event after signalling that they have started, and the test + only releases them once both signals have arrived -- a server that processed requests + sequentially would never start the second handler and the test would time out instead. + """ + started: list[str] = [] + started_events = {"first": anyio.Event(), "second": anyio.Event()} + release = anyio.Event() + results: dict[str, CallToolResult] = {} + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + tag = params.arguments["tag"] + assert isinstance(tag, str) + started.append(tag) + started_events[tag].set() + await release.wait() + return CallToolResult(content=[TextContent(text=tag)]) + + server = Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: # pragma: no branch + + async def call_and_record(tag: str) -> None: + results[tag] = await client.call_tool("echo", {"tag": tag}) + + task_group.start_soon(call_and_record, "first") + 
task_group.start_soon(call_and_record, "second") + + # Both handlers are running at the same time before either is allowed to finish. + await started_events["first"].wait() + await started_events["second"].wait() + release.set() + + assert sorted(started) == ["first", "second"] + assert results == snapshot( + { + "first": CallToolResult(content=[TextContent(text="first")]), + "second": CallToolResult(content=[TextContent(text="second")]), + } + ) + + +@requirement("client:output-schema:validate") +async def test_call_tool_structured_content_violating_output_schema_is_rejected_by_the_client(connect: Connect) -> None: + """A result whose structured content does not conform to the tool's declared output schema never + reaches the caller: the client validates it against the schema cached from tools/list and raises. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")], structured_content={"temperature": "warm"}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + # The message embeds the jsonschema validation error, so only the SDK-authored prefix is pinned. 
+ assert str(exc_info.value).startswith("Invalid structured content returned by tool forecast") + + +@requirement("client:output-schema:skip-on-error") +async def test_is_error_result_bypasses_client_output_schema_validation(connect: Connect) -> None: + """A tool result with isError true is returned as-is even when its structured content violates the schema. + + The schema is cached up front so the client could validate, proving the bypass is specifically the + isError flag and not an empty cache. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult( + content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True + ) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + result = await client.call_tool("forecast", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True) + ) + + +@requirement("client:output-schema:missing-structured") +async def test_declared_output_schema_with_no_structured_content_is_rejected_by_the_client(connect: Connect) -> None: + """A tool that declared an output schema but returned no structuredContent fails the client-side check. + + The error is the SDK's own message, so the full text is snapshotted. 
+ """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")]) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + assert str(exc_info.value) == snapshot("Tool forecast has an output schema but did not return structured content") + + +@requirement("client:output-schema:auto-list") +async def test_call_tool_populates_the_output_schema_cache_via_an_implicit_tools_list(connect: Connect) -> None: + """Calling a tool whose schema is not cached issues exactly one implicit tools/list to populate it. + + The first call_tool of an uncached tool triggers a tools/list the caller never asked for; the + second call hits the cache and does not. This is the SDK's chosen cache strategy and the cause of + the surprising behaviour where a server with only on_call_tool sees a successful call answered + with METHOD_NOT_FOUND from a request the caller never made; see the divergence on the requirement. 
+ """ + list_calls: list[str] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + list_calls.append("called") + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + first = await client.call_tool("forecast", {}) + assert list_calls == ["called"] + second = await client.call_tool("forecast", {}) + + assert list_calls == ["called"] + assert first == snapshot(CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21})) + assert second == first diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py new file mode 100644 index 0000000000..0f9c58aa7a --- /dev/null +++ b/tests/interaction/lowlevel/test_wire.py @@ -0,0 +1,309 @@ +"""Wire-level invariants observed at the client's transport boundary. + +These behaviours are invisible to API callers -- they are properties of the raw JSON-RPC frames. +The tests wrap the in-memory transport in a RecordingTransport, which tees every message crossing +the transport seam into a list without touching the session, so the assertions hold for whatever +the session implementation sends rather than for what its API returns. + +The later tests drive the wire by hand instead: one closes the server-to-client stream while a +request is in flight to pin the connection-closed teardown, and the last two send deliberately +malformed JSON-RPC requests that the typed client API cannot produce. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.client._memory import InMemoryTransport +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CONNECTION_CLOSED, + INVALID_PARAMS, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + EmptyResult, + ErrorData, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListRootsResult, + TextContent, +) +from tests.interaction._helpers import RecordingTransport, _RecordingReadStream +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _echo_server() -> Server: + """A server with one echo tool, used by every test in this module.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + return CallToolResult(content=[TextContent(text="ok")]) + + return Server("wire", on_list_tools=list_tools, on_call_tool=call_tool) + + +@requirement("protocol:request-id:unique") +async def test_request_ids_are_unique_and_never_null() -> None: + """Every request the client sends carries a distinct, non-null id. + + The id sequence is pinned: sequential integers from zero, in send order. 
+ """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + await client.call_tool("echo", {}) + await client.call_tool("echo", {}) + await client.send_ping() + + sent = [message.message for message in recording.sent] + request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + assert all(request_id is not None for request_id in request_ids) + assert len(request_ids) == len(set(request_ids)) + # initialize, tools/list, tools/call, tools/call, ping -- the client does not issue a + # schema-cache refresh here because the explicit tools/list already populated the cache. + assert request_ids == snapshot([0, 1, 2, 3, 4]) + + +@requirement("protocol:notifications:no-response") +async def test_notifications_are_never_answered() -> None: + """A notification produces no response: everything the server sends back answers a request. + + The client sends two notifications (initialized and roots/list_changed) and several requests; + the messages received from the server must be exactly one response per request, each carrying + the id of the request it answers, and nothing else. 
+ """ + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + """Registered so the client declares the roots capability; the server never asks for roots.""" + raise NotImplementedError + + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording, list_roots_callback=list_roots) as client: + await client.send_roots_list_changed() + await client.send_ping() + + sent = [message.message for message in recording.sent] + sent_request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + sent_notifications = [message for message in sent if isinstance(message, JSONRPCNotification)] + received = [message.message for message in recording.received if isinstance(message, SessionMessage)] + received_responses = [message for message in received if isinstance(message, JSONRPCResponse)] + + assert len(sent_notifications) == 2 # notifications/initialized and notifications/roots/list_changed + assert len(received_responses) == len(received) # nothing the server sent was anything but a response + assert [message.id for message in received_responses] == sent_request_ids + + +async def test_recording_read_stream_ends_iteration_when_the_sender_closes() -> None: + """The recording wrapper preserves the end-of-stream behaviour of the stream it wraps. + + This exercises the helper itself rather than an interaction-model behaviour: a transport whose + far end closes must end the client's receive loop cleanly, and the wrapper must not swallow or + mistranslate that. 
+ """ + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + log: list[SessionMessage | Exception] = [] + async with send_stream, _RecordingReadStream(receive_stream, log) as wrapped: + await send_stream.aclose() + items = [item async for item in wrapped] + assert items == [] + assert log == [] + + +@requirement("lifecycle:initialized-notification") +async def test_exactly_one_initialized_notification_is_sent_after_the_handshake() -> None: + """The client sends initialized exactly once, between the initialize response and its first request. + + The full method sequence the client puts on the wire is pinned in send order. + """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + + sent_methods = [ + message.message.method + for message in recording.sent + if isinstance(message.message, JSONRPCRequest | JSONRPCNotification) + ] + assert sent_methods.count("notifications/initialized") == 1 + assert sent_methods == snapshot(["initialize", "notifications/initialized", "tools/list"]) + + +@requirement("protocol:error:connection-closed") +async def test_closing_the_transport_fails_in_flight_requests_with_connection_closed() -> None: + """When the server-to-client stream closes, every in-flight client request fails with CONNECTION_CLOSED. + + Driven over a bare ClientSession against a real Server so the test holds the transport stream + pair directly: once the request is in flight (the server handler signals it has started) the + test closes the server's write stream, which ends the client's receive loop and triggers the + teardown that fails the pending request. 
+ """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + raise NotImplementedError # unreachable: the wait above never completes normally + + server = Server("blocker", on_call_tool=call_tool) + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + errors: list[ErrorData] = [] + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + async with ClientSession(client_read, client_write) as session: + with anyio.fail_after(5): + await session.initialize() + + async def call_and_capture_error() -> None: + with pytest.raises(MCPError) as exc_info: + await session.send_request( + CallToolRequest(params=CallToolRequestParams(name="block")), CallToolResult + ) + errors.append(exc_info.value.error) + + async with anyio.create_task_group() as task_group: # pragma: no branch + task_group.start_soon(call_and_capture_error) + await handler_started.wait() + await server_write.aclose() + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=CONNECTION_CLOSED, message="Connection closed")]) + + +@requirement("protocol:error:invalid-params") +async def test_malformed_request_params_are_answered_with_invalid_params() -> None: + """A request whose params fail validation is answered with -32602 Invalid params. 
+ + The typed client API cannot construct a request with the wrong parameter types, so the test + plays the client's side of the wire by hand against a real Server: it completes the + initialization handshake at the JSON-RPC layer and then sends a tools/call whose `name` is an + integer. Reserve this pattern for behaviour the typed API cannot produce. + """ + server = Server("strict") + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": 42})) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="")]) + + +@requirement("logging:set-level:invalid-level") +async def test_set_level_with_an_unrecognized_value_is_answered_with_invalid_params() -> None: + """logging/setLevel with a value 
outside the spec's level enum is answered with -32602 Invalid params. + + The typed client API cannot construct a setLevel request with an unrecognized level (pyright and + the client-side model both reject it), so the test plays the client's side of the wire by hand + against a real Server. Reserve this pattern for behaviour the typed API cannot produce. + """ + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; never called -- params validation fails first.""" + raise NotImplementedError + + server = Server("logger", on_set_logging_level=set_logging_level) + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage( + JSONRPCRequest(jsonrpc="2.0", id=1, method="logging/setLevel", params={"level": "loud"}) + ) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + 
server_task_group.cancel_scope.cancel() + + assert len(errors) == 1 + assert errors[0].code == INVALID_PARAMS diff --git a/tests/interaction/mcpserver/__init__.py b/tests/interaction/mcpserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/mcpserver/test_completion.py b/tests/interaction/mcpserver/test_completion.py new file mode 100644 index 0000000000..7761066e94 --- /dev/null +++ b/tests/interaction/mcpserver/test_completion.py @@ -0,0 +1,38 @@ +"""Completion behaviour against MCPServer, driven through the public Client API.""" + +import pytest + +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + Completion, + CompletionArgument, + CompletionContext, + CompletionsCapability, + PromptReference, + ResourceTemplateReference, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:completion:capability-auto") +async def test_completion_capability_is_advertised_only_when_a_handler_is_registered(connect: Connect) -> None: + """An MCPServer with a registered completion handler advertises the completions capability; one without does not.""" + with_handler = MCPServer("completer") + + @with_handler.completion() + async def complete( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + async with connect(with_handler) as client: + assert client.initialize_result.capabilities.completions == CompletionsCapability() + + async with connect(MCPServer("plain")) as client: + assert client.initialize_result.capabilities.completions is None diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py new file mode 100644 index 0000000000..26556fea7a --- 
/dev/null +++ b/tests/interaction/mcpserver/test_context.py @@ -0,0 +1,271 @@ +"""The Context convenience methods MCPServer injects into tool functions, observed from the client.""" + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp import MCPError +from mcp.client import ClientRequestContext +from mcp.server.elicitation import AcceptedElicitation +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + METHOD_NOT_FOUND, + CallToolResult, + ElicitRequestFormParams, + ElicitRequestParams, + ElicitResult, + ErrorData, + Implementation, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:context:logging") +@requirement("logging:capability:declared") +async def test_context_logging_helpers_send_log_notifications(connect: Connect) -> None: + """Each Context logging helper sends a log message notification at the matching severity. + + All four notifications reach the client's logging callback before the tool call returns; none + of them carry a logger name unless one is passed explicitly. The server emits these without + advertising the logging capability (see the divergence note on logging:capability). 
+ """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("chatty") + + @mcp.tool() + async def narrate(ctx: Context) -> str: + await ctx.debug("d") + await ctx.info("i") + await ctx.warning("w") + await ctx.error("e") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with connect(mcp, logging_callback=collect) as client: + result = await client.call_tool("narrate", {}) + advertised_logging = client.initialize_result.capabilities.logging + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="d"), + LoggingMessageNotificationParams(level="info", data="i"), + LoggingMessageNotificationParams(level="warning", data="w"), + LoggingMessageNotificationParams(level="error", data="e"), + ] + ) + # The spec requires servers that emit log notifications to declare the logging capability. + assert advertised_logging is None + + +@requirement("mcpserver:context:progress") +async def test_context_report_progress_sends_progress_notifications(connect: Connect) -> None: + """Context.report_progress sends progress notifications correlated to the calling request. + + The caller's progress callback receives each report, in order, before the tool call returns. 
+ """ + received: list[tuple[float, float | None, str | None]] = [] + mcp = MCPServer("worker") + + @mcp.tool() + async def crunch(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.report_progress(2, 3, "halfway there") + return "crunched" + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with connect(mcp) as client: + result = await client.call_tool("crunch", {}, progress_callback=on_progress) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="crunched")], structured_content={"result": "crunched"}) + ) + assert received == snapshot([(1.0, 3.0, None), (2.0, 3.0, "halfway there")]) + + +@requirement("mcpserver:tool:extra") +async def test_context_exposes_request_id_and_client_info_to_a_tool(connect: Connect) -> None: + """A tool can read the per-request id and the connecting client's identity through Context. + + The request id is non-empty (its concrete value depends on transport-level sequencing, so the + test asserts the value the tool saw is the one returned, rather than pinning the literal); the + client info reflects what the caller passed to `Client`. 
+ """ + mcp = MCPServer("introspector") + + @mcp.tool() + async def whoami(ctx: Context) -> str: + client_params = ctx.session.client_params + assert client_params is not None + return f"request {ctx.request_id} from {client_params.client_info.name} {client_params.client_info.version}" + + async with connect(mcp, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: + result = await client.call_tool("whoami", {}) + + assert isinstance(result.content[0], TextContent) + text = result.content[0].text + assert text.startswith("request ") + assert text.endswith(" from acme-agent 9.9.9") + request_id = text.removeprefix("request ").removesuffix(" from acme-agent 9.9.9") + assert request_id + + +@requirement("protocol:progress:no-token") +async def test_report_progress_without_a_progress_token_sends_nothing(connect: Connect) -> None: + """When the caller supplied no progress callback, Context.report_progress is a silent no-op. + + The tool also emits one log message as a sentinel: the message handler receives only that, + proving the notification pipeline works and no progress notification was sent for the + token-less request. 
+ """ + received: list[IncomingMessage] = [] + mcp = MCPServer("quiet") + + @mcp.tool() + async def mill(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.info("milling done") + return "milled" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(mcp, message_handler=collect) as client: + result = await client.call_tool("mill", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="milled")], structured_content={"result": "milled"}) + ) + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="milling done"))] + ) + + +@requirement("mcpserver:context:elicit") +@requirement("tools:call:elicitation-roundtrip") +async def test_context_elicit_returns_typed_result(connect: Connect) -> None: + """Context.elicit sends a form elicitation built from a pydantic schema and returns a typed result. + + The client sees the JSON schema generated from the model; the accepted content is validated + back into the model and handed to the tool as result.data. 
+ """ + received: list[ElicitRequestParams] = [] + mcp = MCPServer("travel") + + class TravelPreferences(BaseModel): + destination: str + window_seat: bool + + @mcp.tool() + async def book_flight(ctx: Context) -> str: + answer = await ctx.elicit("Where to?", TravelPreferences) + assert isinstance(answer, AcceptedElicitation) + return f"{answer.action}: {answer.data.destination} window={answer.data.window_seat}" + + async def answer_form(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"destination": "Lisbon", "window_seat": True}) + + async with connect(mcp, elicitation_callback=answer_form) as client: + result = await client.call_tool("book_flight", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Where to?", + requested_schema={ + "properties": { + "destination": {"title": "Destination", "type": "string"}, + "window_seat": {"title": "Window Seat", "type": "boolean"}, + }, + "required": ["destination", "window_seat"], + "title": "TravelPreferences", + "type": "object", + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept: Lisbon window=True")], + structured_content={"result": "accept: Lisbon window=True"}, + ) + ) + + +@requirement("mcpserver:context:read-resource") +async def test_context_read_resource_reads_registered_resource(connect: Connect) -> None: + """Context.read_resource lets a tool read a resource registered on the same server. + + The tool reports the MIME type and content it read, proving the resource function ran and its + return value came back through the context. 
+ """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + @mcp.tool() + async def show_config(ctx: Context) -> str: + contents = list(await ctx.read_resource("config://app")) + return "\n".join(f"{item.mime_type}: {item.content!r}" for item in contents) + + async with connect(mcp) as client: + result = await client.call_tool("show_config", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="text/plain: 'theme = dark'")], + structured_content={"result": "text/plain: 'theme = dark'"}, + ) + ) + + +@requirement("logging:message:filtered") +async def test_set_logging_level_is_rejected_and_messages_are_never_filtered(connect: Connect) -> None: + """MCPServer does not support logging/setLevel, so log messages are never filtered by severity. + + The request is rejected with METHOD_NOT_FOUND because MCPServer registers no handler for it, + and every message a tool emits is delivered regardless of level. The spec says the server + should only send messages at or above the configured level; with no way to configure one, + everything is sent. 
+ """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("unfilterable") + + @mcp.tool() + async def chatter(ctx: Context) -> str: + await ctx.debug("noise") + await ctx.error("signal") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with connect(mcp, logging_callback=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.set_logging_level("error") + + await client.call_tool("chatter", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="noise"), + LoggingMessageNotificationParams(level="error", data="signal"), + ] + ) diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py new file mode 100644 index 0000000000..2095f086d4 --- /dev/null +++ b/tests/interaction/mcpserver/test_prompts.py @@ -0,0 +1,195 @@ +"""Prompt interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + GetPromptResult, + ListPromptsResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:prompt:decorated") +async def test_list_prompts_derives_arguments_from_signature(connect: Connect) -> None: + """A decorated prompt is listed with arguments derived from the function signature. + + Parameters without a default are required; the description comes from the docstring. 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def code_review(code: str, style_guide: str = "pep8") -> str: + """Review a piece of code.""" + raise NotImplementedError # registered for listing only; never rendered + + async with connect(mcp) as client: + result = await client.list_prompts() + + assert result == snapshot( + ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", required=True), + PromptArgument(name="style_guide", required=False), + ], + ) + ] + ) + ) + + +@requirement("mcpserver:prompt:decorated") +async def test_get_prompt_renders_function_return(connect: Connect) -> None: + """The decorated function's string return value is rendered as a single user message.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A personalised greeting.""" + return f"Say hello to {name}." + + async with connect(mcp) as client: + result = await client.get_prompt("greet", {"name": "Ada"}) + + assert result == snapshot( + GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text="Say hello to Ada."))], + ) + ) + + +@requirement("mcpserver:prompt:unknown-name") +async def test_get_unknown_prompt_is_error(connect: Connect) -> None: + """Getting a prompt name that was never registered fails with a JSON-RPC error. + + The spec reserves -32602 for this case; the SDK reports code 0 (see the divergence note on + the requirement). 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; the test requests a different name.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("nope") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown prompt: nope")) + + +@requirement("prompts:get:missing-required-args") +async def test_get_prompt_with_a_missing_required_argument_is_an_error(connect: Connect) -> None: + """Getting a prompt without one of its required arguments fails with a JSON-RPC error. + + The missing argument is detected before the prompt function is called, but the spec's -32602 + Invalid params is reported as error code 0 with the bare exception text (see the divergence + note on the requirement). + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; validation rejects the call before the function runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("greet") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Missing required arguments: {'name'}")) + + +@requirement("mcpserver:prompt:args-validation") +async def test_get_prompt_with_a_wrong_type_argument_is_rejected_before_the_function_runs(connect: Connect) -> None: + """An argument that fails the function signature's type validation is rejected before the function runs. + + The decorated function is wrapped in pydantic's validate_call, so a value that cannot be + coerced to the parameter's annotation fails before the body executes. The function body + raises NotImplementedError to prove it never ran. The error is wrapped in the SDK's stable + rendering-error prefix; the body of the message is raw pydantic output and is not asserted. 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def repeat(phrase: str, count: int) -> str: + """A registered prompt; type validation rejects the call before the function runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("repeat", {"phrase": "hi", "count": "many"}) + + assert exc_info.value.error.code == 0 + assert exc_info.value.error.message.startswith("Error rendering prompt repeat: 1 validation error") + + +@requirement("mcpserver:prompt:optional-args") +async def test_get_prompt_with_an_optional_argument_omitted_uses_the_default(connect: Connect) -> None: + """A prompt rendered without one of its optional arguments uses that parameter's default value.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def review(code: str, style: str = "pep8") -> str: + """Review a snippet of code against a style guide.""" + return f"Review {code} per {style}." + + async with connect(mcp) as client: + result = await client.get_prompt("review", {"code": "x = 1"}) + + assert result == snapshot( + GetPromptResult( + description="Review a snippet of code against a style guide.", + messages=[PromptMessage(role="user", content=TextContent(text="Review x = 1 per pep8."))], + ) + ) + + +@requirement("mcpserver:prompt:duplicate-name") +async def test_registering_a_duplicate_prompt_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second prompt with an already-used name keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via the decorator with an explicit name so the test does not + redefine the same function name in this scope. 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.prompt(name="greet") + def greet_second() -> str: + """Registered with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_prompts() + result = await client.get_prompt("greet") + + assert [prompt.name for prompt in listed.prompts] == ["greet"] + assert result == snapshot( + GetPromptResult( + description="The first registration; this is the one that wins.", + messages=[PromptMessage(role="user", content=TextContent(text="first"))], + ) + ) diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py new file mode 100644 index 0000000000..57b0fdc86d --- /dev/null +++ b/tests/interaction/mcpserver/test_resources.py @@ -0,0 +1,183 @@ +"""Resource interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:resource:static") +async def test_read_static_resource(connect: Connect) -> None: + """A function registered for a fixed URI is served at that URI with its return value as text.""" + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + async with connect(mcp) as client: + result = await client.read_resource("config://app") + + assert result == snapshot( + ReadResourceResult( + 
contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="theme = dark")] + ) + ) + + +@requirement("mcpserver:resource:static") +async def test_list_static_and_templated_resources(connect: Connect) -> None: + """Statically-registered resources appear in resources/list; templated ones only in templates/list. + + The name and description are derived from the function name and docstring; the MIME type + defaults to text/plain. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + raise NotImplementedError # registered for listing only; never read + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + raise NotImplementedError # registered for listing only; never read + + async with connect(mcp) as client: + resources = await client.list_resources() + templates = await client.list_resource_templates() + + assert resources == snapshot( + ListResourcesResult( + resources=[ + Resource( + name="app_config", + uri="config://app", + description="The application configuration.", + mime_type="text/plain", + ) + ] + ) + ) + assert templates == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate( + name="user_profile", + uri_template="users://{user_id}/profile", + description="A user's profile.", + mime_type="text/plain", + ) + ] + ) + ) + + +@requirement("mcpserver:resource:template") +@requirement("resources:read:template-vars") +async def test_read_templated_resource(connect: Connect) -> None: + """Reading a URI that matches a registered template invokes the function with the extracted parameters.""" + mcp = MCPServer("library") + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + return f"profile for {user_id}" + + async with connect(mcp) as client: + result = await client.read_resource("users://42/profile") + + 
assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="users://42/profile", mime_type="text/plain", text="profile for 42")] + ) + ) + + +@requirement("mcpserver:resource:unknown-uri") +async def test_read_unknown_uri_is_error(connect: Connect) -> None: + """Reading a URI that matches no registered resource fails with a JSON-RPC error. + + The spec reserves -32002 for resource-not-found; see the divergence note on the requirement. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """A registered resource; the test reads a different URI.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("config://missing") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown resource: config://missing")) + + +@requirement("mcpserver:resource:read-throws-surfaced") +async def test_resource_function_that_raises_is_surfaced_as_a_jsonrpc_error(connect: Connect) -> None: + """An exception raised by a resource function reaches the caller as a JSON-RPC error. + + MCPServer wraps the failure in a generic error that names only the URI, so the original + exception text is not leaked to the client. The wrapped exception becomes error code 0 the + same way every other unhandled server-side exception does. + """ + mcp = MCPServer("library") + + @mcp.resource("res://boom") + def boom() -> str: + raise RuntimeError("nope") + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("res://boom") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Error reading resource res://boom")) + + +@requirement("mcpserver:resource:duplicate-name") +async def test_registering_a_duplicate_resource_uri_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second static resource at an already-used URI keeps the first registration. 
+ + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The two + registrations use different function names so the test does not redefine a name in this scope; + the resource decorator keys on the URI, not the function name. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def config_first() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.resource("config://app") + def config_second() -> str: + """Registered at a duplicate URI; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_resources() + result = await client.read_resource("config://app") + + assert [resource.uri for resource in listed.resources] == ["config://app"] + assert listed.resources[0].name == "config_first" + assert result == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="first")]) + ) diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py new file mode 100644 index 0000000000..05135c1286 --- /dev/null +++ b/tests/interaction/mcpserver/test_tools.py @@ -0,0 +1,432 @@ +"""Tool interactions against MCPServer, driven through the public Client API.""" + +import logging +from typing import Annotated, Literal + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel, Field + +from mcp import MCPError +from mcp.server.mcpserver import Context, MCPServer +from mcp.server.mcpserver.exceptions import ToolError +from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.types import ( + URL_ELICITATION_REQUIRED, + CallToolResult, + ElicitRequestURLParams, + ErrorData, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from 
tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content(connect: Connect) -> None: + """Arguments reach the tool function; its return value comes back as text content. + + MCPServer also derives an output schema from the return annotation and attaches the + matching structuredContent to the result. + """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + return str(a + b) + + async with connect(mcp) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")], structured_content={"result": "5"})) + + +@requirement("mcpserver:tool:schema-variants") +async def test_complex_parameter_types_are_validated_and_coerced_before_the_tool_runs(connect: Connect) -> None: + """Literal, nested-model, and constrained parameters are validated and coerced from the wire arguments. + + The string "3" is coerced to `int` and the `point` dict to a `Point` instance before the function + body sees them, proving the generated input schema and validation pipeline cover non-trivial types. 
+ """ + mcp = MCPServer("typed") + + class Point(BaseModel): + x: int + y: int + + @mcp.tool() + def place(mode: Literal["fast", "slow"], point: Point, count: Annotated[int, Field(ge=1, le=10)]) -> str: + assert isinstance(point, Point) + return f"{mode} at ({point.x}, {point.y}) x{count}" + + async with connect(mcp) as client: + result = await client.call_tool("place", {"mode": "fast", "point": {"x": "3", "y": 4}, "count": 5}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="fast at (3, 4) x5")], structured_content={"result": "fast at (3, 4) x5"} + ) + ) + + +@requirement("mcpserver:tool:handler-throws") +@requirement("mcpserver:output-schema:skip-on-error") +async def test_call_tool_function_exception_becomes_error_result(connect: Connect) -> None: + """An exception raised by a tool function is returned as an is_error result, not a JSON-RPC error. + + The function's `-> str` annotation gives the tool a derived output schema, but the error + result is built before any schema validation runs, so no validation failure is layered on + top of the original exception. 
+ """ + mcp = MCPServer("errors") + + @mcp.tool() + def explode() -> str: + raise ValueError("boom") + + async with connect(mcp) as client: + result = await client.call_tool("explode", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool explode: boom")], is_error=True) + ) + + +@requirement("mcpserver:tool:handler-throws") +async def test_call_tool_tool_error_becomes_error_result(connect: Connect) -> None: + """A ToolError raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" + mcp = MCPServer("errors") + + @mcp.tool() + def flux() -> str: + raise ToolError("flux capacitor offline") + + async with connect(mcp) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool flux: flux capacitor offline")], is_error=True) + ) + + +@requirement("mcpserver:tool:unknown-name") +async def test_call_tool_unknown_name_returns_error_result(connect: Connect) -> None: + """Calling a tool name that was never registered is reported as an is_error result. + + The spec classifies unknown tools as a protocol error; see the divergence note on the + requirement. + """ + mcp = MCPServer("errors") + + @mcp.tool() + def add() -> None: + """A registered tool; the test calls a different name.""" + + async with connect(mcp) as client: + result = await client.call_tool("nope", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="Unknown tool: nope")], is_error=True)) + + +@requirement("mcpserver:tool:output-schema:model") +@requirement("tools:call:structured-content:text-mirror") +async def test_call_tool_model_return_becomes_structured_content(connect: Connect) -> None: + """A tool returning a pydantic model advertises the model's schema as the tool's output schema + and returns the model's fields as structured content alongside a serialised text block. 
+ """ + mcp = MCPServer("weather") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def get_weather() -> Weather: + return Weather(temperature=22.5, conditions="sunny") + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("get_weather", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": { + "temperature": {"title": "Temperature", "type": "number"}, + "conditions": {"title": "Conditions", "type": "string"}, + }, + "required": ["temperature", "conditions"], + "title": "Weather", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="""\ +{ + "temperature": 22.5, + "conditions": "sunny" +}\ +""" + ) + ], + structured_content={"temperature": 22.5, "conditions": "sunny"}, + ) + ) + + +@requirement("mcpserver:tool:output-schema:wrapped") +async def test_call_tool_list_return_is_wrapped_in_result_key(connect: Connect) -> None: + """A tool returning a list wraps the value under a "result" key in both the generated output + schema and the structured content. 
+ """ + mcp = MCPServer("primes") + + @mcp.tool() + def primes() -> list[int]: + return [2, 3, 5] + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("primes", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": {"result": {"items": {"type": "integer"}, "title": "Result", "type": "array"}}, + "required": ["result"], + "title": "primesOutput", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="2"), TextContent(text="3"), TextContent(text="5")], + structured_content={"result": [2, 3, 5]}, + ) + ) + + +@requirement("mcpserver:tool:input-validation") +async def test_call_tool_invalid_arguments_become_error_result(connect: Connect) -> None: + """Arguments that fail validation against the tool's signature are reported as an is_error + result describing the failure, not as a protocol error. + """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + """Validation rejects the arguments before the function is ever called.""" + raise NotImplementedError + + async with connect(mcp) as client: + result = await client.call_tool("add", {"b": 3}) + + # The description is raw pydantic output -- it embeds a pydantic-version-specific + # errors.pydantic.dev URL and the internal `addArguments` model name -- so only the stable + # prefix is asserted; a full snapshot would break on every pydantic upgrade. + assert result.is_error is True + assert isinstance(result.content[0], TextContent) + assert result.content[0].text.startswith("Error executing tool add: 1 validation error") + + +@requirement("mcpserver:output-schema:server-validate") +@requirement("mcpserver:output-schema:missing-structured") +async def test_tool_with_output_schema_returning_mismatched_structured_content_is_an_error_result( + connect: Connect, +) -> None: + """Structured content that fails the tool's own output schema is rejected on the server side. 
+ + A tool annotated `Annotated[CallToolResult, Model]` returns a hand-built CallToolResult while + declaring `Model` as its output schema; MCPServer validates the supplied structured_content + against that schema before returning. The two cases -- a content shape that does not match, + and no structured content at all -- both fail that validation and are reported as is_error + results carrying the (raw pydantic) validation error wrapped in the SDK's stable prefix. + """ + mcp = MCPServer("forecaster") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def mismatched() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")], structured_content={"nope": True}) + + @mcp.tool() + def missing() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")]) + + async with connect(mcp) as client: + mismatched_result = await client.call_tool("mismatched", {}) + missing_result = await client.call_tool("missing", {}) + + # The body of each message is raw pydantic ValidationError output (model name, field paths, + # an errors.pydantic.dev URL) and changes across pydantic versions, so only the SDK's stable + # prefix is asserted. + assert mismatched_result.is_error is True + assert isinstance(mismatched_result.content[0], TextContent) + assert mismatched_result.content[0].text.startswith("Error executing tool mismatched: 2 validation errors") + + assert missing_result.is_error is True + assert isinstance(missing_result.content[0], TextContent) + assert missing_result.content[0].text.startswith("Error executing tool missing: 1 validation error") + + +@requirement("mcpserver:tool:duplicate-name") +async def test_registering_a_duplicate_tool_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second tool with an already-used name keeps the first registration. 
+ + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via add_tool with an explicit name so the test does not + redefine the same function name in this scope. + """ + mcp = MCPServer("duplicates") + + @mcp.tool() + def echo() -> str: + return "first" + + def echo_second() -> str: + """Passed to add_tool with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + mcp.add_tool(echo_second, name="echo") + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("echo", {}) + + assert [tool.name for tool in listed.tools] == ["echo"] + assert result == snapshot( + CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + ) + + +@requirement("mcpserver:tool:naming-validation") +async def test_registering_a_tool_with_a_spec_invalid_name_warns_but_does_not_reject( + connect: Connect, caplog: pytest.LogCaptureFixture +) -> None: + """A tool name that violates the SEP-986 rules logs a warning at registration but is still registered. + + The intended behaviour is rejection at registration time; MCPServer instead logs the + naming-rule violation and proceeds (see the divergence note on the requirement). The warning + spans several SDK-authored log records, so only the stable prefix and inclusion of the + offending name are asserted. + """ + mcp = MCPServer("naming") + + with caplog.at_level(logging.WARNING, logger="mcp.shared.tool_name_validation"): + + @mcp.tool(name="bad name!") + def bad() -> str: + return "ok" + + assert any( + rec.levelno == logging.WARNING + and rec.message.startswith("Tool name validation warning") + and "bad name!" 
in rec.message + for rec in caplog.records + ) + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("bad name!", {}) + + assert [tool.name for tool in listed.tools] == ["bad name!"] + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")], structured_content={"result": "ok"})) + + +@requirement("mcpserver:tool:url-elicitation-error") +async def test_decorated_tool_raising_url_elicitation_required_surfaces_as_error_32042(connect: Connect) -> None: + """A decorated tool raising the URL-elicitation-required error reaches the client as error -32042. + + MCPServer wraps every other tool exception as an is_error result; this error is special-cased + so it propagates as the JSON-RPC error the client needs in order to present the listed URL + interactions and retry the call. + """ + mcp = MCPServer("authorizer") + + @mcp.tool() + def read_files() -> str: + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) + + +@requirement("mcpserver:register:post-connect") +async def test_adding_and_removing_tools_does_not_notify_connected_clients(connect: Connect) -> None: + """Mutating the tool set on a running server changes tools/list but sends no notification. 
+ + add_tool and remove_tool only update the registry: a connected client that listed the tools + before the mutation has no way to learn it should list them again. The spec provides + notifications/tools/list_changed for exactly this; MCPServer never sends it. The tool emits + one log message as a sentinel so the test proves notifications do reach the collector -- the + log message arrives, a list_changed does not. + """ + received: list[IncomingMessage] = [] + mcp = MCPServer("mutable") + + def extra() -> str: + """A tool registered at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + def doomed() -> str: + """A tool removed at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + async def grow(ctx: Context) -> str: + mcp.add_tool(extra, name="extra") + mcp.remove_tool("doomed") + await ctx.info("tool set changed") + return "mutated" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(mcp, message_handler=collect) as client: + before = await client.list_tools() + await client.call_tool("grow", {}) + after = await client.list_tools() + + assert [tool.name for tool in before.tools] == ["doomed", "grow"] + assert [tool.name for tool in after.tools] == ["grow", "extra"] + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="tool set changed"))] + ) diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py new file mode 100644 index 0000000000..7821c1eed5 --- /dev/null +++ b/tests/interaction/test_coverage.py @@ -0,0 +1,105 @@ +"""Enforces the contract between the requirements manifest and the test suite. + +The contract runs in both directions: every non-deferred entry in :data:`REQUIREMENTS` must be +exercised by at least one test, and every test in the suite must carry at least one +`@requirement(...)` mark referencing a manifest entry. 
Deferral reasons that point at coverage +elsewhere in the repo must point at paths that exist. Test modules are imported directly +(rather than relying on pytest collection) so the check holds even when only this file is run. +""" + +import importlib +import re +from pathlib import Path +from types import ModuleType + +import pytest + +from tests.interaction._requirements import REQUIREMENTS, Requirement, covered_by, requirement + +_SUITE_ROOT = Path(__file__).parent +_REPO_ROOT = _SUITE_ROOT.parent.parent + +# Repo paths cited inside deferral reasons ("Covered by tests/... "). +_CITED_PATH = re.compile(r"(?:tests|src)/[\w./-]*\w") + +# Tests that exercise the suite's own helpers rather than an interaction-model behaviour. +# Anything listed here is exempt from the every-test-has-a-requirement check. +_HARNESS_SELF_TESTS = { + "tests.interaction.lowlevel.test_wire.test_recording_read_stream_ends_iteration_when_the_sender_closes", + "tests.interaction.transports.test_bridge.test_response_chunks_arrive_as_the_application_sends_them", + "tests.interaction.transports.test_bridge.test_closing_the_response_delivers_a_disconnect_to_the_application", + "tests.interaction.transports.test_bridge.test_an_application_failure_before_the_response_starts_fails_the_request", + "tests.interaction.transports.test_bridge.test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect", + "tests.interaction.auth.test_flow.test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_wrapped_app", +} + + +def _import_all_test_modules() -> list[ModuleType]: + """Import every other test module in the suite so their `@requirement` decorators register.""" + modules: list[ModuleType] = [] + for path in sorted(_SUITE_ROOT.rglob("test_*.py")): + relative = path.relative_to(_SUITE_ROOT).with_suffix("") + name = f"{__package__}.{'.'.join(relative.parts)}" + if name != __name__: + modules.append(importlib.import_module(name)) + return modules + + +def 
test_every_requirement_is_exercised() -> None: + """Each non-deferred requirement is covered by at least one test (deferred ones by none).""" + _import_all_test_modules() + + uncovered = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is None and not covered_by(requirement_id) + ] + assert not uncovered, f"Requirements with no test and no deferred reason: {uncovered}" + + stale_deferrals = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is not None and covered_by(requirement_id) + ] + assert not stale_deferrals, f"Deferred requirements that now have tests (remove deferred): {stale_deferrals}" + + +def test_every_test_exercises_a_requirement() -> None: + """Each test in the suite carries at least one `@requirement` mark (harness self-tests excepted).""" + all_tests = { + f"{module.__name__}.{name}" + for module in _import_all_test_modules() + for name in vars(module) + if name.startswith("test_") + } + linked_tests = {test_name for requirement_id in REQUIREMENTS for test_name in covered_by(requirement_id)} + + unlinked = sorted(all_tests - linked_tests - _HARNESS_SELF_TESTS) + assert not unlinked, f"Tests with no @requirement mark: {unlinked}" + + stale_exemptions = sorted(_HARNESS_SELF_TESTS - all_tests) + assert not stale_exemptions, f"Harness self-test exemptions that no longer exist: {stale_exemptions}" + + +def test_deferral_reasons_cite_existing_paths() -> None: + """Every repo path named in a deferral reason exists, so coverage pointers cannot rot.""" + missing = sorted( + f"{requirement_id}: {cited}" + for requirement_id, spec in REQUIREMENTS.items() + if spec.deferred is not None + for cited in _CITED_PATH.findall(spec.deferred) + if not (_REPO_ROOT / cited).exists() + ) + assert not missing, f"Deferral reasons citing paths that do not exist: {missing}" + + +def test_unknown_requirement_id_is_rejected() -> None: + """Marking a test with an ID that is not in the 
manifest fails at decoration time.""" + with pytest.raises(KeyError, match="Unknown requirement id 'tools:call:does-not-exist'"): + requirement("tools:call:does-not-exist") + + +def test_invalid_requirement_source_is_rejected() -> None: + """A requirement whose source is not a spec URL, 'sdk', or an issue reference fails at construction.""" + with pytest.raises(ValueError, match="source must be a specification URL"): + Requirement(source="https://example.com/not-the-spec", behavior="Never constructed.") diff --git a/tests/interaction/transports/__init__.py b/tests/interaction/transports/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/transports/_bridge.py b/tests/interaction/transports/_bridge.py new file mode 100644 index 0000000000..f78c6d14b5 --- /dev/null +++ b/tests/interaction/transports/_bridge.py @@ -0,0 +1,169 @@ +"""An in-process, full-duplex HTTP transport for driving ASGI applications from httpx. + +`httpx.ASGITransport` runs the application to completion and only then hands the buffered +response to the caller, so a server that streams its response — the streamable HTTP transport's +SSE responses — can never converse with the client mid-request: a server-initiated request +nested inside a still-open call deadlocks. `StreamingASGITransport` removes that limitation by +running the application as a background task and forwarding every `http.response.body` chunk to +the client the moment it is sent. Everything happens on the one event loop: no sockets, no +threads, no sleeps, no extra dependencies. + +The behavioural contract, pinned by `test_bridge.py`: + +- The request body is buffered before the application is invoked (MCP requests are small JSON + documents); the response streams chunk by chunk. +- Closing the response — or the whole client — delivers `http.disconnect` to the application, + exactly as a real server sees when its peer goes away. 
+- An exception the application raises before sending `http.response.start` fails the originating + request with that same exception. After the response has started, a failure is visible to the + client only through the response itself (status code, truncated body) — the same signal a real + server over a real socket would give. + +The transport owns an anyio task group for the application tasks; it is opened and closed by +`httpx.AsyncClient`'s own context manager, so use the client as a context manager (the suite +always does). Closing the transport cancels every running application task by default; set +`cancel_on_close=False` to wait for the application's own disconnect handling instead. +""" + +import math +from collections.abc import AsyncIterator +from types import TracebackType + +import anyio +import anyio.abc +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream +from starlette.types import ASGIApp, Message, Scope + + +class _StreamingResponseBody(httpx.AsyncByteStream): + """A response body that yields chunks as the application produces them. + + Closing it tells the application the client has gone away (`http.disconnect`), mirroring a + peer that drops the connection mid-response. + """ + + def __init__(self, chunks: MemoryObjectReceiveStream[bytes], client_disconnected: anyio.Event) -> None: + self._chunks = chunks + self._client_disconnected = client_disconnected + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for chunk in self._chunks: + yield chunk + + async def aclose(self) -> None: + self._client_disconnected.set() + await self._chunks.aclose() + + +class StreamingASGITransport(httpx.AsyncBaseTransport): + """Drive an ASGI application in-process, streaming each response as it is produced. + + With `cancel_on_close` (the default), closing the transport cancels every application task + still running so harness teardown can never hang. 
Setting it to False makes the transport wait + for the application's own disconnect handling to complete instead, which is the path the legacy + SSE server transport relies on for resource cleanup. + """ + + _task_group: anyio.abc.TaskGroup + + def __init__(self, app: ASGIApp, *, cancel_on_close: bool = True) -> None: + self._app = app + self._cancel_on_close = cancel_on_close + + async def __aenter__(self) -> "StreamingASGITransport": + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + # httpx closes every streamed response before closing the transport, so by now each + # application task has been delivered `http.disconnect`. Either cancel immediately, or wait + # for the application's own disconnect handling to unwind. + if self._cancel_on_close: + self._task_group.cancel_scope.cancel() + await self._task_group.__aexit__(exc_type, exc_value, traceback) + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + assert isinstance(request.stream, httpx.AsyncByteStream) + request_body = b"".join([chunk async for chunk in request.stream]) + + scope: Scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path.split(b"?", maxsplit=1)[0], + "query_string": request.url.query, + "root_path": "", + "headers": [(name.lower(), value) for name, value in request.headers.raw], + "server": (request.url.host, request.url.port), + "client": ("127.0.0.1", 1234), + } + + request_delivered = False + client_disconnected = anyio.Event() + response_started = anyio.Event() + response_status = 0 + response_headers: list[tuple[bytes, bytes]] = [] + application_error: Exception | None = None + 
chunk_writer, chunk_reader = anyio.create_memory_object_stream[bytes](math.inf) + + async def receive_request() -> Message: + nonlocal request_delivered + if not request_delivered: + request_delivered = True + return {"type": "http.request", "body": request_body, "more_body": False} + await client_disconnected.wait() + return {"type": "http.disconnect"} + + async def send_response(message: Message) -> None: + nonlocal response_status, response_headers + if message["type"] == "http.response.start": + response_status = message["status"] + response_headers = list(message.get("headers", [])) + response_started.set() + return + assert message["type"] == "http.response.body" + body: bytes = message.get("body", b"") + if body: + await chunk_writer.send(body) + if not message.get("more_body", False): + await chunk_writer.aclose() + + async def run_application() -> None: + nonlocal application_error + try: + await self._app(scope, receive_request, send_response) + except Exception as exc: # The bridge is the application's outermost boundary: a crash + # must fail the originating request (or show up in the already-started response), + # never tear down the task group shared with every other in-flight request. + application_error = exc + finally: + response_started.set() + await chunk_writer.aclose() + + self._task_group.start_soon(run_application) + try: + await response_started.wait() + if application_error is not None: + raise application_error + except BaseException: + # No response will be built, so close the reader the response body would have owned + # and tell the application its peer has gone away. 
+ client_disconnected.set() + await chunk_reader.aclose() + raise + return httpx.Response( + status_code=response_status, + headers=response_headers, + stream=_StreamingResponseBody(chunk_reader, client_disconnected), + request=request, + ) diff --git a/tests/interaction/transports/_event_store.py b/tests/interaction/transports/_event_store.py new file mode 100644 index 0000000000..84d1a2646a --- /dev/null +++ b/tests/interaction/transports/_event_store.py @@ -0,0 +1,55 @@ +"""A predictable event store for resumability tests. + +The SDK's `EventStore` interface lets a streamable-HTTP server stamp every SSE event with an ID +and replay missed events when a client reconnects with `Last-Event-ID`. This implementation +issues sequential integer IDs starting at "1" so tests can assert exact IDs (the example store +uses uuid4, which cannot be snapshotted) and is small enough that every line is exercised by the +resumability tests themselves. +""" + +import anyio + +from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId +from mcp.types import JSONRPCMessage + + +class SequencedEventStore(EventStore): + """Stores every event in order and replays the same-stream tail after a given ID.""" + + def __init__(self) -> None: + self._events: list[tuple[StreamId, JSONRPCMessage | None]] = [] + self._milestones: dict[int, anyio.Event] = {} + + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + self._events.append((stream_id, message)) + count = len(self._events) + milestone = self._milestones.pop(count, None) + if milestone is not None: + milestone.set() + return str(count) + + async def wait_until_stored(self, count: int) -> None: + """Block until at least `count` events have been stored. 
+ + Tests use this to wait for the server's message router (which runs in another task) to + finish storing a known set of events before issuing a replay, so the replay's content is + deterministic rather than depending on task scheduling order. + """ + if len(self._events) >= count: + return + milestone = self._milestones.setdefault(count, anyio.Event()) + await milestone.wait() + + async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None: + try: + cursor = int(last_event_id) + except ValueError: + return None + if not 0 < cursor <= len(self._events): + return None + stream_id, _ = self._events[cursor - 1] + for index in range(cursor, len(self._events)): + event_stream_id, message = self._events[index] + if event_stream_id == stream_id and message is not None: + await send_callback(EventMessage(message, str(index + 1))) + return stream_id diff --git a/tests/interaction/transports/_stdio_server.py b/tests/interaction/transports/_stdio_server.py new file mode 100644 index 0000000000..5977cc3e99 --- /dev/null +++ b/tests/interaction/transports/_stdio_server.py @@ -0,0 +1,63 @@ +"""A real low-level Server over the stdio transport, for the suite's one subprocess test. + +Runnable as `python -m tests.interaction.transports._stdio_server` from the repo root; the test +launches it that way via `stdio_client`. Kept separate from the test module so the server lives in +its own importable file (subprocess coverage applies) while the test file follows the suite's +test-only-functions convention. 
+""" + +import sys + +import anyio + +from mcp.server import Server, ServerRequestContext +from mcp.server.stdio import stdio_server +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + EmptyResult, + ListToolsResult, + PaginatedRequestParams, + SetLevelRequestParams, + TextContent, + Tool, +) + + +async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo", + input_schema={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, + ) + ] + ) + + +async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + text = params.arguments["text"] + await ctx.session.send_log_message(level="info", data=f"echoing {text}", logger="echo") + return CallToolResult(content=[TextContent(text=text)]) + + +async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + +server = Server("stdio-echo", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + +async def main() -> None: + async with stdio_server() as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) + # Reached only when the run loop exits because stdin closed; if the process were terminated + # the test's stderr capture would not see this line. 
+ print("stdio-echo: clean exit", file=sys.stderr, flush=True) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/tests/interaction/transports/test_bridge.py b/tests/interaction/transports/test_bridge.py new file mode 100644 index 0000000000..7420b9d902 --- /dev/null +++ b/tests/interaction/transports/test_bridge.py @@ -0,0 +1,94 @@ +"""Contract tests for the suite's streaming ASGI bridge. + +These pin what `StreamingASGITransport` itself guarantees — chunk-by-chunk delivery, disconnect +propagation, and failure handling — against minimal hand-written ASGI applications, so the MCP +transport tests built on top of it never have to wonder what the harness provides. They are +harness self-tests, not interaction-model tests, and are exempted from the requirement-coverage +contract in `test_coverage.py`. +""" + +import anyio +import httpx +import pytest +from starlette.types import Message, Receive, Scope, Send + +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +async def test_response_chunks_arrive_as_the_application_sends_them() -> None: + """Each body chunk is delivered as sent, empty chunks are skipped, and the stream ends with the application.""" + + async def chunked_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": [(b"content-type", b"text/plain")]}) + await send({"type": "http.response.body", "body": b"first", "more_body": True}) + await send({"type": "http.response.body", "body": b"", "more_body": True}) + await send({"type": "http.response.body", "body": b"second", "more_body": False}) + + async with ( + httpx.AsyncClient(transport=StreamingASGITransport(chunked_app), base_url="http://bridge") as http, + http.stream("GET", "/chunks") as response, + ): + with anyio.fail_after(5): + chunks = [chunk async for chunk in 
response.aiter_raw()] + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/plain" + assert chunks == [b"first", b"second"] + + +async def test_closing_the_response_delivers_a_disconnect_to_the_application() -> None: + """A client that closes the response early is seen by the application as an http.disconnect.""" + seen_after_request: list[Message] = [] + disconnect_seen = anyio.Event() + + async def waiting_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": []}) + seen_after_request.append(await receive()) + disconnect_seen.set() + + async with httpx.AsyncClient(transport=StreamingASGITransport(waiting_app), base_url="http://bridge") as http: + async with http.stream("GET", "/wait") as response: + assert response.status_code == 200 + # Leaving the stream block closes the response while the application is still mid-response. 
+ with anyio.fail_after(5): + await disconnect_seen.wait() + + assert seen_after_request == [{"type": "http.disconnect"}] + + +async def test_an_application_failure_before_the_response_starts_fails_the_request() -> None: + """An exception raised before http.response.start reaches the caller as that same exception.""" + + async def broken_app(scope: Scope, receive: Receive, send: Send) -> None: + raise RuntimeError("the demo application is broken") + + async with httpx.AsyncClient(transport=StreamingASGITransport(broken_app), base_url="http://bridge") as http: + with pytest.raises(RuntimeError, match="the demo application is broken"): + await http.get("/broken") + + +async def test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect() -> None: + """With cancel_on_close=False, an application that runs cleanup after seeing http.disconnect + completes that cleanup before the transport finishes closing.""" + cleanup_ran = anyio.Event() + + async def lingering_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + await receive() + await send({"type": "http.response.start", "status": 200, "headers": []}) + assert (await receive())["type"] == "http.disconnect" + cleanup_ran.set() + + transport = StreamingASGITransport(lingering_app, cancel_on_close=False) + with anyio.fail_after(5): + async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: + async with http.stream("GET", "/linger") as response: + assert response.status_code == 200 + assert not cleanup_ran.is_set() + assert cleanup_ran.is_set() diff --git a/tests/interaction/transports/test_client_transport_http.py b/tests/interaction/transports/test_client_transport_http.py new file mode 100644 index 0000000000..65ed03f1e4 --- /dev/null +++ b/tests/interaction/transports/test_client_transport_http.py @@ -0,0 +1,247 @@ +"""Behaviour of the streamable-HTTP client transport itself, observed at the wire. 
+ +These tests connect a real `Client` to a real server over the in-process bridge, recording every +HTTP request the SDK client issues, so the assertions are about what the transport sends (headers, +methods, ordering) rather than what the protocol layer on top of it returns. The recording is the +wire-level instrument; the SDK client never exposes these details. +""" + +from collections.abc import AsyncIterator + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot +from starlette.types import Receive, Scope, Send + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.types import INVALID_REQUEST, CallToolResult, ErrorData, ListToolsResult, TextContent, Tool +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION, client_via_http, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _tooled_server() -> Server: + """A low-level server with one echo tool, used by every test in this file.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", description="Echo text.", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["text"]))]) + + return Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + +@pytest.fixture +async def recorded() -> AsyncIterator[list[httpx.Request]]: + """Connect a `Client` over a recording HTTP 
client, list tools, exit, and yield every request sent. + + The HTTP client carries one caller-supplied header (`x-trace`) so its propagation can be + asserted; the recording captures the closing DELETE because it is read after the `Client` has + fully exited. + """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record, headers={"x-trace": "abc"}) as (http, _): + async with client_via_http(http) as client: + result = await client.list_tools() + assert [tool.name for tool in result.tools] == ["echo"] + + yield requests + + +def _after_initialize(recorded: list[httpx.Request]) -> list[httpx.Request]: + """Every recorded request after the initialize POST (which carries no session yet).""" + assert recorded[0].method == "POST" + assert "mcp-session-id" not in recorded[0].headers + return recorded[1:] + + +@requirement("client-transport:http:custom-client") +@requirement("client-transport:http:custom-headers") +async def test_the_client_uses_the_supplied_http_client_and_propagates_its_headers( + recorded: list[httpx.Request], +) -> None: + """A caller-supplied `httpx.AsyncClient` is used for every request and carries its own headers. + + The recording itself proves the supplied client is the one in use; the propagated header + proves the SDK transport does not replace the caller's client configuration. + """ + # Exact ordering past the first request is not guaranteed (the standalone GET stream is + # scheduled concurrently with later POSTs), so methods are asserted as a multiset. 
+ assert sorted(request.method for request in recorded) == snapshot(["DELETE", "GET", "POST", "POST", "POST"]) + assert all(request.headers["x-trace"] == "abc" for request in recorded) + + +@requirement("client-transport:http:session-stored") +async def test_every_request_after_initialize_carries_the_issued_session_id(recorded: list[httpx.Request]) -> None: + """The session id from the initialize response is sent on every subsequent request.""" + session_ids = {request.headers["mcp-session-id"] for request in _after_initialize(recorded)} + assert len(session_ids) == 1 + (session_id,) = session_ids + assert session_id + + +@requirement("client-transport:http:protocol-version-stored") +@requirement("client-transport:http:protocol-version-header") +async def test_every_request_after_initialize_carries_the_negotiated_protocol_version( + recorded: list[httpx.Request], +) -> None: + """The negotiated protocol version is sent on every subsequent request (and not on initialize).""" + assert "mcp-protocol-version" not in recorded[0].headers + versions = {request.headers["mcp-protocol-version"] for request in _after_initialize(recorded)} + assert versions == snapshot({"2025-11-25"}) + + +@requirement("client-transport:http:accept-header-post") +@requirement("client-transport:http:accept-header-get") +async def test_accept_headers_cover_the_response_representations_the_transport_handles( + recorded: list[httpx.Request], +) -> None: + """POSTs accept both JSON and SSE; the standalone GET stream accepts SSE.""" + for request in recorded: + if request.method == "POST": + assert "application/json" in request.headers["accept"] + assert "text/event-stream" in request.headers["accept"] + if request.method == "GET": + assert "text/event-stream" in request.headers["accept"] + + +@requirement("client-transport:http:no-reconnect-after-close") +async def test_closing_the_client_sends_delete_and_does_not_reconnect(recorded: list[httpx.Request]) -> None: + """Client teardown sends DELETE 
and issues no further requests (no resumption GET).""" + assert recorded[-1].method == "DELETE" + assert all("last-event-id" not in request.headers for request in recorded) + + +@requirement("client-transport:http:concurrent-streams") +async def test_concurrent_tool_calls_each_open_a_post_stream_and_receive_their_own_response() -> None: + """Three tool calls issued at once each open their own POST stream and get the right answer.""" + requests: list[httpx.Request] = [] + results: dict[int, CallToolResult] = {} + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record) as (http, _), client_via_http(http) as client: + + async def call(n: int) -> None: + results[n] = await client.call_tool("echo", {"text": str(n)}) + + with anyio.fail_after(5): # pragma: no branch + async with anyio.create_task_group() as tg: # pragma: no branch + for n in (1, 2, 3): + tg.start_soon(call, n) + + assert results == snapshot( + { + 1: CallToolResult(content=[TextContent(text="1")]), + 2: CallToolResult(content=[TextContent(text="2")]), + 3: CallToolResult(content=[TextContent(text="3")]), + } + ) + tools_call_posts = [r for r in requests if r.method == "POST" and b'"tools/call"' in r.content] + assert len(tools_call_posts) == 3 + + +@requirement("client-transport:http:sse-405-tolerated") +@requirement("client-transport:http:terminate-405-ok") +async def test_client_tolerates_405_on_get_and_delete() -> None: + """A 405 on the standalone GET stream or the closing DELETE does not fail the connection. + + The GET-stream task swallows the failure and schedules a reconnect that the closing cancel + interrupts before it ever sleeps the full default delay; the DELETE 405 is logged and ignored. + Neither surfaces to the caller. 
+ """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + + async def filter_methods(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] in ("GET", "DELETE"): + await send({"type": "http.response.start", "status": 405, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + await real_app(scope, receive, send) + + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(filter_methods), base_url=BASE_URL) as http_client, + ): + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): # pragma: no branch + async with Client(transport) as client: # pragma: no branch + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + + +@requirement("client-transport:http:no-reconnect-after-response") +async def test_a_completed_post_stream_is_not_reconnected() -> None: + """A POST stream that delivered its response closes without a resumption GET. + + With an event store the server stamps every SSE event with an ID, so the client transport has a + Last-Event-ID it could resume from -- the test proves it does not, because the response arrived + and the stream completed normally. 
+ """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + server = _tooled_server() + async with ( + mounted_app(server, event_store=SequencedEventStore(), retry_interval=0, on_request=record) as (http, _), + client_via_http(http) as client, + ): + with anyio.fail_after(5): + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + resumption_gets = [r for r in requests if r.method == "GET" and "last-event-id" in r.headers] + assert resumption_gets == [] + + +@requirement("client-transport:http:404-surfaces") +async def test_a_404_mid_session_surfaces_as_a_session_terminated_error() -> None: + """A 404 in response to a request after initialization is reported to the caller as an MCP error. + + The spec says the client MUST start a new session in this situation; the SDK instead surfaces a + `Session terminated` error to the caller. The spec's MUST is tracked at + client-transport:http:session-404-reinitialize; this test pins the SDK's current behaviour. 
+ """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + initialize_seen = anyio.Event() + + async def first_post_then_404(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] == "POST" and initialize_seen.is_set(): + await send({"type": "http.response.start", "status": 404, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + if scope["type"] == "http" and scope["method"] == "POST": + initialize_seen.set() + await real_app(scope, receive, send) + + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(first_post_then_404), base_url=BASE_URL) as http_client, + ): + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): # pragma: no branch + async with Client(transport) as client: # pragma: no branch + with pytest.raises(MCPError) as exc_info: # pragma: no branch + await client.list_tools() + + assert exc_info.value.error == snapshot(ErrorData(code=INVALID_REQUEST, message="Session terminated")) diff --git a/tests/interaction/transports/test_flows.py b/tests/interaction/transports/test_flows.py new file mode 100644 index 0000000000..c428fe2d68 --- /dev/null +++ b/tests/interaction/transports/test_flows.py @@ -0,0 +1,129 @@ +"""Transport-level composed flows: multi-client isolation, reconnection, and dual-transport hosting. + +These scenarios are about how the transport layer holds together across more than one connection +or more than one transport, so they connect real `Client`s against one mounted server rather than +running over the matrix. 
+""" + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.session import LoggingFnT +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import CallToolResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import client_via_http, connect_over_sse, mounted_app +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("flow:multi-client:stateful-isolation") +async def test_concurrent_clients_on_one_stateful_server_receive_only_their_own_notifications() -> None: + """Two clients on one stateful manager each receive only the notifications their own request produced. + + Complements `test_terminating_one_session_leaves_others_working` (which proves session + independence under termination) with the notification-isolation dimension: a notification + emitted by one session's handler does not leak to another session's client. + """ + mcp = MCPServer("multi") + + @mcp.tool() + async def announce(label: str, ctx: Context) -> str: + """Emit one info-level log carrying the caller's label, then return it.""" + await ctx.info(label) + return label + + received_a: list[object] = [] + received_b: list[object] = [] + + async def collect_a(params: LoggingMessageNotificationParams) -> None: + received_a.append(params.data) + + async def collect_b(params: LoggingMessageNotificationParams) -> None: + received_b.append(params.data) + + async with mounted_app(mcp) as (http, _): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call(label: str, collect: LoggingFnT) -> None: + async with client_via_http(http, logging_callback=collect) as client: + await client.call_tool("announce", {"label": label}) + + tg.start_soon(call, "a", collect_a) + tg.start_soon(call, "b", collect_b) + + assert received_a == ["a"] + assert received_b == ["b"] + + 
+@requirement("flow:session:terminate-then-reconnect") +async def test_a_fresh_connection_after_termination_obtains_a_new_session_and_operates() -> None: + """After a client terminates, a fresh connection to the same manager gets a distinct session. + + Steps: (1) connect a client and call list_tools, (2) the client exits (its DELETE fires), + (3) connect a second client to the same mounted app, (4) the second client's call_tool + succeeds and the recorded session ids show two distinct sessions were issued. + """ + mcp = MCPServer("reconnectable") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + session_ids: list[str] = [] + + async def record(request: httpx.Request) -> None: + session_id = request.headers.get("mcp-session-id") + if session_id is not None: + session_ids.append(session_id) + + async with mounted_app(mcp, on_request=record) as (http, _): + async with client_via_http(http) as first: + first_result = await first.list_tools() + async with client_via_http(http) as second: + second_result = await second.call_tool("echo", {"text": "again"}) + + assert {tool.name for tool in first_result.tools} == {"echo"} + assert second_result == snapshot( + CallToolResult(content=[TextContent(text="again")], structured_content={"result": "again"}) + ) + distinct = set(session_ids) + assert len(distinct) == 2, f"expected two distinct session ids across the two connections, saw {distinct}" + + +@requirement("flow:compat:dual-transport-server") +async def test_one_server_serves_streamable_http_and_sse_clients_concurrently() -> None: + """One MCPServer instance serves a streamable-HTTP client and a legacy-SSE client at the same time. + + The two transports have independent connection management (the streamable-HTTP session manager + versus a per-connection SSE handler), but both dispatch into the same server's request + handlers. 
The test connects one client over each transport against the same instance and + proves both reach the same tool. Uses MCPServer because the low-level Server has no SSE + convenience; the entry is about hosting composition, not the low-level API. + """ + mcp = MCPServer("dual") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + async with ( + mounted_app(mcp) as (http, _), + connect_over_sse(mcp) as sse_client, + client_via_http(http) as shttp_client, + ): + with anyio.fail_after(5): + shttp_result = await shttp_client.call_tool("echo", {"text": "via http"}) + sse_result = await sse_client.call_tool("echo", {"text": "via sse"}) + + assert shttp_result == snapshot( + CallToolResult(content=[TextContent(text="via http")], structured_content={"result": "via http"}) + ) + assert sse_result == snapshot( + CallToolResult(content=[TextContent(text="via sse")], structured_content={"result": "via sse"}) + ) diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py new file mode 100644 index 0000000000..85e64ded42 --- /dev/null +++ b/tests/interaction/transports/test_hosting_http.py @@ -0,0 +1,344 @@ +"""Streamable HTTP semantics: status codes, header validation, message routing, and security. + +These tests speak HTTP directly to the server's mounted ASGI app via the in-process bridge, +asserting the wire contract -- which status code answers which condition, which stream a message +travels on -- that the SDK client never exposes. Transport-agnostic behaviour is covered by the +`connect`-fixture matrix. 
+""" + +import anyio +import pytest +from anyio.lowlevel import checkpoint +from httpx_sse import ServerSentEvent, aconnect_sse +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + INVALID_PARAMS, + PARSE_ERROR, + CallToolRequestParams, + CallToolResult, + EmptyResult, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListResourcesResult, + ListToolsResult, + PaginatedRequestParams, + SetLevelRequestParams, + SubscribeRequestParams, + TextContent, +) +from tests.interaction._connect import ( + base_headers, + initialize_body, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A low-level server with one tool that emits a related and an unrelated notification.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "narrate" + await ctx.session.send_log_message(level="info", data="related", logger=None, related_request_id=ctx.request_id) + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="done")]) + + async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + async def list_resources(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists 
resources.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: + """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" + raise NotImplementedError + + return Server( + "hosted", + on_list_tools=list_tools, + on_call_tool=call_tool, + on_set_logging_level=set_logging_level, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + ) + + +@requirement("hosting:http:method-405") +async def test_unsupported_http_methods_return_405() -> None: + """PUT and PATCH on the MCP endpoint return 405 with an Allow header naming the supported methods.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + put = await http.put("/mcp", json={}, headers=base_headers(session_id=session_id)) + patch = await http.patch("/mcp", json={}, headers=base_headers(session_id=session_id)) + + assert (put.status_code, put.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + assert (patch.status_code, patch.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + + +@requirement("hosting:http:accept-406") +async def test_missing_accept_media_types_return_406() -> None: + """A POST whose Accept header lacks both required types, or a GET lacking text/event-stream, returns 406.""" + async with mounted_app(_server()) as (http, _): + post = await http.post( + "/mcp", json=initialize_body(), headers={"accept": "text/plain", "mcp-protocol-version": "2025-11-25"} + ) + session_id = await initialize_via_http(http) + get = await http.get( + "/mcp", + headers={"accept": "application/json", "mcp-protocol-version": "2025-11-25", "mcp-session-id": session_id}, + ) + + assert (post.status_code, post.json()["error"]["message"]) == snapshot( + (406, "Not Acceptable: Client must accept both application/json and text/event-stream") + ) + assert (get.status_code, get.json()["error"]["message"]) 
== snapshot( + (406, "Not Acceptable: Client must accept text/event-stream") + ) + + +@requirement("hosting:http:content-type-415") +async def test_non_json_content_type_is_rejected() -> None: + """A POST with a non-JSON Content-Type is rejected before reaching the transport. + + See the divergence on the requirement: the security middleware rejects with 400, so the + transport's own 415 path is unreachable through any public entry point. + """ + async with mounted_app(_server()) as (http, _): + response = await http.post( + "/mcp", content=b"", headers=base_headers() | {"content-type": "text/plain"} + ) + + assert (response.status_code, response.text) == snapshot((400, "Invalid Content-Type header")) + + +@requirement("hosting:http:parse-error-400") +@requirement("hosting:http:batch") +async def test_malformed_and_batched_bodies_return_400() -> None: + """A non-JSON body returns 400 Parse error; a JSON array of requests returns 400 Invalid params.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + not_json = await http.post( + "/mcp", + content=b"this is not json", + headers=base_headers(session_id=session_id) | {"content-type": "application/json"}, + ) + batched = await http.post( + "/mcp", + json=[ + {"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + {"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + ], + headers=base_headers(session_id=session_id), + ) + + assert not_json.status_code == 400 + assert JSONRPCError.model_validate_json(not_json.text).error.code == PARSE_ERROR + assert batched.status_code == 400 + assert JSONRPCError.model_validate_json(batched.text).error.code == INVALID_PARAMS + + +@requirement("hosting:http:protocol-version-400") +@requirement("hosting:http:protocol-version-default") +async def test_protocol_version_header_is_validated() -> None: + """An unsupported MCP-Protocol-Version header returns 400; an absent header is accepted as the default.""" + async with mounted_app(_server()) 
as (http, _): + session_id = await initialize_via_http(http) + + bad = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + headers=base_headers(session_id=session_id) | {"mcp-protocol-version": "1991-01-01"}, + ) + # Only Accept and the session ID -- no MCP-Protocol-Version header at all. + defaulted = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers={"accept": "application/json, text/event-stream", "mcp-session-id": session_id}, + ) + + assert bad.status_code == 400 + assert JSONRPCError.model_validate_json(bad.text).error.message.startswith( + "Bad Request: Unsupported protocol version: 1991-01-01." + ) + # 202 proves the request was accepted under the assumed default version (2025-03-26). + assert defaulted.status_code == 202 + + +@requirement("hosting:http:json-response-mode") +async def test_json_response_mode_answers_with_application_json_not_sse() -> None: + """With JSON response mode enabled, request POSTs are answered with a single application/json body. + + Asserted at the wire level because the SDK client parses either representation, so a + Client-driven round trip cannot distinguish a JSON response from an SSE one. 
+ """ + async with mounted_app(_server(), json_response=True) as (http, _): + initialized = await http.post("/mcp", json=initialize_body(), headers=base_headers()) + session_id = initialized.headers["mcp-session-id"] + ping = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "ping"}, + headers=base_headers(session_id=session_id), + ) + + assert initialized.status_code == 200 + assert initialized.headers["content-type"].split(";", 1)[0] == "application/json" + assert JSONRPCResponse.model_validate(initialized.json()).id == 1 + assert ping.status_code == 200 + assert ping.headers["content-type"].split(";", 1)[0] == "application/json" + assert JSONRPCResponse.model_validate(ping.json()).id == 2 + + +@requirement("hosting:http:notifications-202") +async def test_notification_post_returns_202_with_no_body() -> None: + """A POST containing only a notification (no request ID) returns 202 Accepted with no body.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + response = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers=base_headers(session_id=session_id), + ) + + assert (response.status_code, response.content) == snapshot((202, b"")) + + +@requirement("hosting:http:second-sse-rejected") +async def test_a_second_standalone_get_stream_on_the_same_session_returns_409() -> None: + """Opening a second standalone GET SSE stream while one is already established returns 409 Conflict.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + + async with aconnect_sse(http, "GET", "/mcp", headers=base_headers(session_id=session_id)) as first: + assert first.response.status_code == 200 + # The standalone-stream writer registers its key as its first action, then parks + # awaiting messages; one yield to the loop lets that registration complete before the + # second GET is 
dispatched. + await checkpoint() + second = await http.get("/mcp", headers=base_headers(session_id=session_id)) + + assert (second.status_code, second.json()["error"]["message"]) == snapshot( + (409, "Conflict: Only one SSE stream is allowed per session") + ) + + +@requirement("hosting:http:standalone-sse") +@requirement("hosting:http:standalone-sse-no-response") +@requirement("hosting:http:response-same-connection") +@requirement("hosting:http:sse-close-after-response") +@requirement("hosting:http:no-broadcast") +async def test_messages_are_routed_to_exactly_one_stream() -> None: + """Each server message travels on exactly one SSE stream and is never broadcast. + + A streamable-HTTP session has two kinds of server-to-client SSE stream: one short-lived stream + per POST request, carrying that request's response and any notifications related to it, and one + long-lived standalone stream (opened by GET) for notifications not tied to any request. The + spec's routing rule is that the POST stream delivers the response (and its related + notifications) and then closes, the standalone stream carries only unrelated notifications and + never a JSON-RPC response, and no message appears on both. The test opens both streams, calls a + tool whose handler emits one related and one unrelated notification, and asserts each message's + routing. 
+ """ + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + post_events: list[ServerSentEvent] = [] + get_events: list[ServerSentEvent] = [] + + async def read_standalone_stream() -> None: + async with aconnect_sse(http, "GET", "/mcp", headers=base_headers(session_id=session_id)) as get: + assert get.response.status_code == 200 + standalone_ready.set() + async for event in get.aiter_sse(): + get_events.append(event) + seen_on_standalone.set() + + standalone_ready = anyio.Event() + seen_on_standalone = anyio.Event() + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(read_standalone_stream) + await standalone_ready.wait() + + params = CallToolRequestParams(name="narrate", arguments={}) + body = JSONRPCRequest(jsonrpc="2.0", id=5, method="tools/call", params=params.model_dump()) + async with aconnect_sse( + http, + "POST", + "/mcp", + json=body.model_dump(by_alias=True, exclude_none=True), + headers=base_headers(session_id=session_id), + ) as post: + assert post.response.status_code == 200 + # The POST stream iterator ends when the server closes the stream after the response. + post_events = [event async for event in post.aiter_sse()] + + await seen_on_standalone.wait() + tg.cancel_scope.cancel() + + post_messages = parse_sse_messages(post_events) + get_messages = parse_sse_messages(get_events) + + # POST stream: the related log notification, then the response, then the iterator ends (close). 
+ assert [type(m).__name__ for m in post_messages] == snapshot(["JSONRPCNotification", "JSONRPCResponse"]) + assert isinstance(post_messages[0], JSONRPCNotification) + assert (post_messages[0].method, post_messages[0].params) == snapshot( + ("notifications/message", {"level": "info", "data": "related"}) + ) + assert isinstance(post_messages[1], JSONRPCResponse) + assert post_messages[1].id == 5 + + # Standalone stream: only the unrelated resource-updated notification, never a response. + assert [type(m).__name__ for m in get_messages] == snapshot(["JSONRPCNotification"]) + assert isinstance(get_messages[0], JSONRPCNotification) + assert get_messages[0].method == snapshot("notifications/resources/updated") + + +@requirement("hosting:http:dns-rebinding") +@requirement("transport:streamable-http:origin-validation") +async def test_origin_validation_rejects_disallowed_origins_when_enabled() -> None: + """A disallowed Origin returns 403 (and Host 421) with protection enabled; disabled lets both through. + + See the divergence on hosting:http:dns-rebinding: the spec's Origin validation is an + unconditional MUST, but the SDK enables it only when the host is localhost (or settings are + passed explicitly) and additionally checks the Host header (returning 421), which the spec + does not require. + """ + # transport_security=None triggers the localhost auto-enable behaviour. 
+ async with mounted_app(Server("guarded"), transport_security=None) as (http, _): + bad_origin = await http.post( + "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://evil.example"} + ) + bad_host = await http.post("/mcp", json=initialize_body(), headers=base_headers() | {"host": "evil.example"}) + async with aconnect_sse( + http, "POST", "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://127.0.0.1:8000"} + ) as ok: + assert ok.response.status_code == 200 + assert [event async for event in ok.aiter_sse()] + + assert (bad_origin.status_code, bad_origin.text) == snapshot((403, "Invalid Origin header")) + assert (bad_host.status_code, bad_host.text) == snapshot((421, "Invalid Host header")) + + async with mounted_app( + Server("unguarded"), transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) as (http, _): + async with aconnect_sse( + http, "POST", "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://evil.example"} + ) as unguarded: + status = unguarded.response.status_code + assert [event async for event in unguarded.aiter_sse()] + + assert status == 200 diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py new file mode 100644 index 0000000000..c7945d56c3 --- /dev/null +++ b/tests/interaction/transports/test_hosting_resume.py @@ -0,0 +1,372 @@ +"""Resumability over the streamable HTTP transport, exercised entirely in process. + +These tests configure the server with an event store, so every SSE event is stamped with an ID +and a client that loses its connection can resume by sending `Last-Event-ID`. The wire-level +tests (`mounted_app` + raw httpx) assert exactly what travels on the wire; the end-to-end test +drives the SDK client through a server-initiated stream close and proves the call still +completes. 
The bridge's `aclose()` delivers `http.disconnect` to the running application, so +closing a streaming response mid-read is a deterministic in-process disconnect -- no sockets, +no real time. Every server here uses `retry_interval=0` so reconnection waits are no-ops. +""" + +import json + +import anyio +import httpx +import pytest +from httpx_sse import EventSource, ServerSentEvent +from inline_snapshot import snapshot + +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client +from mcp.server.mcpserver import Context, MCPServer +from mcp.shared.message import ClientMessageMetadata +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, + jsonrpc_message_adapter, +) +from tests.interaction._connect import ( + BASE_URL, + base_headers, + connect_over_streamable_http, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _counting_server() -> MCPServer: + """A server with one tool that emits related notifications and one unrelated notification.""" + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context, n: int) -> str: + """Emit n log notifications related to this call, plus one unrelated resource update.""" + for i in range(1, n + 1): + await ctx.info(f"tick {i}") + await ctx.session.send_resource_updated("file:///elsewhere.txt") + return f"counted to {n}" + + return mcp + + +def _tools_call(request_id: int, name: str, arguments: dict[str, object]) -> str: + """A serialized tools/call JSON-RPC request body.""" + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="tools/call", params={"name": name, "arguments": arguments} + 
).model_dump_json(by_alias=True, exclude_none=True) + + +async def _read_events(response: httpx.Response, count: int) -> list[ServerSentEvent]: + """Read exactly `count` SSE events from a streaming response without closing it.""" + source = EventSource(response).aiter_sse() + return [await anext(source) for _ in range(count)] + + +@requirement("hosting:resume:event-ids") +@requirement("hosting:resume:priming") +async def test_a_post_sse_stream_begins_with_a_priming_event_and_stamps_every_event() -> None: + """A request's SSE stream opens with a priming event (id, empty data, retry) then stamps each message.""" + async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( # pragma: no branch + "POST", "/mcp", content=_tools_call(1, "count", {"n": 2}), headers=base_headers(session_id=session_id) + ) as response: + assert response.status_code == 200 + events = await _read_events(response, 4) + + priming, first, second, result = events + # The priming event is the only event a client could have seen before any work happened, so it + # is the resumption anchor: it carries an ID and empty data. The SDK attaches the retry hint + # to this event (see the divergence on hosting:resume:priming). + assert (priming.id, priming.data, priming.retry) == snapshot(("3", "", 0)) + assert priming.event == snapshot("message") + # Every subsequent event carries an event-store ID; the related notifications and the response + # all ride this stream and close it after the response. 
+ assert [event.id for event in (first, second, result)] == snapshot(["4", "5", "7"]) + assert [json.loads(event.data)["method"] for event in (first, second)] == snapshot( + ["notifications/message", "notifications/message"] + ) + assert jsonrpc_message_adapter.validate_json(result.data) == snapshot( + JSONRPCResponse( + jsonrpc="2.0", + id=1, + result={ + "content": [{"type": "text", "text": "counted to 2"}], + "structuredContent": {"result": "counted to 2"}, + "isError": False, + }, + ) + ) + + +@requirement("hosting:resume:replay") +@requirement("hosting:resume:stream-scoped") +@requirement("hosting:resume:buffered-replay") +async def test_get_with_last_event_id_replays_only_that_streams_missed_events() -> None: + """Reconnecting with Last-Event-ID returns the missed events from that one stream, in order. + + The handler also emits an unrelated notification (which the server stores under the + standalone-stream key); replay must not return it, proving replay is scoped to the stream + the given event ID belongs to. + + Steps: (1) initialize; (2) POST a tool call and read events until the first notification is + captured; (3) close the response mid-stream -- the bridge delivers `http.disconnect`, the + handler keeps running; (4) release the handler so it emits the remaining messages, which the + server buffers in the event store; (5) wait on the event store for the handler's response to + be stored, so the replay's content is independent of task scheduling; (6) GET with + `Last-Event-ID` and assert the replay is exactly the missed events from this request's stream. 
+ """ + release = anyio.Event() + store = SequencedEventStore() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context) -> str: + """Emit one related notification, wait for the test, then emit two more plus an unrelated one.""" + await ctx.info("tick 1") + await release.wait() + await ctx.info("tick 2") + await ctx.info("tick 3") + await ctx.session.send_resource_updated("file:///elsewhere.txt") + return "counted" + + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "count", {}), headers=base_headers(session_id=session_id) + ) as response: + # Read the priming event and the first notification, then drop the connection. + priming, first = await _read_events(response, 2) + assert (priming.id, first.id) == snapshot(("3", "4")) + last_seen = first.id + release.set() + # The handler keeps running after the disconnect; its remaining messages are stored. + # The first wait returns immediately (the priming and first tick are already stored); + # the second blocks until the response itself is stored so the replay content is fixed. + await store.wait_until_stored(4) + await store.wait_until_stored(8) + replay_headers = base_headers(session_id=session_id) | {"last-event-id": last_seen} + async with http.stream("GET", "/mcp", headers=replay_headers) as replay: # pragma: no branch + assert replay.status_code == 200 + missed = await _read_events(replay, 3) + + decoded = parse_sse_messages(missed) + # Exactly the two remaining related notifications and the response, with their original IDs. 
+ assert [event.id for event in missed] == snapshot(["5", "6", "8"]) + assert [type(message).__name__ for message in decoded] == snapshot( + ["JSONRPCNotification", "JSONRPCNotification", "JSONRPCResponse"] + ) + assert isinstance(decoded[2], JSONRPCResponse) + assert decoded[2].id == 1 + # The unrelated resource-updated notification was stored under the standalone-stream key, not + # this request's stream, so it must not appear in the replay. + assert all( + not (isinstance(message, JSONRPCNotification) and message.method == "notifications/resources/updated") + for message in decoded + ) + + +@requirement("hosting:resume:bad-event-id") +async def test_an_unknown_last_event_id_yields_an_empty_replay_stream() -> None: + """A Last-Event-ID the event store cannot map produces an empty SSE stream rather than an error. + + See the divergence on hosting:resume:bad-event-id: this pins current behaviour. + """ + async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + for unknown in ("no-such-event", "0"): + headers = base_headers(session_id=session_id) | {"last-event-id": unknown} + async with http.stream("GET", "/mcp", headers=headers) as replay: + assert replay.status_code == 200 + assert replay.headers["content-type"].startswith("text/event-stream") + events = [event async for event in EventSource(replay).aiter_sse()] + assert events == [] + + +@requirement("hosting:http:disconnect-not-cancel") +async def test_dropping_the_connection_mid_request_does_not_cancel_the_handler() -> None: + """Closing the request's SSE connection while the handler is running leaves the handler running. + + The handler signals when it has started and when it has finished; the test drops the + connection in between and then releases the handler. If the disconnect cancelled the handler, + `finished` would never be set and the test would time out. 
+ """ + started = anyio.Event() + release = anyio.Event() + finished = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def hold(ctx: Context) -> str: + """Signal start, wait for the test, signal completion.""" + started.set() + await release.wait() + await ctx.info("released") + finished.set() + return "held" + + async with mounted_app(mcp, event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "hold", {}), headers=base_headers(session_id=session_id) + ) as response: + await _read_events(response, 1) + await started.wait() + assert not finished.is_set() + release.set() + await finished.wait() + + +# This test intentionally carries every automatic-reconnection requirement: the +# close-then-resume scenario is indivisible, so splitting it would mean five near-identical bodies. +@requirement("hosting:resume:close-stream") +@requirement("transport:streamable-http:resumability") +@requirement("client-transport:http:reconnect-post-priming") +@requirement("client-transport:http:reconnect-retry-value") +@requirement("flow:resume:tool-call-resumption-token") +async def test_a_call_whose_stream_the_server_closes_is_resumed_by_the_client() -> None: + """A server-closed request stream is reconnected by the client and the call completes. + + The handler emits one notification, closes its own SSE stream, then (once released) emits + another and returns. The client observed the priming event (so it has a Last-Event-ID and a + retry hint of 0ms), sees the stream end, reconnects via GET with Last-Event-ID, and receives + the post-close notification and the result over the replay stream. The shared events make the + test deterministic: the handler only proceeds once the test knows the first notification has + arrived (and so the client's reconnection has begun). 
+ """ + received: list[object] = [] + before_seen = anyio.Event() + gate = anyio.Event() + done = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def interrupt(ctx: Context) -> str: + """Emit, close this call's SSE stream, then emit again after the test releases the gate.""" + await ctx.info("before close") + await ctx.close_sse_stream() + await gate.wait() + await ctx.info("after close") + done.set() + return "resumed" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params.data) + if params.data == "before close": + before_seen.set() + + result: list[CallToolResult] = [] + async with connect_over_streamable_http( + mcp, event_store=SequencedEventStore(), retry_interval=0, logging_callback=collect + ) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call() -> None: + result.append(await client.call_tool("interrupt", {})) + + tg.start_soon(call) + await before_seen.wait() + gate.set() + await done.wait() + + assert result == snapshot( + [CallToolResult(content=[TextContent(text="resumed")], structured_content={"result": "resumed"})] + ) + assert received == snapshot(["before close", "after close"]) + + +@requirement("client-transport:http:resume-stream-api") +async def test_a_captured_resumption_token_replays_missed_messages_on_a_new_connection() -> None: + """A resumption token captured via on_resumption_token_update on one connection lets a fresh + connection retrieve the messages it missed by passing resumption_token to send_request. + + This is the explicit ClientMessageMetadata API, distinct from the automatic reconnection the + previous test covers: the transport dispatches a resumption_token request as a GET with + Last-Event-ID instead of POSTing the body, and remaps the replayed response onto the new + request's id. 
Client.call_tool does not expose ClientMessageMetadata, so the test drives a + bare ClientSession via session.send_request -- the sanctioned drop-down for behaviour Client + cannot express. The second connection carries the original session id but does not initialize + (the server-side session already is), modelling a caller that resumes after a process restart. + """ + captured: list[str] = [] + received: list[object] = [] + first_seen = anyio.Event() + token_seen = anyio.Event() + release = anyio.Event() + store = SequencedEventStore() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def hold(ctx: Context) -> str: + """Emit one notification, wait for the test, emit another, return.""" + await ctx.info("first") + await release.wait() + await ctx.info("second") + return "done" + + async def on_token(token: str) -> None: + captured.append(token) + if len(captured) >= 2: + token_seen.set() + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params.data) + first_seen.set() + + call = CallToolRequest(params=CallToolRequestParams(name="hold", arguments={})) + capture = ClientMessageMetadata(on_resumption_token_update=on_token) + + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, manager): + with anyio.fail_after(5): # pragma: no branch + async with ( # pragma: no branch + streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r1, w1), + ClientSession(r1, w1, logging_callback=collect) as first, + anyio.create_task_group() as tg, + ): + await first.initialize() + tg.start_soon(first.send_request, call, CallToolResult, None, capture) + await first_seen.wait() + await token_seen.wait() + assert captured == snapshot(["3", "4"]) + assert received == snapshot(["first"]) + # The session id is only observable via the manager (the client transport does not expose it). 
+ (session_id,) = manager._server_instances + http.headers["mcp-session-id"] = session_id + http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION + tg.cancel_scope.cancel() + + with anyio.fail_after(5): # pragma: no branch + release.set() # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event + # init priming + init response + call priming + "first" + "second" + result = 6 stored events. + await store.wait_until_stored(6) + async with ( # pragma: no branch + streamable_http_client(f"{BASE_URL}/mcp", http_client=http) as (r2, w2), + ClientSession(r2, w2, logging_callback=collect) as second, + ): + result = await second.send_request( + call, CallToolResult, metadata=ClientMessageMetadata(resumption_token=captured[-1]) + ) + assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert received == snapshot(["first", "second"]) diff --git a/tests/interaction/transports/test_hosting_session.py b/tests/interaction/transports/test_hosting_session.py new file mode 100644 index 0000000000..a926c3e8a2 --- /dev/null +++ b/tests/interaction/transports/test_hosting_session.py @@ -0,0 +1,202 @@ +"""Streamable HTTP session lifecycle: creation, routing, termination, and stateless mode. + +A test here speaks raw HTTP only when its assertion is the wire contract -- which header is +issued, which status code answers which condition -- that the SDK `Client` cannot observe. +Everything else is `Client`-driven against the same mounted session manager. Transport-agnostic +behaviour is covered by the `connect`-fixture matrix. 
+""" + +import re + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.types import JSONRPCResponse, ListToolsResult, PaginatedRequestParams, Tool +from tests.interaction._connect import ( + base_headers, + client_via_http, + initialize_body, + initialize_via_http, + mounted_app, + post_jsonrpc, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A minimal low-level server with one tool, so subsequent-request routing can be observed.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="noop", description="Does nothing.", input_schema={"type": "object"})]) + + return Server("hosted", on_list_tools=list_tools) + + +@requirement("hosting:session:create") +@requirement("hosting:session:id-charset") +async def test_initialize_issues_a_visible_ascii_session_id() -> None: + """An initialize POST without a session ID creates a session and returns a visible-ASCII Mcp-Session-Id.""" + async with mounted_app(_server()) as (http, _): + response, messages = await post_jsonrpc(http, initialize_body()) + + assert response.status_code == 200 + session_id = response.headers.get("mcp-session-id") + assert session_id is not None + # The spec requires the session ID to consist only of visible ASCII (0x21-0x7E). 
+ assert re.fullmatch(r"[\x21-\x7E]+", session_id) + assert isinstance(messages[0], JSONRPCResponse) + assert messages[0].id == 1 + + +@requirement("hosting:session:reuse") +async def test_subsequent_requests_with_the_session_id_route_to_the_same_session() -> None: + """Requests carrying the issued Mcp-Session-Id reuse that session's transport rather than creating another.""" + async with mounted_app(_server()) as (http, manager): + async with client_via_http(http) as client: + await client.list_tools() + await client.list_tools() + # The session count is the only signal that distinguishes routing-to-existing from + # silently creating a second session: both produce a successful result. + assert len(manager._server_instances) == 1 + + +@requirement("hosting:session:unknown-id") +async def test_requests_with_an_unknown_session_id_return_404() -> None: + """POST, GET, and DELETE each carrying an unknown Mcp-Session-Id are answered 404 by the manager.""" + async with mounted_app(_server()) as (http, _): + post = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + headers=base_headers(session_id="not-a-session"), + ) + get = await http.get("/mcp", headers=base_headers(session_id="not-a-session")) + delete = await http.delete("/mcp", headers=base_headers(session_id="not-a-session")) + + assert (post.status_code, post.json()) == snapshot( + (404, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Session not found"}}) + ) + assert (get.status_code, delete.status_code) == (404, 404) + + +@requirement("hosting:session:missing-id") +async def test_non_initialize_post_without_a_session_id_returns_400() -> None: + """A non-initialize POST that omits Mcp-Session-Id in stateful mode is rejected with 400.""" + async with mounted_app(_server()) as (http, _): + await initialize_via_http(http) + response = await http.post( + "/mcp", json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, headers=base_headers() + ) + + assert 
(response.status_code, response.json()) == snapshot( + (400, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Bad Request: Missing session ID"}}) + ) + + +@requirement("hosting:session:delete") +@requirement("hosting:session:post-termination-404") +async def test_delete_terminates_the_session_and_subsequent_requests_return_404() -> None: + """DELETE with a valid Mcp-Session-Id terminates the session; further requests on that ID return 404.""" + async with mounted_app(_server()) as (http, manager): + session_id = await initialize_via_http(http) + + delete = await http.delete("/mcp", headers=base_headers(session_id=session_id)) + assert delete.status_code == 200 + + # The manager keeps the terminated transport registered, so the next request reaches the + # transport's own _terminated check rather than the manager's unknown-session path. + assert session_id in manager._server_instances + post = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + headers=base_headers(session_id=session_id), + ) + assert (post.status_code, post.json()) == snapshot( + ( + 404, + { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32600, "message": "Not Found: Session has been terminated"}, + }, + ) + ) + + +@requirement("hosting:session:isolation") +async def test_terminating_one_session_leaves_others_working() -> None: + """Terminating one session on a manager does not disturb a concurrent session on the same manager.""" + async with mounted_app(_server()) as (http, manager): + async with client_via_http(http) as survivor: + async with client_via_http(http) as terminated: + await terminated.list_tools() + assert len(manager._server_instances) == 2 + # `terminated` has exited (its DELETE has been sent); `survivor` still answers. 
+ result = await survivor.list_tools() + + assert result.tools[0].name == "noop" + + +@requirement("hosting:session:reinitialize") +async def test_second_initialize_on_an_existing_session_is_accepted() -> None: + """A second initialize POST carrying an existing session ID is processed rather than rejected. + + See the divergence on the requirement: the entry expects a rejection, but the SDK forwards the + second initialize to the running server, which answers it as a fresh handshake. + """ + async with mounted_app(_server()) as (http, manager): + session_id = await initialize_via_http(http) + response, messages = await post_jsonrpc(http, initialize_body(request_id=2), session_id=session_id) + assert len(manager._server_instances) == 1 + + assert response.status_code == snapshot(200) + assert isinstance(messages[0], JSONRPCResponse) + assert messages[0].id == 2 + + +@requirement("hosting:stateless:no-session-id") +@requirement("hosting:stateless:no-reuse") +async def test_stateless_mode_never_issues_a_session_id() -> None: + """A stateless server issues no Mcp-Session-Id and creates no persistent transport. + + The recording proves no request the SDK client sent carried an Mcp-Session-Id (the server + cannot have issued one, or the client would echo it); the empty instance map proves the + manager kept no transport between requests. 
+ """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_server(), stateless_http=True, on_request=record) as (http, manager): + async with client_via_http(http) as client: + result = await client.list_tools() + assert manager._server_instances == {} + + assert result.tools[0].name == "noop" + assert all("mcp-session-id" not in request.headers for request in requests) + assert "DELETE" not in {request.method for request in requests} + + +@requirement("hosting:stateless:concurrent-clients") +async def test_stateless_mode_serves_concurrent_clients_independently() -> None: + """Two clients connected concurrently to the same stateless app each complete a round trip.""" + results: dict[str, ListToolsResult] = {} + + async with mounted_app(_server(), stateless_http=True) as (http, _): + + async def list_via(label: str) -> None: + async with client_via_http(http) as client: + results[label] = await client.list_tools() + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(list_via, "a") + tg.start_soon(list_via, "b") + + assert results["a"].tools[0].name == "noop" + assert results["b"].tools[0].name == "noop" diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py new file mode 100644 index 0000000000..9c7353dda5 --- /dev/null +++ b/tests/interaction/transports/test_sse.py @@ -0,0 +1,90 @@ +"""Behaviour specific to the legacy HTTP+SSE transport, exercised entirely in process. + +Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of +the suite over this transport as well; this file pins only what is observable on the SSE wiring +itself: the GET-then-POST connection lifecycle, the endpoint event, and how the message endpoint +rejects requests it cannot route to a session. 
Every test drives the server's real Starlette app +through the suite's streaming ASGI bridge. +""" + +from uuid import UUID, uuid4 + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.sse import sse_client +from mcp.server import Server +from mcp.types import EmptyResult +from tests.interaction._connect import BASE_URL, build_sse_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +@requirement("transport:sse") +@requirement("transport:sse:endpoint-event") +async def test_endpoint_event_names_the_message_endpoint_with_a_fresh_session_id() -> None: + """Connecting opens a GET stream whose first event names the POST endpoint and a fresh + session id; messages POSTed there are answered on that stream, and disconnecting releases the + server's session entry.""" + app, sse = build_sse_app(Server("legacy")) + captured_session_id: list[str] = [] + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client( + f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append + ) + with anyio.fail_after(5): + async with Client(transport) as client: + assert len(captured_session_id) == 1 + assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers + assert await client.send_ping() == snapshot(EmptyResult()) + + assert sse._read_stream_writers == {} + + +@requirement("transport:sse:post:session-routing") +async def test_post_without_a_session_id_is_rejected() -> None: + """A POST to the message endpoint with no session_id 
query parameter is answered 400.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post("/messages/", json={"jsonrpc": "2.0", "method": "ping", "id": 1}) + assert (response.status_code, response.text) == snapshot((400, "session_id is required")) + + +@requirement("transport:sse:post:session-routing") +async def test_post_with_a_malformed_session_id_is_rejected() -> None: + """A POST whose session_id query parameter is not a UUID is answered 400.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post( + "/messages/", params={"session_id": "not-a-uuid"}, json={"jsonrpc": "2.0", "method": "ping", "id": 1} + ) + assert (response.status_code, response.text) == snapshot((400, "Invalid session ID")) + + +@requirement("transport:sse:post:session-routing") +async def test_post_for_an_unknown_session_is_rejected() -> None: + """A POST naming a well-formed session_id that no SSE stream owns is answered 404.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post( + "/messages/", params={"session_id": uuid4().hex}, json={"jsonrpc": "2.0", "method": "ping", "id": 1} + ) + assert (response.status_code, response.text) == snapshot((404, "Could not find session")) diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py new file mode 100644 index 0000000000..27cc65de42 --- /dev/null +++ b/tests/interaction/transports/test_stdio.py @@ -0,0 +1,143 @@ +"""The stdio transport: one subprocess end-to-end test and one in-process framing test. 
+ +Everything else in the suite runs in a single process; the subprocess test exists to prove the same +client↔server round trip works over the stdio transport's real boundary (a child process whose +stdin/stdout carry one newline-delimited JSON-RPC message per line). The server lives in +`_stdio_server.py` and is launched via `python -m` so subprocess coverage measurement applies. + +The framing test drives `stdio_server` in-process by passing it injected text streams instead of the +real stdin/stdout, so the raw lines the transport writes can be asserted directly without a process +boundary. + +stdio is deliberately not a leg of the `connect`-fixture matrix: spawning a subprocess per test +would be slow, and the matrix already proves transport-agnosticism over three in-process +transports. Process-lifecycle edge cases (escalation to terminate/kill, parse errors) are covered by +`tests/client/test_stdio.py` and stay deferred here. +""" + +import io +import json +import os +import sys +import tempfile +from pathlib import Path + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.server.stdio import stdio_server +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, +) +from mcp.types.jsonrpc import jsonrpc_message_adapter +from tests.interaction._connect import initialize_body +from tests.interaction._requirements import requirement +from tests.interaction.transports import _stdio_server + +pytestmark = pytest.mark.anyio + +_REPO_ROOT = Path(__file__).parents[3] + + +@requirement("transport:stdio") +@requirement("transport:stdio:clean-shutdown") +@requirement("transport:stdio:stderr-passthrough") +async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess() -> None: + """A 
Client connected over stdio initializes, calls a tool with arguments, receives the + server's log notification before the call returns, and the server exits when the transport + closes its stdin.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + with tempfile.TemporaryFile(mode="w+") as errlog: + transport = stdio_client( + StdioServerParameters( + command=sys.executable, + args=["-m", _stdio_server.__name__], + cwd=str(_REPO_ROOT), + # stdio_client deliberately filters the inherited environment to a safe minimum, + # which drops the variables coverage.py's subprocess support uses; pass them through + # so the server module is measured. Empty when not running under coverage. + env={key: value for key, value in os.environ.items() if key.startswith("COVERAGE_")}, + ), + errlog=errlog, + ) + + with anyio.fail_after(10): + async with Client(transport, logging_callback=collect) as client: + assert client.initialize_result.server_info.name == "stdio-echo" + result = await client.call_tool("echo", {"text": "across\nprocesses"}) + + errlog.seek(0) + captured_stderr = errlog.read() + + assert result == snapshot(CallToolResult(content=[TextContent(text="across\nprocesses")])) + # stdio carries one ordered server→client stream, so the same notification-before-response + # guarantee holds here as for the in-memory transport. + assert received == snapshot( + [LoggingMessageNotificationParams(level="info", logger="echo", data="echoing across\nprocesses")] + ) + # The server writes this line only after its run loop returns, which happens when stdin closes: + # seeing it proves the process exited on its own rather than via the transport's terminate + # escalation, without a timing-based assertion. The capture itself proves stderr passthrough: + # the transport routes the child's stderr to the caller's `errlog` without consuming it. 
+ assert captured_stderr == snapshot("stdio-echo: clean exit\n") + + +@requirement("transport:stdio:stream-purity") +@requirement("transport:stdio:no-embedded-newlines") +async def test_stdio_server_writes_one_jsonrpc_message_per_line() -> None: + """Everything `stdio_server` writes is a valid JSON-RPC message on its own line, and nothing else. + + The transport's stdin/stdout parameters are public, so the test injects in-process text streams + instead of the real process handles and drives the read/write streams directly: a JSON-RPC line on + stdin is parsed and delivered, and every message sent on the write stream appears as exactly one + newline-terminated line whose payload newlines are JSON-escaped. This proves the transport's own + framing; it does not guard `sys.stdout` against handler code that prints to it directly (see the + divergence on `transport:stdio:stream-purity`). + """ + captured = io.StringIO() + sent_line = json.dumps(initialize_body(request_id=1)) + "\n" + + with anyio.fail_after(5): + async with ( + stdio_server(stdin=anyio.wrap_file(io.StringIO(sent_line)), stdout=anyio.wrap_file(captured)) as ( + read_stream, + write_stream, + ), + read_stream, + write_stream, + ): + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert isinstance(received.message, JSONRPCRequest) + assert received.message.method == "initialize" + + response = JSONRPCResponse(jsonrpc="2.0", id=1, result={"text": "line\nbreak"}) + notification = JSONRPCNotification( + jsonrpc="2.0", method="notifications/message", params={"level": "info", "data": "two\nlines"} + ) + await write_stream.send(SessionMessage(response)) + await write_stream.send(SessionMessage(notification)) + + output = captured.getvalue() + assert output.endswith("\n") + lines = output.removesuffix("\n").split("\n") + assert len(lines) == 2 + messages = [jsonrpc_message_adapter.validate_json(line) for line in lines] + assert [type(message).__name__ for message in 
messages] == snapshot(["JSONRPCResponse", "JSONRPCNotification"]) + # The newline inside the payload is JSON-escaped on the wire, not a literal newline that would + # break the one-message-per-line framing. + assert r"line\nbreak" in lines[0] + assert r"two\nlines" in lines[1] diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py new file mode 100644 index 0000000000..d38e2a0bb3 --- /dev/null +++ b/tests/interaction/transports/test_streamable_http.py @@ -0,0 +1,168 @@ +"""Behaviour specific to the streamable HTTP transport, exercised entirely in process. + +Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of +the suite over this transport as well; this file only pins what cannot be observed in memory: the +server's stateless and JSON-response modes, the standalone GET stream, and the full-duplex +server-initiated exchange on a still-open call. Every test drives the server's real Starlette app +through the suite's streaming ASGI bridge — no sockets, threads, or subprocesses. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp.client import ClientRequestContext +from mcp.server.elicitation import AcceptedElicitation +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + CallToolResult, + ElicitRequestParams, + ElicitResult, + LoggingMessageNotification, + LoggingMessageNotificationParams, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + TextContent, +) +from tests.interaction._connect import connect_over_streamable_http +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _smoke_server() -> MCPServer: + """A server exercising each message shape the transport-specific tests need.""" + mcp = MCPServer("smoke", instructions="Talk to the smoke server.") + + @mcp.tool() + def echo(text: str) -> str: + """Echo the text back.""" + return text + + class Confirmation(BaseModel): + confirmed: bool + + @mcp.tool() + async def ask(ctx: Context) -> str: + """Elicit a confirmation from the client and report the outcome.""" + answer = await ctx.elicit("Proceed?", Confirmation) + # In stateless mode the elicit raises before this point: there is no session to call back through. 
+ assert isinstance(answer, AcceptedElicitation) + return f"confirmed={answer.data.confirmed}" + + @mcp.tool() + async def announce(ctx: Context) -> str: + """Send one notification related to this request and one that is not.""" + await ctx.info("about to announce") + await ctx.session.send_resource_updated("file:///watched.txt") + return "announced" + + return mcp + + +@requirement("transport:streamable-http:json-response") +@requirement("client-transport:http:json-response-parsed") +async def test_tool_call_over_streamable_http_with_json_responses() -> None: + """The round trip works when the server answers with a single JSON body instead of an SSE stream.""" + async with connect_over_streamable_http(_smoke_server(), json_response=True) as client: + assert client.initialize_result.server_info.name == "smoke" + result = await client.call_tool("echo", {"text": "as json"}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="as json")], structured_content={"result": "as json"}) + ) + + +@requirement("transport:streamable-http:stateless") +async def test_tool_calls_over_stateless_streamable_http() -> None: + """Consecutive requests each succeed against a stateless server with no session to share.""" + async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: + first = await client.call_tool("echo", {"text": "first"}) + second = await client.call_tool("echo", {"text": "second"}) + + assert first == snapshot( + CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + ) + assert second == snapshot( + CallToolResult(content=[TextContent(text="second")], structured_content={"result": "second"}) + ) + + +@requirement("transport:streamable-http:stateless-restrictions") +async def test_stateless_streamable_http_rejects_server_initiated_requests() -> None: + """A handler that tries to call back to the client in stateless mode fails: there is no session.""" + async with 
connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: + result = await client.call_tool("ask", {}) + + assert result.is_error is True + assert isinstance(result.content[0], TextContent) + # The exact message is the StatelessModeNotSupported exception text wrapped by the tool-error + # path; pin the stable prefix rather than the full exception prose. + assert result.content[0].text.startswith("Error executing tool ask:") + + +@requirement("transport:streamable-http:notifications") +@requirement("transport:streamable-http:unrelated-messages") +@requirement("hosting:http:standalone-sse") +async def test_unrelated_server_messages_arrive_on_the_standalone_stream() -> None: + """A server message with no related request reaches the client through the standalone GET stream. + + The log notification is related to the tool call and travels on that call's own SSE stream; + the resource-updated notification is not related to any request, so the only way it can reach + the client is the standalone stream the client opens after initialization. Delivery order + across the two streams is not guaranteed, so the unrelated message is awaited rather than + assumed to beat the tool result. + """ + received: list[IncomingMessage] = [] + resource_update_seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + if isinstance(message, ResourceUpdatedNotification): + resource_update_seen.set() + + async with connect_over_streamable_http(_smoke_server(), message_handler=collect) as client: + result = await client.call_tool("announce", {}) + with anyio.fail_after(5): + await resource_update_seen.wait() + + assert result == snapshot( + CallToolResult(content=[TextContent(text="announced")], structured_content={"result": "announced"}) + ) + # The related log notification rides the call's stream; the unrelated resource-updated + # notification rides the standalone stream. Both arrive, nothing else does. 
+ assert [message for message in received if isinstance(message, LoggingMessageNotification)] == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="about to announce"))] + ) + assert [message for message in received if isinstance(message, ResourceUpdatedNotification)] == snapshot( + [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + ) + assert len(received) == 2 + + +@requirement("transport:streamable-http:stateful") +@requirement("transport:streamable-http:server-to-client") +async def test_server_initiated_elicitation_round_trips_during_a_tool_call() -> None: + """An elicitation issued mid-call reaches the client and its answer reaches the handler over stateful HTTP. + + The elicitation request travels on the still-open SSE response of the tool call that triggered + it, and the client's answer arrives as a separate POST -- the full-duplex exchange the + streamable HTTP transport exists to provide. + """ + asked: list[ElicitRequestParams] = [] + + async def answer(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + asked.append(params) + return ElicitResult(action="accept", content={"confirmed": True}) + + async with connect_over_streamable_http(_smoke_server(), elicitation_callback=answer) as client: + # Bounded because a harness regression here historically meant deadlock, not failure. + with anyio.fail_after(5): + result = await client.call_tool("ask", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="confirmed=True")], structured_content={"result": "confirmed=True"}) + ) + assert [params.message for params in asked] == snapshot(["Proceed?"]) diff --git a/uv.lock b/uv.lock index b396898b66..5b72e97fce 100644 --- a/uv.lock +++ b/uv.lock @@ -939,7 +939,7 @@ dev = [ { name = "mcp", extras = ["cli", "ws"], editable = "." 
}, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, - { name = "pytest", specifier = ">=8.3.4" }, + { name = "pytest", specifier = ">=8.4.0" }, { name = "pytest-examples", specifier = ">=0.0.14" }, { name = "pytest-flakefinder", specifier = ">=1.1.0" }, { name = "pytest-pretty", specifier = ">=1.2.0" }, From 616476f6927a5c64213ea97bbd36a7466f410775 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Fri, 29 May 2026 17:48:40 +0100 Subject: [PATCH 49/60] Bind transport sessions to the authenticated principal (#2718) --- src/mcp/server/auth/middleware/bearer_auth.py | 26 +- src/mcp/server/sse.py | 66 ++-- src/mcp/server/streamable_http_manager.py | 40 ++- src/mcp/server/transport_security.py | 14 +- tests/server/test_sse_security.py | 288 +++++++++++++++++- tests/server/test_streamable_http_manager.py | 167 +++++++++- tests/server/test_transport_security.py | 88 ++++++ 7 files changed, 647 insertions(+), 42 deletions(-) create mode 100644 tests/server/test_transport_security.py diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 2eafdc793e..ba66e94226 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,6 +1,6 @@ import json import time -from typing import Any +from typing import Any, TypedDict from pydantic import AnyHttpUrl from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser @@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken): self.scopes = auth_info.scopes +class AuthorizationContext(TypedDict): + client_id: str + issuer: str | None + subject: str | None + + +def authorization_context(user: AuthenticatedUser) -> AuthorizationContext: + """Identify the principal `user` represents, for transports to compare + against the principal that created a session. 
Components the token + verifier does not supply are `None`, so the comparison degrades to the + remaining components. + + See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for + a verifier that populates `subject` and `claims` from an introspection + response.""" + token = user.access_token + issuer = (token.claims or {}).get("iss") + return AuthorizationContext( + client_id=token.client_id, + issuer=str(issuer) if issuer is not None else None, + subject=token.subject, + ) + + class BearerAuthBackend(AuthenticationBackend): """Authentication backend that validates Bearer tokens using a TokenVerifier.""" diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index be8e979c9d..05e948332b 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -50,6 +50,7 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send from mcp import types +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.transport_security import ( TransportSecurityMiddleware, TransportSecuritySettings, @@ -73,6 +74,9 @@ class SseServerTransport: _endpoint: str _read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]] + # Identity of the credential that created each session; requests for a + # session must present the same credential. 
+ _session_owners: dict[UUID, AuthorizationContext] _security: TransportSecurityMiddleware def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: @@ -112,19 +116,20 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | self._endpoint = endpoint self._read_stream_writers = {} + self._session_owners = {} self._security = TransportSecurityMiddleware(security_settings) logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager async def connect_sse(self, scope: Scope, receive: Receive, send: Send): - if scope["type"] != "http": # pragma: no cover + if scope["type"] != "http": logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") # Validate request headers for DNS rebinding protection request = Request(scope, receive) error_response = await self._security.validate_request(request, is_post=False) - if error_response: # pragma: no cover + if error_response: await error_response(scope, receive, send) raise ValueError("Request validation failed") @@ -134,6 +139,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): write_stream, write_stream_reader = create_context_streams[SessionMessage](0) session_id = uuid4() + user = scope.get("user") + if isinstance(user, AuthenticatedUser): + self._session_owners[session_id] = authorization_context(user) self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") @@ -169,27 +177,30 @@ async def sse_writer(): } ) - async with anyio.create_task_group() as tg: - - async def response_wrapper(scope: Scope, receive: Receive, send: Send): - """The EventSourceResponse returning signals a client close / disconnect. - In this case we close our side of the streams to signal the client that - the connection has been closed. 
- """ - await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( - scope, receive, send - ) - await sse_stream_reader.aclose() - await read_stream_writer.aclose() - await write_stream_reader.aclose() - self._read_stream_writers.pop(session_id, None) - logging.debug(f"Client session disconnected {session_id}") + try: + async with anyio.create_task_group() as tg: + + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. + """ + await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( + scope, receive, send + ) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + await sse_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") - logger.debug("Starting SSE response task") - tg.start_soon(response_wrapper, scope, receive, send) + logger.debug("Starting SSE response task") + tg.start_soon(response_wrapper, scope, receive, send) - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) + finally: + self._read_stream_writers.pop(session_id, None) + self._session_owners.pop(session_id, None) async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") @@ -197,7 +208,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) # Validate request headers for DNS rebinding protection error_response = await self._security.validate_request(request, is_post=True) - if error_response: # pragma: no cover + if error_response: return await error_response(scope, receive, send) session_id_param = request.query_params.get("session_id") @@ -220,13 +231,22 @@ async def 
handle_post_message(self, scope: Scope, receive: Receive, send: Send) response = Response("Could not find session", status_code=404) return await response(scope, receive, send) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + if requestor != self._session_owners.get(session_id): + # A session can only be used with the credential that created it. + # Respond exactly as if the session did not exist. + logger.warning("Rejecting message for session %s: credential does not match", session_id) + response = Response("Could not find session", status_code=404) + return await response(scope, receive, send) + body = await request.body() logger.debug(f"Received JSON: {body}") try: message = types.jsonrpc_message_adapter.validate_json(body, by_name=False) logger.debug(f"Validated client message: {message}") - except ValidationError as err: # pragma: no cover + except ValidationError as err: logger.exception("Failed to parse message") response = Response("Could not parse message", status_code=400) await response(scope, receive, send) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 39d434505c..81350a8f24 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -5,7 +5,6 @@ import contextlib import logging from collections.abc import AsyncIterator -from http import HTTPStatus from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -15,6 +14,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, EventStore, @@ -89,6 +89,9 @@ def __init__( # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, 
StreamableHTTPServerTransport] = {} + # Identity of the credential that created each session; requests for a + # session must present the same credential. + self._session_owners: dict[str, AuthorizationContext] = {} # The task group will be set during lifespan self._task_group = None @@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: self._task_group = None # Clear any remaining server instances self._server_instances.clear() + self._session_owners.clear() async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. @@ -192,9 +196,29 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + # Existing session case if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] + if requestor != self._session_owners.get(request_mcp_session_id): + # A session can only be used with the credential that created + # it. Respond exactly as if the session did not exist. 
+ logger.warning( + "Rejecting request for session %s: credential does not match the one that created the session", + request_mcp_session_id[:64], + ) + body = JSONRPCError( + jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found") + ) + response = Response( + body.model_dump_json(by_alias=True, exclude_unset=True), + status_code=404, + media_type="application/json", + ) + await response(scope, receive, send) + return logger.debug("Session already exists, handling request directly") # Push back idle deadline on activity if transport.idle_scope is not None and self.session_idle_timeout is not None: @@ -216,6 +240,8 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S ) assert http_transport.mcp_session_id is not None + if requestor is not None: + self._session_owners[http_transport.mcp_session_id] = requestor self._server_instances[http_transport.mcp_session_id] = http_transport logger.info(f"Created new transport with session ID: {new_session_id}") @@ -246,6 +272,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE assert http_transport.mcp_session_id is not None logger.info(f"Session {http_transport.mcp_session_id} idle timeout") self._server_instances.pop(http_transport.mcp_session_id, None) + self._session_owners.pop(http_transport.mcp_session_id, None) await http_transport.terminate() except Exception: logger.exception(f"Session {http_transport.mcp_session_id} crashed") @@ -260,6 +287,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE f"{http_transport.mcp_session_id} from active instances." 
) del self._server_instances[http_transport.mcp_session_id] + self._session_owners.pop(http_transport.mcp_session_id, None) # Assert task group is not None for type checking assert self._task_group is not None @@ -273,15 +301,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE # TODO: Align error code once spec clarifies # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}") - error_response = JSONRPCError( - jsonrpc="2.0", - id=None, - error=ErrorData(code=INVALID_REQUEST, message="Session not found"), + body = JSONRPCError( + jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found") ) response = Response( - content=error_response.model_dump_json(by_alias=True, exclude_unset=True), - status_code=HTTPStatus.NOT_FOUND, - media_type="application/json", + body.model_dump_json(by_alias=True, exclude_unset=True), status_code=404, media_type="application/json" ) await response(scope, receive, send) diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 707d4b61dd..d9e9f965b3 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -42,17 +42,17 @@ def __init__(self, settings: TransportSecuritySettings | None = None): def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" - if not host: # pragma: no cover + if not host: logger.warning("Missing Host header in request") return False # Check exact match first - if host in self.settings.allowed_hosts: # pragma: no cover + if host in self.settings.allowed_hosts: return True # Check wildcard port patterns for allowed in self.settings.allowed_hosts: - if allowed.endswith(":*"): # pragma: no branch + if allowed.endswith(":*"): # Extract base host from pattern base_host = allowed[:-2] # Check if the actual host starts with 
base host and has a port @@ -65,16 +65,16 @@ def _validate_host(self, host: str | None) -> bool: def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests - if not origin: # pragma: no cover + if not origin: return True # Check exact match first - if origin in self.settings.allowed_origins: # pragma: no cover + if origin in self.settings.allowed_origins: return True # Check wildcard port patterns for allowed in self.settings.allowed_origins: - if allowed.endswith(":*"): # pragma: no branch + if allowed.endswith(":*"): # Extract base origin from pattern base_origin = allowed[:-2] # Check if the actual origin starts with base origin and has a port @@ -94,7 +94,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res Returns None if validation passes, or an error Response if validation fails. """ # Always validate Content-Type for POST requests - if is_post: # pragma: no branch + if is_post: content_type = request.headers.get("content-type") if not self._validate_content_type(content_type): return Response("Invalid Content-Type header", status_code=400) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a25..e95dc51b31 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,27 +1,44 @@ -"""Tests for SSE server DNS rebinding protection.""" +"""Tests for SSE server request validation.""" import logging import multiprocessing +import re import socket +import anyio import httpx import pytest +import sse_starlette.sse import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route +from starlette.types import Message, Receive, Scope, Send from mcp.server import Server +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from 
mcp.server.auth.provider import AccessToken from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool +from mcp.shared._stream_protocols import WriteStream +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCRequest, JSONRPCResponse, Tool from tests.test_helpers import wait_for_server logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" +@pytest.fixture(autouse=True) +def reset_sse_starlette_exit_event() -> None: + """sse-starlette<2 caches a module-level anyio.Event on AppStatus; reset it + between tests so it is not bound to a previous test's event loop.""" + app_status = getattr(sse_starlette.sse, "AppStatus", None) + if app_status is not None and hasattr(app_status, "should_exit_event"): # pragma: lax no cover + app_status.should_exit_event = None + + @pytest.fixture def server_port() -> int: with socket.socket() as s: @@ -291,3 +308,270 @@ async def test_sse_security_post_valid_content_type(server_port: int): finally: process.terminate() process.join() + + +def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _sse_scope( + method: str, path: str, user: AuthenticatedUser | None, *, query_string: bytes = b"", body: bytes = b"" +) -> tuple[Scope, Receive, Send, list[Message]]: + """Build an ASGI scope/receive/send triple for a request to the SSE transport.""" + scope: Scope = { + "type": "http", + "method": method, + "path": path, + "root_path": "", + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + } + if user is not None: + scope["user"] = user + 
sent: list[Message] = [] + + async def receive() -> Message: + return {"type": "http.request", "body": body, "more_body": False} + + async def send(message: Message) -> None: + sent.append(message) + + return scope, receive, send, sent + + +def _response_status(sent: list[Message]) -> int: + response_start = next(msg for msg in sent if msg["type"] == "http.response.start") + return response_start["status"] + + +async def _post_message(transport: SseServerTransport, session_id: str, user: AuthenticatedUser | None) -> int: + """POST a message to an SSE session as `user` and return the response status.""" + body = b'{"jsonrpc": "2.0", "id": 1, "method": "ping", "params": null}' + scope, receive, send, sent = _sse_scope( + "POST", "/messages/", user, query_string=f"session_id={session_id}".encode(), body=body + ) + await transport.handle_post_message(scope, receive, send) + return _response_status(sent) + + +_Principal = tuple[str] | tuple[str, str] | tuple[str, str, str] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("creator", "sender", "expected"), + [ + pytest.param(("client-a",), ("client-b",), 404, id="different-client"), + pytest.param(("client-a",), None, 404, id="unauthenticated-sender"), + pytest.param(("client-a", "alice"), ("client-a", "bob"), 404, id="same-client-different-subject"), + pytest.param(("client-a", "alice"), ("client-a",), 404, id="same-client-no-subject"), + pytest.param( + ("client-a", "alice", "https://i1"), ("client-a", "alice", "https://i2"), 404, id="different-issuer" + ), + pytest.param(None, ("client-a",), 404, id="unauthenticated-creator"), + pytest.param(("client-a",), ("client-a",), 202, id="same-client"), + pytest.param(("client-a", "alice"), ("client-a", "alice"), 202, id="same-client-and-subject"), + pytest.param(None, None, 202, id="both-unauthenticated"), + ], +) +async def test_sse_post_requires_the_credential_that_created_the_session( + creator: _Principal | None, + sender: _Principal | None, + expected: int, +): + 
"""The session endpoint URL issued to one authenticated principal must not + accept messages from a request authenticated as a different one.""" + transport = SseServerTransport("/messages/") + session_id_received = anyio.Event() + session_ids: list[str] = [] + client_disconnected = anyio.Event() + + async def get_send(message: Message) -> None: + # The first body chunk is the SSE event announcing the session URI to POST messages to. + if message["type"] == "http.response.body" and not session_ids: + match = re.search(rb"session_id=([0-9a-f]{32})", message.get("body", b"")) + assert match is not None, f"expected the endpoint event first, got {message!r}" + session_ids.append(match.group(1).decode()) + session_id_received.set() + + async def get_receive() -> Message: + # The SSE client stays connected until the test signals otherwise. + await client_disconnected.wait() + return {"type": "http.disconnect"} + + creator_user = _authenticated_user(*creator) if creator is not None else None + sender_user = _authenticated_user(*sender) if sender is not None else None + + async def hold_sse_connection() -> None: + """Establish the SSE session as `creator` and keep it open, as a server would.""" + scope, _, _, _ = _sse_scope("GET", "/sse", creator_user) + with anyio.fail_after(5): + async with transport.connect_sse(scope, get_receive, get_send) as (read_stream, write_stream): + async with read_stream, write_stream: # pragma: no branch + async for _ in read_stream: + pass + + async with anyio.create_task_group() as tg: + tg.start_soon(hold_sse_connection) + with anyio.fail_after(5): + await session_id_received.wait() + + assert await _post_message(transport, session_ids[0], sender_user) == expected + + client_disconnected.set() + + # Once the connection is gone the session is no longer routable. 
+ assert await _post_message(transport, session_ids[0], creator_user) == 404 + + +@pytest.mark.anyio +async def test_sse_connect_rejects_a_non_http_scope(): + """connect_sse refuses ASGI scopes that are not HTTP requests.""" + transport = SseServerTransport("/messages/") + with pytest.raises(ValueError): + async with transport.connect_sse({"type": "websocket"}, _no_receive, _no_send): + raise NotImplementedError + + +@pytest.mark.anyio +async def test_sse_connect_rejects_a_disallowed_host(): + """connect_sse rejects requests whose Host header fails the configured security check.""" + settings = TransportSecuritySettings(allowed_hosts=["allowed.example.com"]) + transport = SseServerTransport("/messages/", security_settings=settings) + scope, receive, send, sent = _sse_scope("GET", "/sse", None) + scope["headers"] = [(b"host", b"disallowed.example.com")] + + with pytest.raises(ValueError): + async with transport.connect_sse(scope, receive, send): + raise NotImplementedError + assert _response_status(sent) == 421 + + +@pytest.mark.anyio +async def test_sse_post_without_a_session_id_returns_400(): + """POSTs to the messages endpoint must include a session_id query parameter.""" + transport = SseServerTransport("/messages/") + scope, receive, send, sent = _sse_scope("POST", "/messages/", None) + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + + +@pytest.mark.anyio +async def test_sse_post_with_a_malformed_session_id_returns_400(): + """A session_id that is not 32 hex characters is rejected before any session lookup.""" + transport = SseServerTransport("/messages/") + scope, receive, send, sent = _sse_scope("POST", "/messages/", None, query_string=b"session_id=not-hex") + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + + +@pytest.mark.anyio +async def test_sse_post_with_a_disallowed_host_is_rejected_before_session_lookup(): + """The transport security check on POST 
runs before any session-ID handling.""" + settings = TransportSecuritySettings(allowed_hosts=["allowed.example.com"]) + transport = SseServerTransport("/messages/", security_settings=settings) + scope, receive, send, sent = _sse_scope("POST", "/messages/", None) + scope["headers"] = [(b"host", b"disallowed.example.com"), (b"content-type", b"application/json")] + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 421 + + +@pytest.mark.anyio +async def test_sse_round_trip_delivers_posted_messages_and_streams_responses(): + """A POSTed JSON-RPC message reaches the server's read stream, and a message + written to the server's write stream is sent to the client as an SSE event.""" + transport = SseServerTransport("/messages/") + session = _SseSession(transport) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(session.hold) + await session.ready.wait() + + # POST a parse-failing body: client gets 400, server's read stream receives the error. + scope, receive, send, sent = _sse_scope( + "POST", "/messages/", None, query_string=f"session_id={session.session_id}".encode(), body=b"not json" + ) + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + assert isinstance(await session.next_read_item(), Exception) + + # POST a valid message: client gets 202, server's read stream receives it. + assert await _post_message(transport, session.session_id, None) == 202 + received = await session.next_read_item() + assert isinstance(received, SessionMessage) + assert isinstance(received.message, JSONRPCRequest) + assert received.message.method == "ping" + + # Server writes a response: it appears as an SSE `message` event on the GET stream. 
+ outgoing = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + await session.write_stream.send(SessionMessage(outgoing)) + chunk = await session.next_body_chunk() + assert b"event: message" in chunk + assert outgoing.model_dump_json(by_alias=True, exclude_unset=True).encode() in chunk + + session.disconnect() + + +class _SseSession: + """Drive an in-process SSE GET connection and surface what the server reads and the client receives. + + `hold` runs the connection in a background task and consumes the server-side read stream + into a buffer so that `handle_post_message` (which writes to that stream with a zero-capacity + channel) never blocks the test body. + """ + + def __init__(self, transport: SseServerTransport) -> None: + self.transport = transport + self.ready = anyio.Event() + self._disconnected = anyio.Event() + self._body_send, self._body_recv = anyio.create_memory_object_stream[bytes](16) + self._read_send, self._read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](16) + self.session_id = "" + self.write_stream: WriteStream[SessionMessage] + + async def hold(self) -> None: + scope, _, _, _ = _sse_scope("GET", "/sse", None) + async with self.transport.connect_sse(scope, self._receive, self._send) as (read, write): + self.write_stream = write + async with read, write, self._body_send, self._body_recv, self._read_send, self._read_recv: + async for item in read: + await self._read_send.send(item) + + def disconnect(self) -> None: + self._disconnected.set() + + async def next_read_item(self) -> SessionMessage | Exception: + return await self._read_recv.receive() + + async def next_body_chunk(self) -> bytes: + return await self._body_recv.receive() + + async def _receive(self) -> Message: + await self._disconnected.wait() + return {"type": "http.disconnect"} + + async def _send(self, message: Message) -> None: + if message["type"] != "http.response.body": + return + body: bytes = message.get("body", b"") + if not self.session_id: + match = 
re.search(rb"session_id=([0-9a-f]{32})", body) + assert match is not None, f"expected the endpoint event first, got {message!r}" + self.session_id = match.group(1).decode() + self.ready.set() + else: + await self._body_send.send(body) + + +async def _no_receive() -> Message: + raise NotImplementedError + + +async def _no_send(message: Message) -> None: + raise NotImplementedError diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 47cfbf14a4..ba75547964 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -8,11 +8,13 @@ import anyio import httpx import pytest -from starlette.types import Message +from starlette.types import Message, Scope from mcp import Client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext, streamable_http_manager +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams @@ -413,3 +415,166 @@ def test_session_idle_timeout_rejects_non_positive(): def test_session_idle_timeout_rejects_stateless(): with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) + + +def _user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _request_scope( + *, 
session_id: str | None = None, user: AuthenticatedUser | None = None, method: str = "POST" +) -> Scope: + """Build an ASGI scope for a request to the MCP endpoint.""" + headers = [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ] + if session_id is not None: + headers.append((b"mcp-session-id", session_id.encode())) + scope: Scope = { + "type": "http", + "method": method, + "path": "/mcp", + "headers": headers, + } + if user is not None: + scope["user"] = user + return scope + + +async def _open_session(manager: StreamableHTTPSessionManager, user: AuthenticatedUser | None) -> str: + """Create a new session as `user` and return its session ID.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request(_request_scope(user=user), mock_receive, mock_send) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + headers = dict(response_start.get("headers", [])) + return headers[MCP_SESSION_ID_HEADER.encode()].decode() + + +async def _request_session( + manager: StreamableHTTPSessionManager, session_id: str, user: AuthenticatedUser | None, method: str = "POST" +) -> int: + """Send a request for an existing session as `user` and return the response status.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request( + _request_scope(session_id=session_id, user=user, method=method), mock_receive, mock_send + ) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + return response_start["status"] + + +@pytest.fixture +async def 
manager_with_live_session(): + """A running manager around a real `Server`. Sessions remain registered until + `manager.run()` exits because `Server.run` blocks waiting for an initialize message.""" + manager = StreamableHTTPSessionManager(app=Server("test-session-credentials")) + async with manager.run(): + yield manager + + +@pytest.mark.anyio +async def test_session_accepts_requests_from_the_credential_that_created_it( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Requests presenting the same credential as the one that created the session are served.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + status = await _request_session(manager, session_id, _user("client-a")) + + # The request passes the manager's credential check and reaches the + # session's transport, instead of being answered with 404 by the manager. + assert status != 404 + + +@pytest.mark.anyio +@pytest.mark.parametrize("method", ["POST", "GET", "DELETE"]) +async def test_session_rejects_requests_from_a_different_credential( + manager_with_live_session: StreamableHTTPSessionManager, method: str +) -> None: + """A session created by one credential cannot be used with another credential, whatever the method.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, _user("client-b"), method) == 404 + # The session is still registered and still serves its creator. 
+ assert await _request_session(manager, session_id, _user("client-a")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_from_a_different_subject_of_the_same_client( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Two end-users that share an OAuth client cannot use each other's sessions.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a", subject="alice")) + + assert await _request_session(manager, session_id, _user("client-a", subject="bob")) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject=None)) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_with_the_same_subject_from_a_different_issuer( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A subject is unique only per issuer, so a colliding subject from a different issuer is not the same principal.""" + manager = manager_with_live_session + creator = _user("client-a", subject="alice", issuer="https://issuer.one") + session_id = await _open_session(manager, creator) + + other_issuer = _user("client-a", subject="alice", issuer="https://issuer.two") + assert await _request_session(manager, session_id, other_issuer) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) == 404 + assert await _request_session(manager, session_id, creator) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_unauthenticated_requests_for_an_authenticated_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created with a credential cannot be used without one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, None) == 404 + + +@pytest.mark.anyio +async def 
test_session_rejects_authenticated_requests_for_an_anonymous_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created without a credential cannot be used with one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, _user("client-a")) == 404 + + +@pytest.mark.anyio +async def test_anonymous_session_accepts_anonymous_requests( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Servers without authentication keep working: no credential on either side.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, None) != 404 diff --git a/tests/server/test_transport_security.py b/tests/server/test_transport_security.py new file mode 100644 index 0000000000..be28980b53 --- /dev/null +++ b/tests/server/test_transport_security.py @@ -0,0 +1,88 @@ +"""Tests for the transport-security request validation middleware.""" + +import pytest +from starlette.requests import Request + +from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings + + +def _request(host: str | None, origin: str | None, content_type: str | None = "application/json") -> Request: + headers: list[tuple[bytes, bytes]] = [] + if content_type is not None: + headers.append((b"content-type", content_type.encode())) + if host is not None: + headers.append((b"host", host.encode())) + if origin is not None: + headers.append((b"origin", origin.encode())) + return Request({"type": "http", "method": "GET", "headers": headers}) + + +SETTINGS = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["good.example", "wild.example:*"], + allowed_origins=["http://good.example", "http://wild.example:*"], +) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("host", "origin", "expected"), + [ + 
pytest.param(None, None, 421, id="missing-host"), + pytest.param("evil.example", None, 421, id="host-no-match"), + pytest.param("evil.example:9000", None, 421, id="host-wildcard-base-mismatch"), + pytest.param("good.example", None, None, id="host-exact-no-origin"), + pytest.param("wild.example:9000", None, None, id="host-wildcard-match"), + pytest.param("good.example", "http://evil.example", 403, id="origin-no-match"), + pytest.param("good.example", "http://evil.example:9000", 403, id="origin-wildcard-base-mismatch"), + pytest.param("good.example", "http://good.example", None, id="origin-exact"), + pytest.param("good.example", "http://wild.example:9000", None, id="origin-wildcard-match"), + ], +) +async def test_validate_request_checks_host_then_origin( + host: str | None, origin: str | None, expected: int | None +) -> None: + """Host is checked first, then Origin; exact and wildcard-port allowlist entries are honoured.""" + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request(host, origin)) + assert (None if response is None else response.status_code) == expected + + +@pytest.mark.anyio +async def test_validate_request_skips_host_and_origin_when_protection_is_disabled() -> None: + """With DNS-rebinding protection off, any Host/Origin is accepted.""" + middleware = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + assert await middleware.validate_request(_request("evil.example", "http://evil.example")) is None + + +@pytest.mark.anyio +async def test_validate_request_defaults_to_protection_disabled() -> None: + """Constructing the middleware without settings leaves DNS-rebinding protection off.""" + middleware = TransportSecurityMiddleware() + assert await middleware.validate_request(_request("evil.example", "http://evil.example")) is None + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("content_type", "expected"), + [ + pytest.param("application/json", None, 
id="json"), + pytest.param("application/json; charset=utf-8", None, id="json-with-charset"), + pytest.param("APPLICATION/JSON", None, id="case-insensitive"), + pytest.param("text/plain", 400, id="wrong-type"), + pytest.param(None, 400, id="missing"), + ], +) +async def test_validate_request_checks_content_type_on_post(content_type: str | None, expected: int | None) -> None: + """POST requests must carry an application/json Content-Type, regardless of DNS-rebinding settings.""" + middleware = TransportSecurityMiddleware() + response = await middleware.validate_request(_request("any", None, content_type=content_type), is_post=True) + assert (None if response is None else response.status_code) == expected + + +@pytest.mark.anyio +async def test_validate_request_ignores_content_type_on_get() -> None: + """Content-Type is only enforced for POST requests.""" + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request("good.example", None, content_type="text/plain")) + assert response is None From 2a3d065417088eafef86d0e1b47eca3facc16130 Mon Sep 17 00:00:00 2001 From: devteamaegis Date: Tue, 2 Jun 2026 11:53:51 -0400 Subject: [PATCH 50/60] fix: rename `.gitattribute` to `.gitattributes` so git actually reads it (#2656) Co-authored-by: devteamaegis --- .gitattribute => .gitattributes | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename .gitattribute => .gitattributes (100%) diff --git a/.gitattribute b/.gitattributes similarity index 100% rename from .gitattribute rename to .gitattributes From 453cafb9a9633ca747e193c7e1694c10545d91d2 Mon Sep 17 00:00:00 2001 From: Siddhiraj Katkar <91388781+siddhirajkatkar@users.noreply.github.com> Date: Tue, 2 Jun 2026 21:31:50 +0530 Subject: [PATCH 51/60] fix: add 'invalid_target' to AuthorizationErrorCode (RFC 8707) (#2642) Co-authored-by: Marcelo Trylesinski --- src/mcp/server/auth/provider.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mcp/server/auth/provider.py 
b/src/mcp/server/auth/provider.py index 4ce1137575..bb47c19566 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -68,6 +68,7 @@ class RegistrationError(Exception): "invalid_scope", "server_error", "temporarily_unavailable", + "invalid_target", ] From a5b2ebb660e59d4dc52315c2827cde5d5e83e92f Mon Sep 17 00:00:00 2001 From: "S;Co" Date: Tue, 2 Jun 2026 17:02:36 +0100 Subject: [PATCH 52/60] Clarify CLI subprocess environment comment (#2672) Co-authored-by: scosemicolon <277933778+scosemicolon@users.noreply.github.com> --- src/mcp/cli/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index 62334a4a2c..44e2bae48e 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -277,7 +277,7 @@ def dev( [npx_cmd, "@modelcontextprotocol/inspector"] + uv_cmd, check=True, shell=shell, - env=dict(os.environ.items()), # Convert to list of tuples for env update + env=dict(os.environ.items()), # Copy the environment for subprocess launch ) sys.exit(process.returncode) except subprocess.CalledProcessError as e: From 8cc187fac06a85a0c80a84bdf47136a0accbb924 Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:27:05 +0100 Subject: [PATCH 53/60] Remove Tasks (SEP-1686) from the SDK (#2714) --- README.v2.md | 1 - docs/experimental/index.md | 42 - docs/experimental/tasks-client.md | 361 -------- docs/experimental/tasks-server.md | 577 ------------ docs/experimental/tasks.md | 188 ---- docs/migration.md | 38 +- examples/clients/simple-task-client/README.md | 43 - .../mcp_simple_task_client/__init__.py | 0 .../mcp_simple_task_client/__main__.py | 5 - .../mcp_simple_task_client/main.py | 56 -- .../clients/simple-task-client/pyproject.toml | 43 - .../simple-task-interactive-client/README.md | 87 -- .../__init__.py | 0 .../__main__.py | 5 - .../main.py | 137 --- .../pyproject.toml | 43 - .../servers/simple-task-interactive/README.md | 74 -- 
.../mcp_simple_task_interactive/__init__.py | 0 .../mcp_simple_task_interactive/__main__.py | 5 - .../mcp_simple_task_interactive/server.py | 138 --- .../simple-task-interactive/pyproject.toml | 43 - examples/servers/simple-task/README.md | 37 - .../simple-task/mcp_simple_task/__init__.py | 0 .../simple-task/mcp_simple_task/__main__.py | 5 - .../simple-task/mcp_simple_task/server.py | 70 -- examples/servers/simple-task/pyproject.toml | 43 - mkdocs.yml | 6 - src/mcp/client/experimental/__init__.py | 8 - src/mcp/client/experimental/task_handlers.py | 293 ------ src/mcp/client/experimental/tasks.py | 208 ----- src/mcp/client/session.py | 53 +- src/mcp/server/context.py | 2 - src/mcp/server/experimental/__init__.py | 10 - .../server/experimental/request_context.py | 217 ----- .../server/experimental/session_features.py | 209 ----- src/mcp/server/experimental/task_context.py | 587 ------------ .../experimental/task_result_handler.py | 218 ----- src/mcp/server/experimental/task_support.py | 116 --- src/mcp/server/lowlevel/experimental.py | 210 ----- src/mcp/server/lowlevel/server.py | 52 +- src/mcp/server/session.py | 195 ---- src/mcp/server/validation.py | 3 +- src/mcp/shared/experimental/__init__.py | 6 - src/mcp/shared/experimental/tasks/__init__.py | 11 - .../shared/experimental/tasks/capabilities.py | 96 -- src/mcp/shared/experimental/tasks/context.py | 95 -- src/mcp/shared/experimental/tasks/helpers.py | 166 ---- .../tasks/in_memory_task_store.py | 217 ----- .../experimental/tasks/message_queue.py | 230 ----- src/mcp/shared/experimental/tasks/polling.py | 43 - src/mcp/shared/experimental/tasks/resolver.py | 58 -- src/mcp/shared/experimental/tasks/store.py | 144 --- src/mcp/shared/response_router.py | 61 -- src/mcp/shared/session.py | 38 +- src/mcp/types/__init__.py | 82 -- src/mcp/types/_types.py | 304 +----- tests/experimental/__init__.py | 0 tests/experimental/tasks/__init__.py | 1 - tests/experimental/tasks/client/__init__.py | 0 
.../tasks/client/test_capabilities.py | 312 ------- .../tasks/client/test_handlers.py | 874 ------------------ .../tasks/client/test_poll_task.py | 121 --- tests/experimental/tasks/client/test_tasks.py | 309 ------- tests/experimental/tasks/server/__init__.py | 0 .../experimental/tasks/server/test_context.py | 183 ---- .../tasks/server/test_integration.py | 247 ----- .../tasks/server/test_run_task_flow.py | 367 -------- .../experimental/tasks/server/test_server.py | 797 ---------------- .../tasks/server/test_server_task_context.py | 709 -------------- tests/experimental/tasks/server/test_store.py | 406 -------- .../tasks/server/test_task_result_handler.py | 354 ------- tests/experimental/tasks/test_capabilities.py | 283 ------ .../tasks/test_elicitation_scenarios.py | 695 -------------- .../experimental/tasks/test_message_queue.py | 330 ------- .../tasks/test_request_context.py | 166 ---- .../tasks/test_spec_compliance.py | 717 -------------- tests/interaction/_requirements.py | 5 +- tests/issues/test_176_progress_token.py | 2 - tests/server/mcpserver/test_server.py | 2 - tests/server/test_session.py | 43 + uv.lock | 124 --- 81 files changed, 63 insertions(+), 12963 deletions(-) delete mode 100644 docs/experimental/index.md delete mode 100644 docs/experimental/tasks-client.md delete mode 100644 docs/experimental/tasks-server.md delete mode 100644 docs/experimental/tasks.md delete mode 100644 examples/clients/simple-task-client/README.md delete mode 100644 examples/clients/simple-task-client/mcp_simple_task_client/__init__.py delete mode 100644 examples/clients/simple-task-client/mcp_simple_task_client/__main__.py delete mode 100644 examples/clients/simple-task-client/mcp_simple_task_client/main.py delete mode 100644 examples/clients/simple-task-client/pyproject.toml delete mode 100644 examples/clients/simple-task-interactive-client/README.md delete mode 100644 examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py delete mode 
100644 examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py delete mode 100644 examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py delete mode 100644 examples/clients/simple-task-interactive-client/pyproject.toml delete mode 100644 examples/servers/simple-task-interactive/README.md delete mode 100644 examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py delete mode 100644 examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py delete mode 100644 examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py delete mode 100644 examples/servers/simple-task-interactive/pyproject.toml delete mode 100644 examples/servers/simple-task/README.md delete mode 100644 examples/servers/simple-task/mcp_simple_task/__init__.py delete mode 100644 examples/servers/simple-task/mcp_simple_task/__main__.py delete mode 100644 examples/servers/simple-task/mcp_simple_task/server.py delete mode 100644 examples/servers/simple-task/pyproject.toml delete mode 100644 src/mcp/client/experimental/__init__.py delete mode 100644 src/mcp/client/experimental/task_handlers.py delete mode 100644 src/mcp/client/experimental/tasks.py delete mode 100644 src/mcp/server/experimental/__init__.py delete mode 100644 src/mcp/server/experimental/request_context.py delete mode 100644 src/mcp/server/experimental/session_features.py delete mode 100644 src/mcp/server/experimental/task_context.py delete mode 100644 src/mcp/server/experimental/task_result_handler.py delete mode 100644 src/mcp/server/experimental/task_support.py delete mode 100644 src/mcp/server/lowlevel/experimental.py delete mode 100644 src/mcp/shared/experimental/__init__.py delete mode 100644 src/mcp/shared/experimental/tasks/__init__.py delete mode 100644 src/mcp/shared/experimental/tasks/capabilities.py delete mode 100644 src/mcp/shared/experimental/tasks/context.py delete mode 100644 
src/mcp/shared/experimental/tasks/helpers.py delete mode 100644 src/mcp/shared/experimental/tasks/in_memory_task_store.py delete mode 100644 src/mcp/shared/experimental/tasks/message_queue.py delete mode 100644 src/mcp/shared/experimental/tasks/polling.py delete mode 100644 src/mcp/shared/experimental/tasks/resolver.py delete mode 100644 src/mcp/shared/experimental/tasks/store.py delete mode 100644 src/mcp/shared/response_router.py delete mode 100644 tests/experimental/__init__.py delete mode 100644 tests/experimental/tasks/__init__.py delete mode 100644 tests/experimental/tasks/client/__init__.py delete mode 100644 tests/experimental/tasks/client/test_capabilities.py delete mode 100644 tests/experimental/tasks/client/test_handlers.py delete mode 100644 tests/experimental/tasks/client/test_poll_task.py delete mode 100644 tests/experimental/tasks/client/test_tasks.py delete mode 100644 tests/experimental/tasks/server/__init__.py delete mode 100644 tests/experimental/tasks/server/test_context.py delete mode 100644 tests/experimental/tasks/server/test_integration.py delete mode 100644 tests/experimental/tasks/server/test_run_task_flow.py delete mode 100644 tests/experimental/tasks/server/test_server.py delete mode 100644 tests/experimental/tasks/server/test_server_task_context.py delete mode 100644 tests/experimental/tasks/server/test_store.py delete mode 100644 tests/experimental/tasks/server/test_task_result_handler.py delete mode 100644 tests/experimental/tasks/test_capabilities.py delete mode 100644 tests/experimental/tasks/test_elicitation_scenarios.py delete mode 100644 tests/experimental/tasks/test_message_queue.py delete mode 100644 tests/experimental/tasks/test_request_context.py delete mode 100644 tests/experimental/tasks/test_spec_compliance.py diff --git a/README.v2.md b/README.v2.md index d0851c04e5..a888f21bda 100644 --- a/README.v2.md +++ b/README.v2.md @@ -2489,7 +2489,6 @@ MCP servers declare capabilities during initialization: ## Documentation - [API 
Reference](https://modelcontextprotocol.github.io/python-sdk/api/) -- [Experimental Features (Tasks)](https://modelcontextprotocol.github.io/python-sdk/experimental/tasks/) - [Model Context Protocol documentation](https://modelcontextprotocol.io) - [Model Context Protocol specification](https://modelcontextprotocol.io/specification/latest) - [Officially supported servers](https://github.com/modelcontextprotocol/servers) diff --git a/docs/experimental/index.md b/docs/experimental/index.md deleted file mode 100644 index c97fe2a3d6..0000000000 --- a/docs/experimental/index.md +++ /dev/null @@ -1,42 +0,0 @@ -# Experimental Features - -!!! warning "Experimental APIs" - - The features in this section are experimental and may change without notice. - They track the evolving MCP specification and are not yet stable. - -This section documents experimental features in the MCP Python SDK. These features -implement draft specifications that are still being refined. - -## Available Experimental Features - -### [Tasks](tasks.md) - -Tasks enable asynchronous execution of MCP operations. Instead of waiting for a -long-running operation to complete, the server returns a task reference immediately. -Clients can then poll for status updates and retrieve results when ready. - -Tasks are useful for: - -- **Long-running computations** that would otherwise block -- **Batch operations** that process many items -- **Interactive workflows** that require user input (elicitation) or LLM assistance (sampling) - -## Using Experimental APIs - -Experimental features are accessed via the `.experimental` property: - -```python -# Server-side: enable task support (auto-registers default handlers) -server = Server(name="my-server") -server.experimental.enable_tasks() - -# Client-side -result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) -``` - -## Providing Feedback - -Since these features are experimental, feedback is especially valuable. 
If you encounter -issues or have suggestions, please open an issue on the -[python-sdk repository](https://github.com/modelcontextprotocol/python-sdk/issues). diff --git a/docs/experimental/tasks-client.md b/docs/experimental/tasks-client.md deleted file mode 100644 index 0374ed86b5..0000000000 --- a/docs/experimental/tasks-client.md +++ /dev/null @@ -1,361 +0,0 @@ -# Client Task Usage - -!!! warning "Experimental" - - Tasks are an experimental feature. The API may change without notice. - -This guide covers calling task-augmented tools from clients, handling the `input_required` status, and advanced patterns like receiving task requests from servers. - -## Quick Start - -Call a tool as a task and poll for the result: - -```python -from mcp.client.session import ClientSession -from mcp.types import CallToolResult - -async with ClientSession(read, write) as session: - await session.initialize() - - # Call tool as task - result = await session.experimental.call_tool_as_task( - "process_data", - {"input": "hello"}, - ttl=60000, - ) - task_id = result.task.taskId - - # Poll until complete - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status} - {status.statusMessage or ''}") - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) - print(f"Result: {final.content[0].text}") -``` - -## Calling Tools as Tasks - -Use `call_tool_as_task()` to invoke a tool with task augmentation: - -```python -result = await session.experimental.call_tool_as_task( - "my_tool", # Tool name - {"arg": "value"}, # Arguments - ttl=60000, # Time-to-live in milliseconds - meta={"key": "val"}, # Optional metadata -) - -task_id = result.task.taskId -print(f"Task: {task_id}, Status: {result.task.status}") -``` - -The response is a `CreateTaskResult` containing: - -- `task.taskId` - Unique identifier for polling -- `task.status` - Initial status (usually `"working"`) -- `task.pollInterval` - Suggested polling interval 
(milliseconds) -- `task.ttl` - Time-to-live for results -- `task.createdAt` - Creation timestamp - -## Polling with poll_task - -The `poll_task()` async iterator polls until the task reaches a terminal state: - -```python -async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - if status.statusMessage: - print(f"Progress: {status.statusMessage}") -``` - -It automatically: - -- Respects the server's suggested `pollInterval` -- Stops when status is `completed`, `failed`, or `cancelled` -- Yields each status for progress display - -### Handling input_required - -When a task needs user input (elicitation), it transitions to `input_required`. You must call `get_task_result()` to receive and respond to the elicitation: - -```python -async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - - if status.status == "input_required": - # This delivers the elicitation and waits for completion - final = await session.experimental.get_task_result(task_id, CallToolResult) - break -``` - -The elicitation callback (set during session creation) handles the actual user interaction. - -## Elicitation Callbacks - -To handle elicitation requests from the server, provide a callback when creating the session: - -```python -from mcp.types import ElicitRequestParams, ElicitResult - -async def handle_elicitation(context, params: ElicitRequestParams) -> ElicitResult: - # Display the message to the user - print(f"Server asks: {params.message}") - - # Collect user input (this is a simplified example) - response = input("Your response (y/n): ") - confirmed = response.lower() == "y" - - return ElicitResult( - action="accept", - content={"confirm": confirmed}, - ) - -async with ClientSession( - read, - write, - elicitation_callback=handle_elicitation, -) as session: - await session.initialize() - # ... 
call tasks that may require elicitation -``` - -## Sampling Callbacks - -Similarly, handle sampling requests with a callback: - -```python -from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent - -async def handle_sampling(context, params: CreateMessageRequestParams) -> CreateMessageResult: - # In a real implementation, call your LLM here - prompt = params.messages[-1].content.text if params.messages else "" - - # Return a mock response - return CreateMessageResult( - role="assistant", - content=TextContent(type="text", text=f"Response to: {prompt}"), - model="my-model", - ) - -async with ClientSession( - read, - write, - sampling_callback=handle_sampling, -) as session: - # ... -``` - -## Retrieving Results - -Once a task completes, retrieve the result: - -```python -if status.status == "completed": - result = await session.experimental.get_task_result(task_id, CallToolResult) - for content in result.content: - if hasattr(content, "text"): - print(content.text) - -elif status.status == "failed": - print(f"Task failed: {status.statusMessage}") - -elif status.status == "cancelled": - print("Task was cancelled") -``` - -The result type matches the original request: - -- `tools/call` → `CallToolResult` -- `sampling/createMessage` → `CreateMessageResult` -- `elicitation/create` → `ElicitResult` - -## Cancellation - -Cancel a running task: - -```python -cancel_result = await session.experimental.cancel_task(task_id) -print(f"Cancelled, status: {cancel_result.status}") -``` - -Note: Cancellation is cooperative—the server must check for and handle cancellation. 
- -## Listing Tasks - -View all tasks on the server: - -```python -result = await session.experimental.list_tasks() -for task in result.tasks: - print(f"{task.taskId}: {task.status}") - -# Handle pagination -while result.nextCursor: - result = await session.experimental.list_tasks(cursor=result.nextCursor) - for task in result.tasks: - print(f"{task.taskId}: {task.status}") -``` - -## Advanced: Client as Task Receiver - -Servers can send task-augmented requests to clients. This is useful when the server needs the client to perform async work (like complex sampling or user interaction). - -### Declaring Client Capabilities - -Register task handlers to declare what task-augmented requests your client accepts: - -```python -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.types import ( - CreateTaskResult, GetTaskResult, GetTaskPayloadResult, - TaskMetadata, ElicitRequestParams, -) -from mcp.shared.experimental.tasks import InMemoryTaskStore - -# Client-side task store -client_store = InMemoryTaskStore() - -async def handle_augmented_elicitation(context, params: ElicitRequestParams, task_metadata: TaskMetadata): - """Handle task-augmented elicitation from server.""" - # Create a task for this elicitation - task = await client_store.create_task(task_metadata) - - # Start async work (e.g., show UI, wait for user) - async def complete_elicitation(): - # ... do async work ... 
- result = ElicitResult(action="accept", content={"confirm": True}) - await client_store.store_result(task.taskId, result) - await client_store.update_task(task.taskId, status="completed") - - context.session._task_group.start_soon(complete_elicitation) - - # Return task reference immediately - return CreateTaskResult(task=task) - -async def handle_get_task(context, params): - """Handle tasks/get from server.""" - task = await client_store.get_task(params.taskId) - return GetTaskResult( - taskId=task.taskId, - status=task.status, - statusMessage=task.statusMessage, - createdAt=task.createdAt, - lastUpdatedAt=task.lastUpdatedAt, - ttl=task.ttl, - pollInterval=100, - ) - -async def handle_get_task_result(context, params): - """Handle tasks/result from server.""" - result = await client_store.get_result(params.taskId) - return GetTaskPayloadResult.model_validate(result.model_dump()) - -task_handlers = ExperimentalTaskHandlers( - augmented_elicitation=handle_augmented_elicitation, - get_task=handle_get_task, - get_task_result=handle_get_task_result, -) - -async with ClientSession( - read, - write, - experimental_task_handlers=task_handlers, -) as session: - # Client now accepts task-augmented elicitation from server - await session.initialize() -``` - -This enables flows where: - -1. Client calls a task-augmented tool -2. Server's tool work calls `task.elicit_as_task()` -3. Client receives task-augmented elicitation -4. Client creates its own task, does async work -5. Server polls client's task -6. Eventually both tasks complete - -## Complete Example - -A client that handles all task scenarios: - -```python -import anyio -from mcp.client.session import ClientSession -from mcp.client.stdio import stdio_client -from mcp.types import CallToolResult, ElicitRequestParams, ElicitResult - - -async def elicitation_callback(context, params: ElicitRequestParams) -> ElicitResult: - print(f"\n[Elicitation] {params.message}") - response = input("Confirm? 
(y/n): ") - return ElicitResult(action="accept", content={"confirm": response.lower() == "y"}) - - -async def main(): - async with stdio_client(command="python", args=["server.py"]) as (read, write): - async with ClientSession( - read, - write, - elicitation_callback=elicitation_callback, - ) as session: - await session.initialize() - - # List available tools - tools = await session.list_tools() - print("Tools:", [t.name for t in tools.tools]) - - # Call a task-augmented tool - print("\nCalling task tool...") - result = await session.experimental.call_tool_as_task( - "confirm_action", - {"action": "delete files"}, - ) - task_id = result.task.taskId - print(f"Task created: {task_id}") - - # Poll and handle input_required - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - - if status.status == "input_required": - final = await session.experimental.get_task_result(task_id, CallToolResult) - print(f"Result: {final.content[0].text}") - break - - if status.status == "completed": - final = await session.experimental.get_task_result(task_id, CallToolResult) - print(f"Result: {final.content[0].text}") - - -if __name__ == "__main__": - anyio.run(main) -``` - -## Error Handling - -Handle task errors gracefully: - -```python -from mcp.shared.exceptions import MCPError - -try: - result = await session.experimental.call_tool_as_task("my_tool", args) - task_id = result.task.taskId - - async for status in session.experimental.poll_task(task_id): - if status.status == "failed": - raise RuntimeError(f"Task failed: {status.statusMessage}") - - final = await session.experimental.get_task_result(task_id, CallToolResult) - -except MCPError as e: - print(f"MCP error: {e.message}") -except Exception as e: - print(f"Error: {e}") -``` - -## Next Steps - -- [Server Implementation](tasks-server.md) - Build task-supporting servers -- [Tasks Overview](tasks.md) - Review lifecycle and concepts diff --git a/docs/experimental/tasks-server.md 
b/docs/experimental/tasks-server.md deleted file mode 100644 index b350ee3bb6..0000000000 --- a/docs/experimental/tasks-server.md +++ /dev/null @@ -1,577 +0,0 @@ -# Server Task Implementation - -!!! warning "Experimental" - - Tasks are an experimental feature. The API may change without notice. - -This guide covers implementing task support in MCP servers, from basic setup to advanced patterns like elicitation and sampling within tasks. - -## Quick Start - -The simplest way to add task support: - -```python -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, TASK_REQUIRED - -server = Server("my-server") -server.experimental.enable_tasks() # Registers all task handlers automatically - -@server.list_tools() -async def list_tools(): - return [ - Tool( - name="process_data", - description="Process data asynchronously", - inputSchema={"type": "object", "properties": {"input": {"type": "string"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ) - ] - -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: - if name == "process_data": - return await handle_process_data(arguments) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) - -async def handle_process_data(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Processing...") - result = arguments.get("input", "").upper() - return CallToolResult(content=[TextContent(type="text", text=result)]) - - return await ctx.experimental.run_task(work) -``` - -That's it. 
`enable_tasks()` automatically: - -- Creates an in-memory task store -- Registers handlers for `tasks/get`, `tasks/result`, `tasks/list`, `tasks/cancel` -- Updates server capabilities - -## Tool Declaration - -Tools declare task support via the `execution.taskSupport` field: - -```python -from mcp.types import Tool, ToolExecution, TASK_REQUIRED, TASK_OPTIONAL, TASK_FORBIDDEN - -Tool( - name="my_tool", - inputSchema={"type": "object"}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), # or TASK_OPTIONAL, TASK_FORBIDDEN -) -``` - -| Value | Meaning | -|-------|---------| -| `TASK_REQUIRED` | Tool **must** be called as a task | -| `TASK_OPTIONAL` | Tool supports both sync and task execution | -| `TASK_FORBIDDEN` | Tool **cannot** be called as a task (default) | - -Validate the request matches your tool's requirements: - -```python -@server.call_tool() -async def handle_tool(name: str, arguments: dict): - ctx = server.request_context - - if name == "required_task_tool": - ctx.experimental.validate_task_mode(TASK_REQUIRED) # Raises if not task mode - return await handle_as_task(arguments) - - elif name == "optional_task_tool": - if ctx.experimental.is_task: - return await handle_as_task(arguments) - else: - return handle_sync(arguments) -``` - -## The run_task Pattern - -`run_task()` is the recommended way to execute task work: - -```python -async def handle_my_tool(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Your work here - return CallToolResult(content=[TextContent(type="text", text="Done")]) - - return await ctx.experimental.run_task(work) -``` - -**What `run_task()` does:** - -1. Creates a task in the store -2. Spawns your work function in the background -3. Returns `CreateTaskResult` immediately -4. Auto-completes the task when your function returns -5. 
Auto-fails the task if your function raises - -**The `ServerTaskContext` provides:** - -- `task.task_id` - The task identifier -- `task.update_status(message)` - Update progress -- `task.complete(result)` - Explicitly complete (usually automatic) -- `task.fail(error)` - Explicitly fail -- `task.is_cancelled` - Check if cancellation requested - -## Status Updates - -Keep clients informed of progress: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Starting...") - - for i, item in enumerate(items): - await task.update_status(f"Processing {i+1}/{len(items)}") - await process_item(item) - - await task.update_status("Finalizing...") - return CallToolResult(content=[TextContent(type="text", text="Complete")]) -``` - -Status messages appear in `tasks/get` responses, letting clients show progress to users. - -## Elicitation Within Tasks - -Tasks can request user input via elicitation. This transitions the task to `input_required` status. - -### Form Elicitation - -Collect structured data from the user: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Waiting for confirmation...") - - result = await task.elicit( - message="Delete these files?", - requestedSchema={ - "type": "object", - "properties": { - "confirm": {"type": "boolean"}, - "reason": {"type": "string"}, - }, - "required": ["confirm"], - }, - ) - - if result.action == "accept" and result.content.get("confirm"): - # User confirmed - return CallToolResult(content=[TextContent(type="text", text="Files deleted")]) - else: - # User declined or cancelled - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) -``` - -### URL Elicitation - -Direct users to external URLs for OAuth, payments, or other out-of-band flows: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Waiting for OAuth...") - - result = await task.elicit_url( - message="Please 
authorize with GitHub", - url="https://github.com/login/oauth/authorize?client_id=...", - elicitation_id="oauth-github-123", - ) - - if result.action == "accept": - # User completed OAuth flow - return CallToolResult(content=[TextContent(type="text", text="Connected to GitHub")]) - else: - return CallToolResult(content=[TextContent(type="text", text="OAuth cancelled")]) -``` - -## Sampling Within Tasks - -Tasks can request LLM completions from the client: - -```python -from mcp.types import SamplingMessage, TextContent - -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Generating response...") - - result = await task.create_message( - messages=[ - SamplingMessage( - role="user", - content=TextContent(type="text", text="Write a haiku about coding"), - ) - ], - max_tokens=100, - ) - - haiku = result.content.text if isinstance(result.content, TextContent) else "Error" - return CallToolResult(content=[TextContent(type="text", text=haiku)]) -``` - -Sampling supports additional parameters: - -```python -result = await task.create_message( - messages=[...], - max_tokens=500, - system_prompt="You are a helpful assistant", - temperature=0.7, - stop_sequences=["\n\n"], - model_preferences=ModelPreferences(hints=[ModelHint(name="claude-3")]), -) -``` - -## Cancellation Support - -Check for cancellation in long-running work: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - for i in range(1000): - if task.is_cancelled: - # Clean up and exit - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) - - await task.update_status(f"Step {i}/1000") - await process_step(i) - - return CallToolResult(content=[TextContent(type="text", text="Complete")]) -``` - -The SDK's default cancel handler updates the task status. Your work function should check `is_cancelled` periodically. 
- -## Custom Task Store - -For production, implement `TaskStore` with persistent storage: - -```python -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import Task, TaskMetadata, Result - -class RedisTaskStore(TaskStore): - def __init__(self, redis_client): - self.redis = redis_client - - async def create_task(self, metadata: TaskMetadata, task_id: str | None = None) -> Task: - # Create and persist task - ... - - async def get_task(self, task_id: str) -> Task | None: - # Retrieve task from Redis - ... - - async def update_task(self, task_id: str, status: str | None = None, ...) -> Task: - # Update and persist - ... - - async def store_result(self, task_id: str, result: Result) -> None: - # Store result in Redis - ... - - async def get_result(self, task_id: str) -> Result | None: - # Retrieve result - ... - - # ... implement remaining methods -``` - -Use your custom store: - -```python -store = RedisTaskStore(redis_client) -server.experimental.enable_tasks(store=store) -``` - -## Complete Example - -A server with multiple task-supporting tools: - -```python -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import ( - CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, - SamplingMessage, TASK_REQUIRED, -) - -server = Server("task-demo") -server.experimental.enable_tasks() - - -@server.list_tools() -async def list_tools(): - return [ - Tool( - name="confirm_action", - description="Requires user confirmation", - inputSchema={"type": "object", "properties": {"action": {"type": "string"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ), - Tool( - name="generate_text", - description="Generate text via LLM", - inputSchema={"type": "object", "properties": {"prompt": {"type": "string"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ), - ] - - -async def handle_confirm_action(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - 
ctx.experimental.validate_task_mode(TASK_REQUIRED) - - action = arguments.get("action", "unknown action") - - async def work(task: ServerTaskContext) -> CallToolResult: - result = await task.elicit( - message=f"Confirm: {action}?", - requestedSchema={ - "type": "object", - "properties": {"confirm": {"type": "boolean"}}, - "required": ["confirm"], - }, - ) - - if result.action == "accept" and result.content.get("confirm"): - return CallToolResult(content=[TextContent(type="text", text=f"Executed: {action}")]) - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) - - return await ctx.experimental.run_task(work) - - -async def handle_generate_text(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - prompt = arguments.get("prompt", "Hello") - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Generating...") - - result = await task.create_message( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], - max_tokens=200, - ) - - text = result.content.text if isinstance(result.content, TextContent) else "Error" - return CallToolResult(content=[TextContent(type="text", text=text)]) - - return await ctx.experimental.run_task(work) - - -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: - if name == "confirm_action": - return await handle_confirm_action(arguments) - elif name == "generate_text": - return await handle_generate_text(arguments) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) -``` - -## Error Handling in Tasks - -Tasks handle errors automatically, but you can also fail explicitly: - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - try: - result = await risky_operation() - return CallToolResult(content=[TextContent(type="text", text=result)]) - except 
PermissionError: - await task.fail("Access denied - insufficient permissions") - raise - except TimeoutError: - await task.fail("Operation timed out after 30 seconds") - raise -``` - -When `run_task()` catches an exception, it automatically: - -1. Marks the task as `failed` -2. Sets `statusMessage` to the exception message -3. Propagates the exception (which is caught by the task group) - -For custom error messages, call `task.fail()` before raising. - -## HTTP Transport Example - -For web applications, use the Streamable HTTP transport: - -```python -import uvicorn - -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import ( - CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, TASK_REQUIRED, -) - - -server = Server("http-task-server") -server.experimental.enable_tasks() - - -@server.list_tools() -async def list_tools(): - return [ - Tool( - name="long_operation", - description="A long-running operation", - inputSchema={"type": "object", "properties": {"duration": {"type": "number"}}}, - execution=ToolExecution(taskSupport=TASK_REQUIRED), - ) - ] - - -async def handle_long_operation(arguments: dict) -> CreateTaskResult: - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - duration = arguments.get("duration", 5) - - async def work(task: ServerTaskContext) -> CallToolResult: - import anyio - for i in range(int(duration)): - await task.update_status(f"Step {i+1}/{int(duration)}") - await anyio.sleep(1) - return CallToolResult(content=[TextContent(type="text", text=f"Completed after {duration}s")]) - - return await ctx.experimental.run_task(work) - - -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: - if name == "long_operation": - return await handle_long_operation(arguments) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) - - -if __name__ == 
"__main__": - uvicorn.run(server.streamable_http_app(), host="127.0.0.1", port=8000) -``` - -## Testing Task Servers - -Test task functionality with the SDK's testing utilities: - -```python -import pytest -import anyio -from mcp.client.session import ClientSession -from mcp.types import CallToolResult - - -@pytest.mark.anyio -async def test_task_tool(): - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream(10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream(10) - - async def run_server(): - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client(): - async with ClientSession(server_to_client_receive, client_to_server_send) as session: - await session.initialize() - - # Call the tool as a task - result = await session.experimental.call_tool_as_task("my_tool", {"arg": "value"}) - task_id = result.task.taskId - assert result.task.status == "working" - - # Poll until complete - async for status in session.experimental.poll_task(task_id): - if status.status in ("completed", "failed"): - break - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) - assert len(final.content) > 0 - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) -``` - -## Best Practices - -### Keep Work Functions Focused - -```python -# Good: focused work function -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Validating...") - validate_input(arguments) - - await task.update_status("Processing...") - result = await process_data(arguments) - - return CallToolResult(content=[TextContent(type="text", text=result)]) -``` - -### Check Cancellation in Loops - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - results = [] - for item in large_dataset: - if task.is_cancelled: - return 
CallToolResult(content=[TextContent(type="text", text="Cancelled")]) - - results.append(await process(item)) - - return CallToolResult(content=[TextContent(type="text", text=str(results))]) -``` - -### Use Meaningful Status Messages - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Connecting to database...") - db = await connect() - - await task.update_status("Fetching records (0/1000)...") - for i, record in enumerate(records): - if i % 100 == 0: - await task.update_status(f"Processing records ({i}/1000)...") - await process(record) - - await task.update_status("Finalizing results...") - return CallToolResult(content=[TextContent(type="text", text="Done")]) -``` - -### Handle Elicitation Responses - -```python -async def work(task: ServerTaskContext) -> CallToolResult: - result = await task.elicit(message="Continue?", requestedSchema={...}) - - match result.action: - case "accept": - # User accepted, process content - return await process_accepted(result.content) - case "decline": - # User explicitly declined - return CallToolResult(content=[TextContent(type="text", text="User declined")]) - case "cancel": - # User cancelled the elicitation - return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) -``` - -## Next Steps - -- [Client Usage](tasks-client.md) - Learn how clients interact with task servers -- [Tasks Overview](tasks.md) - Review lifecycle and concepts diff --git a/docs/experimental/tasks.md b/docs/experimental/tasks.md deleted file mode 100644 index 2d4d06a025..0000000000 --- a/docs/experimental/tasks.md +++ /dev/null @@ -1,188 +0,0 @@ -# Tasks - -!!! warning "Experimental" - - Tasks are an experimental feature tracking the draft MCP specification. - The API may change without notice. - -Tasks enable asynchronous request handling in MCP. Instead of blocking until an operation completes, the receiver creates a task, returns immediately, and the requestor polls for the result. 
- -## When to Use Tasks - -Tasks are designed for operations that: - -- Take significant time (seconds to minutes) -- Need progress updates during execution -- Require user input mid-execution (elicitation, sampling) -- Should run without blocking the requestor - -Common use cases: - -- Long-running data processing -- Multi-step workflows with user confirmation -- LLM-powered operations requiring sampling -- OAuth flows requiring user browser interaction - -## Task Lifecycle - -```text - ┌─────────────┐ - │ working │ - └──────┬──────┘ - │ - ┌────────────┼────────────┐ - │ │ │ - ▼ ▼ ▼ - ┌────────────┐ ┌───────────┐ ┌───────────┐ - │ completed │ │ failed │ │ cancelled │ - └────────────┘ └───────────┘ └───────────┘ - ▲ - │ - ┌────────┴────────┐ - │ input_required │◄──────┐ - └────────┬────────┘ │ - │ │ - └────────────────┘ -``` - -| Status | Description | -|--------|-------------| -| `working` | Task is being processed | -| `input_required` | Receiver needs input from requestor (elicitation/sampling) | -| `completed` | Task finished successfully | -| `failed` | Task encountered an error | -| `cancelled` | Task was cancelled by requestor | - -Terminal states (`completed`, `failed`, `cancelled`) are final—tasks cannot transition out of them. - -## Bidirectional Flow - -Tasks work in both directions: - -**Client → Server** (most common): - -```text -Client Server - │ │ - │── tools/call (task) ──────────────>│ Creates task - │<── CreateTaskResult ───────────────│ - │ │ - │── tasks/get ──────────────────────>│ - │<── status: working ────────────────│ - │ │ ... work continues ... 
- │── tasks/get ──────────────────────>│ - │<── status: completed ──────────────│ - │ │ - │── tasks/result ───────────────────>│ - │<── CallToolResult ─────────────────│ -``` - -**Server → Client** (for elicitation/sampling): - -```text -Server Client - │ │ - │── elicitation/create (task) ──────>│ Creates task - │<── CreateTaskResult ───────────────│ - │ │ - │── tasks/get ──────────────────────>│ - │<── status: working ────────────────│ - │ │ ... user interaction ... - │── tasks/get ──────────────────────>│ - │<── status: completed ──────────────│ - │ │ - │── tasks/result ───────────────────>│ - │<── ElicitResult ───────────────────│ -``` - -## Key Concepts - -### Task Metadata - -When augmenting a request with task execution, include `TaskMetadata`: - -```python -from mcp.types import TaskMetadata - -task = TaskMetadata(ttl=60000) # TTL in milliseconds -``` - -The `ttl` (time-to-live) specifies how long the task and result are retained after completion. - -### Task Store - -Servers persist task state in a `TaskStore`. The SDK provides `InMemoryTaskStore` for development: - -```python -from mcp.shared.experimental.tasks import InMemoryTaskStore - -store = InMemoryTaskStore() -``` - -For production, implement `TaskStore` with a database or distributed cache. - -### Capabilities - -Both servers and clients declare task support through capabilities: - -**Server capabilities:** - -- `tasks.requests.tools.call` - Server accepts task-augmented tool calls - -**Client capabilities:** - -- `tasks.requests.sampling.createMessage` - Client accepts task-augmented sampling -- `tasks.requests.elicitation.create` - Client accepts task-augmented elicitation - -The SDK manages these automatically when you enable task support. 
- -## Quick Example - -**Server** (simplified API): - -```python -from mcp.server import Server -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.types import CallToolResult, TextContent, TASK_REQUIRED - -server = Server("my-server") -server.experimental.enable_tasks() # One-line setup - -@server.call_tool() -async def handle_tool(name: str, arguments: dict): - ctx = server.request_context - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext): - await task.update_status("Processing...") - # ... do work ... - return CallToolResult(content=[TextContent(type="text", text="Done!")]) - - return await ctx.experimental.run_task(work) -``` - -**Client:** - -```python -from mcp.client.session import ClientSession -from mcp.types import CallToolResult - -async with ClientSession(read, write) as session: - await session.initialize() - - # Call tool as task - result = await session.experimental.call_tool_as_task("my_tool", {"arg": "value"}) - task_id = result.task.taskId - - # Poll until done - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) -``` - -## Next Steps - -- [Server Implementation](tasks-server.md) - Build task-supporting servers -- [Client Usage](tasks-client.md) - Call and poll tasks from clients diff --git a/docs/migration.md b/docs/migration.md index 8b70885e8d..9850f74cd4 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -595,7 +595,7 @@ The `RequestContext` class has been split to separate shared fields from server- **`RequestContext` changes:** - Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]` -- Server-specific fields (`lifespan_context`, `experimental`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in 
`mcp.server.context` +- Server-specific fields (`lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context` **Before (v1):** @@ -861,7 +861,7 @@ server = Server("my-server", on_list_tools=handle_list_tools, on_call_tool=handl **Key differences:** -- Handlers receive `(ctx, params)` instead of the full request object or unpacked arguments. `ctx` is a `ServerRequestContext` with `session`, `lifespan_context`, and `experimental` fields (plus `request_id`, `meta`, etc. for request handlers). `params` is the typed request params object. +- Handlers receive `(ctx, params)` instead of the full request object or unpacked arguments. `ctx` is a `ServerRequestContext` with `session` and `lifespan_context` fields (plus `request_id`, `meta`, etc. for request handlers). `params` is the typed request params object. - Handlers return the full result type (e.g. `ListToolsResult`) rather than unwrapped values (e.g. `list[Tool]`). - The automatic `jsonschema` input/output validation that the old `call_tool()` decorator performed has been removed. There is no built-in replacement — if you relied on schema validation in the lowlevel server, you will need to validate inputs yourself in your handler. @@ -872,7 +872,7 @@ All handlers receive `ctx: ServerRequestContext` as the first argument. 
The seco | v1 decorator | v2 constructor kwarg | `params` type | return type | |---|---|---|---| | `@server.list_tools()` | `on_list_tools` | `PaginatedRequestParams \| None` | `ListToolsResult` | -| `@server.call_tool()` | `on_call_tool` | `CallToolRequestParams` | `CallToolResult \| CreateTaskResult` | +| `@server.call_tool()` | `on_call_tool` | `CallToolRequestParams` | `CallToolResult` | | `@server.list_resources()` | `on_list_resources` | `PaginatedRequestParams \| None` | `ListResourcesResult` | | `@server.list_resource_templates()` | `on_list_resource_templates` | `PaginatedRequestParams \| None` | `ListResourceTemplatesResult` | | `@server.read_resource()` | `on_read_resource` | `ReadResourceRequestParams` | `ReadResourceResult` | @@ -1039,37 +1039,11 @@ from mcp.server import ServerRequestContext # but None in notification handlers ``` -### Experimental: task handler decorators removed +### Experimental Tasks support removed -The experimental decorator methods on `ExperimentalHandlers` (`@server.experimental.list_tasks()`, `@server.experimental.get_task()`, etc.) have been removed. +Tasks (SEP-1686) have been removed from the MCP specification and are no longer part of this SDK. The `mcp.client.experimental`, `mcp.server.experimental`, `mcp.shared.experimental`, and `mcp.server.lowlevel.experimental` modules have been removed, along with all `Task*` types, the `tasks` capability fields, `Tool.execution`, and the `experimental` properties on `ClientSession`, `ServerSession`, `Server`, and `ServerRequestContext`. -Default task handlers are still registered automatically via `server.experimental.enable_tasks()`. Custom handlers can be passed as `on_*` kwargs to override specific defaults. - -**Before (v1):** - -```python -server = Server("my-server") -server.experimental.enable_tasks() - -@server.experimental.get_task() -async def custom_get_task(request: GetTaskRequest) -> GetTaskResult: - ... 
-``` - -**After (v2):** - -```python -from mcp.server import Server, ServerRequestContext -from mcp.types import GetTaskRequestParams, GetTaskResult - - -async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - ... - - -server = Server("my-server") -server.experimental.enable_tasks(on_get_task=custom_get_task) -``` +Tasks are expected to return as a separate MCP extension in a future release. ## Deprecations diff --git a/examples/clients/simple-task-client/README.md b/examples/clients/simple-task-client/README.md deleted file mode 100644 index 103be0f1fb..0000000000 --- a/examples/clients/simple-task-client/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# Simple Task Client - -A minimal MCP client demonstrating polling for task results over streamable HTTP. - -## Running - -First, start the simple-task server in another terminal: - -```bash -cd examples/servers/simple-task -uv run mcp-simple-task -``` - -Then run the client: - -```bash -cd examples/clients/simple-task-client -uv run mcp-simple-task-client -``` - -Use `--url` to connect to a different server. - -## What it does - -1. Connects to the server via streamable HTTP -2. Calls the `long_running_task` tool as a task -3. Polls the task status until completion -4. Retrieves and prints the result - -## Expected output - -```text -Available tools: ['long_running_task'] - -Calling tool as a task... -Task created: - Status: working - Starting work... - Status: working - Processing step 1... - Status: working - Processing step 2... - Status: completed - - -Result: Task completed! 
-``` diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py deleted file mode 100644 index 2fc2cda8d9..0000000000 --- a/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .main import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/main.py b/examples/clients/simple-task-client/mcp_simple_task_client/main.py deleted file mode 100644 index f9e555c8e6..0000000000 --- a/examples/clients/simple-task-client/mcp_simple_task_client/main.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Simple task client demonstrating MCP tasks polling over streamable HTTP.""" - -import asyncio - -import click -from mcp import ClientSession -from mcp.client.streamable_http import streamable_http_client -from mcp.types import CallToolResult, TextContent - - -async def run(url: str) -> None: - async with streamable_http_client(url) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - - # List tools - tools = await session.list_tools() - print(f"Available tools: {[t.name for t in tools.tools]}") - - # Call the tool as a task - print("\nCalling tool as a task...") - - result = await session.experimental.call_tool_as_task( - "long_running_task", - arguments={}, - ttl=60000, - ) - task_id = result.task.task_id - print(f"Task created: {task_id}") - - status = None - # Poll until done (respects server's pollInterval hint) - async for status in session.experimental.poll_task(task_id): - print(f" Status: {status.status} - {status.status_message or ''}") - - # Check final status - if status and status.status != 
"completed": - print(f"Task ended with status: {status.status}") - return - - # Get the result - task_result = await session.experimental.get_task_result(task_id, CallToolResult) - content = task_result.content[0] - if isinstance(content, TextContent): - print(f"\nResult: {content.text}") - - -@click.command() -@click.option("--url", default="http://localhost:8000/mcp", help="Server URL") -def main(url: str) -> int: - asyncio.run(run(url)) - return 0 - - -if __name__ == "__main__": - main() diff --git a/examples/clients/simple-task-client/pyproject.toml b/examples/clients/simple-task-client/pyproject.toml deleted file mode 100644 index c7abf51159..0000000000 --- a/examples/clients/simple-task-client/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task-client" -version = "0.1.0" -description = "A simple MCP client demonstrating task polling" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks", "client"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["click>=8.0", "mcp"] - -[project.scripts] -mcp-simple-task-client = "mcp_simple_task_client.main:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task_client"] - -[tool.pyright] -include = ["mcp_simple_task_client"] -venvPath = "." 
-venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/clients/simple-task-interactive-client/README.md b/examples/clients/simple-task-interactive-client/README.md deleted file mode 100644 index 3397d3b5d7..0000000000 --- a/examples/clients/simple-task-interactive-client/README.md +++ /dev/null @@ -1,87 +0,0 @@ -# Simple Interactive Task Client - -A minimal MCP client demonstrating responses to interactive tasks (elicitation and sampling). - -## Running - -First, start the interactive task server in another terminal: - -```bash -cd examples/servers/simple-task-interactive -uv run mcp-simple-task-interactive -``` - -Then run the client: - -```bash -cd examples/clients/simple-task-interactive-client -uv run mcp-simple-task-interactive-client -``` - -Use `--url` to connect to a different server. - -## What it does - -1. Connects to the server via streamable HTTP -2. Calls `confirm_delete` - server asks for confirmation, client responds via terminal -3. Calls `write_haiku` - server requests LLM completion, client returns a hardcoded haiku - -## Key concepts - -### Elicitation callback - -```python -async def elicitation_callback(context, params) -> ElicitResult: - # Handle user input request from server - return ElicitResult(action="accept", content={"confirm": True}) -``` - -### Sampling callback - -```python -async def sampling_callback(context, params) -> CreateMessageResult: - # Handle LLM completion request from server - return CreateMessageResult(model="...", role="assistant", content=...) 
-``` - -### Using call_tool_as_task - -```python -# Call a tool as a task (returns immediately with task reference) -result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) -task_id = result.task.task_id - -# Get result - this delivers elicitation/sampling requests and blocks until complete -final = await session.experimental.get_task_result(task_id, CallToolResult) -``` - -**Important**: The `get_task_result()` call is what triggers the delivery of elicitation -and sampling requests to your callbacks. It blocks until the task completes and returns -the final result. - -## Expected output - -```text -Available tools: ['confirm_delete', 'write_haiku'] - ---- Demo 1: Elicitation --- -Calling confirm_delete tool... -Task created: - -[Elicitation] Server asks: Are you sure you want to delete 'important.txt'? -Your response (y/n): y -[Elicitation] Responding with: confirm=True -Result: Deleted 'important.txt' - ---- Demo 2: Sampling --- -Calling write_haiku tool... 
-Task created: - -[Sampling] Server requests LLM completion for: Write a haiku about autumn leaves -[Sampling] Responding with haiku -Result: -Haiku: -Cherry blossoms fall -Softly on the quiet pond -Spring whispers goodbye -``` diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py deleted file mode 100644 index 2fc2cda8d9..0000000000 --- a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .main import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py deleted file mode 100644 index ff5f499280..0000000000 --- a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Simple interactive task client demonstrating elicitation and sampling responses. - -This example demonstrates the spec-compliant polling pattern: -1. Poll tasks/get watching for status changes -2. On input_required, call tasks/result to receive elicitation/sampling requests -3. 
Continue until terminal status, then retrieve final result -""" - -import asyncio - -import click -from mcp import ClientSession -from mcp.client.context import ClientRequestContext -from mcp.client.streamable_http import streamable_http_client -from mcp.types import ( - CallToolResult, - CreateMessageRequestParams, - CreateMessageResult, - ElicitRequestParams, - ElicitResult, - TextContent, -) - - -async def elicitation_callback( - context: ClientRequestContext, - params: ElicitRequestParams, -) -> ElicitResult: - """Handle elicitation requests from the server.""" - print(f"\n[Elicitation] Server asks: {params.message}") - - # Simple terminal prompt - response = input("Your response (y/n): ").strip().lower() - confirmed = response in ("y", "yes", "true", "1") - - print(f"[Elicitation] Responding with: confirm={confirmed}") - return ElicitResult(action="accept", content={"confirm": confirmed}) - - -async def sampling_callback( - context: ClientRequestContext, - params: CreateMessageRequestParams, -) -> CreateMessageResult: - """Handle sampling requests from the server.""" - # Get the prompt from the first message - prompt = "unknown" - if params.messages: - content = params.messages[0].content - if isinstance(content, TextContent): - prompt = content.text - - print(f"\n[Sampling] Server requests LLM completion for: {prompt}") - - # Return a hardcoded haiku (in real use, call your LLM here) - haiku = """Cherry blossoms fall -Softly on the quiet pond -Spring whispers goodbye""" - - print("[Sampling] Responding with haiku") - return CreateMessageResult( - model="mock-haiku-model", - role="assistant", - content=TextContent(type="text", text=haiku), - ) - - -def get_text(result: CallToolResult) -> str: - """Extract text from a CallToolResult.""" - if result.content and isinstance(result.content[0], TextContent): - return result.content[0].text - return "(no text)" - - -async def run(url: str) -> None: - async with streamable_http_client(url) as (read, write): - async 
with ClientSession( - read, - write, - elicitation_callback=elicitation_callback, - sampling_callback=sampling_callback, - ) as session: - await session.initialize() - - # List tools - tools = await session.list_tools() - print(f"Available tools: {[t.name for t in tools.tools]}") - - # Demo 1: Elicitation (confirm_delete) - print("\n--- Demo 1: Elicitation ---") - print("Calling confirm_delete tool...") - - elicit_task = await session.experimental.call_tool_as_task("confirm_delete", {"filename": "important.txt"}) - elicit_task_id = elicit_task.task.task_id - print(f"Task created: {elicit_task_id}") - - # Poll until terminal, calling tasks/result on input_required - async for status in session.experimental.poll_task(elicit_task_id): - print(f"[Poll] Status: {status.status}") - if status.status == "input_required": - # Server needs input - tasks/result delivers the elicitation request - elicit_result = await session.experimental.get_task_result(elicit_task_id, CallToolResult) - break - else: - # poll_task exited due to terminal status - elicit_result = await session.experimental.get_task_result(elicit_task_id, CallToolResult) - - print(f"Result: {get_text(elicit_result)}") - - # Demo 2: Sampling (write_haiku) - print("\n--- Demo 2: Sampling ---") - print("Calling write_haiku tool...") - - sampling_task = await session.experimental.call_tool_as_task("write_haiku", {"topic": "autumn leaves"}) - sampling_task_id = sampling_task.task.task_id - print(f"Task created: {sampling_task_id}") - - # Poll until terminal, calling tasks/result on input_required - async for status in session.experimental.poll_task(sampling_task_id): - print(f"[Poll] Status: {status.status}") - if status.status == "input_required": - sampling_result = await session.experimental.get_task_result(sampling_task_id, CallToolResult) - break - else: - sampling_result = await session.experimental.get_task_result(sampling_task_id, CallToolResult) - - print(f"Result:\n{get_text(sampling_result)}") - - 
-@click.command() -@click.option("--url", default="http://localhost:8000/mcp", help="Server URL") -def main(url: str) -> int: - asyncio.run(run(url)) - return 0 - - -if __name__ == "__main__": - main() diff --git a/examples/clients/simple-task-interactive-client/pyproject.toml b/examples/clients/simple-task-interactive-client/pyproject.toml deleted file mode 100644 index 47191573f2..0000000000 --- a/examples/clients/simple-task-interactive-client/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task-interactive-client" -version = "0.1.0" -description = "A simple MCP client demonstrating interactive task responses" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks", "client", "elicitation", "sampling"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["click>=8.0", "mcp"] - -[project.scripts] -mcp-simple-task-interactive-client = "mcp_simple_task_interactive_client.main:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task_interactive_client"] - -[tool.pyright] -include = ["mcp_simple_task_interactive_client"] -venvPath = "." 
-venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/servers/simple-task-interactive/README.md b/examples/servers/simple-task-interactive/README.md deleted file mode 100644 index b8f384cb48..0000000000 --- a/examples/servers/simple-task-interactive/README.md +++ /dev/null @@ -1,74 +0,0 @@ -# Simple Interactive Task Server - -A minimal MCP server demonstrating interactive tasks with elicitation and sampling. - -## Running - -```bash -cd examples/servers/simple-task-interactive -uv run mcp-simple-task-interactive -``` - -The server starts on `http://localhost:8000/mcp` by default. Use `--port` to change. - -## What it does - -This server exposes two tools: - -### `confirm_delete` (demonstrates elicitation) - -Asks the user for confirmation before "deleting" a file. - -- Uses `task.elicit()` to request user input -- Shows the elicitation flow: task -> input_required -> response -> complete - -### `write_haiku` (demonstrates sampling) - -Asks the LLM to write a haiku about a topic. - -- Uses `task.create_message()` to request LLM completion -- Shows the sampling flow: task -> input_required -> response -> complete - -## Usage with the client - -In one terminal, start the server: - -```bash -cd examples/servers/simple-task-interactive -uv run mcp-simple-task-interactive -``` - -In another terminal, run the interactive client: - -```bash -cd examples/clients/simple-task-interactive-client -uv run mcp-simple-task-interactive-client -``` - -## Expected server output - -When a client connects and calls the tools, you'll see: - -```text -Starting server on http://localhost:8000/mcp - -[Server] confirm_delete called for 'important.txt' -[Server] Task created: -[Server] Sending elicitation request to client... 
-[Server] Received elicitation response: action=accept, content={'confirm': True} -[Server] Completing task with result: Deleted 'important.txt' - -[Server] write_haiku called for topic 'autumn leaves' -[Server] Task created: -[Server] Sending sampling request to client... -[Server] Received sampling response: Cherry blossoms fall -Softly on the quiet pon... -[Server] Completing task with haiku -``` - -## Key concepts - -1. **ServerTaskContext**: Provides `elicit()` and `create_message()` for user interaction -2. **run_task()**: Spawns background work, auto-completes/fails, returns immediately -3. **TaskResultHandler**: Delivers queued messages and routes responses -4. **Response routing**: Responses are routed back to waiting resolvers diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py deleted file mode 100644 index e7ef16530b..0000000000 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .server import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py deleted file mode 100644 index bc06e12088..0000000000 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Simple interactive task server demonstrating elicitation and sampling. 
- -This example shows the simplified task API where: -- server.experimental.enable_tasks() sets up all infrastructure -- ctx.experimental.run_task() handles task lifecycle automatically -- ServerTaskContext.elicit() and ServerTaskContext.create_message() queue requests properly -""" - -from typing import Any - -import click -import uvicorn -from mcp import types -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.task_context import ServerTaskContext - - -async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None -) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool( - name="confirm_delete", - description="Asks for confirmation before deleting (demonstrates elicitation)", - input_schema={ - "type": "object", - "properties": {"filename": {"type": "string"}}, - }, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - types.Tool( - name="write_haiku", - description="Asks LLM to write a haiku (demonstrates sampling)", - input_schema={"type": "object", "properties": {"topic": {"type": "string"}}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - ] - ) - - -async def handle_confirm_delete(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: - """Handle the confirm_delete tool - demonstrates elicitation.""" - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) - - filename = arguments.get("filename", "unknown.txt") - print(f"\n[Server] confirm_delete called for '{filename}'") - - async def work(task: ServerTaskContext) -> types.CallToolResult: - print(f"[Server] Task {task.task_id} starting elicitation...") - - result = await task.elicit( - message=f"Are you sure you want to delete '{filename}'?", - requested_schema={ - "type": "object", - "properties": {"confirm": {"type": "boolean"}}, - "required": ["confirm"], - }, - ) - - print(f"[Server] Received elicitation response: 
action={result.action}, content={result.content}") - - if result.action == "accept" and result.content: - confirmed = result.content.get("confirm", False) - text = f"Deleted '{filename}'" if confirmed else "Deletion cancelled" - else: - text = "Deletion cancelled" - - print(f"[Server] Completing task with result: {text}") - return types.CallToolResult(content=[types.TextContent(type="text", text=text)]) - - return await ctx.experimental.run_task(work) - - -async def handle_write_haiku(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: - """Handle the write_haiku tool - demonstrates sampling.""" - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) - - topic = arguments.get("topic", "nature") - print(f"\n[Server] write_haiku called for topic '{topic}'") - - async def work(task: ServerTaskContext) -> types.CallToolResult: - print(f"[Server] Task {task.task_id} starting sampling...") - - result = await task.create_message( - messages=[ - types.SamplingMessage( - role="user", - content=types.TextContent(type="text", text=f"Write a haiku about {topic}"), - ) - ], - max_tokens=50, - ) - - haiku = "No response" - if isinstance(result.content, types.TextContent): - haiku = result.content.text - - print(f"[Server] Received sampling response: {haiku[:50]}...") - return types.CallToolResult(content=[types.TextContent(type="text", text=f"Haiku:\n{haiku}")]) - - return await ctx.experimental.run_task(work) - - -async def handle_call_tool( - ctx: ServerRequestContext, params: types.CallToolRequestParams -) -> types.CallToolResult | types.CreateTaskResult: - """Dispatch tool calls to their handlers.""" - arguments = params.arguments or {} - - if params.name == "confirm_delete": - return await handle_confirm_delete(ctx, arguments) - elif params.name == "write_haiku": - return await handle_write_haiku(ctx, arguments) - - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], - is_error=True, - 
) - - -server = Server( - "simple-task-interactive", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, -) - -# Enable task support - this auto-registers all handlers -server.experimental.enable_tasks() - - -@click.command() -@click.option("--port", default=8000, help="Port to listen on") -def main(port: int) -> int: - starlette_app = server.streamable_http_app() - print(f"Starting server on http://localhost:{port}/mcp") - uvicorn.run(starlette_app, host="127.0.0.1", port=port) - return 0 diff --git a/examples/servers/simple-task-interactive/pyproject.toml b/examples/servers/simple-task-interactive/pyproject.toml deleted file mode 100644 index 4ec9770763..0000000000 --- a/examples/servers/simple-task-interactive/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task-interactive" -version = "0.1.0" -description = "A simple MCP server demonstrating interactive tasks (elicitation & sampling)" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks", "elicitation", "sampling"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["anyio>=4.5", "click>=8.0", "mcp", "starlette", "uvicorn"] - -[project.scripts] -mcp-simple-task-interactive = "mcp_simple_task_interactive.server:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task_interactive"] - -[tool.pyright] -include = ["mcp_simple_task_interactive"] -venvPath = "." 
-venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/servers/simple-task/README.md b/examples/servers/simple-task/README.md deleted file mode 100644 index 6914e0414f..0000000000 --- a/examples/servers/simple-task/README.md +++ /dev/null @@ -1,37 +0,0 @@ -# Simple Task Server - -A minimal MCP server demonstrating the experimental tasks feature over streamable HTTP. - -## Running - -```bash -cd examples/servers/simple-task -uv run mcp-simple-task -``` - -The server starts on `http://localhost:8000/mcp` by default. Use `--port` to change. - -## What it does - -This server exposes a single tool `long_running_task` that: - -1. Must be called as a task (with `task` metadata in the request) -2. Takes ~3 seconds to complete -3. Sends status updates during execution -4. Returns a result when complete - -## Usage with the client - -In one terminal, start the server: - -```bash -cd examples/servers/simple-task -uv run mcp-simple-task -``` - -In another terminal, run the client: - -```bash -cd examples/clients/simple-task-client -uv run mcp-simple-task-client -``` diff --git a/examples/servers/simple-task/mcp_simple_task/__init__.py b/examples/servers/simple-task/mcp_simple_task/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/servers/simple-task/mcp_simple_task/__main__.py b/examples/servers/simple-task/mcp_simple_task/__main__.py deleted file mode 100644 index e7ef16530b..0000000000 --- a/examples/servers/simple-task/mcp_simple_task/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys - -from .server import main - -sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py deleted file mode 100644 index 7583cd8f0e..0000000000 --- 
a/examples/servers/simple-task/mcp_simple_task/server.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Simple task server demonstrating MCP tasks over streamable HTTP.""" - -import anyio -import click -import uvicorn -from mcp import types -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.task_context import ServerTaskContext - - -async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None -) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool( - name="long_running_task", - description="A task that takes a few seconds to complete with status updates", - input_schema={"type": "object", "properties": {}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ) - ] - ) - - -async def handle_call_tool( - ctx: ServerRequestContext, params: types.CallToolRequestParams -) -> types.CallToolResult | types.CreateTaskResult: - """Dispatch tool calls to their handlers.""" - if params.name == "long_running_task": - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> types.CallToolResult: - await task.update_status("Starting work...") - await anyio.sleep(1) - - await task.update_status("Processing step 1...") - await anyio.sleep(1) - - await task.update_status("Processing step 2...") - await anyio.sleep(1) - - return types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) - - return await ctx.experimental.run_task(work) - - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], - is_error=True, - ) - - -server = Server( - "simple-task-server", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, -) - -# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task -server.experimental.enable_tasks() - - -@click.command() -@click.option("--port", default=8000, help="Port to listen on") -def main(port: 
int) -> int: - starlette_app = server.streamable_http_app() - - print(f"Starting server on http://localhost:{port}/mcp") - uvicorn.run(starlette_app, host="127.0.0.1", port=port) - return 0 diff --git a/examples/servers/simple-task/pyproject.toml b/examples/servers/simple-task/pyproject.toml deleted file mode 100644 index 921a1c34fc..0000000000 --- a/examples/servers/simple-task/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "mcp-simple-task" -version = "0.1.0" -description = "A simple MCP server demonstrating tasks" -readme = "README.md" -requires-python = ">=3.10" -authors = [{ name = "Model Context Protocol a Series of LF Projects, LLC." }] -keywords = ["mcp", "llm", "tasks"] -license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] -dependencies = ["anyio>=4.5", "click>=8.0", "mcp", "starlette", "uvicorn"] - -[project.scripts] -mcp-simple-task = "mcp_simple_task.server:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["mcp_simple_task"] - -[tool.pyright] -include = ["mcp_simple_task"] -venvPath = "." 
-venv = ".venv" - -[tool.ruff.lint] -select = ["E", "F", "I"] -ignore = [] - -[tool.ruff] -line-length = 120 -target-version = "py310" - -[dependency-groups] -dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/mkdocs.yml b/mkdocs.yml index e48c64242d..cb89faf0f0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -19,12 +19,6 @@ nav: - Low-Level Server: low-level-server.md - Authorization: authorization.md - Testing: testing.md - - Experimental: - - Overview: experimental/index.md - - Tasks: - - Introduction: experimental/tasks.md - - Server Implementation: experimental/tasks-server.md - - Client Usage: experimental/tasks-client.md - API Reference: api/ theme: diff --git a/src/mcp/client/experimental/__init__.py b/src/mcp/client/experimental/__init__.py deleted file mode 100644 index 8d74cb3044..0000000000 --- a/src/mcp/client/experimental/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Experimental client features. - -WARNING: These APIs are experimental and may change without notice. -""" - -from mcp.client.experimental.tasks import ExperimentalClientFeatures - -__all__ = ["ExperimentalClientFeatures"] diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py deleted file mode 100644 index 0ab513236a..0000000000 --- a/src/mcp/client/experimental/task_handlers.py +++ /dev/null @@ -1,293 +0,0 @@ -"""Experimental task handler protocols for server -> client requests. - -This module provides Protocol types and default handlers for when servers -send task-related requests to clients (the reverse of normal client -> server flow). - -WARNING: These APIs are experimental and may change without notice. - -Use cases: -- Server sends task-augmented sampling/elicitation request to client -- Client creates a local task, spawns background work, returns CreateTaskResult -- Server polls client's task status via tasks/get, tasks/result, etc. 
-""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Protocol - -from pydantic import TypeAdapter - -from mcp import types -from mcp.shared._context import RequestContext -from mcp.shared.session import RequestResponder - -if TYPE_CHECKING: - from mcp.client.session import ClientSession - - -class GetTaskHandlerFnT(Protocol): - """Handler for tasks/get requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.GetTaskRequestParams, - ) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch - - -class GetTaskResultHandlerFnT(Protocol): - """Handler for tasks/result requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.GetTaskPayloadRequestParams, - ) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch - - -class ListTasksHandlerFnT(Protocol): - """Handler for tasks/list requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch - - -class CancelTaskHandlerFnT(Protocol): - """Handler for tasks/cancel requests from server. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, - ) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch - - -class TaskAugmentedSamplingFnT(Protocol): - """Handler for task-augmented sampling/createMessage requests from server. - - When server sends a CreateMessageRequest with task field, this callback - is invoked. 
The callback should create a task, spawn background work, - and return CreateTaskResult immediately. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.CreateMessageRequestParams, - task_metadata: types.TaskMetadata, - ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch - - -class TaskAugmentedElicitationFnT(Protocol): - """Handler for task-augmented elicitation/create requests from server. - - When server sends an ElicitRequest with task field, this callback - is invoked. The callback should create a task, spawn background work, - and return CreateTaskResult immediately. - - WARNING: This is experimental and may change without notice. - """ - - async def __call__( - self, - context: RequestContext[ClientSession], - params: types.ElicitRequestParams, - task_metadata: types.TaskMetadata, - ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch - - -async def default_get_task_handler( - context: RequestContext[ClientSession], - params: types.GetTaskRequestParams, -) -> types.GetTaskResult | types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/get not supported", - ) - - -async def default_get_task_result_handler( - context: RequestContext[ClientSession], - params: types.GetTaskPayloadRequestParams, -) -> types.GetTaskPayloadResult | types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/result not supported", - ) - - -async def default_list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, -) -> types.ListTasksResult | types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/list not supported", - ) - - -async def default_cancel_task_handler( - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, -) -> types.CancelTaskResult | 
types.ErrorData: - return types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="tasks/cancel not supported", - ) - - -async def default_task_augmented_sampling( - context: RequestContext[ClientSession], - params: types.CreateMessageRequestParams, - task_metadata: types.TaskMetadata, -) -> types.CreateTaskResult | types.ErrorData: - return types.ErrorData( - code=types.INVALID_REQUEST, - message="Task-augmented sampling not supported", - ) - - -async def default_task_augmented_elicitation( - context: RequestContext[ClientSession], - params: types.ElicitRequestParams, - task_metadata: types.TaskMetadata, -) -> types.CreateTaskResult | types.ErrorData: - return types.ErrorData( - code=types.INVALID_REQUEST, - message="Task-augmented elicitation not supported", - ) - - -@dataclass -class ExperimentalTaskHandlers: - """Container for experimental task handlers. - - Groups all task-related handlers that handle server -> client requests. - This includes both pure task requests (get, list, cancel, result) and - task-augmented request handlers (sampling, elicitation with task field). - - WARNING: These APIs are experimental and may change without notice. 
- - Example: - ```python - handlers = ExperimentalTaskHandlers( - get_task=my_get_task_handler, - list_tasks=my_list_tasks_handler, - ) - session = ClientSession(..., experimental_task_handlers=handlers) - ``` - """ - - # Pure task request handlers - get_task: GetTaskHandlerFnT = field(default=default_get_task_handler) - get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler) - list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler) - cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler) - - # Task-augmented request handlers - augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling) - augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation) - - def build_capability(self) -> types.ClientTasksCapability | None: - """Build ClientTasksCapability from the configured handlers. - - Returns a capability object that reflects which handlers are configured - (i.e., not using the default "not supported" handlers). 
- - Returns: - ClientTasksCapability if any handlers are provided, None otherwise - """ - has_list = self.list_tasks is not default_list_tasks_handler - has_cancel = self.cancel_task is not default_cancel_task_handler - has_sampling = self.augmented_sampling is not default_task_augmented_sampling - has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation - - # If no handlers are provided, return None - if not any([has_list, has_cancel, has_sampling, has_elicitation]): - return None - - # Build requests capability if any request handlers are provided - requests_capability: types.ClientTasksRequestsCapability | None = None - if has_sampling or has_elicitation: - requests_capability = types.ClientTasksRequestsCapability( - sampling=types.TasksSamplingCapability(create_message=types.TasksCreateMessageCapability()) - if has_sampling - else None, - elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability()) - if has_elicitation - else None, - ) - - return types.ClientTasksCapability( - list=types.TasksListCapability() if has_list else None, - cancel=types.TasksCancelCapability() if has_cancel else None, - requests=requests_capability, - ) - - @staticmethod - def handles_request(request: types.ServerRequest) -> bool: - """Check if this handler handles the given request type.""" - return isinstance( - request, - types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest, - ) - - async def handle_request( - self, - ctx: RequestContext[ClientSession], - responder: RequestResponder[types.ServerRequest, types.ClientResult], - ) -> None: - """Handle a task-related request from the server. - - Call handles_request() first to check if this handler can handle the request. 
- """ - client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter( - types.ClientResult | types.ErrorData - ) - - match responder.request: - case types.GetTaskRequest(params=params): - response = await self.get_task(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case types.GetTaskPayloadRequest(params=params): - response = await self.get_task_result(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case types.ListTasksRequest(params=params): - response = await self.list_tasks(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case types.CancelTaskRequest(params=params): - response = await self.cancel_task(ctx, params) - client_response = client_response_type.validate_python(response) - await responder.respond(client_response) - - case _: # pragma: no cover - raise ValueError(f"Unhandled request type: {type(responder.request)}") - - -# Backwards compatibility aliases -default_task_augmented_sampling_callback = default_task_augmented_sampling -default_task_augmented_elicitation_callback = default_task_augmented_elicitation diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py deleted file mode 100644 index a566df766b..0000000000 --- a/src/mcp/client/experimental/tasks.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Experimental client-side task support. - -This module provides client methods for interacting with MCP tasks. - -WARNING: These APIs are experimental and may change without notice. 
- -Example: - ```python - # Call a tool as a task - result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) - task_id = result.task.task_id - - # Get task status - status = await session.experimental.get_task(task_id) - - # Get task result when complete - if status.status == "completed": - result = await session.experimental.get_task_result(task_id, CallToolResult) - - # List all tasks - tasks = await session.experimental.list_tasks() - - # Cancel a task - await session.experimental.cancel_task(task_id) - ``` -""" - -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any, TypeVar - -from mcp import types -from mcp.shared.experimental.tasks.polling import poll_until_terminal -from mcp.types._types import RequestParamsMeta - -if TYPE_CHECKING: - from mcp.client.session import ClientSession - -ResultT = TypeVar("ResultT", bound=types.Result) - - -class ExperimentalClientFeatures: - """Experimental client features for tasks and other experimental APIs. - - WARNING: These APIs are experimental and may change without notice. - - Access via session.experimental: - status = await session.experimental.get_task(task_id) - """ - - def __init__(self, session: "ClientSession") -> None: - self._session = session - - async def call_tool_as_task( - self, - name: str, - arguments: dict[str, Any] | None = None, - *, - ttl: int = 60000, - meta: RequestParamsMeta | None = None, - ) -> types.CreateTaskResult: - """Call a tool as a task, returning a CreateTaskResult for polling. - - This is a convenience method for calling tools that support task execution. - The server will return a task reference instead of the immediate result, - which can then be polled via `get_task()` and retrieved via `get_task_result()`. 
- - Args: - name: The tool name - arguments: Tool arguments - ttl: Task time-to-live in milliseconds (default: 60000 = 1 minute) - meta: Optional metadata to include in the request - - Returns: - CreateTaskResult containing the task reference - - Example: - ```python - # Create task - result = await session.experimental.call_tool_as_task( - "long_running_tool", {"input": "data"} - ) - task_id = result.task.task_id - - # Poll for completion - while True: - status = await session.experimental.get_task(task_id) - if status.status == "completed": - break - await anyio.sleep(0.5) - - # Get result - final = await session.experimental.get_task_result(task_id, CallToolResult) - ``` - """ - return await self._session.send_request( - types.CallToolRequest( - params=types.CallToolRequestParams( - name=name, - arguments=arguments, - task=types.TaskMetadata(ttl=ttl), - _meta=meta, - ), - ), - types.CreateTaskResult, - ) - - async def get_task(self, task_id: str) -> types.GetTaskResult: - """Get the current status of a task. - - Args: - task_id: The task identifier - - Returns: - GetTaskResult containing the task status and metadata - """ - return await self._session.send_request( - types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id)), - types.GetTaskResult, - ) - - async def get_task_result( - self, - task_id: str, - result_type: type[ResultT], - ) -> ResultT: - """Get the result of a completed task. 
- - The result type depends on the original request type: - - tools/call tasks return CallToolResult - - Other request types return their corresponding result type - - Args: - task_id: The task identifier - result_type: The expected result type (e.g., CallToolResult) - - Returns: - The task result, validated against result_type - """ - return await self._session.send_request( - types.GetTaskPayloadRequest( - params=types.GetTaskPayloadRequestParams(task_id=task_id), - ), - result_type, - ) - - async def list_tasks( - self, - cursor: str | None = None, - ) -> types.ListTasksResult: - """List all tasks. - - Args: - cursor: Optional pagination cursor - - Returns: - ListTasksResult containing tasks and optional next cursor - """ - params = types.PaginatedRequestParams(cursor=cursor) if cursor else None - return await self._session.send_request( - types.ListTasksRequest(params=params), - types.ListTasksResult, - ) - - async def cancel_task(self, task_id: str) -> types.CancelTaskResult: - """Cancel a running task. - - Args: - task_id: The task identifier - - Returns: - CancelTaskResult with the updated task state - """ - return await self._session.send_request( - types.CancelTaskRequest( - params=types.CancelTaskRequestParams(task_id=task_id), - ), - types.CancelTaskResult, - ) - - async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: - """Poll a task until it reaches a terminal status. - - Yields GetTaskResult for each poll, allowing the caller to react to - status changes (e.g., handle input_required). Exits when the task reaches - a terminal status (completed, failed, cancelled). - - Respects the pollInterval hint from the server. 
- - Args: - task_id: The task identifier - - Yields: - GetTaskResult for each poll - - Example: - ```python - async for status in session.experimental.poll_task(task_id): - print(f"Status: {status.status}") - if status.status == "input_required": - # Handle elicitation request via tasks/result - pass - - # Task is now terminal, get the result - result = await session.experimental.get_task_result(task_id, CallToolResult) - ``` - """ - async for status in poll_until_terminal(self.get_task, task_id): - yield status diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 86113874be..08f532eca5 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,8 +8,6 @@ from mcp import types from mcp.client._transport import ReadStream, WriteStream -from mcp.client.experimental import ExperimentalClientFeatures -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder @@ -120,7 +118,6 @@ def __init__( client_info: types.Implementation | None = None, *, sampling_capabilities: types.SamplingCapability | None = None, - experimental_task_handlers: ExperimentalTaskHandlers | None = None, ) -> None: super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds) self._client_info = client_info or DEFAULT_CLIENT_INFO @@ -132,10 +129,6 @@ def __init__( self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._initialize_result: types.InitializeResult | None = None - self._experimental_features: ExperimentalClientFeatures | None = None - - # Experimental: Task handlers (use defaults if not provided) - self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() @property def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]: @@ -174,7 
+167,6 @@ async def initialize(self) -> types.InitializeResult: elicitation=elicitation, experimental=None, roots=roots, - tasks=self._task_handlers.build_capability(), ), client_info=self._client_info, ), @@ -199,23 +191,6 @@ def initialize_result(self) -> types.InitializeResult | None: """ return self._initialize_result - @property - def experimental(self) -> ExperimentalClientFeatures: - """Experimental APIs for tasks and other features. - - !!! warning - These APIs are experimental and may change without notice. - - Example: - ```python - status = await session.experimental.get_task(task_id) - result = await session.experimental.get_task_result(task_id, CallToolResult) - ``` - """ - if self._experimental_features is None: - self._experimental_features = ExperimentalClientFeatures(self) - return self._experimental_features - async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: """Send a ping request.""" return await self.send_request(types.PingRequest(params=types.RequestParams(_meta=meta)), types.EmptyResult) @@ -413,31 +388,16 @@ async def send_roots_list_changed(self) -> None: async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: ctx = RequestContext[ClientSession](request_id=responder.request_id, meta=responder.request_meta, session=self) - # Delegate to experimental task handler if applicable - if self._task_handlers.handles_request(responder.request): - with responder: - await self._task_handlers.handle_request(ctx, responder) - return None - - # Core request handling match responder.request: case types.CreateMessageRequest(params=params): with responder: - # Check if this is a task-augmented request - if params.task is not None: - response = await self._task_handlers.augmented_sampling(ctx, params, params.task) - else: - response = await self._sampling_callback(ctx, params) + response = await self._sampling_callback(ctx, params) client_response = 
ClientResponse.validate_python(response) await responder.respond(client_response) case types.ElicitRequest(params=params): with responder: - # Check if this is a task-augmented request - if params.task is not None: - response = await self._task_handlers.augmented_elicitation(ctx, params, params.task) - else: - response = await self._elicitation_callback(ctx, params) + response = await self._elicitation_callback(ctx, params) client_response = ClientResponse.validate_python(response) await responder.respond(client_response) @@ -447,14 +407,9 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques client_response = ClientResponse.validate_python(response) await responder.respond(client_response) - case types.PingRequest(): + case types.PingRequest(): # pragma: no branch with responder: - return await responder.respond(types.EmptyResult()) - - case _: # pragma: no cover - pass # Task requests handled above by _task_handlers - - return None + await responder.respond(types.EmptyResult()) async def _handle_incoming( self, diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index d8e11d78b2..bc54c5d2eb 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -5,7 +5,6 @@ from typing_extensions import TypeVar -from mcp.server.experimental.request_context import Experimental from mcp.server.session import ServerSession from mcp.shared._context import RequestContext from mcp.shared.message import CloseSSEStreamCallback @@ -17,7 +16,6 @@ @dataclass(kw_only=True) class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContextT, RequestT]): lifespan_context: LifespanContextT - experimental: Experimental request: RequestT | None = None close_sse_stream: CloseSSEStreamCallback | None = None close_standalone_sse_stream: CloseSSEStreamCallback | None = None diff --git a/src/mcp/server/experimental/__init__.py b/src/mcp/server/experimental/__init__.py deleted file mode 100644 index fd1db623f2..0000000000 
--- a/src/mcp/server/experimental/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Server-side experimental features. - -WARNING: These APIs are experimental and may change without notice. - -Import directly from submodules: -- mcp.server.experimental.task_context.ServerTaskContext -- mcp.server.experimental.task_support.TaskSupport -- mcp.server.experimental.task_result_handler.TaskResultHandler -- mcp.server.experimental.request_context.Experimental -""" diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py deleted file mode 100644 index 3eba65822a..0000000000 --- a/src/mcp/server/experimental/request_context.py +++ /dev/null @@ -1,217 +0,0 @@ -"""Experimental request context features. - -This module provides the Experimental class which gives access to experimental -features within a request context, such as task-augmented request handling. - -WARNING: These APIs are experimental and may change without notice. -""" - -from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field -from typing import Any - -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.experimental.task_support import TaskSupport -from mcp.server.session import ServerSession -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY, is_terminal -from mcp.types import ( - METHOD_NOT_FOUND, - TASK_FORBIDDEN, - TASK_REQUIRED, - ClientCapabilities, - CreateTaskResult, - ErrorData, - Result, - TaskExecutionMode, - TaskMetadata, - Tool, -) - - -@dataclass -class Experimental: - """Experimental features context for task-augmented requests. - - Provides helpers for validating task execution compatibility and - running tasks with automatic lifecycle management. - - WARNING: This API is experimental and may change without notice. 
- """ - - task_metadata: TaskMetadata | None = None - _client_capabilities: ClientCapabilities | None = field(default=None, repr=False) - _session: ServerSession | None = field(default=None, repr=False) - _task_support: TaskSupport | None = field(default=None, repr=False) - - @property - def is_task(self) -> bool: - """Check if this request is task-augmented.""" - return self.task_metadata is not None - - @property - def client_supports_tasks(self) -> bool: - """Check if the client declared task support.""" - if self._client_capabilities is None: - return False - return self._client_capabilities.tasks is not None - - def validate_task_mode( - self, tool_task_mode: TaskExecutionMode | None, *, raise_error: bool = True - ) -> ErrorData | None: - """Validate that the request is compatible with the tool's task execution mode. - - Per MCP spec: - - "required": Clients MUST invoke as a task. Server returns -32601 if not. - - "forbidden" (or None): Clients MUST NOT invoke as a task. Server returns -32601 if they do. - - "optional": Either is acceptable. - - Args: - tool_task_mode: The tool's execution.taskSupport value - ("forbidden", "optional", "required", or None) - raise_error: If True, raises MCPError on validation failure. If False, returns ErrorData. 
- - Returns: - None if valid, ErrorData if invalid and raise_error=False - - Raises: - MCPError: If invalid and raise_error=True - """ - - mode = tool_task_mode or TASK_FORBIDDEN - - error: ErrorData | None = None - - if mode == TASK_REQUIRED and not self.is_task: - error = ErrorData(code=METHOD_NOT_FOUND, message="This tool requires task-augmented invocation") - elif mode == TASK_FORBIDDEN and self.is_task: - error = ErrorData(code=METHOD_NOT_FOUND, message="This tool does not support task-augmented invocation") - - if error is not None and raise_error: - raise MCPError.from_error_data(error) - - return error - - def validate_for_tool(self, tool: Tool, *, raise_error: bool = True) -> ErrorData | None: - """Validate that the request is compatible with the given tool. - - Convenience wrapper around validate_task_mode that extracts the mode from a Tool. - - Args: - tool: The Tool definition - raise_error: If True, raises MCPError on validation failure. - - Returns: - None if valid, ErrorData if invalid and raise_error=False - """ - mode = tool.execution.task_support if tool.execution else None - return self.validate_task_mode(mode, raise_error=raise_error) - - def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool: - """Check if this client can use a tool with the given task mode. - - Useful for filtering tool lists or providing warnings. - Returns False if the tool's task mode is "required" but the client doesn't support tasks. 
- - Args: - tool_task_mode: The tool's execution.taskSupport value - - Returns: - True if the client can use this tool, False otherwise - """ - mode = tool_task_mode or TASK_FORBIDDEN - if mode == TASK_REQUIRED and not self.client_supports_tasks: - return False - return True - - async def run_task( - self, - work: Callable[[ServerTaskContext], Awaitable[Result]], - *, - task_id: str | None = None, - model_immediate_response: str | None = None, - ) -> CreateTaskResult: - """Create a task, spawn background work, and return CreateTaskResult immediately. - - This is the recommended way to handle task-augmented tool calls. It: - 1. Creates a task in the store - 2. Spawns the work function in a background task - 3. Returns CreateTaskResult immediately - - The work function receives a ServerTaskContext with: - - elicit() for sending elicitation requests - - create_message() for sampling requests - - update_status() for progress updates - - complete()/fail() for finishing the task - - When work() returns a Result, the task is auto-completed with that result. - If work() raises an exception, the task is auto-failed. 
- - Args: - work: Async function that does the actual work - task_id: Optional task ID (generated if not provided) - model_immediate_response: Optional string to include in _meta as - io.modelcontextprotocol/model-immediate-response - - Returns: - CreateTaskResult to return to the client - - Raises: - RuntimeError: If task support is not enabled or task_metadata is missing - - Example: - ```python - async def handle_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: - async def work(task: ServerTaskContext) -> CallToolResult: - result = await task.elicit( - message="Are you sure?", - requested_schema={"type": "object", ...} - ) - confirmed = result.content.get("confirm", False) - return CallToolResult(content=[TextContent(text="Done" if confirmed else "Cancelled")]) - - return await ctx.experimental.run_task(work) - ``` - - WARNING: This API is experimental and may change without notice. - """ - if self._task_support is None: - raise RuntimeError("Task support not enabled. Call server.experimental.enable_tasks() first.") - if self._session is None: - raise RuntimeError("Session not available.") - if self.task_metadata is None: - raise RuntimeError( - "Request is not task-augmented (no task field in params). " - "The client must send a task-augmented request." 
- ) - - support = self._task_support - # Access task_group via TaskSupport - raises if not in run() context - task_group = support.task_group - - task = await support.store.create_task(self.task_metadata, task_id) - - task_ctx = ServerTaskContext( - task=task, - store=support.store, - session=self._session, - queue=support.queue, - handler=support.handler, - ) - - async def execute() -> None: - try: - result = await work(task_ctx) - if not is_terminal(task_ctx.task.status): - await task_ctx.complete(result) - except Exception as e: - if not is_terminal(task_ctx.task.status): - await task_ctx.fail(str(e)) - - task_group.start_soon(execute) - - meta: dict[str, Any] | None = None - if model_immediate_response is not None: - meta = {MODEL_IMMEDIATE_RESPONSE_KEY: model_immediate_response} - - return CreateTaskResult(task=task, **{"_meta": meta} if meta else {}) diff --git a/src/mcp/server/experimental/session_features.py b/src/mcp/server/experimental/session_features.py deleted file mode 100644 index 2f9d1b0320..0000000000 --- a/src/mcp/server/experimental/session_features.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Experimental server session features for server→client task operations. - -This module provides the server-side equivalent of ExperimentalClientFeatures, -allowing the server to send task-augmented requests to the client and poll for results. - -WARNING: These APIs are experimental and may change without notice. 
-""" - -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any, TypeVar - -from mcp import types -from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages -from mcp.shared.experimental.tasks.capabilities import ( - require_task_augmented_elicitation, - require_task_augmented_sampling, -) -from mcp.shared.experimental.tasks.polling import poll_until_terminal - -if TYPE_CHECKING: - from mcp.server.session import ServerSession - -ResultT = TypeVar("ResultT", bound=types.Result) - - -class ExperimentalServerSessionFeatures: - """Experimental server session features for server→client task operations. - - This provides the server-side equivalent of ExperimentalClientFeatures, - allowing the server to send task-augmented requests to the client and - poll for results. - - WARNING: These APIs are experimental and may change without notice. - - Access via session.experimental: - result = await session.experimental.elicit_as_task(...) - """ - - def __init__(self, session: "ServerSession") -> None: - self._session = session - - async def get_task(self, task_id: str) -> types.GetTaskResult: - """Send tasks/get to the client to get task status. - - Args: - task_id: The task identifier - - Returns: - GetTaskResult containing the task status - """ - return await self._session.send_request( - types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id)), - types.GetTaskResult, - ) - - async def get_task_result( - self, - task_id: str, - result_type: type[ResultT], - ) -> ResultT: - """Send tasks/result to the client to retrieve the final result. 
- - Args: - task_id: The task identifier - result_type: The expected result type - - Returns: - The task result, validated against result_type - """ - return await self._session.send_request( - types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(task_id=task_id)), - result_type, - ) - - async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: - """Poll a client task until it reaches terminal status. - - Yields GetTaskResult for each poll, allowing the caller to react to - status changes. Exits when task reaches a terminal status. - - Respects the pollInterval hint from the client. - - Args: - task_id: The task identifier - - Yields: - GetTaskResult for each poll - """ - async for status in poll_until_terminal(self.get_task, task_id): - yield status - - async def elicit_as_task( - self, - message: str, - requested_schema: types.ElicitRequestedSchema, - *, - ttl: int = 60000, - ) -> types.ElicitResult: - """Send a task-augmented elicitation to the client and poll until complete. - - The client will create a local task, process the elicitation asynchronously, - and return the result when ready. This method handles the full flow: - 1. Send elicitation with task field - 2. Receive CreateTaskResult from client - 3. Poll client's task until terminal - 4. 
Retrieve and return the final ElicitResult - - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response - ttl: Task time-to-live in milliseconds - - Returns: - The client's elicitation response - - Raises: - MCPError: If client doesn't support task-augmented elicitation - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_elicitation(client_caps) - - create_result = await self._session.send_request( - types.ElicitRequest( - params=types.ElicitRequestFormParams( - message=message, - requested_schema=requested_schema, - task=types.TaskMetadata(ttl=ttl), - ) - ), - types.CreateTaskResult, - ) - - task_id = create_result.task.task_id - - async for _ in self.poll_task(task_id): - pass - - return await self.get_task_result(task_id, types.ElicitResult) - - async def create_message_as_task( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - ttl: int = 60000, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - tools: list[types.Tool] | None = None, - tool_choice: types.ToolChoice | None = None, - ) -> types.CreateMessageResult: - """Send a task-augmented sampling request and poll until complete. - - The client will create a local task, process the sampling request - asynchronously, and return the result when ready. 
- - Args: - messages: The conversation messages for sampling - max_tokens: Maximum tokens in the response - ttl: Task time-to-live in milliseconds - system_prompt: Optional system prompt - include_context: Context inclusion strategy - temperature: Sampling temperature - stop_sequences: Stop sequences - metadata: Additional metadata - model_preferences: Model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - - Returns: - The sampling result from the client - - Raises: - MCPError: If client doesn't support task-augmented sampling or tools - ValueError: If tool_use or tool_result message structure is invalid - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_sampling(client_caps) - validate_sampling_tools(client_caps, tools, tool_choice) - validate_tool_use_result_messages(messages) - - create_result = await self._session.send_request( - types.CreateMessageRequest( - params=types.CreateMessageRequestParams( - messages=messages, - max_tokens=max_tokens, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - task=types.TaskMetadata(ttl=ttl), - ) - ), - types.CreateTaskResult, - ) - - task_id = create_result.task.task_id - - async for _ in self.poll_task(task_id): - pass - - return await self.get_task_result(task_id, types.CreateMessageResult) diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py deleted file mode 100644 index 1fc45badfd..0000000000 --- a/src/mcp/server/experimental/task_context.py +++ /dev/null @@ -1,587 +0,0 @@ -"""ServerTaskContext - Server-integrated task context with elicitation and sampling. 
- -This wraps the pure TaskContext and adds server-specific functionality: -- Elicitation (task.elicit()) -- Sampling (task.create_message()) -- Status notifications -""" - -from typing import Any - -import anyio - -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.server.session import ServerSession -from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.capabilities import ( - require_task_augmented_elicitation, - require_task_augmented_sampling, -) -from mcp.shared.experimental.tasks.context import TaskContext -from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import ( - INVALID_REQUEST, - TASK_STATUS_INPUT_REQUIRED, - TASK_STATUS_WORKING, - ClientCapabilities, - CreateMessageResult, - CreateTaskResult, - ElicitationCapability, - ElicitRequestedSchema, - ElicitResult, - IncludeContext, - ModelPreferences, - RequestId, - Result, - SamplingCapability, - SamplingMessage, - Task, - TaskMetadata, - TaskStatusNotification, - TaskStatusNotificationParams, - Tool, - ToolChoice, -) - - -class ServerTaskContext: - """Server-integrated task context with elicitation and sampling. 
- - This wraps a pure TaskContext and adds server-specific functionality: - - elicit() for sending elicitation requests to the client - - create_message() for sampling requests - - Status notifications via the session - - Example: - ```python - async def my_task_work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Starting...") - - result = await task.elicit( - message="Continue?", - requested_schema={"type": "object", "properties": {"ok": {"type": "boolean"}}} - ) - - if result.content.get("ok"): - return CallToolResult(content=[TextContent(text="Done!")]) - else: - return CallToolResult(content=[TextContent(text="Cancelled")]) - ``` - """ - - def __init__( - self, - *, - task: Task, - store: TaskStore, - session: ServerSession, - queue: TaskMessageQueue, - handler: TaskResultHandler | None = None, - ): - """Create a ServerTaskContext. - - Args: - task: The Task object - store: The task store - session: The server session - queue: The message queue for elicitation/sampling - handler: The result handler for response routing (required for elicit/create_message) - """ - self._ctx = TaskContext(task=task, store=store) - self._session = session - self._queue = queue - self._handler = handler - self._store = store - - # Delegate pure properties to inner context - - @property - def task_id(self) -> str: - """The task identifier.""" - return self._ctx.task_id - - @property - def task(self) -> Task: - """The current task state.""" - return self._ctx.task - - @property - def is_cancelled(self) -> bool: - """Whether cancellation has been requested.""" - return self._ctx.is_cancelled - - def request_cancellation(self) -> None: - """Request cancellation of this task.""" - self._ctx.request_cancellation() - - # Enhanced methods with notifications - - async def update_status(self, message: str, *, notify: bool = True) -> None: - """Update the task's status message. 
- - Args: - message: The new status message - notify: Whether to send a notification to the client - """ - await self._ctx.update_status(message) - if notify: - await self._send_notification() - - async def complete(self, result: Result, *, notify: bool = True) -> None: - """Mark the task as completed with the given result. - - Args: - result: The task result - notify: Whether to send a notification to the client - """ - await self._ctx.complete(result) - if notify: - await self._send_notification() - - async def fail(self, error: str, *, notify: bool = True) -> None: - """Mark the task as failed with an error message. - - Args: - error: The error message - notify: Whether to send a notification to the client - """ - await self._ctx.fail(error) - if notify: - await self._send_notification() - - async def _send_notification(self) -> None: - """Send a task status notification to the client.""" - task = self._ctx.task - await self._session.send_notification( - TaskStatusNotification( - params=TaskStatusNotificationParams( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - ) - ) - - # Server-specific methods: elicitation and sampling - - def _check_elicitation_capability(self) -> None: - """Check if the client supports elicitation.""" - if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())): - raise MCPError(code=INVALID_REQUEST, message="Client does not support elicitation capability") - - def _check_sampling_capability(self) -> None: - """Check if the client supports sampling.""" - if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())): - raise MCPError(code=INVALID_REQUEST, message="Client does not support sampling capability") - - async def elicit( - self, - message: str, - requested_schema: ElicitRequestedSchema, - ) -> 
ElicitResult: - """Send an elicitation request via the task message queue. - - This method: - 1. Checks client capability - 2. Updates task status to "input_required" - 3. Queues the elicitation request - 4. Waits for the response (delivered via tasks/result round-trip) - 5. Updates task status back to "working" - 6. Returns the result - - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - - Returns: - The client's response - - Raises: - MCPError: If client doesn't support elicitation capability - """ - self._check_elicitation_capability() - - if self._handler is None: - raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build the request using session's helper - request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage] - message=message, - requested_schema=requested_schema, - related_task_id=self.task_id, - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for response (routed back via TaskResultHandler) - response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return ElicitResult.model_validate(response_data) - except anyio.get_cancelled_exc_class(): - # This path is tested in test_elicit_restores_status_on_cancellation - # which verifies status is restored to "working" after cancellation. 
- await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise - - async def elicit_url( - self, - message: str, - url: str, - elicitation_id: str, - ) -> ElicitResult: - """Send a URL mode elicitation request via the task message queue. - - This directs the user to an external URL for out-of-band interactions - like OAuth flows, credential collection, or payment processing. - - This method: - 1. Checks client capability - 2. Updates task status to "input_required" - 3. Queues the elicitation request - 4. Waits for the response (delivered via tasks/result round-trip) - 5. Updates task status back to "working" - 6. Returns the result - - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - - Returns: - The client's response indicating acceptance, decline, or cancellation - - Raises: - MCPError: If client doesn't support elicitation capability - RuntimeError: If handler is not configured - """ - self._check_elicitation_capability() - - if self._handler is None: - raise RuntimeError("handler is required for elicit_url(). 
Pass handler= to ServerTaskContext.") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build the request using session's helper - request = self._session._build_elicit_url_request( # pyright: ignore[reportPrivateUsage] - message=message, - url=url, - elicitation_id=elicitation_id, - related_task_id=self.task_id, - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for response (routed back via TaskResultHandler) - response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return ElicitResult.model_validate(response_data) - except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise - - async def create_message( - self, - messages: list[SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: ModelPreferences | None = None, - tools: list[Tool] | None = None, - tool_choice: ToolChoice | None = None, - ) -> CreateMessageResult: - """Send a sampling request via the task message queue. - - This method: - 1. Checks client capability - 2. Updates task status to "input_required" - 3. Queues the sampling request - 4. Waits for the response (delivered via tasks/result round-trip) - 5. Updates task status back to "working" - 6. 
Returns the result - - Args: - messages: The conversation messages for sampling - max_tokens: Maximum tokens in the response - system_prompt: Optional system prompt - include_context: Context inclusion strategy - temperature: Sampling temperature - stop_sequences: Stop sequences - metadata: Additional metadata - model_preferences: Model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - - Returns: - The sampling result from the client - - Raises: - MCPError: If client doesn't support sampling capability or tools - ValueError: If tool_use or tool_result message structure is invalid - """ - self._check_sampling_capability() - client_caps = self._session.client_params.capabilities if self._session.client_params else None - validate_sampling_tools(client_caps, tools, tool_choice) - validate_tool_use_result_messages(messages) - - if self._handler is None: - raise RuntimeError("handler is required for create_message(). 
Pass handler= to ServerTaskContext.") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build the request using session's helper - request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage] - messages=messages, - max_tokens=max_tokens, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - related_task_id=self.task_id, - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for response (routed back via TaskResultHandler) - response_data = await resolver.wait() - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return CreateMessageResult.model_validate(response_data) - except anyio.get_cancelled_exc_class(): - # This path is tested in test_create_message_restores_status_on_cancellation - # which verifies status is restored to "working" after cancellation. - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise - - async def elicit_as_task( - self, - message: str, - requested_schema: ElicitRequestedSchema, - *, - ttl: int = 60000, - ) -> ElicitResult: - """Send a task-augmented elicitation via the queue, then poll client. - - This is for use inside a task-augmented tool call when you want the client - to handle the elicitation as its own task. The elicitation request is queued - and delivered when the client calls tasks/result. After the client responds - with CreateTaskResult, we poll the client's task until complete. 
- - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - ttl: Task time-to-live in milliseconds for the client's task - - Returns: - The client's elicitation response - - Raises: - MCPError: If client doesn't support task-augmented elicitation - RuntimeError: If handler is not configured - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_elicitation(client_caps) - - if self._handler is None: - raise RuntimeError("handler is required for elicit_as_task()") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage] - message=message, - requested_schema=requested_schema, - related_task_id=self.task_id, - task=TaskMetadata(ttl=ttl), - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for initial response (CreateTaskResult from client) - response_data = await resolver.wait() - create_result = CreateTaskResult.model_validate(response_data) - client_task_id = create_result.task.task_id - - # Poll the client's task using session.experimental - async for _ in self._session.experimental.poll_task(client_task_id): - pass - - # Get final result from client - result = await self._session.experimental.get_task_result( - client_task_id, - ElicitResult, - ) - - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return result - - except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, 
status=TASK_STATUS_WORKING) - raise - - async def create_message_as_task( - self, - messages: list[SamplingMessage], - *, - max_tokens: int, - ttl: int = 60000, - system_prompt: str | None = None, - include_context: IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: ModelPreferences | None = None, - tools: list[Tool] | None = None, - tool_choice: ToolChoice | None = None, - ) -> CreateMessageResult: - """Send a task-augmented sampling request via the queue, then poll client. - - This is for use inside a task-augmented tool call when you want the client - to handle the sampling as its own task. The request is queued and delivered - when the client calls tasks/result. After the client responds with - CreateTaskResult, we poll the client's task until complete. - - Args: - messages: The conversation messages for sampling - max_tokens: Maximum tokens in the response - ttl: Task time-to-live in milliseconds for the client's task - system_prompt: Optional system prompt - include_context: Context inclusion strategy - temperature: Sampling temperature - stop_sequences: Stop sequences - metadata: Additional metadata - model_preferences: Model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - - Returns: - The sampling result from the client - - Raises: - MCPError: If client doesn't support task-augmented sampling or tools - ValueError: If tool_use or tool_result message structure is invalid - RuntimeError: If handler is not configured - """ - client_caps = self._session.client_params.capabilities if self._session.client_params else None - require_task_augmented_sampling(client_caps) - validate_sampling_tools(client_caps, tools, tool_choice) - validate_tool_use_result_messages(messages) - - if self._handler is None: - raise RuntimeError("handler is required for 
create_message_as_task()") - - # Update status to input_required - await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) - - # Build request WITH task field for task-augmented sampling - request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage] - messages=messages, - max_tokens=max_tokens, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - related_task_id=self.task_id, - task=TaskMetadata(ttl=ttl), - ) - request_id: RequestId = request.id - - resolver: Resolver[dict[str, Any]] = Resolver() - self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] - - queued = QueuedMessage( - type="request", - message=request, - resolver=resolver, - original_request_id=request_id, - ) - await self._queue.enqueue(self.task_id, queued) - - try: - # Wait for initial response (CreateTaskResult from client) - response_data = await resolver.wait() - create_result = CreateTaskResult.model_validate(response_data) - client_task_id = create_result.task.task_id - - # Poll the client's task using session.experimental - async for _ in self._session.experimental.poll_task(client_task_id): - pass - - # Get final result from client - result = await self._session.experimental.get_task_result( - client_task_id, - CreateMessageResult, - ) - - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - return result - - except anyio.get_cancelled_exc_class(): # pragma: no cover - await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) - raise diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py deleted file mode 100644 index b2268bc1c8..0000000000 --- a/src/mcp/server/experimental/task_result_handler.py +++ /dev/null @@ -1,218 +0,0 @@ 
-"""TaskResultHandler - Integrated handler for tasks/result endpoint. - -This implements the dequeue-send-wait pattern from the MCP Tasks spec: -1. Dequeue all pending messages for the task -2. Send them to the client via transport with relatedRequestId routing -3. Wait if task is not in terminal state -4. Return final result when task completes - -This is the core of the task message queue pattern. -""" - -import logging -from typing import Any - -import anyio - -from mcp.server.session import ServerSession -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal -from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.types import ( - INVALID_PARAMS, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadResult, - RelatedTaskMetadata, - RequestId, -) - -logger = logging.getLogger(__name__) - - -class TaskResultHandler: - """Handler for tasks/result that implements the message queue pattern. - - This handler: - 1. Dequeues pending messages (elicitations, notifications) for the task - 2. Sends them to the client via the response stream - 3. Waits for responses and resolves them back to callers - 4. Blocks until task reaches terminal state - 5. Returns the final result - - Usage: - async def handle_task_result( - ctx: ServerRequestContext, params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - ... 
- - server.experimental.enable_tasks( - on_task_result=handle_task_result, - ) - """ - - def __init__( - self, - store: TaskStore, - queue: TaskMessageQueue, - ): - self._store = store - self._queue = queue - # Map from internal request ID to resolver for routing responses - self._pending_requests: dict[RequestId, Resolver[dict[str, Any]]] = {} - - async def send_message( - self, - session: ServerSession, - message: SessionMessage, - ) -> None: - """Send a message via the session. - - This is a helper for delivering queued task messages. - """ - await session.send_message(message) - - async def handle( - self, - request: GetTaskPayloadRequest, - session: ServerSession, - request_id: RequestId, - ) -> GetTaskPayloadResult: - """Handle a tasks/result request. - - This implements the dequeue-send-wait loop: - 1. Dequeue all pending messages - 2. Send each via transport with relatedRequestId = this request's ID - 3. If task not terminal, wait for status change - 4. Loop until task is terminal - 5. 
Return final result - - Args: - request: The GetTaskPayloadRequest - session: The server session for sending messages - request_id: The request ID for relatedRequestId routing - - Returns: - GetTaskPayloadResult with the task's final payload - """ - task_id = request.params.task_id - - while True: - task = await self._store.get_task(task_id) - if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}") - - await self._deliver_queued_messages(task_id, session, request_id) - - # If task is terminal, return result - if is_terminal(task.status): - result = await self._store.get_result(task_id) - # GetTaskPayloadResult is a Result with extra="allow" - # The stored result contains the actual payload data - # Per spec: tasks/result MUST include _meta with related-task metadata - related_task = RelatedTaskMetadata(task_id=task_id) - related_task_meta: dict[str, Any] = {RELATED_TASK_METADATA_KEY: related_task.model_dump(by_alias=True)} - if result is not None: - result_data = result.model_dump(by_alias=True) - existing_meta: dict[str, Any] = result_data.get("_meta") or {} - result_data["_meta"] = {**existing_meta, **related_task_meta} - return GetTaskPayloadResult.model_validate(result_data) - return GetTaskPayloadResult.model_validate({"_meta": related_task_meta}) - - # Wait for task update (status change or new messages) - await self._wait_for_task_update(task_id) - - async def _deliver_queued_messages( - self, - task_id: str, - session: ServerSession, - request_id: RequestId, - ) -> None: - """Dequeue and send all pending messages for a task. - - Each message is sent via the session's write stream with - relatedRequestId set so responses route back to this stream. 
- """ - while True: - message = await self._queue.dequeue(task_id) - if message is None: - break - - # If this is a request (not notification), wait for response - if message.type == "request" and message.resolver is not None: - # Store the resolver so we can route the response back - original_id = message.original_request_id - if original_id is not None: - self._pending_requests[original_id] = message.resolver - - logger.debug("Delivering queued message for task %s: %s", task_id, message.type) - - # Send the message with relatedRequestId for routing - session_message = SessionMessage( - message=message.message, - metadata=ServerMessageMetadata(related_request_id=request_id), - ) - await self.send_message(session, session_message) - - async def _wait_for_task_update(self, task_id: str) -> None: - """Wait for task to be updated (status change or new message). - - Races between store update and queue message - first one wins. - """ - async with anyio.create_task_group() as tg: - - async def wait_for_store() -> None: - try: - await self._store.wait_for_update(task_id) - except Exception: - pass - finally: - tg.cancel_scope.cancel() - - async def wait_for_queue() -> None: - try: - await self._queue.wait_for_message(task_id) - except Exception: - pass - finally: - tg.cancel_scope.cancel() - - tg.start_soon(wait_for_store) - tg.start_soon(wait_for_queue) - - def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: - """Route a response back to the waiting resolver. - - This is called when a response arrives for a queued request. 
- - Args: - request_id: The request ID from the response - response: The response data - - Returns: - True if response was routed, False if no pending request - """ - resolver = self._pending_requests.pop(request_id, None) - if resolver is not None and not resolver.done(): - resolver.set_result(response) - return True - return False - - def route_error(self, request_id: RequestId, error: ErrorData) -> bool: - """Route an error back to the waiting resolver. - - Args: - request_id: The request ID from the error response - error: The error data - - Returns: - True if error was routed, False if no pending request - """ - resolver = self._pending_requests.pop(request_id, None) - if resolver is not None and not resolver.done(): - resolver.set_exception(MCPError.from_error_data(error)) - return True - return False diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py deleted file mode 100644 index b542195048..0000000000 --- a/src/mcp/server/experimental/task_support.py +++ /dev/null @@ -1,116 +0,0 @@ -"""TaskSupport - Configuration for experimental task support. - -This module provides the TaskSupport class which encapsulates all the -infrastructure needed for task-augmented requests: store, queue, and handler. -""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from dataclasses import dataclass, field - -import anyio -from anyio.abc import TaskGroup - -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.server.session import ServerSession -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue -from mcp.shared.experimental.tasks.store import TaskStore - - -@dataclass -class TaskSupport: - """Configuration for experimental task support. 
- - Encapsulates the task store, message queue, result handler, and task group - for spawning background work. - - When enabled on a server, this automatically: - - Configures response routing for each session - - Provides default handlers for task operations - - Manages a task group for background task execution - - Example: - Simple in-memory setup: - - ```python - server.experimental.enable_tasks() - ``` - - Custom store/queue for distributed systems: - - ```python - server.experimental.enable_tasks( - store=RedisTaskStore(redis_url), - queue=RedisTaskMessageQueue(redis_url), - ) - ``` - """ - - store: TaskStore - queue: TaskMessageQueue - handler: TaskResultHandler = field(init=False) - _task_group: TaskGroup | None = field(init=False, default=None) - - def __post_init__(self) -> None: - """Create the result handler from store and queue.""" - self.handler = TaskResultHandler(self.store, self.queue) - - @property - def task_group(self) -> TaskGroup: - """Get the task group for spawning background work. - - Raises: - RuntimeError: If not within a run() context - """ - if self._task_group is None: - raise RuntimeError("TaskSupport not running. Ensure Server.run() is active.") - return self._task_group - - @asynccontextmanager - async def run(self) -> AsyncIterator[None]: - """Run the task support lifecycle. - - This creates a task group for spawning background task work. - Called automatically by Server.run(). - - Usage: - async with task_support.run(): - # Task group is now available - ... - """ - async with anyio.create_task_group() as tg: - self._task_group = tg - try: - yield - finally: - self._task_group = None - - def configure_session(self, session: ServerSession) -> None: - """Configure a session for task support. - - This registers the result handler as a response router so that - responses to queued requests (elicitation, sampling) are routed - back to the waiting resolvers. - - Called automatically by Server.run() for each new session. 
- - Args: - session: The session to configure - """ - session.add_response_router(self.handler) - - @classmethod - def in_memory(cls) -> "TaskSupport": - """Create in-memory task support. - - Suitable for development, testing, and single-process servers. - For distributed systems, provide custom store and queue implementations. - - Returns: - TaskSupport configured with in-memory store and queue - """ - return cls( - store=InMemoryTaskStore(), - queue=InMemoryTaskMessageQueue(), - ) diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py deleted file mode 100644 index 5a907b6407..0000000000 --- a/src/mcp/server/lowlevel/experimental.py +++ /dev/null @@ -1,210 +0,0 @@ -"""Experimental handlers for the low-level MCP server. - -WARNING: These APIs are experimental and may change without notice. -""" - -from __future__ import annotations - -import logging -from collections.abc import Awaitable, Callable -from typing import Any, Generic - -from typing_extensions import TypeVar - -from mcp.server.context import ServerRequestContext -from mcp.server.experimental.task_support import TaskSupport -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import cancel_task -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import ( - INVALID_PARAMS, - CancelTaskRequestParams, - CancelTaskResult, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - PaginatedRequestParams, - ServerCapabilities, - ServerTasksCapability, - ServerTasksRequestsCapability, - TasksCallCapability, - TasksCancelCapability, - TasksListCapability, - TasksToolsCapability, -) - -logger = logging.getLogger(__name__) - -LifespanResultT = 
TypeVar("LifespanResultT", default=Any) - - -class ExperimentalHandlers(Generic[LifespanResultT]): - """Experimental request/notification handlers. - - WARNING: These APIs are experimental and may change without notice. - """ - - def __init__( - self, - add_request_handler: Callable[ - [str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]], None - ], - has_handler: Callable[[str], bool], - ) -> None: - self._add_request_handler = add_request_handler - self._has_handler = has_handler - self._task_support: TaskSupport | None = None - - @property - def task_support(self) -> TaskSupport | None: - """Get the task support configuration, if enabled.""" - return self._task_support - - def update_capabilities(self, capabilities: ServerCapabilities) -> None: - # Only add tasks capability if handlers are registered - if not any(self._has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"]): - return - - capabilities.tasks = ServerTasksCapability() - if self._has_handler("tasks/list"): - capabilities.tasks.list = TasksListCapability() - if self._has_handler("tasks/cancel"): - capabilities.tasks.cancel = TasksCancelCapability() - - capabilities.tasks.requests = ServerTasksRequestsCapability( - tools=TasksToolsCapability(call=TasksCallCapability()) - ) # assuming always supported for now - - def enable_tasks( - self, - store: TaskStore | None = None, - queue: TaskMessageQueue | None = None, - *, - on_get_task: Callable[[ServerRequestContext[LifespanResultT], GetTaskRequestParams], Awaitable[GetTaskResult]] - | None = None, - on_task_result: Callable[ - [ServerRequestContext[LifespanResultT], GetTaskPayloadRequestParams], Awaitable[GetTaskPayloadResult] - ] - | None = None, - on_list_tasks: Callable[ - [ServerRequestContext[LifespanResultT], PaginatedRequestParams | None], Awaitable[ListTasksResult] - ] - | None = None, - on_cancel_task: Callable[ - [ServerRequestContext[LifespanResultT], CancelTaskRequestParams], 
Awaitable[CancelTaskResult] - ] - | None = None, - ) -> TaskSupport: - """Enable experimental task support. - - This sets up the task infrastructure and registers handlers for - tasks/get, tasks/result, tasks/list, and tasks/cancel. Custom handlers - can be provided via the on_* kwargs; any not provided will use defaults. - - Args: - store: Custom TaskStore implementation (defaults to InMemoryTaskStore) - queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue) - on_get_task: Custom handler for tasks/get - on_task_result: Custom handler for tasks/result - on_list_tasks: Custom handler for tasks/list - on_cancel_task: Custom handler for tasks/cancel - - Returns: - The TaskSupport configuration object - - Example: - Simple in-memory setup: - - ```python - server.experimental.enable_tasks() - ``` - - Custom store/queue for distributed systems: - - ```python - server.experimental.enable_tasks( - store=RedisTaskStore(redis_url), - queue=RedisTaskMessageQueue(redis_url), - ) - ``` - - WARNING: This API is experimental and may change without notice. 
- """ - if store is None: - store = InMemoryTaskStore() - if queue is None: - queue = InMemoryTaskMessageQueue() - - self._task_support = TaskSupport(store=store, queue=queue) - task_support = self._task_support - - # Register user-provided handlers - if on_get_task is not None: - self._add_request_handler("tasks/get", on_get_task) - if on_task_result is not None: - self._add_request_handler("tasks/result", on_task_result) - if on_list_tasks is not None: - self._add_request_handler("tasks/list", on_list_tasks) - if on_cancel_task is not None: - self._add_request_handler("tasks/cancel", on_cancel_task) - - # Fill in defaults for any not provided - if not self._has_handler("tasks/get"): - - async def _default_get_task( - ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams - ) -> GetTaskResult: - task = await task_support.store.get_task(params.task_id) - if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - self._add_request_handler("tasks/get", _default_get_task) - - if not self._has_handler("tasks/result"): - - async def _default_get_task_result( - ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - assert ctx.request_id is not None - req = GetTaskPayloadRequest(params=params) - result = await task_support.handler.handle(req, ctx.session, ctx.request_id) - return result - - self._add_request_handler("tasks/result", _default_get_task_result) - - if not self._has_handler("tasks/list"): - - async def _default_list_tasks( - ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None - ) -> ListTasksResult: - cursor = params.cursor if params else None - tasks, next_cursor = await 
task_support.store.list_tasks(cursor) - return ListTasksResult(tasks=tasks, next_cursor=next_cursor) - - self._add_request_handler("tasks/list", _default_list_tasks) - - if not self._has_handler("tasks/cancel"): - - async def _default_cancel_task( - ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams - ) -> CancelTaskResult: - result = await cancel_task(task_support.store, params.task_id) - return result - - self._add_request_handler("tasks/cancel", _default_cancel_task) - - return task_support diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 5e4e2e6f5b..37127c5621 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -59,8 +59,6 @@ async def main(): from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings from mcp.server.context import ServerRequestContext -from mcp.server.experimental.request_context import Experimental -from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.streamable_http import EventStore @@ -121,7 +119,7 @@ def __init__( | None = None, on_call_tool: Callable[ [ServerRequestContext[LifespanResultT], types.CallToolRequestParams], - Awaitable[types.CallToolResult | types.CreateTaskResult], + Awaitable[types.CallToolResult], ] | None = None, on_list_resources: Callable[ @@ -197,7 +195,6 @@ def __init__( self._notification_handlers: dict[ str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] ] = {} - self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None logger.debug("Initializing server %r", name) @@ -242,10 +239,6 @@ def _add_request_handler( """Add a request handler, silently replacing any existing 
handler for the same method.""" self._request_handlers[method] = handler - def _has_handler(self, method: str) -> bool: - """Check if a handler is registered for the given method.""" - return method in self._request_handlers or method in self._notification_handlers - # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities @@ -323,25 +316,8 @@ def get_capabilities( experimental=experimental_capabilities, completions=completions_capability, ) - if self._experimental_handlers: - self._experimental_handlers.update_capabilities(capabilities) return capabilities - @property - def experimental(self) -> ExperimentalHandlers[LifespanResultT]: - """Experimental APIs for tasks and other features. - - WARNING: These APIs are experimental and may change without notice. - """ - - # We create this inline so we only add these capabilities _if_ they're actually used - if self._experimental_handlers is None: - self._experimental_handlers = ExperimentalHandlers( - add_request_handler=self._add_request_handler, - has_handler=self._has_handler, - ) - return self._experimental_handlers - @property def session_manager(self) -> StreamableHTTPSessionManager: """Get the StreamableHTTP session manager. 
@@ -383,12 +359,6 @@ async def run( ) ) - # Configure task support for this session if enabled - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - if task_support is not None: - task_support.configure_session(session) - await stack.enter_async_context(task_support.run()) - async with anyio.create_task_group() as tg: try: async for message in session.incoming_messages: @@ -476,23 +446,11 @@ async def _handle_request( close_sse_stream_cb = message.message_metadata.close_sse_stream close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - # Get task metadata from request params if present - task_metadata = None - if hasattr(req, "params") and req.params is not None: # pragma: no branch - task_metadata = getattr(req.params, "task", None) ctx = ServerRequestContext( request_id=message.request_id, meta=message.request_meta, session=session, lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=task_metadata, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), request=request_data, close_sse_stream=close_sse_stream_cb, close_standalone_sse_stream=close_standalone_sse_stream_cb, @@ -543,17 +501,9 @@ async def _handle_notification( logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None ctx = ServerRequestContext( session=session, lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=None, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), ) await 
handler(ctx, notify.params) except Exception: # pragma: no cover diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index fc2f97a9cb..3fc7bbf0d3 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -37,13 +37,10 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: from pydantic import AnyUrl, TypeAdapter from mcp import types -from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import StatelessModeNotSupported -from mcp.shared.experimental.tasks.capabilities import check_tasks_capability -from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, @@ -76,7 +73,6 @@ class ServerSession( ): _initialized: InitializationState = InitializationState.NotInitialized _client_params: types.InitializeRequestParams | None = None - _experimental_features: ExperimentalServerSessionFeatures | None = None def __init__( self, @@ -109,16 +105,6 @@ def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification] def client_params(self) -> types.InitializeRequestParams | None: return self._client_params - @property - def experimental(self) -> ExperimentalServerSessionFeatures: - """Experimental APIs for server→client task operations. - - WARNING: These APIs are experimental and may change without notice. 
- """ - if self._experimental_features is None: - self._experimental_features = ExperimentalServerSessionFeatures(self) - return self._experimental_features - def check_client_capability(self, capability: types.ClientCapabilities) -> bool: """Check if the client supports a specific capability.""" if self._client_params is None: # pragma: lax no cover @@ -150,12 +136,6 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: return False - if capability.tasks is not None: # pragma: lax no cover - if client_caps.tasks is None: - return False - if not check_tasks_capability(capability.tasks, client_caps.tasks): - return False - return True async def _receive_loop(self) -> None: @@ -509,181 +489,6 @@ async def send_elicit_complete( related_request_id, ) - def _build_elicit_form_request( - self, - message: str, - requested_schema: types.ElicitRequestedSchema, - related_task_id: str | None = None, - task: types.TaskMetadata | None = None, - ) -> types.JSONRPCRequest: - """Build a form mode elicitation request without sending it. 
- - Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - task: If provided, makes this a task-augmented request - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.ElicitRequestFormParams( - message=message, - requested_schema=requested_schema, - task=task, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="elicitation/create", - params=params_data, - ) - - def _build_elicit_url_request( - self, - message: str, - url: str, - elicitation_id: str, - related_task_id: str | None = None, - ) -> types.JSONRPCRequest: - """Build a URL mode elicitation request without sending it. 
- - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.ElicitRequestURLParams( - message=message, - url=url, - elicitation_id=elicitation_id, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="elicitation/create", - params=params_data, - ) - - def _build_create_message_request( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - tools: list[types.Tool] | None = None, - tool_choice: types.ToolChoice | None = None, - related_task_id: str | None = None, - task: types.TaskMetadata | None = None, - ) -> types.JSONRPCRequest: - """Build a sampling/createMessage request without sending it. 
- - Args: - messages: The conversation messages to send - max_tokens: Maximum number of tokens to generate - system_prompt: Optional system prompt - include_context: Optional context inclusion setting - temperature: Optional sampling temperature - stop_sequences: Optional stop sequences - metadata: Optional metadata to pass through to the LLM provider - model_preferences: Optional model selection preferences - tools: Optional list of tools the LLM can use during sampling - tool_choice: Optional control over tool usage behavior - related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata - task: If provided, makes this a task-augmented request - - Returns: - A JSONRPCRequest ready to be sent or queued - """ - params = types.CreateMessageRequestParams( - messages=messages, - system_prompt=system_prompt, - include_context=include_context, - temperature=temperature, - max_tokens=max_tokens, - stop_sequences=stop_sequences, - metadata=metadata, - model_preferences=model_preferences, - tools=tools, - tool_choice=tool_choice, - task=task, - ) - params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) - - # Add related-task metadata if associated with a parent task - if related_task_id is not None: - # Defensive: model_dump() never includes _meta, but guard against future changes - if "_meta" not in params_data: # pragma: no branch - params_data["_meta"] = {} - params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( - task_id=related_task_id - ).model_dump(by_alias=True) - - request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id - if related_task_id is None: - self._request_id += 1 - - return types.JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - method="sampling/createMessage", - params=params_data, - ) - - async def send_message(self, message: SessionMessage) -> None: - """Send a raw session message. 
- - This is primarily used by TaskResultHandler to deliver queued messages - (elicitation/sampling requests) to the client during task execution. - - WARNING: This is a low-level experimental method that may change without - notice. Prefer using higher-level methods like send_notification() or - send_request() for normal operations. - - Args: - message: The session message to send - """ - await self._write_stream.send(message) - async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) diff --git a/src/mcp/server/validation.py b/src/mcp/server/validation.py index 5708628074..08f5754f1e 100644 --- a/src/mcp/server/validation.py +++ b/src/mcp/server/validation.py @@ -1,7 +1,6 @@ """Shared validation functions for server requests. -This module provides validation logic for sampling and elicitation requests -that is shared across normal and task-augmented code paths. +This module provides validation logic for sampling and elicitation requests. """ from mcp.shared.exceptions import MCPError diff --git a/src/mcp/shared/experimental/__init__.py b/src/mcp/shared/experimental/__init__.py deleted file mode 100644 index fa6940acc6..0000000000 --- a/src/mcp/shared/experimental/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Pure experimental MCP features (no server dependencies). - -WARNING: These APIs are experimental and may change without notice. - -For server-integrated experimental features, use mcp.server.experimental. -""" diff --git a/src/mcp/shared/experimental/tasks/__init__.py b/src/mcp/shared/experimental/tasks/__init__.py deleted file mode 100644 index 52793e408b..0000000000 --- a/src/mcp/shared/experimental/tasks/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Pure task state management for MCP. - -WARNING: These APIs are experimental and may change without notice. 
- -Import directly from submodules: -- mcp.shared.experimental.tasks.store.TaskStore -- mcp.shared.experimental.tasks.context.TaskContext -- mcp.shared.experimental.tasks.in_memory_task_store.InMemoryTaskStore -- mcp.shared.experimental.tasks.message_queue.TaskMessageQueue -- mcp.shared.experimental.tasks.helpers.is_terminal -""" diff --git a/src/mcp/shared/experimental/tasks/capabilities.py b/src/mcp/shared/experimental/tasks/capabilities.py deleted file mode 100644 index 51fe64ecc3..0000000000 --- a/src/mcp/shared/experimental/tasks/capabilities.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Tasks capability checking utilities. - -This module provides functions for checking and requiring task-related -capabilities. All tasks capability logic is centralized here to keep -the main session code clean. - -WARNING: These APIs are experimental and may change without notice. -""" - -from mcp.shared.exceptions import MCPError -from mcp.types import INVALID_REQUEST, ClientCapabilities, ClientTasksCapability - - -def check_tasks_capability( - required: ClientTasksCapability, - client: ClientTasksCapability, -) -> bool: - """Check if client's tasks capability matches the required capability. 
- - Args: - required: The capability being checked for - client: The client's declared capabilities - - Returns: - True if client has the required capability, False otherwise - """ - if required.requests is None: - return True - if client.requests is None: - return False - - # Check elicitation.create - if required.requests.elicitation is not None: - if client.requests.elicitation is None: - return False - if required.requests.elicitation.create is not None: - if client.requests.elicitation.create is None: - return False - - # Check sampling.createMessage - if required.requests.sampling is not None: - if client.requests.sampling is None: - return False - if required.requests.sampling.create_message is not None: - if client.requests.sampling.create_message is None: - return False - - return True - - -def has_task_augmented_elicitation(caps: ClientCapabilities) -> bool: - """Check if capabilities include task-augmented elicitation support.""" - if caps.tasks is None: - return False - if caps.tasks.requests is None: - return False - if caps.tasks.requests.elicitation is None: - return False - return caps.tasks.requests.elicitation.create is not None - - -def has_task_augmented_sampling(caps: ClientCapabilities) -> bool: - """Check if capabilities include task-augmented sampling support.""" - if caps.tasks is None: - return False - if caps.tasks.requests is None: - return False - if caps.tasks.requests.sampling is None: - return False - return caps.tasks.requests.sampling.create_message is not None - - -def require_task_augmented_elicitation(client_caps: ClientCapabilities | None) -> None: - """Raise MCPError if client doesn't support task-augmented elicitation. 
- - Args: - client_caps: The client's declared capabilities, or None if not initialized - - Raises: - MCPError: If client doesn't support task-augmented elicitation - """ - if client_caps is None or not has_task_augmented_elicitation(client_caps): - raise MCPError(code=INVALID_REQUEST, message="Client does not support task-augmented elicitation") - - -def require_task_augmented_sampling(client_caps: ClientCapabilities | None) -> None: - """Raise MCPError if client doesn't support task-augmented sampling. - - Args: - client_caps: The client's declared capabilities, or None if not initialized - - Raises: - MCPError: If client doesn't support task-augmented sampling - """ - if client_caps is None or not has_task_augmented_sampling(client_caps): - raise MCPError(code=INVALID_REQUEST, message="Client does not support task-augmented sampling") diff --git a/src/mcp/shared/experimental/tasks/context.py b/src/mcp/shared/experimental/tasks/context.py deleted file mode 100644 index ed0d2b91b6..0000000000 --- a/src/mcp/shared/experimental/tasks/context.py +++ /dev/null @@ -1,95 +0,0 @@ -"""TaskContext - Pure task state management. - -This module provides TaskContext, which manages task state without any -server/session dependencies. It can be used standalone for distributed -workers or wrapped by ServerTaskContext for full server integration. -""" - -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, Result, Task - - -class TaskContext: - """Pure task state management - no session dependencies. - - This class handles: - - Task state (status, result) - - Cancellation tracking - - Store interactions - - For server-integrated features (elicit, create_message, notifications), - use ServerTaskContext from mcp.server.experimental. 
- - Example (distributed worker): - async def worker_job(task_id: str): - store = RedisTaskStore(redis_url) - task = await store.get_task(task_id) - ctx = TaskContext(task=task, store=store) - - await ctx.update_status("Working...") - result = await do_work() - await ctx.complete(result) - """ - - def __init__(self, task: Task, store: TaskStore): - self._task = task - self._store = store - self._cancelled = False - - @property - def task_id(self) -> str: - """The task identifier.""" - return self._task.task_id - - @property - def task(self) -> Task: - """The current task state.""" - return self._task - - @property - def is_cancelled(self) -> bool: - """Whether cancellation has been requested.""" - return self._cancelled - - def request_cancellation(self) -> None: - """Request cancellation of this task. - - This sets is_cancelled=True. Task work should check this - periodically and exit gracefully if set. - """ - self._cancelled = True - - async def update_status(self, message: str) -> None: - """Update the task's status message. - - Args: - message: The new status message - """ - self._task = await self._store.update_task( - self.task_id, - status_message=message, - ) - - async def complete(self, result: Result) -> None: - """Mark the task as completed with the given result. - - Args: - result: The task result - """ - await self._store.store_result(self.task_id, result) - self._task = await self._store.update_task( - self.task_id, - status=TASK_STATUS_COMPLETED, - ) - - async def fail(self, error: str) -> None: - """Mark the task as failed with an error message. 
- - Args: - error: The error message - """ - self._task = await self._store.update_task( - self.task_id, - status=TASK_STATUS_FAILED, - status_message=error, - ) diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py deleted file mode 100644 index 3f91cd0d06..0000000000 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Helper functions for pure task management. - -These helpers work with pure TaskContext and don't require server dependencies. -For server-integrated task helpers, use mcp.server.experimental. -""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from datetime import datetime, timezone -from uuid import uuid4 - -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.context import TaskContext -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import ( - INVALID_PARAMS, - TASK_STATUS_CANCELLED, - TASK_STATUS_COMPLETED, - TASK_STATUS_FAILED, - TASK_STATUS_WORKING, - CancelTaskResult, - Task, - TaskMetadata, - TaskStatus, -) - -# Metadata key for model-immediate-response (per MCP spec) -# Servers MAY include this in CreateTaskResult._meta to provide an immediate -# response string while the task executes in the background. -MODEL_IMMEDIATE_RESPONSE_KEY = "io.modelcontextprotocol/model-immediate-response" - -# Metadata key for associating requests with a task (per MCP spec) -RELATED_TASK_METADATA_KEY = "io.modelcontextprotocol/related-task" - - -def is_terminal(status: TaskStatus) -> bool: - """Check if a task status represents a terminal state. - - Terminal states are those where the task has finished and will not change. 
- - Args: - status: The task status to check - - Returns: - True if the status is terminal (completed, failed, or cancelled) - """ - return status in (TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED) - - -async def cancel_task( - store: TaskStore, - task_id: str, -) -> CancelTaskResult: - """Cancel a task with spec-compliant validation. - - Per spec: "Receivers MUST reject cancellation of terminal status tasks - with -32602 (Invalid params)" - - This helper validates that the task exists and is not in a terminal state - before setting it to "cancelled". - - Args: - store: The task store - task_id: The task identifier to cancel - - Returns: - CancelTaskResult with the cancelled task state - - Raises: - MCPError: With INVALID_PARAMS (-32602) if: - - Task does not exist - - Task is already in a terminal state (completed, failed, cancelled) - - Example: - ```python - async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult: - return await cancel_task(store, params.task_id) - ``` - """ - task = await store.get_task(task_id) - if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}") - - if is_terminal(task.status): - raise MCPError(code=INVALID_PARAMS, message=f"Cannot cancel task in terminal state '{task.status}'") - - # Update task to cancelled status - cancelled_task = await store.update_task(task_id, status=TASK_STATUS_CANCELLED) - return CancelTaskResult(**cancelled_task.model_dump()) - - -def generate_task_id() -> str: - """Generate a unique task ID.""" - return str(uuid4()) - - -def create_task_state( - metadata: TaskMetadata, - task_id: str | None = None, -) -> Task: - """Create a Task object with initial state. - - This is a helper for TaskStore implementations. 
- - Args: - metadata: Task metadata - task_id: Optional task ID (generated if not provided) - - Returns: - A new Task in "working" status - """ - now = datetime.now(timezone.utc) - return Task( - task_id=task_id or generate_task_id(), - status=TASK_STATUS_WORKING, - created_at=now, - last_updated_at=now, - ttl=metadata.ttl, - poll_interval=500, # Default 500ms poll interval - ) - - -@asynccontextmanager -async def task_execution( - task_id: str, - store: TaskStore, -) -> AsyncIterator[TaskContext]: - """Context manager for safe task execution (pure, no server dependencies). - - Loads a task from the store and provides a TaskContext for the work. - If an unhandled exception occurs, the task is automatically marked as failed - and the exception is suppressed (since the failure is captured in task state). - - This is useful for distributed workers that don't have a server session. - - Args: - task_id: The task identifier to execute - store: The task store (must be accessible by the worker) - - Yields: - TaskContext for updating status and completing/failing the task - - Raises: - ValueError: If the task is not found in the store - - Example (distributed worker): - async def worker_process(task_id: str): - store = RedisTaskStore(redis_url) - async with task_execution(task_id, store) as ctx: - await ctx.update_status("Working...") - result = await do_work() - await ctx.complete(result) - """ - task = await store.get_task(task_id) - if task is None: - raise ValueError(f"Task {task_id} not found") - - ctx = TaskContext(task, store) - try: - yield ctx - except Exception as e: - # Auto-fail the task if an exception occurs and task isn't already terminal - # Exception is suppressed since failure is captured in task state - if not is_terminal(ctx.task.status): - await ctx.fail(str(e)) - # Don't re-raise - the failure is recorded in task state diff --git a/src/mcp/shared/experimental/tasks/in_memory_task_store.py b/src/mcp/shared/experimental/tasks/in_memory_task_store.py 
deleted file mode 100644 index 42f4fb7035..0000000000 --- a/src/mcp/shared/experimental/tasks/in_memory_task_store.py +++ /dev/null @@ -1,217 +0,0 @@ -"""In-memory implementation of TaskStore for demonstration purposes. - -This implementation stores all tasks in memory and provides automatic cleanup -based on the TTL duration specified in the task metadata using lazy expiration. - -Note: This is not suitable for production use as all data is lost on restart. -For production, consider implementing TaskStore with a database or distributed cache. -""" - -from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone - -import anyio - -from mcp.shared.experimental.tasks.helpers import create_task_state, is_terminal -from mcp.shared.experimental.tasks.store import TaskStore -from mcp.types import Result, Task, TaskMetadata, TaskStatus - - -@dataclass -class StoredTask: - """Internal storage representation of a task.""" - - task: Task - result: Result | None = None - # Time when this task should be removed (None = never) - expires_at: datetime | None = field(default=None) - - -class InMemoryTaskStore(TaskStore): - """A simple in-memory implementation of TaskStore. - - Features: - - Automatic TTL-based cleanup (lazy expiration) - - Thread-safe for single-process async use - - Pagination support for list_tasks - - Limitations: - - All data lost on restart - - Not suitable for distributed systems - - No persistence - - For production, implement TaskStore with Redis, PostgreSQL, etc. 
- """ - - def __init__(self, page_size: int = 10) -> None: - self._tasks: dict[str, StoredTask] = {} - self._page_size = page_size - self._update_events: dict[str, anyio.Event] = {} - - def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None: - """Calculate expiry time from TTL in milliseconds.""" - if ttl_ms is None: - return None - return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms) - - def _is_expired(self, stored: StoredTask) -> bool: - """Check if a task has expired.""" - if stored.expires_at is None: - return False - return datetime.now(timezone.utc) >= stored.expires_at - - def _cleanup_expired(self) -> None: - """Remove all expired tasks. Called lazily during access operations.""" - expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)] - for task_id in expired_ids: - del self._tasks[task_id] - - async def create_task( - self, - metadata: TaskMetadata, - task_id: str | None = None, - ) -> Task: - """Create a new task with the given metadata.""" - # Cleanup expired tasks on access - self._cleanup_expired() - - task = create_task_state(metadata, task_id) - - if task.task_id in self._tasks: - raise ValueError(f"Task with ID {task.task_id} already exists") - - stored = StoredTask( - task=task, - expires_at=self._calculate_expiry(metadata.ttl), - ) - self._tasks[task.task_id] = stored - - # Return a copy to prevent external modification - return Task(**task.model_dump()) - - async def get_task(self, task_id: str) -> Task | None: - """Get a task by ID.""" - # Cleanup expired tasks on access - self._cleanup_expired() - - stored = self._tasks.get(task_id) - if stored is None: - return None - - # Return a copy to prevent external modification - return Task(**stored.task.model_dump()) - - async def update_task( - self, - task_id: str, - status: TaskStatus | None = None, - status_message: str | None = None, - ) -> Task: - """Update a task's status and/or message.""" - stored = self._tasks.get(task_id) - 
if stored is None: - raise ValueError(f"Task with ID {task_id} not found") - - # Per spec: Terminal states MUST NOT transition to any other status - if status is not None and status != stored.task.status and is_terminal(stored.task.status): - raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'") - - status_changed = False - if status is not None and stored.task.status != status: - stored.task.status = status - status_changed = True - - if status_message is not None: - stored.task.status_message = status_message - - # Update last_updated_at on any change - stored.task.last_updated_at = datetime.now(timezone.utc) - - # If task is now terminal and has TTL, reset expiry timer - if status is not None and is_terminal(status) and stored.task.ttl is not None: - stored.expires_at = self._calculate_expiry(stored.task.ttl) - - # Notify waiters if status changed - if status_changed: - await self.notify_update(task_id) - - return Task(**stored.task.model_dump()) - - async def store_result(self, task_id: str, result: Result) -> None: - """Store the result for a task.""" - stored = self._tasks.get(task_id) - if stored is None: - raise ValueError(f"Task with ID {task_id} not found") - - stored.result = result - - async def get_result(self, task_id: str) -> Result | None: - """Get the stored result for a task.""" - stored = self._tasks.get(task_id) - if stored is None: - return None - - return stored.result - - async def list_tasks( - self, - cursor: str | None = None, - ) -> tuple[list[Task], str | None]: - """List tasks with pagination.""" - # Cleanup expired tasks on access - self._cleanup_expired() - - all_task_ids = list(self._tasks.keys()) - - start_index = 0 - if cursor is not None: - try: - cursor_index = all_task_ids.index(cursor) - start_index = cursor_index + 1 - except ValueError: - raise ValueError(f"Invalid cursor: {cursor}") - - page_task_ids = all_task_ids[start_index : start_index + self._page_size] - tasks = 
[Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids] - - # Determine next cursor - next_cursor = None - if start_index + self._page_size < len(all_task_ids) and page_task_ids: - next_cursor = page_task_ids[-1] - - return tasks, next_cursor - - async def delete_task(self, task_id: str) -> bool: - """Delete a task.""" - if task_id not in self._tasks: - return False - - del self._tasks[task_id] - return True - - async def wait_for_update(self, task_id: str) -> None: - """Wait until the task status changes.""" - if task_id not in self._tasks: - raise ValueError(f"Task with ID {task_id} not found") - - # Create a fresh event for waiting (anyio.Event can't be cleared) - self._update_events[task_id] = anyio.Event() - event = self._update_events[task_id] - await event.wait() - - async def notify_update(self, task_id: str) -> None: - """Signal that a task has been updated.""" - if task_id in self._update_events: - self._update_events[task_id].set() - - # --- Testing/debugging helpers --- - - def cleanup(self) -> None: - """Cleanup all tasks (useful for testing or graceful shutdown).""" - self._tasks.clear() - self._update_events.clear() - - def get_all_tasks(self) -> list[Task]: - """Get all tasks (useful for debugging). Returns copies to prevent modification.""" - self._cleanup_expired() - return [Task(**stored.task.model_dump()) for stored in self._tasks.values()] diff --git a/src/mcp/shared/experimental/tasks/message_queue.py b/src/mcp/shared/experimental/tasks/message_queue.py deleted file mode 100644 index e17c4a8650..0000000000 --- a/src/mcp/shared/experimental/tasks/message_queue.py +++ /dev/null @@ -1,230 +0,0 @@ -"""TaskMessageQueue - FIFO queue for task-related messages. - -This implements the core message queue pattern from the MCP Tasks spec. -When a handler needs to send a request (like elicitation) during a task-augmented -request, the message is enqueued instead of sent directly. 
Messages are delivered -to the client only through the `tasks/result` endpoint. - -This pattern enables: -1. Decoupling request handling from message delivery -2. Proper bidirectional communication via the tasks/result stream -3. Automatic status management (working <-> input_required) -""" - -from abc import ABC, abstractmethod -from collections import deque -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Any, Literal - -import anyio - -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.types import JSONRPCNotification, JSONRPCRequest, RequestId - - -@dataclass -class QueuedMessage: - """A message queued for delivery via tasks/result. - - Messages are stored with their type and a resolver for requests - that expect responses. - """ - - type: Literal["request", "notification"] - """Whether this is a request (expects response) or notification (one-way).""" - - message: JSONRPCRequest | JSONRPCNotification - """The JSON-RPC message to send.""" - - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - """When the message was enqueued.""" - - resolver: Resolver[dict[str, Any]] | None = None - """Resolver to set when response arrives (only for requests).""" - - original_request_id: RequestId | None = None - """The original request ID used internally, for routing responses back.""" - - -class TaskMessageQueue(ABC): - """Abstract interface for task message queuing. - - This is a FIFO queue that stores messages to be delivered via `tasks/result`. - When a task-augmented handler calls elicit() or sends a notification, the - message is enqueued here instead of being sent directly to the client. - - The `tasks/result` handler then dequeues and sends these messages through - the transport, with `relatedRequestId` set to the tasks/result request ID - so responses are routed correctly. - - Implementations can use in-memory storage, Redis, etc. 
- """ - - @abstractmethod - async def enqueue(self, task_id: str, message: QueuedMessage) -> None: - """Add a message to the queue for a task. - - Args: - task_id: The task identifier - message: The message to enqueue - """ - - @abstractmethod - async def dequeue(self, task_id: str) -> QueuedMessage | None: - """Remove and return the next message from the queue. - - Args: - task_id: The task identifier - - Returns: - The next message, or None if queue is empty - """ - - @abstractmethod - async def peek(self, task_id: str) -> QueuedMessage | None: - """Return the next message without removing it. - - Args: - task_id: The task identifier - - Returns: - The next message, or None if queue is empty - """ - - @abstractmethod - async def is_empty(self, task_id: str) -> bool: - """Check if the queue is empty for a task. - - Args: - task_id: The task identifier - - Returns: - True if no messages are queued - """ - - @abstractmethod - async def clear(self, task_id: str) -> list[QueuedMessage]: - """Remove and return all messages from the queue. - - This is useful for cleanup when a task is cancelled or completed. - - Args: - task_id: The task identifier - - Returns: - All queued messages (may be empty) - """ - - @abstractmethod - async def wait_for_message(self, task_id: str) -> None: - """Wait until a message is available in the queue. - - This blocks until either: - 1. A message is enqueued for this task - 2. The wait is cancelled - - Args: - task_id: The task identifier - """ - - @abstractmethod - async def notify_message_available(self, task_id: str) -> None: - """Signal that a message is available for a task. - - This wakes up any coroutines waiting in wait_for_message(). - - Args: - task_id: The task identifier - """ - - -class InMemoryTaskMessageQueue(TaskMessageQueue): - """In-memory implementation of TaskMessageQueue. - - This is suitable for single-process servers. For distributed systems, - implement TaskMessageQueue with Redis, RabbitMQ, etc. 
- - Features: - - FIFO ordering per task - - Async wait for message availability - - Thread-safe for single-process async use - """ - - def __init__(self) -> None: - self._queues: dict[str, deque[QueuedMessage]] = {} - self._events: dict[str, anyio.Event] = {} - - def _get_queue(self, task_id: str) -> deque[QueuedMessage]: - """Get or create the queue for a task.""" - if task_id not in self._queues: - self._queues[task_id] = deque() - return self._queues[task_id] - - async def enqueue(self, task_id: str, message: QueuedMessage) -> None: - """Add a message to the queue.""" - queue = self._get_queue(task_id) - queue.append(message) - # Signal that a message is available - await self.notify_message_available(task_id) - - async def dequeue(self, task_id: str) -> QueuedMessage | None: - """Remove and return the next message.""" - queue = self._get_queue(task_id) - if not queue: - return None - return queue.popleft() - - async def peek(self, task_id: str) -> QueuedMessage | None: - """Return the next message without removing it.""" - queue = self._get_queue(task_id) - if not queue: - return None - return queue[0] - - async def is_empty(self, task_id: str) -> bool: - """Check if the queue is empty.""" - queue = self._get_queue(task_id) - return len(queue) == 0 - - async def clear(self, task_id: str) -> list[QueuedMessage]: - """Remove and return all messages.""" - queue = self._get_queue(task_id) - messages = list(queue) - queue.clear() - return messages - - async def wait_for_message(self, task_id: str) -> None: - """Wait until a message is available.""" - # Check if there are already messages - if not await self.is_empty(task_id): - return - - # Create a fresh event for waiting (anyio.Event can't be cleared) - self._events[task_id] = anyio.Event() - event = self._events[task_id] - - # Double-check after creating event (avoid race condition) - if not await self.is_empty(task_id): - return - - # Wait for a new message - await event.wait() - - async def 
notify_message_available(self, task_id: str) -> None: - """Signal that a message is available.""" - if task_id in self._events: - self._events[task_id].set() - - def cleanup(self, task_id: str | None = None) -> None: - """Clean up queues and events. - - Args: - task_id: If provided, clean up only this task. Otherwise clean up all. - """ - if task_id is not None: - self._queues.pop(task_id, None) - self._events.pop(task_id, None) - else: - self._queues.clear() - self._events.clear() diff --git a/src/mcp/shared/experimental/tasks/polling.py b/src/mcp/shared/experimental/tasks/polling.py deleted file mode 100644 index e4e13b6640..0000000000 --- a/src/mcp/shared/experimental/tasks/polling.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Shared polling utilities for task operations. - -This module provides generic polling logic that works for both client→server -and server→client task polling. - -WARNING: These APIs are experimental and may change without notice. -""" - -from collections.abc import AsyncIterator, Awaitable, Callable - -import anyio - -from mcp.shared.experimental.tasks.helpers import is_terminal -from mcp.types import GetTaskResult - - -async def poll_until_terminal( - get_task: Callable[[str], Awaitable[GetTaskResult]], - task_id: str, - default_interval_ms: int = 500, -) -> AsyncIterator[GetTaskResult]: - """Poll a task until it reaches terminal status. - - This is a generic utility that works for both client→server and server→client - polling. The caller provides the get_task function appropriate for their direction. 
- - Args: - get_task: Async function that takes task_id and returns GetTaskResult - task_id: The task to poll - default_interval_ms: Fallback poll interval if server doesn't specify - - Yields: - GetTaskResult for each poll - """ - while True: - status = await get_task(task_id) - yield status - - if is_terminal(status.status): - break - - interval_ms = status.poll_interval if status.poll_interval is not None else default_interval_ms - await anyio.sleep(interval_ms / 1000) diff --git a/src/mcp/shared/experimental/tasks/resolver.py b/src/mcp/shared/experimental/tasks/resolver.py deleted file mode 100644 index 1d233a9309..0000000000 --- a/src/mcp/shared/experimental/tasks/resolver.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Resolver - An anyio-compatible future-like object for async result passing. - -This provides a simple way to pass a result (or exception) from one coroutine -to another without depending on asyncio.Future. -""" - -from typing import Generic, TypeVar, cast - -import anyio - -T = TypeVar("T") - - -class Resolver(Generic[T]): - """A simple resolver for passing results between coroutines. - - Unlike asyncio.Future, this works with any anyio-compatible async backend. 
- - Usage: - resolver: Resolver[str] = Resolver() - - # In one coroutine: - resolver.set_result("hello") - - # In another coroutine: - result = await resolver.wait() # returns "hello" - """ - - def __init__(self) -> None: - self._event = anyio.Event() - self._value: T | None = None - self._exception: BaseException | None = None - - def set_result(self, value: T) -> None: - """Set the result value and wake up waiters.""" - if self._event.is_set(): - raise RuntimeError("Resolver already completed") - self._value = value - self._event.set() - - def set_exception(self, exc: BaseException) -> None: - """Set an exception and wake up waiters.""" - if self._event.is_set(): - raise RuntimeError("Resolver already completed") - self._exception = exc - self._event.set() - - async def wait(self) -> T: - """Wait for the result and return it, or raise the exception.""" - await self._event.wait() - if self._exception is not None: - raise self._exception - # If we reach here, set_result() was called, so _value is set - return cast(T, self._value) - - def done(self) -> bool: - """Return True if the resolver has been completed.""" - return self._event.is_set() diff --git a/src/mcp/shared/experimental/tasks/store.py b/src/mcp/shared/experimental/tasks/store.py deleted file mode 100644 index 7de97d40ca..0000000000 --- a/src/mcp/shared/experimental/tasks/store.py +++ /dev/null @@ -1,144 +0,0 @@ -"""TaskStore - Abstract interface for task state storage.""" - -from abc import ABC, abstractmethod - -from mcp.types import Result, Task, TaskMetadata, TaskStatus - - -class TaskStore(ABC): - """Abstract interface for task state storage. - - This is a pure storage interface - it doesn't manage execution. - Implementations can use in-memory storage, databases, Redis, etc. - - All methods are async to support various backends. - """ - - @abstractmethod - async def create_task( - self, - metadata: TaskMetadata, - task_id: str | None = None, - ) -> Task: - """Create a new task. 
- - Args: - metadata: Task metadata (ttl, etc.) - task_id: Optional task ID. If None, implementation should generate one. - - Returns: - The created Task with status="working" - - Raises: - ValueError: If task_id already exists - """ - - @abstractmethod - async def get_task(self, task_id: str) -> Task | None: - """Get a task by ID. - - Args: - task_id: The task identifier - - Returns: - The Task, or None if not found - """ - - @abstractmethod - async def update_task( - self, - task_id: str, - status: TaskStatus | None = None, - status_message: str | None = None, - ) -> Task: - """Update a task's status and/or message. - - Args: - task_id: The task identifier - status: New status (if changing) - status_message: New status message (if changing) - - Returns: - The updated Task - - Raises: - ValueError: If task not found - ValueError: If attempting to transition from a terminal status - (completed, failed, cancelled). Per spec, terminal states - MUST NOT transition to any other status. - """ - - @abstractmethod - async def store_result(self, task_id: str, result: Result) -> None: - """Store the result for a task. - - Args: - task_id: The task identifier - result: The result to store - - Raises: - ValueError: If task not found - """ - - @abstractmethod - async def get_result(self, task_id: str) -> Result | None: - """Get the stored result for a task. - - Args: - task_id: The task identifier - - Returns: - The stored Result, or None if not available - """ - - @abstractmethod - async def list_tasks( - self, - cursor: str | None = None, - ) -> tuple[list[Task], str | None]: - """List tasks with pagination. - - Args: - cursor: Optional cursor for pagination - - Returns: - Tuple of (tasks, next_cursor). next_cursor is None if no more pages. - """ - - @abstractmethod - async def delete_task(self, task_id: str) -> bool: - """Delete a task. 
- - Args: - task_id: The task identifier - - Returns: - True if deleted, False if not found - """ - - @abstractmethod - async def wait_for_update(self, task_id: str) -> None: - """Wait until the task status changes. - - This blocks until either: - 1. The task status changes - 2. The wait is cancelled - - Used by tasks/result to wait for task completion or status changes. - - Args: - task_id: The task identifier - - Raises: - ValueError: If task not found - """ - - @abstractmethod - async def notify_update(self, task_id: str) -> None: - """Signal that a task has been updated. - - This wakes up any coroutines waiting in wait_for_update(). - - Args: - task_id: The task identifier - """ diff --git a/src/mcp/shared/response_router.py b/src/mcp/shared/response_router.py deleted file mode 100644 index fe24b016f1..0000000000 --- a/src/mcp/shared/response_router.py +++ /dev/null @@ -1,61 +0,0 @@ -"""ResponseRouter - Protocol for pluggable response routing. - -This module defines a protocol for routing JSON-RPC responses to alternative -handlers before falling back to the default response stream mechanism. - -The primary use case is task-augmented requests: when a TaskSession enqueues -a request (like elicitation), the response needs to be routed back to the -waiting resolver instead of the normal response stream. - -Design: -- Protocol-based for testability and flexibility -- Returns bool to indicate if response was handled -- Supports both success responses and errors -""" - -from typing import Any, Protocol - -from mcp.types import ErrorData, RequestId - - -class ResponseRouter(Protocol): - """Protocol for routing responses to alternative handlers. - - Implementations check if they have a pending request for the given ID - and deliver the response/error to the appropriate handler. 
- - Example: - ```python - class TaskResultHandler(ResponseRouter): - def route_response(self, request_id, response): - resolver = self._pending_requests.pop(request_id, None) - if resolver: - resolver.set_result(response) - return True - return False - ``` - """ - - def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: - """Try to route a response to a pending request handler. - - Args: - request_id: The JSON-RPC request ID from the response - response: The response result data - - Returns: - True if the response was handled, False otherwise - """ - ... # pragma: no cover - - def route_error(self, request_id: RequestId, error: ErrorData) -> bool: - """Try to route an error to a pending request handler. - - Args: - request_id: The JSON-RPC request ID from the error response - error: The error data - - Returns: - True if the error was handled, False otherwise - """ - ... # pragma: no cover diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 9c72a23844..ea5d8833bd 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -17,7 +17,6 @@ from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage -from mcp.shared.response_router import ResponseRouter from mcp.types import ( CONNECTION_CLOSED, INVALID_PARAMS, @@ -183,7 +182,6 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] - _response_routers: list[ResponseRouter] def __init__( self, @@ -199,24 +197,8 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} - self._response_routers = [] self._exit_stack = AsyncExitStack() - def add_response_router(self, router: ResponseRouter) -> None: - """Register a response router to handle responses for 
non-standard requests. - - Response routers are checked in order before falling back to the default - response stream mechanism. This is used by TaskResultHandler to route - responses for queued task requests back to their resolvers. - - !!! warning - This is an experimental API that may change without notice. - - Args: - router: A ResponseRouter implementation - """ - self._response_routers.append(router) - async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() await self._task_group.__aenter__() @@ -477,11 +459,7 @@ def _normalize_request_id(self, response_id: RequestId) -> RequestId: return response_id async def _handle_response(self, message: SessionMessage) -> None: - """Handle an incoming response or error message. - - Checks response routers first (e.g., for task-related responses), - then falls back to the normal response stream mechanism. - """ + """Handle an incoming response or error message.""" # This check is always true at runtime: the caller (_receive_loop) only invokes # this method in the else branch after checking for JSONRPCRequest and # JSONRPCNotification. 
However, the type checker can't infer this from the @@ -498,20 +476,6 @@ async def _handle_response(self, message: SessionMessage) -> None: # Normalize response ID to handle type mismatches (e.g., "0" vs 0) response_id = self._normalize_request_id(message.message.id) - # First, check response routers (e.g., TaskResultHandler) - if isinstance(message.message, JSONRPCError): - # Route error to routers - for router in self._response_routers: - if router.route_error(response_id, message.message.error): - return # Handled - else: - # Route success response to routers - response_data: dict[str, Any] = message.message.result or {} - for router in self._response_routers: - if router.route_response(response_id, response_data): - return # Handled - - # Fall back to normal response streams stream = self._response_streams.pop(response_id, None) if stream: await stream.send(message.message) diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index b442303937..b2d537fb70 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -8,14 +8,6 @@ from mcp.types._types import ( DEFAULT_NEGOTIATED_VERSION, LATEST_PROTOCOL_VERSION, - TASK_FORBIDDEN, - TASK_OPTIONAL, - TASK_REQUIRED, - TASK_STATUS_CANCELLED, - TASK_STATUS_COMPLETED, - TASK_STATUS_FAILED, - TASK_STATUS_INPUT_REQUIRED, - TASK_STATUS_WORKING, Annotations, AudioContent, BaseMetadata, @@ -25,15 +17,10 @@ CallToolResult, CancelledNotification, CancelledNotificationParams, - CancelTaskRequest, - CancelTaskRequestParams, - CancelTaskResult, ClientCapabilities, ClientNotification, ClientRequest, ClientResult, - ClientTasksCapability, - ClientTasksRequestsCapability, CompleteRequest, CompleteRequestParams, CompleteResult, @@ -46,7 +33,6 @@ CreateMessageRequestParams, CreateMessageResult, CreateMessageResultWithTools, - CreateTaskResult, ElicitationCapability, ElicitationRequiredErrorData, ElicitCompleteNotification, @@ -63,12 +49,6 @@ GetPromptRequest, GetPromptRequestParams, GetPromptResult, - 
GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequest, - GetTaskRequestParams, - GetTaskResult, Icon, IconTheme, ImageContent, @@ -86,8 +66,6 @@ ListResourceTemplatesResult, ListRootsRequest, ListRootsResult, - ListTasksRequest, - ListTasksResult, ListToolsRequest, ListToolsResult, LoggingCapability, @@ -114,7 +92,6 @@ ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, - RelatedTaskMetadata, Request, RequestParams, RequestParamsMeta, @@ -142,33 +119,16 @@ ServerNotification, ServerRequest, ServerResult, - ServerTasksCapability, - ServerTasksRequestsCapability, SetLevelRequest, SetLevelRequestParams, StopReason, SubscribeRequest, SubscribeRequestParams, - Task, - TaskExecutionMode, - TaskMetadata, - TasksCallCapability, - TasksCancelCapability, - TasksCreateElicitationCapability, - TasksCreateMessageCapability, - TasksElicitationCapability, - TasksListCapability, - TasksSamplingCapability, - TaskStatus, - TaskStatusNotification, - TaskStatusNotificationParams, - TasksToolsCapability, TextContent, TextResourceContents, Tool, ToolAnnotations, ToolChoice, - ToolExecution, ToolListChangedNotification, ToolResultContent, ToolsCapability, @@ -208,16 +168,6 @@ # Protocol version constants "LATEST_PROTOCOL_VERSION", "DEFAULT_NEGOTIATED_VERSION", - # Task execution mode constants - "TASK_FORBIDDEN", - "TASK_OPTIONAL", - "TASK_REQUIRED", - # Task status constants - "TASK_STATUS_CANCELLED", - "TASK_STATUS_COMPLETED", - "TASK_STATUS_FAILED", - "TASK_STATUS_INPUT_REQUIRED", - "TASK_STATUS_WORKING", # Type aliases and variables "ContentBlock", "ElicitRequestedSchema", @@ -229,8 +179,6 @@ "SamplingContent", "SamplingMessageContentBlock", "StopReason", - "TaskExecutionMode", - "TaskStatus", # Base classes "BaseMetadata", "Request", @@ -245,8 +193,6 @@ "EmptyResult", # Capabilities "ClientCapabilities", - "ClientTasksCapability", - "ClientTasksRequestsCapability", "CompletionsCapability", "ElicitationCapability", 
"FormElicitationCapability", @@ -258,16 +204,6 @@ "SamplingContextCapability", "SamplingToolsCapability", "ServerCapabilities", - "ServerTasksCapability", - "ServerTasksRequestsCapability", - "TasksCancelCapability", - "TasksCallCapability", - "TasksCreateElicitationCapability", - "TasksCreateMessageCapability", - "TasksElicitationCapability", - "TasksListCapability", - "TasksSamplingCapability", - "TasksToolsCapability", "ToolsCapability", "UrlElicitationCapability", # Content types @@ -300,18 +236,12 @@ "ResourceTemplateReference", "Root", "SamplingMessage", - "Task", - "TaskMetadata", - "RelatedTaskMetadata", "Tool", "ToolAnnotations", "ToolChoice", - "ToolExecution", # Requests "CallToolRequest", "CallToolRequestParams", - "CancelTaskRequest", - "CancelTaskRequestParams", "CompleteRequest", "CompleteRequestParams", "CreateMessageRequest", @@ -321,17 +251,12 @@ "ElicitRequestURLParams", "GetPromptRequest", "GetPromptRequestParams", - "GetTaskPayloadRequest", - "GetTaskPayloadRequestParams", - "GetTaskRequest", - "GetTaskRequestParams", "InitializeRequest", "InitializeRequestParams", "ListPromptsRequest", "ListResourcesRequest", "ListResourceTemplatesRequest", "ListRootsRequest", - "ListTasksRequest", "ListToolsRequest", "PingRequest", "ReadResourceRequest", @@ -344,22 +269,17 @@ "UnsubscribeRequestParams", # Results "CallToolResult", - "CancelTaskResult", "CompleteResult", "CreateMessageResult", "CreateMessageResultWithTools", - "CreateTaskResult", "ElicitResult", "ElicitationRequiredErrorData", "GetPromptResult", - "GetTaskPayloadResult", - "GetTaskResult", "InitializeResult", "ListPromptsResult", "ListResourcesResult", "ListResourceTemplatesResult", "ListRootsResult", - "ListTasksResult", "ListToolsResult", "ReadResourceResult", # Notifications @@ -377,8 +297,6 @@ "ResourceUpdatedNotification", "ResourceUpdatedNotificationParams", "RootsListChangedNotification", - "TaskStatusNotification", - "TaskStatusNotificationParams", "ToolListChangedNotification", # 
Union types for request/response routing "ClientNotification", diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index 9005d253af..34800ba12e 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -1,7 +1,6 @@ from __future__ import annotations -from datetime import datetime -from typing import Annotated, Any, Final, Generic, Literal, TypeAlias, TypeVar +from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, Field, FileUrl, TypeAdapter from pydantic.alias_generators import to_camel @@ -30,11 +29,6 @@ IconTheme = Literal["light", "dark"] -TaskExecutionMode = Literal["forbidden", "optional", "required"] -TASK_FORBIDDEN: Final[Literal["forbidden"]] = "forbidden" -TASK_OPTIONAL: Final[Literal["optional"]] = "optional" -TASK_REQUIRED: Final[Literal["required"]] = "required" - class MCPModel(BaseModel): """Base class for all MCP protocol types.""" @@ -55,27 +49,7 @@ class RequestParamsMeta(TypedDict, extra_items=Any): """ -class TaskMetadata(MCPModel): - """Metadata for augmenting a request with task execution. - - Include this in the `task` field of the request parameters. - """ - - ttl: Annotated[int, Field(strict=True)] | None = None - """Requested duration in milliseconds to retain task from creation.""" - - class RequestParams(MCPModel): - task: TaskMetadata | None = None - """ - If specified, the caller is requesting task-augmented execution for this request. - The request will return a CreateTaskResult immediately, and the actual result can be - retrieved later via tasks/result. - - Task augmentation is subject to capability negotiation - receivers MUST declare support - for task augmentation of specific request types in their capabilities. 
- """ - meta: RequestParamsMeta | None = Field(alias="_meta", default=None) @@ -258,55 +232,6 @@ class SamplingCapability(MCPModel): """ -class TasksListCapability(MCPModel): - """Capability for tasks listing operations.""" - - -class TasksCancelCapability(MCPModel): - """Capability for tasks cancel operations.""" - - -class TasksCreateMessageCapability(MCPModel): - """Capability for tasks create messages.""" - - -class TasksSamplingCapability(MCPModel): - """Capability for tasks sampling operations.""" - - create_message: TasksCreateMessageCapability | None = None - - -class TasksCreateElicitationCapability(MCPModel): - """Capability for tasks create elicitation operations.""" - - -class TasksElicitationCapability(MCPModel): - """Capability for tasks elicitation operations.""" - - create: TasksCreateElicitationCapability | None = None - - -class ClientTasksRequestsCapability(MCPModel): - """Capability for tasks requests operations.""" - - sampling: TasksSamplingCapability | None = None - - elicitation: TasksElicitationCapability | None = None - - -class ClientTasksCapability(MCPModel): - """Capability for client tasks operations.""" - - list: TasksListCapability | None = None - """Whether this client supports tasks/list.""" - - cancel: TasksCancelCapability | None = None - """Whether this client supports tasks/cancel.""" - - requests: ClientTasksRequestsCapability | None = None - """Specifies which request types can be augmented with tasks.""" - - class ClientCapabilities(MCPModel): """Capabilities a client may support.""" @@ -321,8 +246,6 @@ class ClientCapabilities(MCPModel): """Present if the client supports elicitation from the user.""" roots: RootsCapability | None = None """Present if the client supports listing roots.""" - tasks: ClientTasksCapability | None = None - """Present if the client supports task-augmented requests.""" class PromptsCapability(MCPModel): @@ -356,30 +279,6 @@ class CompletionsCapability(MCPModel): """Capability for completions 
operations.""" -class TasksCallCapability(MCPModel): - """Capability for tasks call operations.""" - - -class TasksToolsCapability(MCPModel): - """Capability for tasks tools operations.""" - - call: TasksCallCapability | None = None - - -class ServerTasksRequestsCapability(MCPModel): - """Capability for tasks requests operations.""" - - tools: TasksToolsCapability | None = None - - -class ServerTasksCapability(MCPModel): - """Capability for server tasks operations.""" - - list: TasksListCapability | None = None - cancel: TasksCancelCapability | None = None - requests: ServerTasksRequestsCapability | None = None - - class ServerCapabilities(MCPModel): """Capabilities that a server may support.""" @@ -401,146 +300,6 @@ class ServerCapabilities(MCPModel): completions: CompletionsCapability | None = None """Present if the server offers autocompletion suggestions for prompts and resources.""" - tasks: ServerTasksCapability | None = None - """Present if the server supports task-augmented requests.""" - - -TaskStatus = Literal["working", "input_required", "completed", "failed", "cancelled"] - -# Task status constants -TASK_STATUS_WORKING: Final[Literal["working"]] = "working" -TASK_STATUS_INPUT_REQUIRED: Final[Literal["input_required"]] = "input_required" -TASK_STATUS_COMPLETED: Final[Literal["completed"]] = "completed" -TASK_STATUS_FAILED: Final[Literal["failed"]] = "failed" -TASK_STATUS_CANCELLED: Final[Literal["cancelled"]] = "cancelled" - - -class RelatedTaskMetadata(MCPModel): - """Metadata for associating messages with a task. - - Include this in the `_meta` field under the key `io.modelcontextprotocol/related-task`. - """ - - task_id: str - """The task identifier this message is associated with.""" - - -class Task(MCPModel): - """Data associated with a task.""" - - task_id: str - """The task identifier.""" - - status: TaskStatus - """Current task state.""" - - status_message: str | None = None - """Optional human-readable message describing the current task state. 
- - This can provide context for any status, including: - - Reasons for "cancelled" status - - Summaries for "completed" status - - Diagnostic information for "failed" status (e.g., error details, what went wrong) - """ - - created_at: datetime # Pydantic will enforce ISO 8601 and re-serialize as a string later - """ISO 8601 timestamp when the task was created.""" - - last_updated_at: datetime - """ISO 8601 timestamp when the task was last updated.""" - - ttl: Annotated[int, Field(strict=True)] | None - """Actual retention duration from creation in milliseconds, null for unlimited.""" - - poll_interval: Annotated[int, Field(strict=True)] | None = None - """Suggested polling interval in milliseconds.""" - - -class CreateTaskResult(Result): - """A response to a task-augmented request.""" - - task: Task - - -class GetTaskRequestParams(RequestParams): - task_id: str - """The task identifier to query.""" - - -class GetTaskRequest(Request[GetTaskRequestParams, Literal["tasks/get"]]): - """A request to retrieve the state of a task.""" - - method: Literal["tasks/get"] = "tasks/get" - - params: GetTaskRequestParams - - -class GetTaskResult(Result, Task): - """The response to a tasks/get request.""" - - -class GetTaskPayloadRequestParams(RequestParams): - task_id: str - """The task identifier to retrieve results for.""" - - -class GetTaskPayloadRequest(Request[GetTaskPayloadRequestParams, Literal["tasks/result"]]): - """A request to retrieve the result of a completed task.""" - - method: Literal["tasks/result"] = "tasks/result" - params: GetTaskPayloadRequestParams - - -class GetTaskPayloadResult(Result): - """The response to a tasks/result request. - - The structure matches the result type of the original request. - For example, a tools/call task would return the CallToolResult structure. 
- """ - - model_config = ConfigDict(extra="allow", alias_generator=to_camel, populate_by_name=True) - - -class CancelTaskRequestParams(RequestParams): - task_id: str - """The task identifier to cancel.""" - - -class CancelTaskRequest(Request[CancelTaskRequestParams, Literal["tasks/cancel"]]): - """A request to cancel a task.""" - - method: Literal["tasks/cancel"] = "tasks/cancel" - params: CancelTaskRequestParams - - -class CancelTaskResult(Result, Task): - """The response to a tasks/cancel request.""" - - -class ListTasksRequest(PaginatedRequest[Literal["tasks/list"]]): - """A request to retrieve a list of tasks.""" - - method: Literal["tasks/list"] = "tasks/list" - - -class ListTasksResult(PaginatedResult): - """The response to a tasks/list request.""" - - tasks: list[Task] - - -class TaskStatusNotificationParams(NotificationParams, Task): - """Parameters for a `notifications/tasks/status` notification.""" - - -class TaskStatusNotification(Notification[TaskStatusNotificationParams, Literal["notifications/tasks/status"]]): - """An optional notification from the receiver to the requestor, informing them that a task's status has changed. - Receivers are not required to send these notifications. - """ - - method: Literal["notifications/tasks/status"] = "notifications/tasks/status" - params: TaskStatusNotificationParams - class InitializeRequestParams(RequestParams): """Parameters for the initialize request.""" @@ -1133,23 +892,6 @@ class ToolAnnotations(MCPModel): """ -class ToolExecution(MCPModel): - """Execution-related properties for a tool.""" - - task_support: TaskExecutionMode | None = None - """ - Indicates whether this tool supports task-augmented execution. - This allows clients to handle long-running operations through polling - the task system. 
- - - "forbidden": Tool does not support task-augmented execution (default when absent) - - "optional": Tool may support task-augmented execution - - "required": Tool requires task-augmented execution - - Default: "forbidden" - """ - - class Tool(BaseMetadata): """Definition for a tool the client can call.""" @@ -1172,8 +914,6 @@ class Tool(BaseMetadata): for notes on _meta usage. """ - execution: ToolExecution | None = None - class ListToolsResult(PaginatedResult): """The server's response to a tools/list request from the client.""" @@ -1554,8 +1294,6 @@ class CancelledNotificationParams(NotificationParams): The ID of the request to cancel. This MUST correspond to the ID of a request previously issued in the same direction. - This MUST be provided for cancelling non-task requests. - This MUST NOT be used for cancelling tasks (use the `tasks/cancel` request instead). """ reason: str | None = None """An optional string describing the reason for the cancellation.""" @@ -1607,20 +1345,12 @@ class ElicitCompleteNotification( | UnsubscribeRequest | CallToolRequest | ListToolsRequest - | GetTaskRequest - | GetTaskPayloadRequest - | ListTasksRequest - | CancelTaskRequest ) client_request_adapter = TypeAdapter[ClientRequest](ClientRequest) ClientNotification = ( - CancelledNotification - | ProgressNotification - | InitializedNotification - | RootsListChangedNotification - | TaskStatusNotification + CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification ) client_notification_adapter = TypeAdapter[ClientNotification](ClientNotification) @@ -1716,31 +1446,11 @@ class ElicitationRequiredErrorData(MCPModel): """List of URL mode elicitations that must be completed.""" -ClientResult = ( - EmptyResult - | CreateMessageResult - | CreateMessageResultWithTools - | ListRootsResult - | ElicitResult - | GetTaskResult - | GetTaskPayloadResult - | ListTasksResult - | CancelTaskResult - | CreateTaskResult -) +ClientResult = EmptyResult | 
CreateMessageResult | CreateMessageResultWithTools | ListRootsResult | ElicitResult client_result_adapter = TypeAdapter[ClientResult](ClientResult) -ServerRequest = ( - PingRequest - | CreateMessageRequest - | ListRootsRequest - | ElicitRequest - | GetTaskRequest - | GetTaskPayloadRequest - | ListTasksRequest - | CancelTaskRequest -) +ServerRequest = PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest server_request_adapter = TypeAdapter[ServerRequest](ServerRequest) @@ -1753,7 +1463,6 @@ class ElicitationRequiredErrorData(MCPModel): | ToolListChangedNotification | PromptListChangedNotification | ElicitCompleteNotification - | TaskStatusNotification ) server_notification_adapter = TypeAdapter[ServerNotification](ServerNotification) @@ -1769,10 +1478,5 @@ class ElicitationRequiredErrorData(MCPModel): | ReadResourceResult | CallToolResult | ListToolsResult - | GetTaskResult - | GetTaskPayloadResult - | ListTasksResult - | CancelTaskResult - | CreateTaskResult ) server_result_adapter = TypeAdapter[ServerResult](ServerResult) diff --git a/tests/experimental/__init__.py b/tests/experimental/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/experimental/tasks/__init__.py b/tests/experimental/tasks/__init__.py deleted file mode 100644 index 6e8649d283..0000000000 --- a/tests/experimental/tasks/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for MCP task support.""" diff --git a/tests/experimental/tasks/client/__init__.py b/tests/experimental/tasks/client/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py deleted file mode 100644 index 1ea2199e8c..0000000000 --- a/tests/experimental/tasks/client/test_capabilities.py +++ /dev/null @@ -1,312 +0,0 @@ -"""Tests for client task capabilities declaration during initialization.""" - -import anyio -import pytest - -from mcp import 
ClientCapabilities, types -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.client.session import ClientSession -from mcp.shared._context import RequestContext -from mcp.shared.message import SessionMessage -from mcp.types import ( - LATEST_PROTOCOL_VERSION, - Implementation, - InitializeRequest, - InitializeResult, - JSONRPCRequest, - JSONRPCResponse, - ServerCapabilities, - client_request_adapter, -) - - -@pytest.mark.anyio -async def test_client_capabilities_without_tasks(): - """Test that tasks capability is None when not provided.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - - received_capabilities = None - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request = client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await client_to_server_receive.receive() - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - 
tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability is None when not provided - assert received_capabilities is not None - assert received_capabilities.tasks is None - - -@pytest.mark.anyio -async def test_client_capabilities_with_tasks(): - """Test that tasks capability is properly set when handlers are provided.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - - received_capabilities: ClientCapabilities | None = None - - # Define custom handlers to trigger capability building (never actually called) - async def my_list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> types.ListTasksResult | types.ErrorData: - raise NotImplementedError - - async def my_cancel_task_handler( - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, - ) -> types.CancelTaskResult | types.ErrorData: - raise NotImplementedError - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request = client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await 
client_to_server_receive.receive() - - # Create handlers container - task_handlers = ExperimentalTaskHandlers( - list_tasks=my_list_tasks_handler, - cancel_task=my_cancel_task_handler, - ) - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability is properly set from handlers - assert received_capabilities is not None - assert received_capabilities.tasks is not None - assert isinstance(received_capabilities.tasks, types.ClientTasksCapability) - assert received_capabilities.tasks.list is not None - assert received_capabilities.tasks.cancel is not None - - -@pytest.mark.anyio -async def test_client_capabilities_auto_built_from_handlers(): - """Test that tasks capability is automatically built from provided handlers.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - - received_capabilities: ClientCapabilities | None = None - - # Define custom handlers (not defaults) - async def my_list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> types.ListTasksResult | types.ErrorData: - raise NotImplementedError - - async def my_cancel_task_handler( - context: RequestContext[ClientSession], - params: types.CancelTaskRequestParams, - ) -> types.CancelTaskResult | types.ErrorData: - raise NotImplementedError - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request 
= client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await client_to_server_receive.receive() - - # Provide handlers via ExperimentalTaskHandlers - task_handlers = ExperimentalTaskHandlers( - list_tasks=my_list_tasks_handler, - cancel_task=my_cancel_task_handler, - ) - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability was auto-built from handlers - assert received_capabilities is not None - assert received_capabilities.tasks is not None - assert received_capabilities.tasks.list is not None - assert received_capabilities.tasks.cancel is not None - # requests should be None since we didn't provide task-augmented handlers - assert received_capabilities.tasks.requests is None - - -@pytest.mark.anyio -async def test_client_capabilities_with_task_augmented_handlers(): - """Test that requests capability is built when augmented handlers are provided.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - 
- received_capabilities: ClientCapabilities | None = None - - # Define task-augmented handler - async def my_augmented_sampling_handler( - context: RequestContext[ClientSession], - params: types.CreateMessageRequestParams, - task_metadata: types.TaskMetadata, - ) -> types.CreateTaskResult | types.ErrorData: - raise NotImplementedError - - async def mock_server(): - nonlocal received_capabilities - - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request, JSONRPCRequest) - request = client_request_adapter.validate_python( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request, InitializeRequest) - received_capabilities = request.params.capabilities - - result = InitializeResult( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ) - - async with server_to_client_send: - await server_to_client_send.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - await client_to_server_receive.receive() - - # Provide task-augmented sampling handler - task_handlers = ExperimentalTaskHandlers( - augmented_sampling=my_augmented_sampling_handler, - ) - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as session, - anyio.create_task_group() as tg, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - tg.start_soon(mock_server) - await session.initialize() - - # Assert that tasks capability includes requests.sampling - assert received_capabilities is not None - assert received_capabilities.tasks is not None - assert received_capabilities.tasks.requests is not None - assert 
received_capabilities.tasks.requests.sampling is not None - assert received_capabilities.tasks.requests.elicitation is None # Not provided diff --git a/tests/experimental/tasks/client/test_handlers.py b/tests/experimental/tasks/client/test_handlers.py deleted file mode 100644 index 137ff80106..0000000000 --- a/tests/experimental/tasks/client/test_handlers.py +++ /dev/null @@ -1,874 +0,0 @@ -"""Tests for client-side task management handlers (server -> client requests). - -These tests verify that clients can handle task-related requests from servers: -- GetTaskRequest - server polling client's task status -- GetTaskPayloadRequest - server getting result from client's task -- ListTasksRequest - server listing client's tasks -- CancelTaskRequest - server cancelling client's task - -This is the inverse of the existing tests in test_tasks.py, which test -client -> server task requests. -""" - -from collections.abc import AsyncIterator -from dataclasses import dataclass - -import anyio -import pytest -from anyio import Event -from anyio.abc import TaskGroup -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - -from mcp import types -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.client.session import ClientSession -from mcp.shared._context import RequestContext -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ( - CancelTaskRequest, - CancelTaskRequestParams, - CancelTaskResult, - ClientResult, - CreateMessageRequest, - CreateMessageRequestParams, - CreateMessageResult, - CreateTaskResult, - ElicitRequest, - ElicitRequestFormParams, - ElicitRequestParams, - ElicitResult, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequest, - GetTaskRequestParams, - GetTaskResult, - ListTasksRequest, - 
ListTasksResult, - SamplingMessage, - ServerNotification, - ServerRequest, - TaskMetadata, - TextContent, -) - -# Buffer size for test streams -STREAM_BUFFER_SIZE = 10 - - -@dataclass -class ClientTestStreams: - """Bidirectional message streams for client/server communication in tests.""" - - server_send: MemoryObjectSendStream[SessionMessage] - server_receive: MemoryObjectReceiveStream[SessionMessage] - client_send: MemoryObjectSendStream[SessionMessage] - client_receive: MemoryObjectReceiveStream[SessionMessage] - - -@pytest.fixture -async def client_streams() -> AsyncIterator[ClientTestStreams]: - """Create bidirectional message streams for client tests. - - Automatically closes all streams after the test completes. - """ - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage]( - STREAM_BUFFER_SIZE - ) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage]( - STREAM_BUFFER_SIZE - ) - - streams = ClientTestStreams( - server_send=server_to_client_send, - server_receive=client_to_server_receive, - client_send=client_to_server_send, - client_receive=server_to_client_receive, - ) - - yield streams - - # Cleanup - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -async def _default_message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, -) -> None: - """Default message handler that ignores messages (tests handle them explicitly).""" - ... 
- - -@pytest.mark.anyio -async def test_client_handles_get_task_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to GetTaskRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - received_task_id: str | None = None - - async def get_task_handler( - context: RequestContext[ClientSession], - params: GetTaskRequestParams, - ) -> GetTaskResult | ErrorData: - nonlocal received_task_id - received_task_id = params.task_id - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-123") - - task_handlers = ExperimentalTaskHandlers(get_task=get_task_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskRequest(params=GetTaskRequestParams(task_id="test-task-123")) - request = types.JSONRPCRequest(jsonrpc="2.0", id="req-1", **typed_request.model_dump(by_alias=True)) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - assert response.id == "req-1" - - result = GetTaskResult.model_validate(response.result) - assert result.task_id == "test-task-123" - assert result.status == "working" - assert 
received_task_id == "test-task-123" - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_handles_get_task_result_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to GetTaskPayloadRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - - async def get_task_result_handler( - context: RequestContext[ClientSession], - params: GetTaskPayloadRequestParams, - ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, types.CallToolResult) - return GetTaskPayloadResult(**result.model_dump()) - - await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-456") - await store.store_result( - "test-task-456", - types.CallToolResult(content=[TextContent(type="text", text="Task completed successfully!")]), - ) - await store.update_task("test-task-456", status="completed") - - task_handlers = ExperimentalTaskHandlers(get_task_result=get_task_result_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="test-task-456")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-2", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - assert 
isinstance(response.result, dict) - result_dict = response.result - assert "content" in result_dict - assert len(result_dict["content"]) == 1 - assert result_dict["content"][0]["text"] == "Task completed successfully!" - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_handles_list_tasks_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to ListTasksRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - - async def list_tasks_handler( - context: RequestContext[ClientSession], - params: types.PaginatedRequestParams | None, - ) -> ListTasksResult | ErrorData: - cursor = params.cursor if params else None - tasks_list, next_cursor = await store.list_tasks(cursor=cursor) - return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) - - await store.create_task(TaskMetadata(ttl=60000), task_id="task-1") - await store.create_task(TaskMetadata(ttl=60000), task_id="task-2") - - task_handlers = ExperimentalTaskHandlers(list_tasks=list_tasks_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = ListTasksRequest() - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-3", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - result = ListTasksResult.model_validate(response.result) - assert len(result.tasks) == 2 - - tg.cancel_scope.cancel() - - store.cleanup() - - 
-@pytest.mark.anyio -async def test_client_handles_cancel_task_request(client_streams: ClientTestStreams) -> None: - """Test that client can respond to CancelTaskRequest from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - - async def cancel_task_handler( - context: RequestContext[ClientSession], - params: CancelTaskRequestParams, - ) -> CancelTaskResult | ErrorData: - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - await store.update_task(params.task_id, status="cancelled") - updated = await store.get_task(params.task_id) - assert updated is not None - return CancelTaskResult( - task_id=updated.task_id, - status=updated.status, - created_at=updated.created_at, - last_updated_at=updated.last_updated_at, - ttl=updated.ttl, - ) - - await store.create_task(TaskMetadata(ttl=60000), task_id="task-to-cancel") - - task_handlers = ExperimentalTaskHandlers(cancel_task=cancel_task_handler) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = CancelTaskRequest(params=CancelTaskRequestParams(task_id="task-to-cancel")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-4", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - result = CancelTaskResult.model_validate(response.result) - assert result.task_id == "task-to-cancel" - assert result.status == "cancelled" - - 
tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_task_augmented_sampling(client_streams: ClientTestStreams) -> None: - """Test that client can handle task-augmented sampling request from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - sampling_completed = Event() - created_task_id: list[str | None] = [None] - background_tg: list[TaskGroup | None] = [None] - - async def task_augmented_sampling_callback( - context: RequestContext[ClientSession], - params: CreateMessageRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult: - task = await store.create_task(task_metadata) - created_task_id[0] = task.task_id - - async def do_sampling() -> None: - result = CreateMessageResult( - role="assistant", - content=TextContent(type="text", text="Sampled response"), - model="test-model", - stop_reason="endTurn", - ) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") - sampling_completed.set() - - assert background_tg[0] is not None - background_tg[0].start_soon(do_sampling) - return CreateTaskResult(task=task) - - async def get_task_handler( - context: RequestContext[ClientSession], - params: GetTaskRequestParams, - ) -> GetTaskResult | ErrorData: - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def get_task_result_handler( - context: RequestContext[ClientSession], - params: GetTaskPayloadRequestParams, - ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, 
CreateMessageResult) - return GetTaskPayloadResult(**result.model_dump()) - - task_handlers = ExperimentalTaskHandlers( - augmented_sampling=task_augmented_sampling_callback, - get_task=get_task_handler, - get_task_result=get_task_result_handler, - ) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - background_tg[0] = tg - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Step 1: Server sends task-augmented CreateMessageRequest - typed_request = CreateMessageRequest( - params=CreateMessageRequestParams( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-sampling", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - # Step 2: Client responds with CreateTaskResult - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - task_result = CreateTaskResult.model_validate(response.result) - task_id = task_result.task.task_id - assert task_id == created_task_id[0] - - # Step 3: Wait for background sampling - await sampling_completed.wait() - - # Step 4: Server polls task status - typed_poll = GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)) - poll_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-poll", - **typed_poll.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(poll_request)) - - poll_response_msg = await client_streams.server_receive.receive() - poll_response = poll_response_msg.message - 
assert isinstance(poll_response, types.JSONRPCResponse) - - status = GetTaskResult.model_validate(poll_response.result) - assert status.status == "completed" - - # Step 5: Server gets result - typed_result_req = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id)) - result_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-result", - **typed_result_req.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(result_request)) - - result_response_msg = await client_streams.server_receive.receive() - result_response = result_response_msg.message - assert isinstance(result_response, types.JSONRPCResponse) - - assert isinstance(result_response.result, dict) - assert result_response.result["role"] == "assistant" - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_task_augmented_elicitation(client_streams: ClientTestStreams) -> None: - """Test that client can handle task-augmented elicitation request from server.""" - with anyio.fail_after(10): - store = InMemoryTaskStore() - elicitation_completed = Event() - created_task_id: list[str | None] = [None] - background_tg: list[TaskGroup | None] = [None] - - async def task_augmented_elicitation_callback( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult | ErrorData: - task = await store.create_task(task_metadata) - created_task_id[0] = task.task_id - - async def do_elicitation() -> None: - # Simulate user providing elicitation response - result = ElicitResult(action="accept", content={"name": "Test User"}) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") - elicitation_completed.set() - - assert background_tg[0] is not None - background_tg[0].start_soon(do_elicitation) - return CreateTaskResult(task=task) - - async def get_task_handler( - context: RequestContext[ClientSession], - params: 
GetTaskRequestParams, - ) -> GetTaskResult | ErrorData: - task = await store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def get_task_result_handler( - context: RequestContext[ClientSession], - params: GetTaskPayloadRequestParams, - ) -> GetTaskPayloadResult | ErrorData: - result = await store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, ElicitResult) - return GetTaskPayloadResult(**result.model_dump()) - - task_handlers = ExperimentalTaskHandlers( - augmented_elicitation=task_augmented_elicitation_callback, - get_task=get_task_handler, - get_task_result=get_task_result_handler, - ) - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - background_tg[0] = tg - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - experimental_task_handlers=task_handlers, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Step 1: Server sends task-augmented ElicitRequest - typed_request = ElicitRequest( - params=ElicitRequestFormParams( - message="What is your name?", - requested_schema={"type": "object", "properties": {"name": {"type": "string"}}}, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-elicit", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - # Step 2: Client responds with CreateTaskResult - response_msg = await 
client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCResponse) - - task_result = CreateTaskResult.model_validate(response.result) - task_id = task_result.task.task_id - assert task_id == created_task_id[0] - - # Step 3: Wait for background elicitation - await elicitation_completed.wait() - - # Step 4: Server polls task status - typed_poll = GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)) - poll_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-poll", - **typed_poll.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(poll_request)) - - poll_response_msg = await client_streams.server_receive.receive() - poll_response = poll_response_msg.message - assert isinstance(poll_response, types.JSONRPCResponse) - - status = GetTaskResult.model_validate(poll_response.result) - assert status.status == "completed" - - # Step 5: Server gets result - typed_result_req = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id)) - result_request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-result", - **typed_result_req.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(result_request)) - - result_response_msg = await client_streams.server_receive.receive() - result_response = result_response_msg.message - assert isinstance(result_response, types.JSONRPCResponse) - - # Verify the elicitation result - assert isinstance(result_response.result, dict) - assert result_response.result["action"] == "accept" - assert result_response.result["content"] == {"name": "Test User"} - - tg.cancel_scope.cancel() - - store.cleanup() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error when no handler is registered for task request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with 
anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskRequest(params=GetTaskRequestParams(task_id="nonexistent")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-unhandled", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert ( - "not supported" in response.error.message.lower() - or "method not found" in response.error.message.lower() - ) - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_result_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error for unhandled tasks/result request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="nonexistent")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-result", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - 
tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_list_tasks_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error for unhandled tasks/list request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = ListTasksRequest() - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-list", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_cancel_task_request(client_streams: ClientTestStreams) -> None: - """Test that client returns error for unhandled tasks/cancel request.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - typed_request = CancelTaskRequest(params=CancelTaskRequestParams(task_id="nonexistent")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-cancel", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await 
client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_augmented_sampling(client_streams: ClientTestStreams) -> None: - """Test that client returns error for task-augmented sampling without handler.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - # No task handlers provided - uses defaults - async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Send task-augmented sampling request - typed_request = CreateMessageRequest( - params=CreateMessageRequestParams( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-sampling", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_client_returns_error_for_unhandled_task_augmented_elicitation( - client_streams: ClientTestStreams, -) -> None: - """Test that client returns error for task-augmented elicitation without handler.""" - with anyio.fail_after(10): - client_ready = anyio.Event() - - async with anyio.create_task_group() as tg: - - async def run_client() -> None: - # No task handlers provided - uses defaults 
- async with ClientSession( - client_streams.client_receive, - client_streams.client_send, - message_handler=_default_message_handler, - ): - client_ready.set() - await anyio.sleep_forever() - - tg.start_soon(run_client) - await client_ready.wait() - - # Send task-augmented elicitation request - typed_request = ElicitRequest( - params=ElicitRequestFormParams( - message="What is your name?", - requested_schema={"type": "object", "properties": {"name": {"type": "string"}}}, - task=TaskMetadata(ttl=60000), - ) - ) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-elicit", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(request)) - - response_msg = await client_streams.server_receive.receive() - response = response_msg.message - assert isinstance(response, types.JSONRPCError) - assert "not supported" in response.error.message.lower() - - tg.cancel_scope.cancel() diff --git a/tests/experimental/tasks/client/test_poll_task.py b/tests/experimental/tasks/client/test_poll_task.py deleted file mode 100644 index 5e3158d955..0000000000 --- a/tests/experimental/tasks/client/test_poll_task.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Tests for poll_task async iterator.""" - -from collections.abc import Callable, Coroutine -from datetime import datetime, timezone -from typing import Any -from unittest.mock import AsyncMock - -import pytest - -from mcp.client.experimental.tasks import ExperimentalClientFeatures -from mcp.types import GetTaskResult, TaskStatus - - -def make_task_result( - status: TaskStatus = "working", - poll_interval: int = 0, - task_id: str = "test-task", - status_message: str | None = None, -) -> GetTaskResult: - """Create GetTaskResult with sensible defaults.""" - now = datetime.now(timezone.utc) - return GetTaskResult( - task_id=task_id, - status=status, - status_message=status_message, - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=poll_interval, - ) - - -def 
make_status_sequence( - *statuses: TaskStatus, - task_id: str = "test-task", -) -> Callable[[str], Coroutine[Any, Any, GetTaskResult]]: - """Create mock get_task that returns statuses in sequence.""" - status_iter = iter(statuses) - - async def mock_get_task(tid: str) -> GetTaskResult: - return make_task_result(status=next(status_iter), task_id=tid) - - return mock_get_task - - -@pytest.fixture -def mock_session() -> AsyncMock: - return AsyncMock() - - -@pytest.fixture -def features(mock_session: AsyncMock) -> ExperimentalClientFeatures: - return ExperimentalClientFeatures(mock_session) - - -@pytest.mark.anyio -async def test_poll_task_yields_until_completed(features: ExperimentalClientFeatures) -> None: - """poll_task yields each status until terminal.""" - features.get_task = make_status_sequence("working", "working", "completed") # type: ignore[method-assign] - - statuses = [s.status async for s in features.poll_task("test-task")] - - assert statuses == ["working", "working", "completed"] - - -@pytest.mark.anyio -@pytest.mark.parametrize("terminal_status", ["completed", "failed", "cancelled"]) -async def test_poll_task_exits_on_terminal(features: ExperimentalClientFeatures, terminal_status: TaskStatus) -> None: - """poll_task exits immediately when task is already terminal.""" - features.get_task = make_status_sequence(terminal_status) # type: ignore[method-assign] - - statuses = [s.status async for s in features.poll_task("test-task")] - - assert statuses == [terminal_status] - - -@pytest.mark.anyio -async def test_poll_task_continues_through_input_required(features: ExperimentalClientFeatures) -> None: - """poll_task yields input_required and continues (non-terminal).""" - features.get_task = make_status_sequence("working", "input_required", "working", "completed") # type: ignore[method-assign] - - statuses = [s.status async for s in features.poll_task("test-task")] - - assert statuses == ["working", "input_required", "working", "completed"] - - 
-@pytest.mark.anyio -async def test_poll_task_passes_task_id(features: ExperimentalClientFeatures) -> None: - """poll_task passes correct task_id to get_task.""" - received_ids: list[str] = [] - - async def mock_get_task(task_id: str) -> GetTaskResult: - received_ids.append(task_id) - return make_task_result(status="completed", task_id=task_id) - - features.get_task = mock_get_task # type: ignore[method-assign] - - _ = [s async for s in features.poll_task("my-task-123")] - - assert received_ids == ["my-task-123"] - - -@pytest.mark.anyio -async def test_poll_task_yields_full_result(features: ExperimentalClientFeatures) -> None: - """poll_task yields complete GetTaskResult objects.""" - - async def mock_get_task(task_id: str) -> GetTaskResult: - return make_task_result( - status="completed", - task_id=task_id, - status_message="All done!", - ) - - features.get_task = mock_get_task # type: ignore[method-assign] - - results = [r async for r in features.poll_task("test-task")] - - assert len(results) == 1 - assert results[0].status == "completed" - assert results[0].status_message == "All done!" 
- assert results[0].task_id == "test-task" diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py deleted file mode 100644 index 613c794ebf..0000000000 --- a/tests/experimental/tasks/client/test_tasks.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Tests for the experimental client task methods (session.experimental).""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from dataclasses import dataclass, field - -import anyio -import pytest -from anyio import Event -from anyio.abc import TaskGroup - -from mcp import Client -from mcp.server import Server, ServerRequestContext -from mcp.shared.experimental.tasks.helpers import task_execution -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import ( - CallToolRequest, - CallToolRequestParams, - CallToolResult, - CancelTaskRequestParams, - CancelTaskResult, - CreateTaskResult, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - ListToolsResult, - PaginatedRequestParams, - TaskMetadata, - TextContent, -) - -pytestmark = pytest.mark.anyio - - -@dataclass -class AppContext: - """Application context passed via lifespan_context.""" - - task_group: TaskGroup - store: InMemoryTaskStore - task_done_events: dict[str, Event] = field(default_factory=lambda: {}) - - -async def _handle_list_tools( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None -) -> ListToolsResult: - raise NotImplementedError - - -async def _handle_call_tool_with_done_event( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams, *, result_text: str = "Done" -) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - 
app.task_done_events[task.task_id] = done_event - - async def do_work() -> None: - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - -def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): - @asynccontextmanager - async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: - async with anyio.create_task_group() as tg: - yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) - - return app_lifespan - - -async def test_session_experimental_get_task() -> None: - """Test session.experimental.get_task() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=_handle_call_tool_with_done_event, - ) - server.experimental.enable_tasks(on_get_task=handle_get_task) - - async with Client(server) as client: - # Create a task - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task 
to complete - await task_done_events[task_id].wait() - - # Use session.experimental to get task status - task_status = await client.session.experimental.get_task(task_id) - - assert task_status.task_id == task_id - assert task_status.status == "completed" - - -async def test_session_experimental_get_task_result() -> None: - """Test session.experimental.get_task_result() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_call_tool( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - return await _handle_call_tool_with_done_event(ctx, params, result_text="Task result content") - - async def handle_get_task_result( - ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - app = ctx.lifespan_context - result = await app.store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, CallToolResult) - return GetTaskPayloadResult(**result.model_dump()) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks(on_task_result=handle_get_task_result) - - async with Client(server) as client: - # Create a task - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await task_done_events[task_id].wait() - - # Use TaskClient to get task result - task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - 
assert content.text == "Task result content" - - -async def test_session_experimental_list_tasks() -> None: - """Test TaskClient.list_tasks() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_list_tasks( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListTasksResult: - app = ctx.lifespan_context - cursor = params.cursor if params else None - tasks_list, next_cursor = await app.store.list_tasks(cursor=cursor) - return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=_handle_call_tool_with_done_event, - ) - server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) - - async with Client(server) as client: - # Create two tasks - for _ in range(2): - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - await task_done_events[create_result.task.task_id].wait() - - # Use TaskClient to list tasks - list_result = await client.session.experimental.list_tasks() - - assert len(list_result.tasks) == 2 - - -async def test_session_experimental_cancel_task() -> None: - """Test TaskClient.cancel_task() method.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_call_tool_no_work( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - # Don't start any work - task stays in "working" status - return CreateTaskResult(task=task) - raise NotImplementedError - - async def 
handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def handle_cancel_task( - ctx: ServerRequestContext[AppContext], params: CancelTaskRequestParams - ) -> CancelTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - await app.store.update_task(params.task_id, status="cancelled") - updated_task = await app.store.get_task(params.task_id) - assert updated_task is not None - return CancelTaskResult( - task_id=updated_task.task_id, - status=updated_task.status, - created_at=updated_task.created_at, - last_updated_at=updated_task.last_updated_at, - ttl=updated_task.ttl, - ) - - server: Server[AppContext] = Server( - "test-server", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=_handle_list_tools, - on_call_tool=handle_call_tool_no_work, - ) - server.experimental.enable_tasks(on_get_task=handle_get_task, on_cancel_task=handle_cancel_task) - - async with Client(server) as client: - # Create a task (but don't complete it) - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Verify task is working - status_before = await client.session.experimental.get_task(task_id) - assert status_before.status == "working" - - # Cancel the task - await client.session.experimental.cancel_task(task_id) - - # 
Verify task is cancelled - status_after = await client.session.experimental.get_task(task_id) - assert status_after.status == "cancelled" diff --git a/tests/experimental/tasks/server/__init__.py b/tests/experimental/tasks/server/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/experimental/tasks/server/test_context.py b/tests/experimental/tasks/server/test_context.py deleted file mode 100644 index a0f1a190d2..0000000000 --- a/tests/experimental/tasks/server/test_context.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Tests for TaskContext and helper functions.""" - -import pytest - -from mcp.shared.experimental.tasks.context import TaskContext -from mcp.shared.experimental.tasks.helpers import create_task_state, task_execution -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import CallToolResult, TaskMetadata, TextContent - - -@pytest.mark.anyio -async def test_task_context_properties() -> None: - """Test TaskContext basic properties.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - assert ctx.task_id == task.task_id - assert ctx.task.task_id == task.task_id - assert ctx.task.status == "working" - assert ctx.is_cancelled is False - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_update_status() -> None: - """Test TaskContext.update_status.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - await ctx.update_status("Processing step 1...") - - # Check status message was updated - updated = await store.get_task(task.task_id) - assert updated is not None - assert updated.status_message == "Processing step 1..." 
- - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_complete() -> None: - """Test TaskContext.complete.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - result = CallToolResult(content=[TextContent(type="text", text="Done!")]) - await ctx.complete(result) - - # Check task status - updated = await store.get_task(task.task_id) - assert updated is not None - assert updated.status == "completed" - - # Check result is stored - stored_result = await store.get_result(task.task_id) - assert stored_result is not None - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_fail() -> None: - """Test TaskContext.fail.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - await ctx.fail("Something went wrong!") - - # Check task status - updated = await store.get_task(task.task_id) - assert updated is not None - assert updated.status == "failed" - assert updated.status_message == "Something went wrong!" 
- - store.cleanup() - - -@pytest.mark.anyio -async def test_task_context_cancellation() -> None: - """Test TaskContext cancellation request.""" - store = InMemoryTaskStore() - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - ctx = TaskContext(task, store) - - assert ctx.is_cancelled is False - - ctx.request_cancellation() - - assert ctx.is_cancelled is True - - store.cleanup() - - -def test_create_task_state_generates_id() -> None: - """create_task_state generates a unique task ID when none provided.""" - task1 = create_task_state(TaskMetadata(ttl=60000)) - task2 = create_task_state(TaskMetadata(ttl=60000)) - - assert task1.task_id != task2.task_id - - -def test_create_task_state_uses_provided_id() -> None: - """create_task_state uses the provided task ID.""" - task = create_task_state(TaskMetadata(ttl=60000), task_id="my-task-123") - assert task.task_id == "my-task-123" - - -def test_create_task_state_null_ttl() -> None: - """create_task_state handles null TTL.""" - task = create_task_state(TaskMetadata(ttl=None)) - assert task.ttl is None - - -def test_create_task_state_has_created_at() -> None: - """create_task_state sets createdAt timestamp.""" - task = create_task_state(TaskMetadata(ttl=60000)) - assert task.created_at is not None - - -@pytest.mark.anyio -async def test_task_execution_provides_context() -> None: - """task_execution provides a TaskContext for the task.""" - store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-test-1") - - async with task_execution("exec-test-1", store) as ctx: - assert ctx.task_id == "exec-test-1" - assert ctx.task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_execution_auto_fails_on_exception() -> None: - """task_execution automatically fails task on unhandled exception.""" - store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-fail-1") - - async with task_execution("exec-fail-1", store): - 
raise RuntimeError("Oops!") - - # Task should be failed - failed_task = await store.get_task("exec-fail-1") - assert failed_task is not None - assert failed_task.status == "failed" - assert "Oops!" in (failed_task.status_message or "") - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_execution_doesnt_fail_if_already_terminal() -> None: - """task_execution doesn't re-fail if task already terminal.""" - store = InMemoryTaskStore() - await store.create_task(TaskMetadata(ttl=60000), task_id="exec-term-1") - - async with task_execution("exec-term-1", store) as ctx: - # Complete the task first - await ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) - # Then raise - shouldn't change status - raise RuntimeError("This shouldn't matter") - - # Task should remain completed - final_task = await store.get_task("exec-term-1") - assert final_task is not None - assert final_task.status == "completed" - - store.cleanup() - - -@pytest.mark.anyio -async def test_task_execution_not_found() -> None: - """task_execution raises ValueError for non-existent task.""" - store = InMemoryTaskStore() - - with pytest.raises(ValueError, match="not found"): - async with task_execution("nonexistent", store): - ... diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py deleted file mode 100644 index b5b79033d0..0000000000 --- a/tests/experimental/tasks/server/test_integration.py +++ /dev/null @@ -1,247 +0,0 @@ -"""End-to-end integration tests for tasks functionality. - -These tests demonstrate the full task lifecycle: -1. Client sends task-augmented request (tools/call with task metadata) -2. Server creates task and returns CreateTaskResult immediately -3. Background work executes (using task_execution context manager) -4. Client polls with tasks/get -5. 
Client retrieves result with tasks/result -""" - -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from dataclasses import dataclass, field - -import anyio -import pytest -from anyio import Event -from anyio.abc import TaskGroup - -from mcp import Client -from mcp.server import Server, ServerRequestContext -from mcp.shared.experimental.tasks.helpers import task_execution -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import ( - CallToolRequest, - CallToolRequestParams, - CallToolResult, - CreateTaskResult, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - ListToolsResult, - PaginatedRequestParams, - TaskMetadata, - TextContent, -) - -pytestmark = pytest.mark.anyio - - -@dataclass -class AppContext: - """Application context passed via lifespan_context.""" - - task_group: TaskGroup - store: InMemoryTaskStore - task_done_events: dict[str, Event] = field(default_factory=lambda: {}) - - -def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): - @asynccontextmanager - async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: - async with anyio.create_task_group() as tg: - yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) - - return app_lifespan - - -async def test_task_lifecycle_with_task_execution() -> None: - """Test the complete task lifecycle using the task_execution pattern.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_list_tools( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if params.name == "process_data" and 
ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work() -> None: - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.update_status("Processing input...") - input_value = (params.arguments or {}).get("input", "") - result_text = f"Processed: {input_value.upper()}" - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - async def handle_get_task_result( - ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - app = ctx.lifespan_context - result = await app.store.get_result(params.task_id) - assert result is not None, f"Test setup error: result for {params.task_id} should exist" - assert isinstance(result, CallToolResult) - return GetTaskPayloadResult(**result.model_dump()) - - async def handle_list_tasks( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListTasksResult: - raise NotImplementedError - - server: Server[AppContext] = Server( - "test-tasks", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - 
server.experimental.enable_tasks( - on_get_task=handle_get_task, - on_task_result=handle_get_task_result, - on_list_tasks=handle_list_tasks, - ) - - async with Client(server) as client: - # Step 1: Send task-augmented tool call - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="process_data", - arguments={"input": "hello world"}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - assert isinstance(create_result, CreateTaskResult) - assert create_result.task.status == "working" - task_id = create_result.task.task_id - - # Step 2: Wait for task to complete - await task_done_events[task_id].wait() - - task_status = await client.session.experimental.get_task(task_id) - assert task_status.task_id == task_id - assert task_status.status == "completed" - - # Step 3: Retrieve the actual result - task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - assert content.text == "Processed: HELLO WORLD" - - -async def test_task_auto_fails_on_exception() -> None: - """Test that task_execution automatically fails the task on unhandled exception.""" - store = InMemoryTaskStore() - task_done_events: dict[str, Event] = {} - - async def handle_list_tools( - ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None - ) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext[AppContext], params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - app = ctx.lifespan_context - if params.name == "failing_task" and ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_failing_work() 
-> None: - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.update_status("About to fail...") - raise RuntimeError("Something went wrong!") - # This line is reached because task_execution suppresses the exception - done_event.set() - - app.task_group.start_soon(do_failing_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: - app = ctx.lifespan_context - task = await app.store.get_task(params.task_id) - assert task is not None, f"Test setup error: task {params.task_id} should exist" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=task.poll_interval, - ) - - server: Server[AppContext] = Server( - "test-tasks-failure", - lifespan=_make_lifespan(store, task_done_events), - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks(on_get_task=handle_get_task) - - async with Client(server) as client: - # Send task request - create_result = await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="failing_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - task_id = create_result.task.task_id - - # Wait for task to complete (even though it fails) - await task_done_events[task_id].wait() - - # Check that task was auto-failed - task_status = await client.session.experimental.get_task(task_id) - - assert task_status.status == "failed" - assert task_status.status_message == "Something went wrong!" 
diff --git a/tests/experimental/tasks/server/test_run_task_flow.py b/tests/experimental/tasks/server/test_run_task_flow.py deleted file mode 100644 index 027382e69e..0000000000 --- a/tests/experimental/tasks/server/test_run_task_flow.py +++ /dev/null @@ -1,367 +0,0 @@ -"""Tests for the simplified task API: enable_tasks() + run_task() - -This tests the recommended user flow: -1. server.experimental.enable_tasks() - one-line setup -2. ctx.experimental.run_task(work) - spawns work, returns CreateTaskResult -3. work function uses ServerTaskContext for elicit/create_message - -These are integration tests that verify the complete flow works end-to-end. -""" - -from unittest.mock import Mock - -import anyio -import pytest -from anyio import Event - -from mcp import Client -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.request_context import Experimental -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.experimental.task_support import TaskSupport -from mcp.server.lowlevel import NotificationOptions -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue -from mcp.types import ( - TASK_REQUIRED, - CallToolRequestParams, - CallToolResult, - CreateTaskResult, - GetTaskRequestParams, - GetTaskResult, - ListToolsResult, - PaginatedRequestParams, - TextContent, -) - -pytestmark = pytest.mark.anyio - - -async def _handle_list_tools_simple_task( - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - raise NotImplementedError - - -async def test_run_task_basic_flow() -> None: - """Test the basic run_task flow without elicitation.""" - work_completed = Event() - received_meta: list[str | None] = [None] - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - 
ctx.experimental.validate_task_mode(TASK_REQUIRED) - - if ctx.meta is not None: # pragma: no branch - received_meta[0] = ctx.meta.get("custom_field") - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.update_status("Working...") - input_val = (params.arguments or {}).get("input", "default") - result = CallToolResult(content=[TextContent(type="text", text=f"Processed: {input_val}")]) - work_completed.set() - return result - - return await ctx.experimental.run_task(work) - - server = Server( - "test-run-task", - on_list_tools=_handle_list_tools_simple_task, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task( - "simple_task", - {"input": "hello"}, - meta={"custom_field": "test_value"}, - ) - - task_id = result.task.task_id - assert result.task.status == "working" - - with anyio.fail_after(5): - await work_completed.wait() - - with anyio.fail_after(5): - while True: - task_status = await client.session.experimental.get_task(task_id) - if task_status.status == "completed": # pragma: no branch - break - - assert received_meta[0] == "test_value" - - -async def test_run_task_auto_fails_on_exception() -> None: - """Test that run_task automatically fails the task when work raises.""" - work_failed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - work_failed.set() - raise RuntimeError("Something went wrong!") - - return await ctx.experimental.run_task(work) - - server = Server( - "test-run-task-fail", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - 
server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("failing_task", {}) - task_id = result.task.task_id - - with anyio.fail_after(5): - await work_failed.wait() - - with anyio.fail_after(5): - while True: - task_status = await client.session.experimental.get_task(task_id) - if task_status.status == "failed": # pragma: no branch - break - - assert "Something went wrong" in (task_status.status_message or "") - - -async def test_enable_tasks_auto_registers_handlers() -> None: - """Test that enable_tasks() auto-registers get_task, list_tasks, cancel_task handlers.""" - server = Server("test-enable-tasks") - - # Before enable_tasks, no task capabilities - caps_before = server.get_capabilities(NotificationOptions(), {}) - assert caps_before.tasks is None - - # Enable tasks - server.experimental.enable_tasks() - - # After enable_tasks, should have task capabilities - caps_after = server.get_capabilities(NotificationOptions(), {}) - assert caps_after.tasks is not None - assert caps_after.tasks.list is not None - assert caps_after.tasks.cancel is not None - assert caps_after.tasks.requests is not None - assert caps_after.tasks.requests.tools is not None - assert caps_after.tasks.requests.tools.call is not None - - -async def test_enable_tasks_with_custom_store_and_queue() -> None: - """Test that enable_tasks() uses provided store and queue instead of defaults.""" - server = Server("test-custom-store-queue") - - custom_store = InMemoryTaskStore() - custom_queue = InMemoryTaskMessageQueue() - - task_support = server.experimental.enable_tasks(store=custom_store, queue=custom_queue) - - assert task_support.store is custom_store - assert task_support.queue is custom_queue - - -async def test_enable_tasks_skips_default_handlers_when_custom_registered() -> None: - """Test that enable_tasks() doesn't override already-registered handlers.""" - server = Server("test-custom-handlers") - - # Register 
custom handlers via enable_tasks kwargs - async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - raise NotImplementedError - - server.experimental.enable_tasks(on_get_task=custom_get_task) - - # Verify handler is registered - assert server._has_handler("tasks/get") - assert server._has_handler("tasks/list") - assert server._has_handler("tasks/cancel") - assert server._has_handler("tasks/result") - - -async def test_run_task_without_enable_tasks_raises() -> None: - """Test that run_task raises when enable_tasks() wasn't called.""" - experimental = Experimental( - task_metadata=None, - _client_capabilities=None, - _session=None, - _task_support=None, # Not enabled - ) - - async def work(task: ServerTaskContext) -> CallToolResult: - raise NotImplementedError - - with pytest.raises(RuntimeError, match="Task support not enabled"): - await experimental.run_task(work) - - -async def test_task_support_task_group_before_run_raises() -> None: - """Test that accessing task_group before run() raises RuntimeError.""" - task_support = TaskSupport.in_memory() - - with pytest.raises(RuntimeError, match="TaskSupport not running"): - _ = task_support.task_group - - -async def test_run_task_without_session_raises() -> None: - """Test that run_task raises when session is not available.""" - task_support = TaskSupport.in_memory() - - experimental = Experimental( - task_metadata=None, - _client_capabilities=None, - _session=None, # No session - _task_support=task_support, - ) - - async def work(task: ServerTaskContext) -> CallToolResult: - raise NotImplementedError - - with pytest.raises(RuntimeError, match="Session not available"): - await experimental.run_task(work) - - -async def test_run_task_without_task_metadata_raises() -> None: - """Test that run_task raises when request is not task-augmented.""" - task_support = TaskSupport.in_memory() - mock_session = Mock() - - experimental = Experimental( - task_metadata=None, # Not a 
task-augmented request - _client_capabilities=None, - _session=mock_session, - _task_support=task_support, - ) - - async def work(task: ServerTaskContext) -> CallToolResult: - raise NotImplementedError - - with pytest.raises(RuntimeError, match="Request is not task-augmented"): - await experimental.run_task(work) - - -async def test_run_task_with_model_immediate_response() -> None: - """Test that run_task includes model_immediate_response in CreateTaskResult._meta.""" - work_completed = Event() - immediate_response_text = "Processing your request..." - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="Done")]) - - return await ctx.experimental.run_task(work, model_immediate_response=immediate_response_text) - - server = Server( - "test-run-task-immediate", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("task_with_immediate", {}) - - assert result.meta is not None - assert "io.modelcontextprotocol/model-immediate-response" in result.meta - assert result.meta["io.modelcontextprotocol/model-immediate-response"] == immediate_response_text - - with anyio.fail_after(5): - await work_completed.wait() - - -async def test_run_task_doesnt_complete_if_already_terminal() -> None: - """Test that run_task doesn't auto-complete if work manually completed the task.""" - work_completed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> 
ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - manual_result = CallToolResult(content=[TextContent(type="text", text="Manually completed")]) - await task.complete(manual_result, notify=False) - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="This should be ignored")]) - - return await ctx.experimental.run_task(work) - - server = Server( - "test-already-complete", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("manual_complete_task", {}) - task_id = result.task.task_id - - with anyio.fail_after(5): - await work_completed.wait() - - with anyio.fail_after(5): - while True: - status = await client.session.experimental.get_task(task_id) - if status.status == "completed": # pragma: no branch - break - - -async def test_run_task_doesnt_fail_if_already_terminal() -> None: - """Test that run_task doesn't auto-fail if work manually failed/cancelled the task.""" - work_completed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool( - ctx: ServerRequestContext, params: CallToolRequestParams - ) -> CallToolResult | CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - await task.fail("Manually failed", notify=False) - work_completed.set() - raise RuntimeError("This error should not change status") - - return await ctx.experimental.run_task(work) - - server = Server( - "test-already-failed", - 
on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - server.experimental.enable_tasks() - - async with Client(server) as client: - result = await client.session.experimental.call_tool_as_task("manual_cancel_task", {}) - task_id = result.task.task_id - - with anyio.fail_after(5): - await work_completed.wait() - - with anyio.fail_after(5): - while True: - status = await client.session.experimental.get_task(task_id) - if status.status == "failed": # pragma: no branch - break - - assert status.status_message == "Manually failed" diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py deleted file mode 100644 index 6a28b274ea..0000000000 --- a/tests/experimental/tasks/server/test_server.py +++ /dev/null @@ -1,797 +0,0 @@ -"""Tests for server-side task support (handlers, capabilities, integration).""" - -from datetime import datetime, timezone -from typing import Any - -import anyio -import pytest - -from mcp import Client -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.exceptions import MCPError -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.response_router import ResponseRouter -from mcp.shared.session import RequestResponder -from mcp.types import ( - INVALID_REQUEST, - TASK_FORBIDDEN, - TASK_OPTIONAL, - TASK_REQUIRED, - CallToolRequest, - CallToolRequestParams, - CallToolResult, - CancelTaskRequestParams, - CancelTaskResult, - ClientResult, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequestParams, - GetTaskResult, - JSONRPCError, - JSONRPCNotification, - JSONRPCResponse, - ListTasksResult, - ListToolsResult, - PaginatedRequestParams, - SamplingMessage, - ServerCapabilities, - 
ServerNotification, - ServerRequest, - Task, - TaskMetadata, - TextContent, - Tool, - ToolExecution, -) - -pytestmark = pytest.mark.anyio - - -async def test_list_tasks_handler() -> None: - """Test that experimental list_tasks handler works via Client.""" - now = datetime.now(timezone.utc) - test_tasks = [ - Task(task_id="task-1", status="working", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), - Task(task_id="task-2", status="completed", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), - ] - - async def handle_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - return ListTasksResult(tasks=test_tasks) - - server = Server("test") - server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) - - async with Client(server) as client: - result = await client.session.experimental.list_tasks() - assert len(result.tasks) == 2 - assert result.tasks[0].task_id == "task-1" - assert result.tasks[1].task_id == "task-2" - - -async def test_get_task_handler() -> None: - """Test that experimental get_task handler works via Client.""" - - async def handle_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - now = datetime.now(timezone.utc) - return GetTaskResult( - task_id=params.task_id, - status="working", - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=1000, - ) - - server = Server("test") - server.experimental.enable_tasks(on_get_task=handle_get_task) - - async with Client(server) as client: - result = await client.session.experimental.get_task("test-task-123") - assert result.task_id == "test-task-123" - assert result.status == "working" - - -async def test_get_task_result_handler() -> None: - """Test that experimental get_task_result handler works via Client.""" - - async def handle_get_task_result( - ctx: ServerRequestContext, params: GetTaskPayloadRequestParams - ) -> GetTaskPayloadResult: - return GetTaskPayloadResult() - - 
server = Server("test") - server.experimental.enable_tasks(on_task_result=handle_get_task_result) - - async with Client(server) as client: - result = await client.session.send_request( - GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="test-task-123")), - GetTaskPayloadResult, - ) - assert isinstance(result, GetTaskPayloadResult) - - -async def test_cancel_task_handler() -> None: - """Test that experimental cancel_task handler works via Client.""" - - async def handle_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: - now = datetime.now(timezone.utc) - return CancelTaskResult( - task_id=params.task_id, - status="cancelled", - created_at=now, - last_updated_at=now, - ttl=60000, - ) - - server = Server("test") - server.experimental.enable_tasks(on_cancel_task=handle_cancel_task) - - async with Client(server) as client: - result = await client.session.experimental.cancel_task("test-task-123") - assert result.task_id == "test-task-123" - assert result.status == "cancelled" - - -async def test_server_capabilities_include_tasks() -> None: - """Test that server capabilities include tasks when handlers are registered.""" - server = Server("test") - - async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - raise NotImplementedError - - async def noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: - raise NotImplementedError - - server.experimental.enable_tasks(on_list_tasks=noop_list_tasks, on_cancel_task=noop_cancel_task) - - capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) - - assert capabilities.tasks is not None - assert capabilities.tasks.list is not None - assert capabilities.tasks.cancel is not None - assert capabilities.tasks.requests is not None - assert capabilities.tasks.requests.tools is not None - - -@pytest.mark.skip( - 
reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " - "so partial capabilities aren't possible yet. Low-level API should support " - "selectively enabling/disabling task capabilities." -) -async def test_server_capabilities_partial_tasks() -> None: # pragma: no cover - """Test capabilities with only some task handlers registered.""" - server = Server("test") - - async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - raise NotImplementedError - - # Only list_tasks registered, not cancel_task - server.experimental.enable_tasks(on_list_tasks=noop_list_tasks) - - capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) - - assert capabilities.tasks is not None - assert capabilities.tasks.list is not None - assert capabilities.tasks.cancel is None # Not registered - - -async def test_tool_with_task_execution_metadata() -> None: - """Test that tools can declare task execution mode.""" - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="quick_tool", - description="Fast tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_FORBIDDEN), - ), - Tool( - name="long_tool", - description="Long running tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ), - Tool( - name="flexible_tool", - description="Can be either", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_OPTIONAL), - ), - ] - ) - - server = Server("test", on_list_tools=handle_list_tools) - - async with Client(server) as client: - result = await client.list_tools() - tools = result.tools - - assert tools[0].execution is not None - assert tools[0].execution.task_support == TASK_FORBIDDEN - assert 
tools[1].execution is not None - assert tools[1].execution.task_support == TASK_REQUIRED - assert tools[2].execution is not None - assert tools[2].execution.task_support == TASK_OPTIONAL - - -async def test_task_metadata_in_call_tool_request() -> None: - """Test that task metadata is accessible via ctx when calling a tool.""" - captured_task_metadata: TaskMetadata | None = None - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - nonlocal captured_task_metadata - captured_task_metadata = ctx.experimental.task_metadata - return CallToolResult(content=[TextContent(type="text", text="done")]) - - server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - - async with Client(server) as client: - # Call tool with task metadata - await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="long_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CallToolResult, - ) - - assert captured_task_metadata is not None - assert captured_task_metadata.ttl == 60000 - - -async def test_task_metadata_is_task_property() -> None: - """Test that ctx.experimental.is_task works correctly.""" - is_task_values: list[bool] = [] - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - is_task_values.append(ctx.experimental.is_task) - return CallToolResult(content=[TextContent(type="text", text="done")]) - - server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - - async with Client(server) as client: - # Call without task metadata - await client.session.send_request( - 
CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), - CallToolResult, - ) - - # Call with task metadata - await client.session.send_request( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}, task=TaskMetadata(ttl=60000)), - ), - CallToolResult, - ) - - assert len(is_task_values) == 2 - assert is_task_values[0] is False # First call without task - assert is_task_values[1] is True # Second call with task - - -async def test_update_capabilities_no_handlers() -> None: - """Test that update_capabilities returns early when no task handlers are registered.""" - server = Server("test-no-handlers") - _ = server.experimental - - caps = server.get_capabilities(NotificationOptions(), {}) - assert caps.tasks is None - - -async def test_update_capabilities_partial_handlers() -> None: - """Test that update_capabilities skips list/cancel when only tasks/get is registered.""" - server = Server("test-partial") - # Access .experimental to create the ExperimentalHandlers instance - exp = server.experimental - # Second access returns the same cached instance - assert server.experimental is exp - - async def noop_get(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - raise NotImplementedError - - server._add_request_handler("tasks/get", noop_get) - - caps = server.get_capabilities(NotificationOptions(), {}) - assert caps.tasks is not None - assert caps.tasks.list is None - assert caps.tasks.cancel is None - - -async def test_default_task_handlers_via_enable_tasks() -> None: - """Test that enable_tasks() auto-registers working default handlers.""" - server = Server("test-default-handlers") - task_support = server.experimental.enable_tasks() - store = task_support.store - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def 
message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server() -> None: - async with task_support.run(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - task_support.configure_session(server_session) - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, {}, False) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task directly in the store for testing - task = await store.create_task(TaskMetadata(ttl=60000)) - - # Test list_tasks (default handler) - list_result = await client_session.experimental.list_tasks() - assert len(list_result.tasks) == 1 - assert list_result.tasks[0].task_id == task.task_id - - # Test get_task (default handler - found) - get_result = await client_session.experimental.get_task(task.task_id) - assert get_result.task_id == task.task_id - assert get_result.status == "working" - - # Test get_task (default handler - not found path) - with pytest.raises(MCPError, match="not found"): - await client_session.experimental.get_task("nonexistent-task") - - # Create a completed task to test get_task_result - completed_task = await store.create_task(TaskMetadata(ttl=60000)) - await store.store_result( - completed_task.task_id, CallToolResult(content=[TextContent(type="text", text="Test result")]) - ) - await store.update_task(completed_task.task_id, status="completed") - - # Test get_task_result (default handler) - 
payload_result = await client_session.send_request( - GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=completed_task.task_id)), - GetTaskPayloadResult, - ) - # The result should have the related-task metadata - assert payload_result.meta is not None - assert "io.modelcontextprotocol/related-task" in payload_result.meta - - # Test cancel_task (default handler) - cancel_result = await client_session.experimental.cancel_task(task.task_id) - assert cancel_result.task_id == task.task_id - assert cancel_result.status == "cancelled" - - tg.cancel_scope.cancel() - - -@pytest.mark.anyio -async def test_build_elicit_form_request() -> None: - """Test that _build_elicit_form_request builds a proper elicitation request.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions(server_name="test-server", server_version="1.0.0", capabilities=ServerCapabilities()), - ) as server_session: - # Test without task_id - request = server_session._build_elicit_form_request( - message="Test message", - requested_schema={"type": "object", "properties": {"answer": {"type": "string"}}}, - ) - assert request.method == "elicitation/create" - assert request.params is not None - assert request.params["message"] == "Test message" - - # Test with related_task_id (adds related-task metadata) - request_with_task = server_session._build_elicit_form_request( - message="Task message", - requested_schema={"type": "object"}, - related_task_id="test-task-123", - ) - assert request_with_task.method == "elicitation/create" - assert request_with_task.params is not None - assert "_meta" in request_with_task.params - assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] - assert ( - 
request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] == "test-task-123" - ) - finally: - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_build_elicit_url_request() -> None: - """Test that _build_elicit_url_request builds a proper URL mode elicitation request.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions(server_name="test-server", server_version="1.0.0", capabilities=ServerCapabilities()), - ) as server_session: - # Test without related_task_id - request = server_session._build_elicit_url_request( - message="Please authorize with GitHub", - url="https://github.com/login/oauth/authorize", - elicitation_id="oauth-123", - ) - assert request.method == "elicitation/create" - assert request.params is not None - assert request.params["message"] == "Please authorize with GitHub" - assert request.params["url"] == "https://github.com/login/oauth/authorize" - assert request.params["elicitationId"] == "oauth-123" - assert request.params["mode"] == "url" - - # Test with related_task_id (adds related-task metadata) - request_with_task = server_session._build_elicit_url_request( - message="OAuth required", - url="https://example.com/oauth", - elicitation_id="oauth-456", - related_task_id="test-task-789", - ) - assert request_with_task.method == "elicitation/create" - assert request_with_task.params is not None - assert "_meta" in request_with_task.params - assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] - assert ( - 
request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] == "test-task-789" - ) - finally: - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_build_create_message_request() -> None: - """Test that _build_create_message_request builds a proper sampling request.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - messages = [ - SamplingMessage(role="user", content=TextContent(type="text", text="Hello")), - ] - - # Test without task_id - request = server_session._build_create_message_request( - messages=messages, - max_tokens=100, - system_prompt="You are helpful", - ) - assert request.method == "sampling/createMessage" - assert request.params is not None - assert request.params["maxTokens"] == 100 - - # Test with related_task_id (adds related-task metadata) - request_with_task = server_session._build_create_message_request( - messages=messages, - max_tokens=50, - related_task_id="sampling-task-456", - ) - assert request_with_task.method == "sampling/createMessage" - assert request_with_task.params is not None - assert "_meta" in request_with_task.params - assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] - assert ( - request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] - == "sampling-task-456" - ) - finally: - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await 
client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_send_message() -> None: - """Test that send_message sends a raw session message.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - # Create a test message - notification = JSONRPCNotification(jsonrpc="2.0", method="test/notification") - message = SessionMessage( - message=notification, - metadata=ServerMessageMetadata(related_request_id="test-req-1"), - ) - - # Send the message - await server_session.send_message(message) - - # Verify it was sent to the stream - received = await server_to_client_receive.receive() - assert isinstance(received.message, JSONRPCNotification) - assert received.message.method == "test/notification" - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_response_routing_success() -> None: - """Test that response routing works for success responses.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track routed responses with event for synchronization - routed_responses: list[dict[str, Any]] = [] - response_received = anyio.Event() - - class TestRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - routed_responses.append({"id": request_id, "response": response}) - 
response_received.set() - return True # Handled - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - raise NotImplementedError - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - router = TestRouter() - server_session.add_response_router(router) - - # Simulate receiving a response from client - response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) - message = SessionMessage(message=response) - - # Send from "client" side - await client_to_server_send.send(message) - - # Wait for response to be routed - with anyio.fail_after(5): - await response_received.wait() - - # Verify response was routed - assert len(routed_responses) == 1 - assert routed_responses[0]["id"] == "test-req-1" - assert routed_responses[0]["response"]["status"] == "ok" - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_response_routing_error() -> None: - """Test that error routing works for error responses.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track routed errors with event for synchronization - routed_errors: list[dict[str, Any]] = [] - error_received = anyio.Event() - - class TestRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - raise NotImplementedError - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - routed_errors.append({"id": request_id, "error": error}) - error_received.set() - return True # Handled - - 
try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - router = TestRouter() - server_session.add_response_router(router) - - # Simulate receiving an error response from client - error_data = ErrorData(code=INVALID_REQUEST, message="Test error") - error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) - message = SessionMessage(message=error_response) - - # Send from "client" side - await client_to_server_send.send(message) - - # Wait for error to be routed - with anyio.fail_after(5): - await error_received.wait() - - # Verify error was routed - assert len(routed_errors) == 1 - assert routed_errors[0]["id"] == "test-req-2" - assert routed_errors[0]["error"].message == "Test error" - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_response_routing_skips_non_matching_routers() -> None: - """Test that routing continues to next router when first doesn't match.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track which routers were called - router_calls: list[str] = [] - response_received = anyio.Event() - - class NonMatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - router_calls.append("non_matching_response") - return False # Doesn't handle it - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - raise NotImplementedError - - class MatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: 
dict[str, Any]) -> bool: - router_calls.append("matching_response") - response_received.set() - return True # Handles it - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - raise NotImplementedError - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - # Add non-matching router first, then matching router - server_session.add_response_router(NonMatchingRouter()) - server_session.add_response_router(MatchingRouter()) - - # Send a response - should skip first router and be handled by second - response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) - message = SessionMessage(message=response) - await client_to_server_send.send(message) - - with anyio.fail_after(5): - await response_received.wait() - - # Verify both routers were called (first returned False, second returned True) - assert router_calls == ["non_matching_response", "matching_response"] - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() - - -@pytest.mark.anyio -async def test_error_routing_skips_non_matching_routers() -> None: - """Test that error routing continues to next router when first doesn't match.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Track which routers were called - router_calls: list[str] = [] - error_received = anyio.Event() - - class NonMatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - raise NotImplementedError - - def route_error(self, request_id: str | int, error: 
ErrorData) -> bool: - router_calls.append("non_matching_error") - return False # Doesn't handle it - - class MatchingRouter(ResponseRouter): - def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: - raise NotImplementedError - - def route_error(self, request_id: str | int, error: ErrorData) -> bool: - router_calls.append("matching_error") - error_received.set() - return True # Handles it - - try: - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - # Add non-matching router first, then matching router - server_session.add_response_router(NonMatchingRouter()) - server_session.add_response_router(MatchingRouter()) - - # Send an error - should skip first router and be handled by second - error_data = ErrorData(code=INVALID_REQUEST, message="Test error") - error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) - message = SessionMessage(message=error_response) - await client_to_server_send.send(message) - - with anyio.fail_after(5): - await error_received.wait() - - # Verify both routers were called (first returned False, second returned True) - assert router_calls == ["non_matching_error", "matching_error"] - finally: # pragma: lax no cover - await server_to_client_send.aclose() - await server_to_client_receive.aclose() - await client_to_server_send.aclose() - await client_to_server_receive.aclose() diff --git a/tests/experimental/tasks/server/test_server_task_context.py b/tests/experimental/tasks/server/test_server_task_context.py deleted file mode 100644 index e23299698c..0000000000 --- a/tests/experimental/tasks/server/test_server_task_context.py +++ /dev/null @@ -1,709 +0,0 @@ -"""Tests for ServerTaskContext.""" - -import asyncio -from unittest.mock import AsyncMock, Mock - -import anyio -import pytest - -from 
mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue -from mcp.types import ( - CallToolResult, - ClientCapabilities, - ClientTasksCapability, - ClientTasksRequestsCapability, - Implementation, - InitializeRequestParams, - JSONRPCRequest, - SamplingMessage, - TaskMetadata, - TasksCreateElicitationCapability, - TasksCreateMessageCapability, - TasksElicitationCapability, - TasksSamplingCapability, - TextContent, -) - - -@pytest.mark.anyio -async def test_server_task_context_properties() -> None: - """Test ServerTaskContext property accessors.""" - store = InMemoryTaskStore() - mock_session = Mock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-123") - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - assert ctx.task_id == "test-123" - assert ctx.task.task_id == "test-123" - assert ctx.is_cancelled is False - - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_request_cancellation() -> None: - """Test ServerTaskContext.request_cancellation().""" - store = InMemoryTaskStore() - mock_session = Mock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - assert ctx.is_cancelled is False - ctx.request_cancellation() - assert ctx.is_cancelled is True - - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_update_status_with_notify() -> None: - """Test update_status sends notification when notify=True.""" - store = InMemoryTaskStore() - mock_session = Mock() - 
mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - await ctx.update_status("Working...", notify=True) - - mock_session.send_notification.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_update_status_without_notify() -> None: - """Test update_status skips notification when notify=False.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - await ctx.update_status("Working...", notify=False) - - mock_session.send_notification.assert_not_called() - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_complete_with_notify() -> None: - """Test complete sends notification when notify=True.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - ) - - result = CallToolResult(content=[TextContent(type="text", text="Done")]) - await ctx.complete(result, notify=True) - - mock_session.send_notification.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_server_task_context_fail_with_notify() -> None: - """Test fail sends notification when notify=True.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.send_notification = AsyncMock() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - 
session=mock_session, - queue=queue, - ) - - await ctx.fail("Something went wrong", notify=True) - - mock_session.send_notification.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_raises_when_client_lacks_capability() -> None: - """Test that elicit() raises MCPError when client doesn't support elicitation.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=False) - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - with pytest.raises(MCPError) as exc_info: - await ctx.elicit(message="Test?", requested_schema={"type": "object"}) - - assert "elicitation capability" in exc_info.value.error.message - mock_session.check_client_capability.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_raises_when_client_lacks_capability() -> None: - """Test that create_message() raises MCPError when client doesn't support sampling.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=False) - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - with pytest.raises(MCPError) as exc_info: - await ctx.create_message(messages=[], max_tokens=100) - - assert "sampling capability" in exc_info.value.error.message - mock_session.check_client_capability.assert_called_once() - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_raises_without_handler() -> None: - """Test that elicit() raises when handler is not provided.""" - store = InMemoryTaskStore() - mock_session = 
Mock() - mock_session.check_client_capability = Mock(return_value=True) - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required"): - await ctx.elicit(message="Test?", requested_schema={"type": "object"}) - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_url_raises_without_handler() -> None: - """Test that elicit_url() raises when handler is not provided.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required for elicit_url"): - await ctx.elicit_url( - message="Please authorize", - url="https://example.com/oauth", - elicitation_id="oauth-123", - ) - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_raises_without_handler() -> None: - """Test that create_message() raises when handler is not provided.""" - store = InMemoryTaskStore() - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required"): - await ctx.create_message(messages=[], max_tokens=100) - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_queues_request_and_waits_for_response() -> None: - """Test that elicit() queues request and waits for response.""" - store = InMemoryTaskStore() - queue = 
InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - mock_session._build_elicit_form_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-1", - method="elicitation/create", - params={"message": "Test?", "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - elicit_result = None - - async def run_elicit() -> None: - nonlocal elicit_result - elicit_result = await ctx.elicit( - message="Test?", - requested_schema={"type": "object"}, - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_elicit) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Dequeue and simulate response - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Resolve with mock elicitation response - msg.resolver.set_result({"action": "accept", "content": {"name": "Alice"}}) - - # Verify result - assert elicit_result is not None - assert elicit_result.action == "accept" - assert elicit_result.content == {"name": "Alice"} - - # Verify task is back to working - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_url_queues_request_and_waits_for_response() -> None: - """Test that elicit_url() queues request and waits for response.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - 
mock_session = Mock() - mock_session.check_client_capability = Mock(return_value=True) - mock_session._build_elicit_url_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-url-req-1", - method="elicitation/create", - params={"message": "Authorize", "url": "https://example.com", "elicitationId": "123", "mode": "url"}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - elicit_result = None - - async def run_elicit_url() -> None: - nonlocal elicit_result - elicit_result = await ctx.elicit_url( - message="Authorize", - url="https://example.com/oauth", - elicitation_id="oauth-123", - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_elicit_url) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Dequeue and simulate response - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Resolve with mock elicitation response (URL mode just returns action) - msg.resolver.set_result({"action": "accept"}) - - # Verify result - assert elicit_result is not None - assert elicit_result.action == "accept" - - # Verify task is back to working - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_queues_request_and_waits_for_response() -> None: - """Test that create_message() queues request and waits for response.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - mock_session.check_client_capability = 
Mock(return_value=True) - mock_session._build_create_message_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-2", - method="sampling/createMessage", - params={"messages": [], "maxTokens": 100, "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - sampling_result = None - - async def run_sampling() -> None: - nonlocal sampling_result - sampling_result = await ctx.create_message( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(run_sampling) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Dequeue and simulate response - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Resolve with mock sampling response - msg.resolver.set_result( - { - "role": "assistant", - "content": {"type": "text", "text": "Hello back!"}, - "model": "test-model", - "stopReason": "endTurn", - } - ) - - # Verify result - assert sampling_result is not None - assert sampling_result.role == "assistant" - assert sampling_result.model == "test-model" - - # Verify task is back to working - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_restores_status_on_cancellation() -> None: - """Test that elicit() restores task status to working when cancelled.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - 
mock_session.check_client_capability = Mock(return_value=True) - mock_session._build_elicit_form_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-cancel", - method="elicitation/create", - params={"message": "Test?", "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - cancelled_error_raised = False - - async with anyio.create_task_group() as tg: - - async def do_elicit() -> None: - nonlocal cancelled_error_raised - try: - await ctx.elicit( - message="Test?", - requested_schema={"type": "object"}, - ) - except anyio.get_cancelled_exc_class(): - cancelled_error_raised = True - # Don't re-raise - let the test continue - - tg.start_soon(do_elicit) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Get the queued message and set cancellation exception on its resolver - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Trigger cancellation by setting exception (use asyncio.CancelledError directly) - msg.resolver.set_exception(asyncio.CancelledError()) - - # Verify task is back to working after cancellation - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - assert cancelled_error_raised - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_restores_status_on_cancellation() -> None: - """Test that create_message() restores task status to working when cancelled.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - handler = TaskResultHandler(store, queue) - task = await store.create_task(TaskMetadata(ttl=60000)) - - mock_session = Mock() - mock_session.check_client_capability 
= Mock(return_value=True) - mock_session._build_create_message_request = Mock( - return_value=JSONRPCRequest( - jsonrpc="2.0", - id="test-req-cancel-2", - method="sampling/createMessage", - params={"messages": [], "maxTokens": 100, "_meta": {}}, - ) - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=handler, - ) - - cancelled_error_raised = False - - async with anyio.create_task_group() as tg: - - async def do_sampling() -> None: - nonlocal cancelled_error_raised - try: - await ctx.create_message( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ) - except anyio.get_cancelled_exc_class(): - cancelled_error_raised = True - # Don't re-raise - - tg.start_soon(do_sampling) - - # Wait for request to be queued - await queue.wait_for_message(task.task_id) - - # Verify task is in input_required status - updated_task = await store.get_task(task.task_id) - assert updated_task is not None - assert updated_task.status == "input_required" - - # Get the queued message and set cancellation exception on its resolver - msg = await queue.dequeue(task.task_id) - assert msg is not None - assert msg.resolver is not None - - # Trigger cancellation by setting exception (use asyncio.CancelledError directly) - msg.resolver.set_exception(asyncio.CancelledError()) - - # Verify task is back to working after cancellation - final_task = await store.get_task(task.task_id) - assert final_task is not None - assert final_task.status == "working" - assert cancelled_error_raised - - store.cleanup() - - -@pytest.mark.anyio -async def test_elicit_as_task_raises_without_handler() -> None: - """Test that elicit_as_task() raises when handler is not provided.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - # Create mock session with proper client capabilities - mock_session = Mock() - 
mock_session.client_params = InitializeRequestParams( - protocol_version="2025-01-01", - capabilities=ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - ), - client_info=Implementation(name="test", version="1.0"), - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required for elicit_as_task"): - await ctx.elicit_as_task(message="Test?", requested_schema={"type": "object"}) - - store.cleanup() - - -@pytest.mark.anyio -async def test_create_message_as_task_raises_without_handler() -> None: - """Test that create_message_as_task() raises when handler is not provided.""" - store = InMemoryTaskStore() - queue = InMemoryTaskMessageQueue() - task = await store.create_task(TaskMetadata(ttl=60000)) - - # Create mock session with proper client capabilities - mock_session = Mock() - mock_session.client_params = InitializeRequestParams( - protocol_version="2025-01-01", - capabilities=ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - ), - client_info=Implementation(name="test", version="1.0"), - ) - - ctx = ServerTaskContext( - task=task, - store=store, - session=mock_session, - queue=queue, - handler=None, - ) - - with pytest.raises(RuntimeError, match="handler is required for create_message_as_task"): - await ctx.create_message_as_task( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ) - - store.cleanup() diff --git a/tests/experimental/tasks/server/test_store.py b/tests/experimental/tasks/server/test_store.py deleted file mode 100644 index 0d431899c8..0000000000 --- a/tests/experimental/tasks/server/test_store.py +++ 
/dev/null @@ -1,406 +0,0 @@ -"""Tests for InMemoryTaskStore.""" - -from collections.abc import AsyncIterator -from datetime import datetime, timedelta, timezone - -import pytest - -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.helpers import cancel_task -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.types import INVALID_PARAMS, CallToolResult, TaskMetadata, TextContent - - -@pytest.fixture -async def store() -> AsyncIterator[InMemoryTaskStore]: - """Provide a clean InMemoryTaskStore for each test with automatic cleanup.""" - store = InMemoryTaskStore() - yield store - store.cleanup() - - -@pytest.mark.anyio -async def test_create_and_get(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore create and get operations.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - assert task.task_id is not None - assert task.status == "working" - assert task.ttl == 60000 - - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - assert retrieved.task_id == task.task_id - assert retrieved.status == "working" - - -@pytest.mark.anyio -async def test_create_with_custom_id(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore create with custom task ID.""" - task = await store.create_task( - metadata=TaskMetadata(ttl=60000), - task_id="my-custom-id", - ) - - assert task.task_id == "my-custom-id" - assert task.status == "working" - - retrieved = await store.get_task("my-custom-id") - assert retrieved is not None - assert retrieved.task_id == "my-custom-id" - - -@pytest.mark.anyio -async def test_create_duplicate_id_raises(store: InMemoryTaskStore) -> None: - """Test that creating a task with duplicate ID raises.""" - await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") - - with pytest.raises(ValueError, match="already exists"): - await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") - - 
-@pytest.mark.anyio -async def test_get_nonexistent_returns_none(store: InMemoryTaskStore) -> None: - """Test that getting a nonexistent task returns None.""" - retrieved = await store.get_task("nonexistent") - assert retrieved is None - - -@pytest.mark.anyio -async def test_update_status(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore status updates.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - updated = await store.update_task(task.task_id, status="completed", status_message="All done!") - - assert updated.status == "completed" - assert updated.status_message == "All done!" - - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - assert retrieved.status == "completed" - assert retrieved.status_message == "All done!" - - -@pytest.mark.anyio -async def test_update_nonexistent_raises(store: InMemoryTaskStore) -> None: - """Test that updating a nonexistent task raises.""" - with pytest.raises(ValueError, match="not found"): - await store.update_task("nonexistent", status="completed") - - -@pytest.mark.anyio -async def test_store_and_get_result(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore result storage and retrieval.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - # Store result - result = CallToolResult(content=[TextContent(type="text", text="Result data")]) - await store.store_result(task.task_id, result) - - # Retrieve result - retrieved_result = await store.get_result(task.task_id) - assert retrieved_result == result - - -@pytest.mark.anyio -async def test_get_result_nonexistent_returns_none(store: InMemoryTaskStore) -> None: - """Test that getting result for nonexistent task returns None.""" - result = await store.get_result("nonexistent") - assert result is None - - -@pytest.mark.anyio -async def test_get_result_no_result_returns_none(store: InMemoryTaskStore) -> None: - """Test that getting result when none stored returns None.""" - task = await 
store.create_task(metadata=TaskMetadata(ttl=60000)) - result = await store.get_result(task.task_id) - assert result is None - - -@pytest.mark.anyio -async def test_list_tasks(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore list operation.""" - # Create multiple tasks - for _ in range(3): - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - tasks, next_cursor = await store.list_tasks() - assert len(tasks) == 3 - assert next_cursor is None # Less than page size - - -@pytest.mark.anyio -async def test_list_tasks_pagination() -> None: - """Test InMemoryTaskStore pagination.""" - # Needs custom page_size, can't use fixture - store = InMemoryTaskStore(page_size=2) - - # Create 5 tasks - for _ in range(5): - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - # First page - tasks, next_cursor = await store.list_tasks() - assert len(tasks) == 2 - assert next_cursor is not None - - # Second page - tasks, next_cursor = await store.list_tasks(cursor=next_cursor) - assert len(tasks) == 2 - assert next_cursor is not None - - # Third page (last) - tasks, next_cursor = await store.list_tasks(cursor=next_cursor) - assert len(tasks) == 1 - assert next_cursor is None - - store.cleanup() - - -@pytest.mark.anyio -async def test_list_tasks_invalid_cursor(store: InMemoryTaskStore) -> None: - """Test that invalid cursor raises.""" - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - with pytest.raises(ValueError, match="Invalid cursor"): - await store.list_tasks(cursor="invalid-cursor") - - -@pytest.mark.anyio -async def test_delete_task(store: InMemoryTaskStore) -> None: - """Test InMemoryTaskStore delete operation.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - deleted = await store.delete_task(task.task_id) - assert deleted is True - - retrieved = await store.get_task(task.task_id) - assert retrieved is None - - # Delete non-existent - deleted = await store.delete_task(task.task_id) - assert deleted is False - - 
-@pytest.mark.anyio -async def test_get_all_tasks_helper(store: InMemoryTaskStore) -> None: - """Test the get_all_tasks debugging helper.""" - await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.create_task(metadata=TaskMetadata(ttl=60000)) - - all_tasks = store.get_all_tasks() - assert len(all_tasks) == 2 - - -@pytest.mark.anyio -async def test_store_result_nonexistent_raises(store: InMemoryTaskStore) -> None: - """Test that storing result for nonexistent task raises ValueError.""" - result = CallToolResult(content=[TextContent(type="text", text="Result")]) - - with pytest.raises(ValueError, match="not found"): - await store.store_result("nonexistent-id", result) - - -@pytest.mark.anyio -async def test_create_task_with_null_ttl(store: InMemoryTaskStore) -> None: - """Test creating task with null TTL (never expires).""" - task = await store.create_task(metadata=TaskMetadata(ttl=None)) - - assert task.ttl is None - - # Task should persist (not expire) - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - - -@pytest.mark.anyio -async def test_task_expiration_cleanup(store: InMemoryTaskStore) -> None: - """Test that expired tasks are cleaned up lazily.""" - # Create a task with very short TTL - task = await store.create_task(metadata=TaskMetadata(ttl=1)) # 1ms TTL - - # Manually force the expiry to be in the past - stored = store._tasks.get(task.task_id) - assert stored is not None - stored.expires_at = datetime.now(timezone.utc) - timedelta(seconds=10) - - # Task should still exist in internal dict but be expired - assert task.task_id in store._tasks - - # Any access operation should clean up expired tasks - # list_tasks triggers cleanup - tasks, _ = await store.list_tasks() - - # Expired task should be cleaned up - assert task.task_id not in store._tasks - assert len(tasks) == 0 - - -@pytest.mark.anyio -async def test_task_with_null_ttl_never_expires(store: InMemoryTaskStore) -> None: - """Test that tasks with null TTL 
never expire during cleanup.""" - # Create task with null TTL - task = await store.create_task(metadata=TaskMetadata(ttl=None)) - - # Verify internal storage has no expiry - stored = store._tasks.get(task.task_id) - assert stored is not None - assert stored.expires_at is None - - # Access operations should NOT remove this task - await store.list_tasks() - await store.get_task(task.task_id) - - # Task should still exist - assert task.task_id in store._tasks - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - - -@pytest.mark.anyio -async def test_terminal_task_ttl_reset(store: InMemoryTaskStore) -> None: - """Test that TTL is reset when task enters terminal state.""" - # Create task with short TTL - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) # 60s - - # Get the initial expiry - stored = store._tasks.get(task.task_id) - assert stored is not None - initial_expiry = stored.expires_at - assert initial_expiry is not None - - # Update to terminal state (completed) - await store.update_task(task.task_id, status="completed") - - # Expiry should be reset to a new time (from now + TTL) - new_expiry = stored.expires_at - assert new_expiry is not None - assert new_expiry >= initial_expiry - - -@pytest.mark.anyio -async def test_terminal_status_transition_rejected(store: InMemoryTaskStore) -> None: - """Test that transitions from terminal states are rejected. - - Per spec: Terminal states (completed, failed, cancelled) MUST NOT - transition to any other status. 
- """ - # Test each terminal status - for terminal_status in ("completed", "failed", "cancelled"): - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - - # Move to terminal state - await store.update_task(task.task_id, status=terminal_status) - - # Attempting to transition to any other status should raise - with pytest.raises(ValueError, match="Cannot transition from terminal status"): - await store.update_task(task.task_id, status="working") - - # Also test transitioning to another terminal state - other_terminal = "failed" if terminal_status != "failed" else "completed" - with pytest.raises(ValueError, match="Cannot transition from terminal status"): - await store.update_task(task.task_id, status=other_terminal) - - -@pytest.mark.anyio -async def test_terminal_status_allows_same_status(store: InMemoryTaskStore) -> None: - """Test that setting the same terminal status doesn't raise. - - This is not a transition, so it should be allowed (no-op). - """ - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="completed") - - # Setting the same status should not raise - updated = await store.update_task(task.task_id, status="completed") - assert updated.status == "completed" - - # Updating just the message should also work - updated = await store.update_task(task.task_id, status_message="Updated message") - assert updated.status_message == "Updated message" - - -@pytest.mark.anyio -async def test_wait_for_update_nonexistent_raises(store: InMemoryTaskStore) -> None: - """Test that wait_for_update raises for nonexistent task.""" - with pytest.raises(ValueError, match="not found"): - await store.wait_for_update("nonexistent-task-id") - - -@pytest.mark.anyio -async def test_cancel_task_succeeds_for_working_task(store: InMemoryTaskStore) -> None: - """Test cancel_task helper succeeds for a working task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - assert task.status == 
"working" - - result = await cancel_task(store, task.task_id) - - assert result.task_id == task.task_id - assert result.status == "cancelled" - - # Verify store is updated - retrieved = await store.get_task(task.task_id) - assert retrieved is not None - assert retrieved.status == "cancelled" - - -@pytest.mark.anyio -async def test_cancel_task_rejects_nonexistent_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for nonexistent task.""" - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, "nonexistent-task-id") - - assert exc_info.value.error.code == INVALID_PARAMS - assert "not found" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_rejects_completed_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for completed task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="completed") - - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) - - assert exc_info.value.error.code == INVALID_PARAMS - assert "terminal state 'completed'" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_rejects_failed_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for failed task.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="failed") - - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) - - assert exc_info.value.error.code == INVALID_PARAMS - assert "terminal state 'failed'" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_rejects_already_cancelled_task(store: InMemoryTaskStore) -> None: - """Test cancel_task raises MCPError with INVALID_PARAMS for already cancelled task.""" - task = await 
store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="cancelled") - - with pytest.raises(MCPError) as exc_info: - await cancel_task(store, task.task_id) - - assert exc_info.value.error.code == INVALID_PARAMS - assert "terminal state 'cancelled'" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_cancel_task_succeeds_for_input_required_task(store: InMemoryTaskStore) -> None: - """Test cancel_task helper succeeds for a task in input_required status.""" - task = await store.create_task(metadata=TaskMetadata(ttl=60000)) - await store.update_task(task.task_id, status="input_required") - - result = await cancel_task(store, task.task_id) - - assert result.task_id == task.task_id - assert result.status == "cancelled" diff --git a/tests/experimental/tasks/server/test_task_result_handler.py b/tests/experimental/tasks/server/test_task_result_handler.py deleted file mode 100644 index 8b5a03ce2b..0000000000 --- a/tests/experimental/tasks/server/test_task_result_handler.py +++ /dev/null @@ -1,354 +0,0 @@ -"""Tests for TaskResultHandler.""" - -from collections.abc import AsyncIterator -from typing import Any -from unittest.mock import AsyncMock, Mock - -import anyio -import pytest - -from mcp.server.experimental.task_result_handler import TaskResultHandler -from mcp.shared.exceptions import MCPError -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, QueuedMessage -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.shared.message import SessionMessage -from mcp.types import ( - INVALID_REQUEST, - CallToolResult, - ErrorData, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - JSONRPCRequest, - TaskMetadata, - TextContent, -) - - -@pytest.fixture -async def store() -> AsyncIterator[InMemoryTaskStore]: - """Provide a clean store for each test.""" - s = 
InMemoryTaskStore() - yield s - s.cleanup() - - -@pytest.fixture -def queue() -> InMemoryTaskMessageQueue: - """Provide a clean queue for each test.""" - return InMemoryTaskMessageQueue() - - -@pytest.fixture -def handler(store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue) -> TaskResultHandler: - """Provide a handler for each test.""" - return TaskResultHandler(store, queue) - - -@pytest.mark.anyio -async def test_handle_returns_result_for_completed_task( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() returns the stored result for a completed task.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - result = CallToolResult(content=[TextContent(type="text", text="Done!")]) - await store.store_result(task.task_id, result) - await store.update_task(task.task_id, status="completed") - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - response = await handler.handle(request, mock_session, "req-1") - - assert response is not None - assert response.meta is not None - assert "io.modelcontextprotocol/related-task" in response.meta - - -@pytest.mark.anyio -async def test_handle_raises_for_nonexistent_task( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() raises MCPError for nonexistent task.""" - mock_session = Mock() - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="nonexistent")) - - with pytest.raises(MCPError) as exc_info: - await handler.handle(request, mock_session, "req-1") - - assert "not found" in exc_info.value.error.message - - -@pytest.mark.anyio -async def test_handle_returns_empty_result_when_no_result_stored( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() 
returns minimal result when task completed without stored result.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - await store.update_task(task.task_id, status="completed") - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - response = await handler.handle(request, mock_session, "req-1") - - assert response is not None - assert response.meta is not None - assert "io.modelcontextprotocol/related-task" in response.meta - - -@pytest.mark.anyio -async def test_handle_delivers_queued_messages( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() delivers queued messages before returning.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - queued_msg = QueuedMessage( - type="notification", - message=JSONRPCRequest( - jsonrpc="2.0", - id="notif-1", - method="test/notification", - params={}, - ), - ) - await queue.enqueue(task.task_id, queued_msg) - await store.update_task(task.task_id, status="completed") - - sent_messages: list[SessionMessage] = [] - - async def track_send(msg: SessionMessage) -> None: - sent_messages.append(msg) - - mock_session = Mock() - mock_session.send_message = track_send - - request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - await handler.handle(request, mock_session, "req-1") - - assert len(sent_messages) == 1 - - -@pytest.mark.anyio -async def test_handle_waits_for_task_completion( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that handle() waits for task to complete before returning.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - request = 
GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task.task_id)) - result_holder: list[GetTaskPayloadResult | None] = [None] - - async def run_handle() -> None: - result_holder[0] = await handler.handle(request, mock_session, "req-1") - - async with anyio.create_task_group() as tg: - tg.start_soon(run_handle) - - # Wait for handler to start waiting (event gets created when wait starts) - while task.task_id not in store._update_events: - await anyio.sleep(0) - - await store.store_result(task.task_id, CallToolResult(content=[TextContent(type="text", text="Done")])) - await store.update_task(task.task_id, status="completed") - - assert result_holder[0] is not None - - -@pytest.mark.anyio -async def test_route_response_resolves_pending_request( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_response() resolves a pending request.""" - resolver: Resolver[dict[str, Any]] = Resolver() - handler._pending_requests["req-123"] = resolver - - result = handler.route_response("req-123", {"status": "ok"}) - - assert result is True - assert resolver.done() - assert await resolver.wait() == {"status": "ok"} - - -@pytest.mark.anyio -async def test_route_response_returns_false_for_unknown_request( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_response() returns False for unknown request ID.""" - result = handler.route_response("unknown-req", {"status": "ok"}) - assert result is False - - -@pytest.mark.anyio -async def test_route_response_returns_false_for_already_done_resolver( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_response() returns False if resolver already completed.""" - resolver: Resolver[dict[str, Any]] = Resolver() - resolver.set_result({"already": "done"}) - handler._pending_requests["req-123"] = resolver - - result = 
handler.route_response("req-123", {"new": "data"}) - - assert result is False - - -@pytest.mark.anyio -async def test_route_error_resolves_pending_request_with_exception( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_error() sets exception on pending request.""" - resolver: Resolver[dict[str, Any]] = Resolver() - handler._pending_requests["req-123"] = resolver - - error = ErrorData(code=INVALID_REQUEST, message="Something went wrong") - result = handler.route_error("req-123", error) - - assert result is True - assert resolver.done() - - with pytest.raises(MCPError) as exc_info: - await resolver.wait() - assert exc_info.value.error.message == "Something went wrong" - - -@pytest.mark.anyio -async def test_route_error_returns_false_for_unknown_request( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that route_error() returns False for unknown request ID.""" - error = ErrorData(code=INVALID_REQUEST, message="Error") - result = handler.route_error("unknown-req", error) - assert result is False - - -@pytest.mark.anyio -async def test_deliver_registers_resolver_for_request_messages( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _deliver_queued_messages registers resolvers for request messages.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - resolver: Resolver[dict[str, Any]] = Resolver() - queued_msg = QueuedMessage( - type="request", - message=JSONRPCRequest( - jsonrpc="2.0", - id="inner-req-1", - method="elicitation/create", - params={}, - ), - resolver=resolver, - original_request_id="inner-req-1", - ) - await queue.enqueue(task.task_id, queued_msg) - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - await handler._deliver_queued_messages(task.task_id, mock_session, "outer-req-1") - - assert "inner-req-1" in 
handler._pending_requests - assert handler._pending_requests["inner-req-1"] is resolver - - -@pytest.mark.anyio -async def test_deliver_skips_resolver_registration_when_no_original_id( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _deliver_queued_messages skips resolver registration when original_request_id is None.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - resolver: Resolver[dict[str, Any]] = Resolver() - queued_msg = QueuedMessage( - type="request", - message=JSONRPCRequest( - jsonrpc="2.0", - id="inner-req-1", - method="elicitation/create", - params={}, - ), - resolver=resolver, - original_request_id=None, # No original request ID - ) - await queue.enqueue(task.task_id, queued_msg) - - mock_session = Mock() - mock_session.send_message = AsyncMock() - - await handler._deliver_queued_messages(task.task_id, mock_session, "outer-req-1") - - # Resolver should NOT be registered since original_request_id is None - assert len(handler._pending_requests) == 0 - # But the message should still be sent - mock_session.send_message.assert_called_once() - - -@pytest.mark.anyio -async def test_wait_for_task_update_handles_store_exception( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _wait_for_task_update handles store exception gracefully.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - # Make wait_for_update raise an exception - async def failing_wait(task_id: str) -> None: - raise RuntimeError("Store error") - - store.wait_for_update = failing_wait # type: ignore[method-assign] - - # Queue a message to unblock the race via the queue path - async def enqueue_later() -> None: - # Wait for queue to start waiting (event gets created when wait starts) - while task.task_id not in queue._events: - await anyio.sleep(0) - await queue.enqueue( - task.task_id, - 
QueuedMessage( - type="notification", - message=JSONRPCRequest( - jsonrpc="2.0", - id="notif-1", - method="test/notification", - params={}, - ), - ), - ) - - async with anyio.create_task_group() as tg: - tg.start_soon(enqueue_later) - # This should complete via the queue path even though store raises - await handler._wait_for_task_update(task.task_id) - - -@pytest.mark.anyio -async def test_wait_for_task_update_handles_queue_exception( - store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler -) -> None: - """Test that _wait_for_task_update handles queue exception gracefully.""" - task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") - - # Make wait_for_message raise an exception - async def failing_wait(task_id: str) -> None: - raise RuntimeError("Queue error") - - queue.wait_for_message = failing_wait # type: ignore[method-assign] - - # Update the store to unblock the race via the store path - async def update_later() -> None: - # Wait for store to start waiting (event gets created when wait starts) - while task.task_id not in store._update_events: - await anyio.sleep(0) - await store.update_task(task.task_id, status="completed") - - async with anyio.create_task_group() as tg: - tg.start_soon(update_later) - # This should complete via the store path even though queue raises - await handler._wait_for_task_update(task.task_id) diff --git a/tests/experimental/tasks/test_capabilities.py b/tests/experimental/tasks/test_capabilities.py deleted file mode 100644 index 90a8656ba0..0000000000 --- a/tests/experimental/tasks/test_capabilities.py +++ /dev/null @@ -1,283 +0,0 @@ -"""Tests for tasks capability checking utilities.""" - -import pytest - -from mcp import MCPError -from mcp.shared.experimental.tasks.capabilities import ( - check_tasks_capability, - has_task_augmented_elicitation, - has_task_augmented_sampling, - require_task_augmented_elicitation, - require_task_augmented_sampling, -) -from mcp.types import ( - 
ClientCapabilities, - ClientTasksCapability, - ClientTasksRequestsCapability, - TasksCreateElicitationCapability, - TasksCreateMessageCapability, - TasksElicitationCapability, - TasksSamplingCapability, -) - - -class TestCheckTasksCapability: - """Tests for check_tasks_capability function.""" - - def test_required_requests_none_returns_true(self) -> None: - """When required.requests is None, should return True.""" - required = ClientTasksCapability() - client = ClientTasksCapability() - assert check_tasks_capability(required, client) is True - - def test_client_requests_none_returns_false(self) -> None: - """When client.requests is None but required.requests is set, should return False.""" - required = ClientTasksCapability(requests=ClientTasksRequestsCapability()) - client = ClientTasksCapability() - assert check_tasks_capability(required, client) is False - - def test_elicitation_required_but_client_missing(self) -> None: - """When elicitation is required but client doesn't have it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability(elicitation=TasksElicitationCapability()) - ) - client = ClientTasksCapability(requests=ClientTasksRequestsCapability()) - assert check_tasks_capability(required, client) is False - - def test_elicitation_create_required_but_client_missing(self) -> None: - """When elicitation.create is required but client doesn't have it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability() # No create - ) - ) - assert check_tasks_capability(required, client) is False - - def test_elicitation_create_present(self) -> None: - """When elicitation.create is required and client has it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - 
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - def test_sampling_required_but_client_missing(self) -> None: - """When sampling is required but client doesn't have it.""" - required = ClientTasksCapability(requests=ClientTasksRequestsCapability(sampling=TasksSamplingCapability())) - client = ClientTasksCapability(requests=ClientTasksRequestsCapability()) - assert check_tasks_capability(required, client) is False - - def test_sampling_create_message_required_but_client_missing(self) -> None: - """When sampling.createMessage is required but client doesn't have it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability() # No createMessage - ) - ) - assert check_tasks_capability(required, client) is False - - def test_sampling_create_message_present(self) -> None: - """When sampling.createMessage is required and client has it.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - def test_both_elicitation_and_sampling_present(self) -> None: - """When both elicitation.create and sampling.createMessage are required and client has both.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - 
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()), - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()), - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()), - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()), - ) - ) - assert check_tasks_capability(required, client) is True - - def test_elicitation_without_create_required(self) -> None: - """When elicitation is required but not create specifically.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability() # No create - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - def test_sampling_without_create_message_required(self) -> None: - """When sampling is required but not createMessage specifically.""" - required = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability() # No createMessage - ) - ) - client = ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - assert check_tasks_capability(required, client) is True - - -class TestHasTaskAugmentedElicitation: - """Tests for has_task_augmented_elicitation function.""" - - def test_tasks_none(self) -> None: - """Returns False when caps.tasks is None.""" - caps = ClientCapabilities() - assert has_task_augmented_elicitation(caps) is False - - def test_requests_none(self) -> None: - """Returns False when caps.tasks.requests is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability()) - assert has_task_augmented_elicitation(caps) is False - - def 
test_elicitation_none(self) -> None: - """Returns False when caps.tasks.requests.elicitation is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability())) - assert has_task_augmented_elicitation(caps) is False - - def test_create_none(self) -> None: - """Returns False when caps.tasks.requests.elicitation.create is None.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability(elicitation=TasksElicitationCapability()) - ) - ) - assert has_task_augmented_elicitation(caps) is False - - def test_create_present(self) -> None: - """Returns True when full capability path is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - ) - assert has_task_augmented_elicitation(caps) is True - - -class TestHasTaskAugmentedSampling: - """Tests for has_task_augmented_sampling function.""" - - def test_tasks_none(self) -> None: - """Returns False when caps.tasks is None.""" - caps = ClientCapabilities() - assert has_task_augmented_sampling(caps) is False - - def test_requests_none(self) -> None: - """Returns False when caps.tasks.requests is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability()) - assert has_task_augmented_sampling(caps) is False - - def test_sampling_none(self) -> None: - """Returns False when caps.tasks.requests.sampling is None.""" - caps = ClientCapabilities(tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability())) - assert has_task_augmented_sampling(caps) is False - - def test_create_message_none(self) -> None: - """Returns False when caps.tasks.requests.sampling.createMessage is None.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability(sampling=TasksSamplingCapability())) - ) - assert has_task_augmented_sampling(caps) is False - - def 
test_create_message_present(self) -> None: - """Returns True when full capability path is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - ) - assert has_task_augmented_sampling(caps) is True - - -class TestRequireTaskAugmentedElicitation: - """Tests for require_task_augmented_elicitation function.""" - - def test_raises_when_none(self) -> None: - """Raises MCPError when client_caps is None.""" - with pytest.raises(MCPError) as exc_info: - require_task_augmented_elicitation(None) - assert "task-augmented elicitation" in str(exc_info.value) - - def test_raises_when_missing(self) -> None: - """Raises MCPError when capability is missing.""" - caps = ClientCapabilities() - with pytest.raises(MCPError) as exc_info: - require_task_augmented_elicitation(caps) - assert "task-augmented elicitation" in str(exc_info.value) - - def test_passes_when_present(self) -> None: - """Does not raise when capability is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) - ) - ) - ) - require_task_augmented_elicitation(caps) - - -class TestRequireTaskAugmentedSampling: - """Tests for require_task_augmented_sampling function.""" - - def test_raises_when_none(self) -> None: - """Raises MCPError when client_caps is None.""" - with pytest.raises(MCPError) as exc_info: - require_task_augmented_sampling(None) - assert "task-augmented sampling" in str(exc_info.value) - - def test_raises_when_missing(self) -> None: - """Raises MCPError when capability is missing.""" - caps = ClientCapabilities() - with pytest.raises(MCPError) as exc_info: - require_task_augmented_sampling(caps) - assert "task-augmented sampling" in str(exc_info.value) - - def test_passes_when_present(self) -> None: - """Does not 
raise when capability is present.""" - caps = ClientCapabilities( - tasks=ClientTasksCapability( - requests=ClientTasksRequestsCapability( - sampling=TasksSamplingCapability(create_message=TasksCreateMessageCapability()) - ) - ) - ) - require_task_augmented_sampling(caps) diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py deleted file mode 100644 index 2d0378a9ce..0000000000 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ /dev/null @@ -1,695 +0,0 @@ -"""Tests for the four elicitation scenarios with tasks. - -This tests all combinations of tool call types and elicitation types: -1. Normal tool call + Normal elicitation (session.elicit) -2. Normal tool call + Task-augmented elicitation (session.experimental.elicit_as_task) -3. Task-augmented tool call + Normal elicitation (task.elicit) -4. Task-augmented tool call + Task-augmented elicitation (task.elicit_as_task) - -And the same for sampling (create_message). 
-""" - -from typing import Any - -import anyio -import pytest -from anyio import Event - -from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.lowlevel import NotificationOptions -from mcp.shared._context import RequestContext -from mcp.shared.experimental.tasks.helpers import is_terminal -from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.types import ( - TASK_REQUIRED, - CallToolRequestParams, - CallToolResult, - CreateMessageRequestParams, - CreateMessageResult, - CreateTaskResult, - ElicitRequestParams, - ElicitResult, - ErrorData, - GetTaskPayloadResult, - GetTaskResult, - ListToolsResult, - PaginatedRequestParams, - SamplingMessage, - TaskMetadata, - TextContent, - Tool, -) - - -def create_client_task_handlers( - client_task_store: InMemoryTaskStore, - elicit_received: Event, -) -> ExperimentalTaskHandlers: - """Create task handlers for client to handle task-augmented elicitation from server.""" - - elicit_response = ElicitResult(action="accept", content={"confirm": True}) - task_complete_events: dict[str, Event] = {} - - async def handle_augmented_elicitation( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult: - """Handle task-augmented elicitation by creating a client-side task.""" - elicit_received.set() - task = await client_task_store.create_task(task_metadata) - task_complete_events[task.task_id] = Event() - - async def complete_task() -> None: - # Store result before updating status to avoid race condition - await client_task_store.store_result(task.task_id, elicit_response) - await client_task_store.update_task(task.task_id, status="completed") - task_complete_events[task.task_id].set() 
- - context.session._task_group.start_soon(complete_task) # pyright: ignore[reportPrivateUsage] - return CreateTaskResult(task=task) - - async def handle_get_task( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskResult: - """Handle tasks/get from server.""" - task = await client_task_store.get_task(params.task_id) - assert task is not None, f"Task not found: {params.task_id}" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=100, - ) - - async def handle_get_task_result( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskPayloadResult | ErrorData: - """Handle tasks/result from server.""" - event = task_complete_events.get(params.task_id) - assert event is not None, f"No completion event for task: {params.task_id}" - await event.wait() - result = await client_task_store.get_result(params.task_id) - assert result is not None, f"Result not found for task: {params.task_id}" - return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True)) - - return ExperimentalTaskHandlers( - augmented_elicitation=handle_augmented_elicitation, - get_task=handle_get_task, - get_task_result=handle_get_task_result, - ) - - -def create_sampling_task_handlers( - client_task_store: InMemoryTaskStore, - sampling_received: Event, -) -> ExperimentalTaskHandlers: - """Create task handlers for client to handle task-augmented sampling from server.""" - - sampling_response = CreateMessageResult( - role="assistant", - content=TextContent(type="text", text="Hello from the model!"), - model="test-model", - ) - task_complete_events: dict[str, Event] = {} - - async def handle_augmented_sampling( - context: RequestContext[ClientSession], - params: CreateMessageRequestParams, - task_metadata: TaskMetadata, - ) -> CreateTaskResult: - """Handle task-augmented sampling by creating a client-side 
task.""" - sampling_received.set() - task = await client_task_store.create_task(task_metadata) - task_complete_events[task.task_id] = Event() - - async def complete_task() -> None: - # Store result before updating status to avoid race condition - await client_task_store.store_result(task.task_id, sampling_response) - await client_task_store.update_task(task.task_id, status="completed") - task_complete_events[task.task_id].set() - - context.session._task_group.start_soon(complete_task) # pyright: ignore[reportPrivateUsage] - return CreateTaskResult(task=task) - - async def handle_get_task( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskResult: - """Handle tasks/get from server.""" - task = await client_task_store.get_task(params.task_id) - assert task is not None, f"Task not found: {params.task_id}" - return GetTaskResult( - task_id=task.task_id, - status=task.status, - status_message=task.status_message, - created_at=task.created_at, - last_updated_at=task.last_updated_at, - ttl=task.ttl, - poll_interval=100, - ) - - async def handle_get_task_result( - context: RequestContext[ClientSession], - params: Any, - ) -> GetTaskPayloadResult | ErrorData: - """Handle tasks/result from server.""" - event = task_complete_events.get(params.task_id) - assert event is not None, f"No completion event for task: {params.task_id}" - await event.wait() - result = await client_task_store.get_result(params.task_id) - assert result is not None, f"Result not found for task: {params.task_id}" - return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True)) - - return ExperimentalTaskHandlers( - augmented_sampling=handle_augmented_sampling, - get_task=handle_get_task, - get_task_result=handle_get_task_result, - ) - - -@pytest.mark.anyio -async def test_scenario1_normal_tool_normal_elicitation() -> None: - """Scenario 1: Normal tool call with normal elicitation. - - Server calls session.elicit() directly, client responds immediately. 
- """ - elicit_received = Event() - tool_result: list[str] = [] - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - # Normal elicitation - expects immediate response - result = await ctx.session.elicit( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - tool_result.append("confirmed" if confirmed else "cancelled") - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - server = Server("test-scenario1", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - - # Elicitation callback for client - async def elicitation_callback( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - ) -> ElicitResult: - elicit_received.set() - return ElicitResult(action="accept", content={"confirm": True}) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - elicitation_callback=elicitation_callback, - ) as client_session: - await client_session.initialize() - - # Call tool normally (not as 
task) - result = await client_session.call_tool("confirm_action", {}) - - # Verify elicitation was received and tool completed - assert elicit_received.is_set() - assert len(result.content) > 0 - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert tool_result[0] == "confirmed" - - -@pytest.mark.anyio -async def test_scenario2_normal_tool_task_augmented_elicitation() -> None: - """Scenario 2: Normal tool call with task-augmented elicitation. - - Server calls session.experimental.elicit_as_task(), client creates a task - for the elicitation and returns CreateTaskResult. Server polls client. - """ - elicit_received = Event() - tool_result: list[str] = [] - - # Client-side task store for handling task-augmented elicitation - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - # Task-augmented elicitation - server polls client - result = await ctx.session.experimental.elicit_as_task( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ttl=60000, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - tool_result.append("confirmed" if confirmed else "cancelled") - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - server = Server("test-scenario2", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - task_handlers = create_client_task_handlers(client_task_store, elicit_received) 
- - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await client_session.initialize() - - # Call tool normally (not as task) - result = await client_session.call_tool("confirm_action", {}) - - # Verify elicitation was received and tool completed - assert elicit_received.is_set() - assert len(result.content) > 0 - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert tool_result[0] == "confirmed" - client_task_store.cleanup() - - -@pytest.mark.anyio -async def test_scenario3_task_augmented_tool_normal_elicitation() -> None: - """Scenario 3: Task-augmented tool call with normal elicitation. - - Client calls tool as task. Inside the task, server uses task.elicit() - which queues the request and delivers via tasks/result. 
- """ - elicit_received = Event() - work_completed = Event() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Normal elicitation within task - queued and delivered via tasks/result - result = await task.elicit( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - return await ctx.experimental.run_task(work) - - server = Server("test-scenario3", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - server.experimental.enable_tasks() - - # Elicitation callback for client - async def elicitation_callback( - context: RequestContext[ClientSession], - params: ElicitRequestParams, - ) -> ElicitResult: - elicit_received.set() - return ElicitResult(action="accept", content={"confirm": True}) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - elicitation_callback=elicitation_callback, - ) as client_session: - await 
client_session.initialize() - - # Call tool as task - create_result = await client_session.experimental.call_tool_as_task("confirm_action", {}) - task_id = create_result.task.task_id - assert create_result.task.status == "working" - - # Poll until input_required, then call tasks/result - found_input_required = False - async for status in client_session.experimental.poll_task(task_id): # pragma: no branch - if status.status == "input_required": # pragma: no branch - found_input_required = True - break - assert found_input_required, "Expected to see input_required status" - - # This will deliver the elicitation and get the response - final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) - - # Verify - assert elicit_received.is_set() - assert len(final_result.content) > 0 - assert isinstance(final_result.content[0], TextContent) - assert final_result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert work_completed.is_set() - - -@pytest.mark.anyio -async def test_scenario4_task_augmented_tool_task_augmented_elicitation() -> None: - """Scenario 4: Task-augmented tool call with task-augmented elicitation. - - Client calls tool as task. Inside the task, server uses task.elicit_as_task() - which sends task-augmented elicitation. Client creates its own task for the - elicitation, and server polls the client. - - This tests the full bidirectional flow where: - 1. Client calls tasks/result on server (for tool task) - 2. Server delivers task-augmented elicitation through that stream - 3. Client creates its own task and returns CreateTaskResult - 4. Server polls the client's task while the client's tasks/result is still open - 5. Server gets the ElicitResult and completes the tool task - 6. 
Client's tasks/result returns with the CallToolResult - """ - elicit_received = Event() - work_completed = Event() - - # Client-side task store for handling task-augmented elicitation - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Task-augmented elicitation within task - server polls client - result = await task.elicit_as_task( - message="Please confirm the action", - requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, - ttl=60000, - ) - - confirmed = result.content.get("confirm", False) if result.content else False - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) - - return await ctx.experimental.run_task(work) - - server = Server("test-scenario4", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - server.experimental.enable_tasks() - task_handlers = create_client_task_handlers(client_task_store, elicit_received) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await 
client_session.initialize() - - # Call tool as task - create_result = await client_session.experimental.call_tool_as_task("confirm_action", {}) - task_id = create_result.task.task_id - assert create_result.task.status == "working" - - # Poll until input_required or terminal, then call tasks/result - found_expected_status = False - async for status in client_session.experimental.poll_task(task_id): # pragma: no branch - if status.status == "input_required" or is_terminal(status.status): # pragma: no branch - found_expected_status = True - break - assert found_expected_status, "Expected to see input_required or terminal status" - - # This will deliver the task-augmented elicitation, - # server will poll client, and eventually return the tool result - final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) - - # Verify - assert elicit_received.is_set() - assert len(final_result.content) > 0 - assert isinstance(final_result.content[0], TextContent) - assert final_result.content[0].text == "confirmed" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert work_completed.is_set() - client_task_store.cleanup() - - -@pytest.mark.anyio -async def test_scenario2_sampling_normal_tool_task_augmented_sampling() -> None: - """Scenario 2 for sampling: Normal tool call with task-augmented sampling. - - Server calls session.experimental.create_message_as_task(), client creates - a task for the sampling and returns CreateTaskResult. Server polls client. 
- """ - sampling_received = Event() - tool_result: list[str] = [] - - # Client-side task store for handling task-augmented sampling - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - ) - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - # Task-augmented sampling - server polls client - result = await ctx.session.experimental.create_message_as_task( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ttl=60000, - ) - - assert isinstance(result.content, TextContent), "Expected TextContent response" - response_text = result.content.text - - tool_result.append(response_text) - return CallToolResult(content=[TextContent(type="text", text=response_text)]) - - server = Server("test-scenario2-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await client_session.initialize() - - # Call tool normally (not as task) - result = await 
client_session.call_tool("generate_text", {}) - - # Verify sampling was received and tool completed - assert sampling_received.is_set() - assert len(result.content) > 0 - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Hello from the model!" - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert tool_result[0] == "Hello from the model!" - client_task_store.cleanup() - - -@pytest.mark.anyio -async def test_scenario4_sampling_task_augmented_tool_task_augmented_sampling() -> None: - """Scenario 4 for sampling: Task-augmented tool call with task-augmented sampling. - - Client calls tool as task. Inside the task, server uses task.create_message_as_task() - which sends task-augmented sampling. Client creates its own task for the sampling, - and server polls the client. - """ - sampling_received = Event() - work_completed = Event() - - # Client-side task store for handling task-augmented sampling - client_task_store = InMemoryTaskStore() - - async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: - raise NotImplementedError - - async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: - ctx.experimental.validate_task_mode(TASK_REQUIRED) - - async def work(task: ServerTaskContext) -> CallToolResult: - # Task-augmented sampling within task - server polls client - result = await task.create_message_as_task( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], - max_tokens=100, - ttl=60000, - ) - - assert isinstance(result.content, TextContent), "Expected TextContent response" - response_text = result.content.text - - work_completed.set() - return CallToolResult(content=[TextContent(type="text", text=response_text)]) - - return await ctx.experimental.run_task(work) - - server = Server("test-scenario4-sampling", 
on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - server.experimental.enable_tasks() - task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) - - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ) - - async def run_client() -> None: - async with ClientSession( - server_to_client_receive, - client_to_server_send, - experimental_task_handlers=task_handlers, - ) as client_session: - await client_session.initialize() - - # Call tool as task - create_result = await client_session.experimental.call_tool_as_task("generate_text", {}) - task_id = create_result.task.task_id - assert create_result.task.status == "working" - - # Poll until input_required or terminal - found_expected_status = False - async for status in client_session.experimental.poll_task(task_id): # pragma: no branch - if status.status == "input_required" or is_terminal(status.status): # pragma: no branch - found_expected_status = True - break - assert found_expected_status, "Expected to see input_required or terminal status" - - final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) - - # Verify - assert sampling_received.is_set() - assert len(final_result.content) > 0 - assert isinstance(final_result.content[0], TextContent) - assert final_result.content[0].text == "Hello from the model!" 
- - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - assert work_completed.is_set() - client_task_store.cleanup() diff --git a/tests/experimental/tasks/test_message_queue.py b/tests/experimental/tasks/test_message_queue.py deleted file mode 100644 index eca113d5b4..0000000000 --- a/tests/experimental/tasks/test_message_queue.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Tests for TaskMessageQueue and InMemoryTaskMessageQueue.""" - -from collections import deque -from datetime import datetime, timezone - -import anyio -import pytest - -from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, QueuedMessage -from mcp.shared.experimental.tasks.resolver import Resolver -from mcp.types import JSONRPCNotification, JSONRPCRequest - - -@pytest.fixture -def queue() -> InMemoryTaskMessageQueue: - return InMemoryTaskMessageQueue() - - -def make_request(id: int = 1, method: str = "test/method") -> JSONRPCRequest: - return JSONRPCRequest(jsonrpc="2.0", id=id, method=method) - - -def make_notification(method: str = "test/notify") -> JSONRPCNotification: - return JSONRPCNotification(jsonrpc="2.0", method=method) - - -class TestInMemoryTaskMessageQueue: - @pytest.mark.anyio - async def test_enqueue_and_dequeue(self, queue: InMemoryTaskMessageQueue) -> None: - """Test basic enqueue and dequeue operations.""" - task_id = "task-1" - msg = QueuedMessage(type="request", message=make_request()) - - await queue.enqueue(task_id, msg) - result = await queue.dequeue(task_id) - - assert result is not None - assert result.type == "request" - assert result.message.method == "test/method" - - @pytest.mark.anyio - async def test_dequeue_empty_returns_none(self, queue: InMemoryTaskMessageQueue) -> None: - """Dequeue from empty queue returns None.""" - result = await queue.dequeue("nonexistent-task") - assert result is None - - @pytest.mark.anyio - async def test_fifo_ordering(self, queue: InMemoryTaskMessageQueue) -> None: - 
"""Messages are dequeued in FIFO order.""" - task_id = "task-1" - - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(1, "first"))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(2, "second"))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(3, "third"))) - - msg1 = await queue.dequeue(task_id) - msg2 = await queue.dequeue(task_id) - msg3 = await queue.dequeue(task_id) - - assert msg1 is not None and msg1.message.method == "first" - assert msg2 is not None and msg2.message.method == "second" - assert msg3 is not None and msg3.message.method == "third" - - @pytest.mark.anyio - async def test_separate_queues_per_task(self, queue: InMemoryTaskMessageQueue) -> None: - """Each task has its own queue.""" - await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1, "task1-msg"))) - await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2, "task2-msg"))) - - msg1 = await queue.dequeue("task-1") - msg2 = await queue.dequeue("task-2") - - assert msg1 is not None and msg1.message.method == "task1-msg" - assert msg2 is not None and msg2.message.method == "task2-msg" - - @pytest.mark.anyio - async def test_peek_does_not_remove(self, queue: InMemoryTaskMessageQueue) -> None: - """Peek returns message without removing it.""" - task_id = "task-1" - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) - - peeked = await queue.peek(task_id) - dequeued = await queue.dequeue(task_id) - - assert peeked is not None - assert dequeued is not None - assert isinstance(peeked.message, JSONRPCRequest) - assert isinstance(dequeued.message, JSONRPCRequest) - assert peeked.message.id == dequeued.message.id - - @pytest.mark.anyio - async def test_is_empty(self, queue: InMemoryTaskMessageQueue) -> None: - """Test is_empty method.""" - task_id = "task-1" - - assert await queue.is_empty(task_id) is True - - await 
queue.enqueue(task_id, QueuedMessage(type="notification", message=make_notification())) - assert await queue.is_empty(task_id) is False - - await queue.dequeue(task_id) - assert await queue.is_empty(task_id) is True - - @pytest.mark.anyio - async def test_clear_returns_all_messages(self, queue: InMemoryTaskMessageQueue) -> None: - """Clear removes and returns all messages.""" - task_id = "task-1" - - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(1))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(2))) - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(3))) - - messages = await queue.clear(task_id) - - assert len(messages) == 3 - assert await queue.is_empty(task_id) is True - - @pytest.mark.anyio - async def test_clear_empty_queue(self, queue: InMemoryTaskMessageQueue) -> None: - """Clear on empty queue returns empty list.""" - messages = await queue.clear("nonexistent") - assert messages == [] - - @pytest.mark.anyio - async def test_notification_messages(self, queue: InMemoryTaskMessageQueue) -> None: - """Test queuing notification messages.""" - task_id = "task-1" - msg = QueuedMessage(type="notification", message=make_notification("log/message")) - - await queue.enqueue(task_id, msg) - result = await queue.dequeue(task_id) - - assert result is not None - assert result.type == "notification" - assert result.message.method == "log/message" - - @pytest.mark.anyio - async def test_message_timestamp(self, queue: InMemoryTaskMessageQueue) -> None: - """Messages have timestamps.""" - before = datetime.now(timezone.utc) - msg = QueuedMessage(type="request", message=make_request()) - after = datetime.now(timezone.utc) - - assert before <= msg.timestamp <= after - - @pytest.mark.anyio - async def test_message_with_resolver(self, queue: InMemoryTaskMessageQueue) -> None: - """Messages can have resolvers.""" - task_id = "task-1" - resolver: Resolver[dict[str, str]] = 
Resolver() - - msg = QueuedMessage( - type="request", - message=make_request(), - resolver=resolver, - original_request_id=42, - ) - - await queue.enqueue(task_id, msg) - result = await queue.dequeue(task_id) - - assert result is not None - assert result.resolver is resolver - assert result.original_request_id == 42 - - @pytest.mark.anyio - async def test_cleanup_specific_task(self, queue: InMemoryTaskMessageQueue) -> None: - """Cleanup removes specific task's data.""" - await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1))) - await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2))) - - queue.cleanup("task-1") - - assert await queue.is_empty("task-1") is True - assert await queue.is_empty("task-2") is False - - @pytest.mark.anyio - async def test_cleanup_all(self, queue: InMemoryTaskMessageQueue) -> None: - """Cleanup without task_id removes all data.""" - await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1))) - await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2))) - - queue.cleanup() - - assert await queue.is_empty("task-1") is True - assert await queue.is_empty("task-2") is True - - @pytest.mark.anyio - async def test_wait_for_message_returns_immediately_if_message_exists( - self, queue: InMemoryTaskMessageQueue - ) -> None: - """wait_for_message returns immediately if queue not empty.""" - task_id = "task-1" - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) - - # Should return immediately, not block - with anyio.fail_after(1): - await queue.wait_for_message(task_id) - - @pytest.mark.anyio - async def test_wait_for_message_blocks_until_message(self, queue: InMemoryTaskMessageQueue) -> None: - """wait_for_message blocks until a message is enqueued.""" - task_id = "task-1" - received = False - waiter_started = anyio.Event() - - async def enqueue_when_ready() -> None: - # Wait until the waiter has started 
before enqueueing - await waiter_started.wait() - await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) - - async def wait_for_msg() -> None: - nonlocal received - # Signal that we're about to start waiting - waiter_started.set() - await queue.wait_for_message(task_id) - received = True - - async with anyio.create_task_group() as tg: - tg.start_soon(wait_for_msg) - tg.start_soon(enqueue_when_ready) - - assert received is True - - @pytest.mark.anyio - async def test_notify_message_available_wakes_waiter(self, queue: InMemoryTaskMessageQueue) -> None: - """notify_message_available wakes up waiting coroutines.""" - task_id = "task-1" - notified = False - waiter_started = anyio.Event() - - async def notify_when_ready() -> None: - # Wait until the waiter has started before notifying - await waiter_started.wait() - await queue.notify_message_available(task_id) - - async def wait_for_notification() -> None: - nonlocal notified - # Signal that we're about to start waiting - waiter_started.set() - await queue.wait_for_message(task_id) - notified = True - - async with anyio.create_task_group() as tg: - tg.start_soon(wait_for_notification) - tg.start_soon(notify_when_ready) - - assert notified is True - - @pytest.mark.anyio - async def test_peek_empty_queue_returns_none(self, queue: InMemoryTaskMessageQueue) -> None: - """Peek on empty queue returns None.""" - result = await queue.peek("nonexistent-task") - assert result is None - - @pytest.mark.anyio - async def test_wait_for_message_double_check_race_condition(self, queue: InMemoryTaskMessageQueue) -> None: - """wait_for_message returns early if message arrives after event creation but before wait.""" - task_id = "task-1" - - # To test the double-check path (lines 223-225), we need a message to arrive - # after the event is created (line 220) but before event.wait() (line 228). - # We simulate this by injecting a message before is_empty is called the second time. 
- - original_is_empty = queue.is_empty - call_count = 0 - - async def is_empty_with_injection(tid: str) -> bool: - nonlocal call_count - call_count += 1 - if call_count == 2 and tid == task_id: - # Before second check, inject a message - this simulates a message - # arriving between event creation and the double-check - queue._queues[task_id] = deque([QueuedMessage(type="request", message=make_request())]) - return await original_is_empty(tid) - - queue.is_empty = is_empty_with_injection # type: ignore[method-assign] - - # Should return immediately due to double-check finding the message - with anyio.fail_after(1): - await queue.wait_for_message(task_id) - - -class TestResolver: - @pytest.mark.anyio - async def test_set_result_and_wait(self) -> None: - """Test basic set_result and wait flow.""" - resolver: Resolver[str] = Resolver() - - resolver.set_result("hello") - result = await resolver.wait() - - assert result == "hello" - assert resolver.done() - - @pytest.mark.anyio - async def test_set_exception_and_wait(self) -> None: - """Test set_exception raises on wait.""" - resolver: Resolver[str] = Resolver() - - resolver.set_exception(ValueError("test error")) - - with pytest.raises(ValueError, match="test error"): - await resolver.wait() - - assert resolver.done() - - @pytest.mark.anyio - async def test_set_result_when_already_completed_raises(self) -> None: - """Test that set_result raises if resolver already completed.""" - resolver: Resolver[str] = Resolver() - resolver.set_result("first") - - with pytest.raises(RuntimeError, match="already completed"): - resolver.set_result("second") - - @pytest.mark.anyio - async def test_set_exception_when_already_completed_raises(self) -> None: - """Test that set_exception raises if resolver already completed.""" - resolver: Resolver[str] = Resolver() - resolver.set_result("done") - - with pytest.raises(RuntimeError, match="already completed"): - resolver.set_exception(ValueError("too late")) - - @pytest.mark.anyio - async 
def test_done_returns_false_before_completion(self) -> None: - """Test done() returns False before any result is set.""" - resolver: Resolver[str] = Resolver() - assert resolver.done() is False diff --git a/tests/experimental/tasks/test_request_context.py b/tests/experimental/tasks/test_request_context.py deleted file mode 100644 index ad4023389e..0000000000 --- a/tests/experimental/tasks/test_request_context.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Tests for the RequestContext.experimental (Experimental class) task validation helpers.""" - -import pytest - -from mcp.server.experimental.request_context import Experimental -from mcp.shared.exceptions import MCPError -from mcp.types import ( - METHOD_NOT_FOUND, - TASK_FORBIDDEN, - TASK_OPTIONAL, - TASK_REQUIRED, - ClientCapabilities, - ClientTasksCapability, - TaskMetadata, - Tool, - ToolExecution, -) - - -def test_is_task_true_when_metadata_present() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - assert exp.is_task is True - - -def test_is_task_false_when_no_metadata() -> None: - exp = Experimental(task_metadata=None) - assert exp.is_task is False - - -def test_client_supports_tasks_true() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability())) - assert exp.client_supports_tasks is True - - -def test_client_supports_tasks_false_no_tasks() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.client_supports_tasks is False - - -def test_client_supports_tasks_false_no_capabilities() -> None: - exp = Experimental(_client_capabilities=None) - assert exp.client_supports_tasks is False - - -def test_validate_task_mode_required_with_task_is_valid() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(TASK_REQUIRED, raise_error=False) - assert error is None - - -def test_validate_task_mode_required_without_task_returns_error() -> None: - exp = Experimental(task_metadata=None) - 
error = exp.validate_task_mode(TASK_REQUIRED, raise_error=False) - assert error is not None - assert error.code == METHOD_NOT_FOUND - assert "requires task-augmented" in error.message - - -def test_validate_task_mode_required_without_task_raises_by_default() -> None: - exp = Experimental(task_metadata=None) - with pytest.raises(MCPError) as exc_info: - exp.validate_task_mode(TASK_REQUIRED) - assert exc_info.value.error.code == METHOD_NOT_FOUND - - -def test_validate_task_mode_forbidden_without_task_is_valid() -> None: - exp = Experimental(task_metadata=None) - error = exp.validate_task_mode(TASK_FORBIDDEN, raise_error=False) - assert error is None - - -def test_validate_task_mode_forbidden_with_task_returns_error() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(TASK_FORBIDDEN, raise_error=False) - assert error is not None - assert error.code == METHOD_NOT_FOUND - assert "does not support task-augmented" in error.message - - -def test_validate_task_mode_forbidden_with_task_raises_by_default() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - with pytest.raises(MCPError) as exc_info: - exp.validate_task_mode(TASK_FORBIDDEN) - assert exc_info.value.error.code == METHOD_NOT_FOUND - - -def test_validate_task_mode_none_treated_as_forbidden() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(None, raise_error=False) - assert error is not None - assert "does not support task-augmented" in error.message - - -def test_validate_task_mode_optional_with_task_is_valid() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - error = exp.validate_task_mode(TASK_OPTIONAL, raise_error=False) - assert error is None - - -def test_validate_task_mode_optional_without_task_is_valid() -> None: - exp = Experimental(task_metadata=None) - error = exp.validate_task_mode(TASK_OPTIONAL, raise_error=False) - assert error is None - - -def 
test_validate_for_tool_with_execution_required() -> None: - exp = Experimental(task_metadata=None) - tool = Tool( - name="test", - description="test", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - error = exp.validate_for_tool(tool, raise_error=False) - assert error is not None - assert "requires task-augmented" in error.message - - -def test_validate_for_tool_without_execution() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - tool = Tool( - name="test", - description="test", - input_schema={"type": "object"}, - execution=None, - ) - error = exp.validate_for_tool(tool, raise_error=False) - assert error is not None - assert "does not support task-augmented" in error.message - - -def test_validate_for_tool_optional_with_task() -> None: - exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) - tool = Tool( - name="test", - description="test", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_OPTIONAL), - ) - error = exp.validate_for_tool(tool, raise_error=False) - assert error is None - - -def test_can_use_tool_required_with_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability())) - assert exp.can_use_tool(TASK_REQUIRED) is True - - -def test_can_use_tool_required_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.can_use_tool(TASK_REQUIRED) is False - - -def test_can_use_tool_optional_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.can_use_tool(TASK_OPTIONAL) is True - - -def test_can_use_tool_forbidden_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert exp.can_use_tool(TASK_FORBIDDEN) is True - - -def test_can_use_tool_none_without_task_support() -> None: - exp = Experimental(_client_capabilities=ClientCapabilities()) - assert 
exp.can_use_tool(None) is True diff --git a/tests/experimental/tasks/test_spec_compliance.py b/tests/experimental/tasks/test_spec_compliance.py deleted file mode 100644 index 38d7d0a664..0000000000 --- a/tests/experimental/tasks/test_spec_compliance.py +++ /dev/null @@ -1,717 +0,0 @@ -"""Tasks Spec Compliance Tests -=========================== - -Test structure mirrors: https://modelcontextprotocol.io/specification/draft/basic/utilities/tasks.md - -Each section contains tests for normative requirements (MUST/SHOULD/MAY). -""" - -from datetime import datetime, timezone - -import pytest - -from mcp.server import Server, ServerRequestContext -from mcp.server.lowlevel import NotificationOptions -from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY -from mcp.types import ( - CancelTaskRequestParams, - CancelTaskResult, - CreateTaskResult, - GetTaskRequestParams, - GetTaskResult, - ListTasksResult, - PaginatedRequestParams, - ServerCapabilities, - Task, -) - -# Shared test datetime -TEST_DATETIME = datetime(2025, 1, 1, tzinfo=timezone.utc) - - -def _get_capabilities(server: Server) -> ServerCapabilities: - """Helper to get capabilities from a server.""" - return server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ) - - -def test_server_without_task_handlers_has_no_tasks_capability() -> None: - """Server without any task handlers has no tasks capability.""" - server: Server = Server("test") - caps = _get_capabilities(server) - assert caps.tasks is None - - -async def _noop_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: - raise NotImplementedError - - -async def _noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: - raise NotImplementedError - - -async def _noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: - raise NotImplementedError - - -def 
test_server_with_list_tasks_handler_declares_list_capability() -> None: - """Server with list_tasks handler declares tasks.list capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_list_tasks=_noop_list_tasks) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.list is not None - - -def test_server_with_cancel_task_handler_declares_cancel_capability() -> None: - """Server with cancel_task handler declares tasks.cancel capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_cancel_task=_noop_cancel_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.cancel is not None - - -def test_server_with_get_task_handler_declares_requests_tools_call_capability() -> None: - """Server with get_task handler declares tasks.requests.tools.call capability. - (get_task is required for task-augmented tools/call support) - """ - server: Server = Server("test") - server.experimental.enable_tasks(on_get_task=_noop_get_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.requests is not None - assert caps.tasks.requests.tools is not None - - -@pytest.mark.skip( - reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " - "so partial capabilities aren't possible yet. Low-level API should support " - "selectively enabling/disabling task capabilities." -) -def test_server_without_list_handler_has_no_list_capability() -> None: # pragma: no cover - """Server without list_tasks handler has no tasks.list capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_get_task=_noop_get_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.list is None - - -@pytest.mark.skip( - reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " - "so partial capabilities aren't possible yet. 
Low-level API should support " - "selectively enabling/disabling task capabilities." -) -def test_server_without_cancel_handler_has_no_cancel_capability() -> None: # pragma: no cover - """Server without cancel_task handler has no tasks.cancel capability.""" - server: Server = Server("test") - server.experimental.enable_tasks(on_get_task=_noop_get_task) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.cancel is None - - -def test_server_with_all_task_handlers_has_full_capability() -> None: - """Server with all task handlers declares complete tasks capability.""" - server: Server = Server("test") - server.experimental.enable_tasks( - on_list_tasks=_noop_list_tasks, - on_cancel_task=_noop_cancel_task, - on_get_task=_noop_get_task, - ) - - caps = _get_capabilities(server) - assert caps.tasks is not None - assert caps.tasks.list is not None - assert caps.tasks.cancel is not None - assert caps.tasks.requests is not None - assert caps.tasks.requests.tools is not None - - -class TestClientCapabilities: - """Clients declare: - - tasks.list — supports listing operations - - tasks.cancel — supports cancellation - - tasks.requests.sampling.createMessage — task-augmented sampling - - tasks.requests.elicitation.create — task-augmented elicitation - """ - - def test_client_declares_tasks_capability(self) -> None: - """Client can declare tasks capability.""" - pytest.skip("TODO") - - -class TestToolLevelNegotiation: - """Tools in tools/list responses include execution.taskSupport with values: - - Not present or "forbidden": No task augmentation allowed - - "optional": Task augmentation allowed at requestor discretion - - "required": Task augmentation is mandatory - """ - - def test_tool_execution_task_forbidden_rejects_task_augmented_call(self) -> None: - """Tool with execution.taskSupport="forbidden" MUST reject task-augmented calls (-32601).""" - pytest.skip("TODO") - - def test_tool_execution_task_absent_rejects_task_augmented_call(self) 
-> None: - """Tool without execution.taskSupport MUST reject task-augmented calls (-32601).""" - pytest.skip("TODO") - - def test_tool_execution_task_optional_accepts_normal_call(self) -> None: - """Tool with execution.taskSupport="optional" accepts normal calls.""" - pytest.skip("TODO") - - def test_tool_execution_task_optional_accepts_task_augmented_call(self) -> None: - """Tool with execution.taskSupport="optional" accepts task-augmented calls.""" - pytest.skip("TODO") - - def test_tool_execution_task_required_rejects_normal_call(self) -> None: - """Tool with execution.taskSupport="required" MUST reject non-task calls (-32601).""" - pytest.skip("TODO") - - def test_tool_execution_task_required_accepts_task_augmented_call(self) -> None: - """Tool with execution.taskSupport="required" accepts task-augmented calls.""" - pytest.skip("TODO") - - -class TestCapabilityNegotiation: - """Requestors SHOULD only augment requests with a task if the corresponding - capability has been declared by the receiver. - - Receivers that do not declare the task capability for a request type - MUST process requests of that type normally, ignoring any task-augmentation - metadata if present. - """ - - def test_receiver_without_capability_ignores_task_metadata(self) -> None: - """Receiver without task capability MUST process request normally, - ignoring task-augmentation metadata. - """ - pytest.skip("TODO") - - def test_receiver_with_capability_may_require_task_augmentation(self) -> None: - """Receivers that declare task capability MAY return error (-32600) - for non-task-augmented requests, requiring task augmentation. 
- """ - pytest.skip("TODO") - - -class TestTaskStatusLifecycle: - """Tasks begin in working status and follow valid transitions: - working → input_required → working → terminal - working → terminal (directly) - input_required → terminal (directly) - - Terminal states (no further transitions allowed): - - completed - - failed - - cancelled - """ - - def test_task_begins_in_working_status(self) -> None: - """Tasks MUST begin in working status.""" - pytest.skip("TODO") - - def test_working_to_completed_transition(self) -> None: - """working → completed is valid.""" - pytest.skip("TODO") - - def test_working_to_failed_transition(self) -> None: - """working → failed is valid.""" - pytest.skip("TODO") - - def test_working_to_cancelled_transition(self) -> None: - """working → cancelled is valid.""" - pytest.skip("TODO") - - def test_working_to_input_required_transition(self) -> None: - """working → input_required is valid.""" - pytest.skip("TODO") - - def test_input_required_to_working_transition(self) -> None: - """input_required → working is valid.""" - pytest.skip("TODO") - - def test_input_required_to_terminal_transition(self) -> None: - """input_required → terminal is valid.""" - pytest.skip("TODO") - - def test_terminal_state_no_further_transitions(self) -> None: - """Terminal states allow no further transitions.""" - pytest.skip("TODO") - - def test_completed_is_terminal(self) -> None: - """completed is a terminal state.""" - pytest.skip("TODO") - - def test_failed_is_terminal(self) -> None: - """failed is a terminal state.""" - pytest.skip("TODO") - - def test_cancelled_is_terminal(self) -> None: - """cancelled is a terminal state.""" - pytest.skip("TODO") - - -class TestInputRequiredStatus: - """When a receiver needs information to proceed, it moves the task to input_required. - The requestor should call tasks/result to retrieve input requests. - The task must include io.modelcontextprotocol/related-task metadata in associated requests. 
- """ - - def test_input_required_status_retrievable_via_tasks_get(self) -> None: - """Task in input_required status is retrievable via tasks/get.""" - pytest.skip("TODO") - - def test_input_required_related_task_metadata_in_requests(self) -> None: - """Task MUST include io.modelcontextprotocol/related-task metadata - in associated requests. - """ - pytest.skip("TODO") - - -class TestCreatingTask: - """Request structure: - {"method": "tools/call", "params": {"name": "...", "arguments": {...}, "task": {"ttl": 60000}}} - - Response (CreateTaskResult): - {"result": {"task": {"taskId": "...", "status": "working", ...}}} - - Receivers may include io.modelcontextprotocol/model-immediate-response in _meta. - """ - - def test_task_augmented_request_returns_create_task_result(self) -> None: - """Task-augmented request MUST return CreateTaskResult immediately.""" - pytest.skip("TODO") - - def test_create_task_result_contains_task_id(self) -> None: - """CreateTaskResult MUST contain taskId.""" - pytest.skip("TODO") - - def test_create_task_result_contains_status_working(self) -> None: - """CreateTaskResult MUST have status=working initially.""" - pytest.skip("TODO") - - def test_create_task_result_contains_created_at(self) -> None: - """CreateTaskResult MUST contain createdAt timestamp.""" - pytest.skip("TODO") - - def test_create_task_result_created_at_is_iso8601(self) -> None: - """createdAt MUST be ISO 8601 formatted.""" - pytest.skip("TODO") - - def test_create_task_result_may_contain_ttl(self) -> None: - """CreateTaskResult MAY contain ttl.""" - pytest.skip("TODO") - - def test_create_task_result_may_contain_poll_interval(self) -> None: - """CreateTaskResult MAY contain pollInterval.""" - pytest.skip("TODO") - - def test_create_task_result_may_contain_status_message(self) -> None: - """CreateTaskResult MAY contain statusMessage.""" - pytest.skip("TODO") - - def test_receiver_may_override_requested_ttl(self) -> None: - """Receiver MAY override requested ttl but MUST 
return actual value.""" - pytest.skip("TODO") - - def test_model_immediate_response_in_meta(self) -> None: - """Receiver MAY include io.modelcontextprotocol/model-immediate-response - in _meta to provide immediate response while task executes. - """ - # Verify the constant has the correct value per spec - assert MODEL_IMMEDIATE_RESPONSE_KEY == "io.modelcontextprotocol/model-immediate-response" - - # CreateTaskResult can include model-immediate-response in _meta - task = Task( - task_id="test-123", - status="working", - created_at=TEST_DATETIME, - last_updated_at=TEST_DATETIME, - ttl=60000, - ) - immediate_msg = "Task started, processing your request..." - # Note: Must use _meta= (alias) not meta= due to Pydantic alias handling - result = CreateTaskResult( - task=task, - **{"_meta": {MODEL_IMMEDIATE_RESPONSE_KEY: immediate_msg}}, - ) - - # Verify the metadata is present and correct - assert result.meta is not None - assert MODEL_IMMEDIATE_RESPONSE_KEY in result.meta - assert result.meta[MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg - - # Verify it serializes correctly with _meta alias - serialized = result.model_dump(by_alias=True) - assert "_meta" in serialized - assert MODEL_IMMEDIATE_RESPONSE_KEY in serialized["_meta"] - assert serialized["_meta"][MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg - - -class TestGettingTaskStatus: - """Request: {"method": "tasks/get", "params": {"taskId": "..."}} - Response: Returns full Task object with current status and pollInterval. 
- """ - - def test_tasks_get_returns_task_object(self) -> None: - """tasks/get MUST return full Task object.""" - pytest.skip("TODO") - - def test_tasks_get_returns_current_status(self) -> None: - """tasks/get MUST return current status.""" - pytest.skip("TODO") - - def test_tasks_get_may_return_poll_interval(self) -> None: - """tasks/get MAY return pollInterval.""" - pytest.skip("TODO") - - def test_tasks_get_invalid_task_id_returns_error(self) -> None: - """tasks/get with invalid taskId MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_get_nonexistent_task_id_returns_error(self) -> None: - """tasks/get with nonexistent taskId MUST return -32602.""" - pytest.skip("TODO") - - -class TestRetrievingResults: - """Request: {"method": "tasks/result", "params": {"taskId": "..."}} - Response: The actual operation result structure (e.g., CallToolResult). - - This call blocks until terminal status. - """ - - def test_tasks_result_returns_underlying_result(self) -> None: - """tasks/result MUST return exactly what underlying request would return.""" - pytest.skip("TODO") - - def test_tasks_result_blocks_until_terminal(self) -> None: - """tasks/result MUST block for non-terminal tasks.""" - pytest.skip("TODO") - - def test_tasks_result_unblocks_on_terminal(self) -> None: - """tasks/result MUST unblock upon reaching terminal status.""" - pytest.skip("TODO") - - def test_tasks_result_includes_related_task_metadata(self) -> None: - """tasks/result MUST include io.modelcontextprotocol/related-task in _meta.""" - pytest.skip("TODO") - - def test_tasks_result_returns_error_for_failed_task(self) -> None: - """tasks/result returns the same error the underlying request - would have produced for failed tasks. 
- """ - pytest.skip("TODO") - - def test_tasks_result_invalid_task_id_returns_error(self) -> None: - """tasks/result with invalid taskId MUST return -32602.""" - pytest.skip("TODO") - - -class TestListingTasks: - """Request: {"method": "tasks/list", "params": {"cursor": "optional"}} - Response: Array of tasks with pagination support via nextCursor. - """ - - def test_tasks_list_returns_array_of_tasks(self) -> None: - """tasks/list MUST return array of tasks.""" - pytest.skip("TODO") - - def test_tasks_list_pagination_with_cursor(self) -> None: - """tasks/list supports pagination via cursor.""" - pytest.skip("TODO") - - def test_tasks_list_returns_next_cursor_when_more_results(self) -> None: - """tasks/list MUST return nextCursor when more results available.""" - pytest.skip("TODO") - - def test_tasks_list_cursors_are_opaque(self) -> None: - """Implementers MUST treat cursors as opaque tokens.""" - pytest.skip("TODO") - - def test_tasks_list_invalid_cursor_returns_error(self) -> None: - """tasks/list with invalid cursor MUST return -32602.""" - pytest.skip("TODO") - - -class TestCancellingTasks: - """Request: {"method": "tasks/cancel", "params": {"taskId": "..."}} - Response: Returns the task object with status: "cancelled". 
- """ - - def test_tasks_cancel_returns_cancelled_task(self) -> None: - """tasks/cancel MUST return task with status=cancelled.""" - pytest.skip("TODO") - - def test_tasks_cancel_terminal_task_returns_error(self) -> None: - """Cancelling already-terminal task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_completed_task_returns_error(self) -> None: - """Cancelling completed task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_failed_task_returns_error(self) -> None: - """Cancelling failed task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_already_cancelled_task_returns_error(self) -> None: - """Cancelling already-cancelled task MUST return -32602.""" - pytest.skip("TODO") - - def test_tasks_cancel_invalid_task_id_returns_error(self) -> None: - """tasks/cancel with invalid taskId MUST return -32602.""" - pytest.skip("TODO") - - -class TestStatusNotifications: - """Receivers MAY send: {"method": "notifications/tasks/status", "params": {...}} - These are optional; requestors MUST NOT rely on them and SHOULD continue polling. 
- """ - - def test_receiver_may_send_status_notification(self) -> None: - """Receiver MAY send notifications/tasks/status.""" - pytest.skip("TODO") - - def test_status_notification_contains_task_id(self) -> None: - """Status notification MUST contain taskId.""" - pytest.skip("TODO") - - def test_status_notification_contains_status(self) -> None: - """Status notification MUST contain status.""" - pytest.skip("TODO") - - -class TestTaskManagement: - """- Receivers generate unique task IDs as strings - - Tasks must begin in working status - - createdAt timestamps must be ISO 8601 formatted - - Receivers may override requested ttl but must return actual value - - Receivers may delete tasks after TTL expires - - All task-related messages must include io.modelcontextprotocol/related-task - in _meta except for tasks/get, tasks/list, tasks/cancel operations - """ - - def test_task_ids_are_unique_strings(self) -> None: - """Receivers MUST generate unique task IDs as strings.""" - pytest.skip("TODO") - - def test_multiple_tasks_have_unique_ids(self) -> None: - """Multiple tasks MUST have unique IDs.""" - pytest.skip("TODO") - - def test_receiver_may_delete_tasks_after_ttl(self) -> None: - """Receivers MAY delete tasks after TTL expires.""" - pytest.skip("TODO") - - def test_related_task_metadata_in_task_messages(self) -> None: - """All task-related messages MUST include io.modelcontextprotocol/related-task - in _meta. 
- """ - pytest.skip("TODO") - - def test_tasks_get_does_not_require_related_task_metadata(self) -> None: - """tasks/get does not require related-task metadata.""" - pytest.skip("TODO") - - def test_tasks_list_does_not_require_related_task_metadata(self) -> None: - """tasks/list does not require related-task metadata.""" - pytest.skip("TODO") - - def test_tasks_cancel_does_not_require_related_task_metadata(self) -> None: - """tasks/cancel does not require related-task metadata.""" - pytest.skip("TODO") - - -class TestResultHandling: - """- Receivers must return CreateTaskResult immediately upon accepting task-augmented requests - - tasks/result must return exactly what the underlying request would return - - tasks/result blocks for non-terminal tasks; must unblock upon reaching terminal status - """ - - def test_create_task_result_returned_immediately(self) -> None: - """Receiver MUST return CreateTaskResult immediately (not after work completes).""" - pytest.skip("TODO") - - def test_tasks_result_matches_underlying_result_structure(self) -> None: - """tasks/result MUST return same structure as underlying request.""" - pytest.skip("TODO") - - def test_tasks_result_for_tool_call_returns_call_tool_result(self) -> None: - """tasks/result for tools/call returns CallToolResult.""" - pytest.skip("TODO") - - -class TestProgressTracking: - """Task-augmented requests support progress notifications using the progressToken - mechanism, which remains valid throughout the task lifetime. 
- """ - - def test_progress_token_valid_throughout_task_lifetime(self) -> None: - """progressToken remains valid throughout task lifetime.""" - pytest.skip("TODO") - - def test_progress_notifications_sent_during_task_execution(self) -> None: - """Progress notifications can be sent during task execution.""" - pytest.skip("TODO") - - -class TestProtocolErrors: - """Protocol Errors (JSON-RPC standard codes): - - -32600 (Invalid request): Non-task requests to endpoint requiring task augmentation - - -32602 (Invalid params): Invalid/nonexistent taskId, invalid cursor, cancel terminal task - - -32603 (Internal error): Server-side execution failures - """ - - def test_invalid_request_for_required_task_augmentation(self) -> None: - """Non-task request to task-required endpoint returns -32600.""" - pytest.skip("TODO") - - def test_invalid_params_for_invalid_task_id(self) -> None: - """Invalid taskId returns -32602.""" - pytest.skip("TODO") - - def test_invalid_params_for_nonexistent_task_id(self) -> None: - """Nonexistent taskId returns -32602.""" - pytest.skip("TODO") - - def test_invalid_params_for_invalid_cursor(self) -> None: - """Invalid cursor in tasks/list returns -32602.""" - pytest.skip("TODO") - - def test_invalid_params_for_cancel_terminal_task(self) -> None: - """Attempt to cancel terminal task returns -32602.""" - pytest.skip("TODO") - - def test_internal_error_for_server_failure(self) -> None: - """Server-side execution failure returns -32603.""" - pytest.skip("TODO") - - -class TestTaskExecutionErrors: - """When underlying requests fail, the task moves to failed status. 
- - tasks/get response should include statusMessage explaining failure - - tasks/result returns same error the underlying request would have produced - - For tool calls, isError: true moves task to failed status - """ - - def test_underlying_failure_moves_task_to_failed(self) -> None: - """Underlying request failure moves task to failed status.""" - pytest.skip("TODO") - - def test_failed_task_has_status_message(self) -> None: - """Failed task SHOULD include statusMessage explaining failure.""" - pytest.skip("TODO") - - def test_tasks_result_returns_underlying_error(self) -> None: - """tasks/result returns same error underlying request would produce.""" - pytest.skip("TODO") - - def test_tool_call_is_error_true_moves_to_failed(self) -> None: - """Tool call with isError: true moves task to failed status.""" - pytest.skip("TODO") - - -class TestTaskObject: - """Task Object fields: - - taskId: String identifier - - status: Current execution state - - statusMessage: Optional human-readable description - - createdAt: ISO 8601 timestamp of creation - - ttl: Milliseconds before potential deletion - - pollInterval: Suggested milliseconds between polls - """ - - def test_task_has_task_id_string(self) -> None: - """Task MUST have taskId as string.""" - pytest.skip("TODO") - - def test_task_has_status(self) -> None: - """Task MUST have status.""" - pytest.skip("TODO") - - def test_task_status_message_is_optional(self) -> None: - """Task statusMessage is optional.""" - pytest.skip("TODO") - - def test_task_has_created_at(self) -> None: - """Task MUST have createdAt.""" - pytest.skip("TODO") - - def test_task_ttl_is_optional(self) -> None: - """Task ttl is optional.""" - pytest.skip("TODO") - - def test_task_poll_interval_is_optional(self) -> None: - """Task pollInterval is optional.""" - pytest.skip("TODO") - - -class TestRelatedTaskMetadata: - """Related Task Metadata structure: - {"_meta": {"io.modelcontextprotocol/related-task": {"taskId": "..."}}} - """ - - def 
test_related_task_metadata_structure(self) -> None: - """Related task metadata has correct structure.""" - pytest.skip("TODO") - - def test_related_task_metadata_contains_task_id(self) -> None: - """Related task metadata contains taskId.""" - pytest.skip("TODO") - - -class TestAccessAndIsolation: - """- Task IDs enable access to sensitive results - - Authorization context binding is essential where available - - For non-authorized environments: strong entropy IDs, strict TTL limits - """ - - def test_task_bound_to_authorization_context(self) -> None: - """Receivers receiving authorization context MUST bind tasks to that context.""" - pytest.skip("TODO") - - def test_reject_task_operations_outside_authorization_context(self) -> None: - """Receivers MUST reject task operations for tasks outside - requestor's authorization context. - """ - pytest.skip("TODO") - - def test_non_authorized_environments_use_secure_ids(self) -> None: - """For non-authorized environments, receivers SHOULD use - cryptographically secure IDs. 
- """ - pytest.skip("TODO") - - def test_non_authorized_environments_use_shorter_ttls(self) -> None: - """For non-authorized environments, receivers SHOULD use shorter TTLs.""" - pytest.skip("TODO") - - -class TestResourceLimits: - """Receivers should: - - Enforce concurrent task limits per requestor - - Implement maximum TTL constraints - - Clean up expired tasks promptly - """ - - def test_concurrent_task_limit_enforced(self) -> None: - """Receiver SHOULD enforce concurrent task limits per requestor.""" - pytest.skip("TODO") - - def test_maximum_ttl_constraint_enforced(self) -> None: - """Receiver SHOULD implement maximum TTL constraints.""" - pytest.skip("TODO") - - def test_expired_tasks_cleaned_up(self) -> None: - """Receiver SHOULD clean up expired tasks promptly.""" - pytest.skip("TODO") diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 109b30fc77..caed8905d0 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -49,8 +49,9 @@ _SOURCE_PATTERN = re.compile(r"https://modelcontextprotocol\.io/specification/.+|sdk|issue:#\d+") _TASKS_DEFERRAL = ( - "Tasks are experimental and the spec is being substantially revised; python task behaviour is " - "covered by tests/experimental/tasks/ until the next spec revision settles." + "Tasks have been removed from the draft spec and from this SDK; they are expected to return " + "as a separate MCP extension. These 2025-11-25 requirements are tracked but intentionally " + "unimplemented." 
) diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 5d5f8b8fc9..bef44928ac 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -3,7 +3,6 @@ import pytest from mcp.server.context import ServerRequestContext -from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context pytestmark = pytest.mark.anyio @@ -22,7 +21,6 @@ async def test_progress_token_zero_first_call(): session=mock_session, meta={"progress_token": 0}, lifespan_context=None, - experimental=Experimental(), ) # Create context with our mocks diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3457ec944a..21352b5f2f 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -11,7 +11,6 @@ from mcp.client import Client from mcp.server.context import ServerRequestContext -from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.prompts.base import Message, UserMessage @@ -1502,7 +1501,6 @@ async def test_report_progress_passes_related_request_id(): session=mock_session, meta={"progress_token": "tok-1"}, lifespan_context=None, - experimental=Experimental(), ) ctx = Context(request_context=request_context, mcp_server=MagicMock()) diff --git a/tests/server/test_session.py b/tests/server/test_session.py index a2786d865d..6116a7c7f5 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -76,6 +76,49 @@ async def run_server(): assert received_initialized +@pytest.mark.anyio +async def test_check_client_capability(): + """check_client_capability reflects the capabilities sent by the client at initialize.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + 
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + + initialized = anyio.Event() + + async def list_roots_callback(context: Any) -> types.ListRootsResult: # pragma: no cover + return types.ListRootsResult(roots=[]) + + async def run_server(server_session: ServerSession): + async for message in server_session.incoming_messages: # pragma: no branch + if isinstance(message, ClientNotification) and isinstance( + message, InitializedNotification + ): # pragma: no branch + initialized.set() + return + + async with ( + ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions(server_name="mcp", server_version="0.1.0", capabilities=ServerCapabilities()), + ) as server_session, + ClientSession( + server_to_client_receive, + client_to_server_send, + list_roots_callback=list_roots_callback, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server, server_session) + await client_session.initialize() + with anyio.fail_after(5): + await initialized.wait() + + # ClientSession advertises roots when a list_roots_callback is provided. + assert server_session.check_client_capability(types.ClientCapabilities(roots=types.RootsCapability())) + # ClientSession does not advertise sampling without a sampling_callback. 
+ assert not server_session.check_client_capability(types.ClientCapabilities(sampling=types.SamplingCapability())) + + @pytest.mark.anyio async def test_server_capabilities(): notification_options = NotificationOptions() diff --git a/uv.lock b/uv.lock index 5b72e97fce..df63607f40 100644 --- a/uv.lock +++ b/uv.lock @@ -18,10 +18,6 @@ members = [ "mcp-simple-resource", "mcp-simple-streamablehttp", "mcp-simple-streamablehttp-stateless", - "mcp-simple-task", - "mcp-simple-task-client", - "mcp-simple-task-interactive", - "mcp-simple-task-interactive-client", "mcp-simple-tool", "mcp-snippets", "mcp-sse-polling-client", @@ -1268,126 +1264,6 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] -[[package]] -name = "mcp-simple-task" -version = "0.1.0" -source = { editable = "examples/servers/simple-task" } -dependencies = [ - { name = "anyio" }, - { name = "click" }, - { name = "mcp" }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "anyio", specifier = ">=4.5" }, - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - -[[package]] -name = "mcp-simple-task-client" -version = "0.1.0" -source = { editable = "examples/clients/simple-task-client" } -dependencies = [ - { name = "click" }, - { name = "mcp" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." 
}, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - -[[package]] -name = "mcp-simple-task-interactive" -version = "0.1.0" -source = { editable = "examples/servers/simple-task-interactive" } -dependencies = [ - { name = "anyio" }, - { name = "click" }, - { name = "mcp" }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "anyio", specifier = ">=4.5" }, - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." }, - { name = "starlette" }, - { name = "uvicorn" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - -[[package]] -name = "mcp-simple-task-interactive-client" -version = "0.1.0" -source = { editable = "examples/clients/simple-task-interactive-client" } -dependencies = [ - { name = "click" }, - { name = "mcp" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "click", specifier = ">=8.0" }, - { name = "mcp", editable = "." 
}, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.378" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - [[package]] name = "mcp-simple-tool" version = "0.1.0" From c91f4069386ad724b275f537ee0f0af2af5a6a1c Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:27:41 +0100 Subject: [PATCH 54/60] Require protocol_version to be a string (#2763) --- src/mcp/client/streamable_http.py | 2 +- src/mcp/types/_types.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index aa3e50e07e..9cdf717c73 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -120,7 +120,7 @@ def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) try: # Parse the result as InitializeResult for type safety init_result = InitializeResult.model_validate(message.result, by_name=False) - self.protocol_version = str(init_result.protocol_version) + self.protocol_version = init_result.protocol_version logger.info(f"Negotiated protocol version: {self.protocol_version}") except Exception: # pragma: no cover logger.warning("Failed to parse initialization response as InitializeResult", exc_info=True) diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index 34800ba12e..e9d39ef6f3 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -304,7 +304,7 @@ class ServerCapabilities(MCPModel): class InitializeRequestParams(RequestParams): """Parameters for the initialize request.""" - protocol_version: str | int + protocol_version: str """The latest version of the Model Context Protocol that the client supports.""" capabilities: ClientCapabilities client_info: Implementation @@ -322,7 +322,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]) class InitializeResult(Result): """After receiving an initialize request from the 
client, the server sends this.""" - protocol_version: str | int + protocol_version: str """The version of the Model Context Protocol that the server wants to use.""" capabilities: ServerCapabilities server_info: Implementation From 60c04207750f15cb75a0bd6f9eed78cfe06bc492 Mon Sep 17 00:00:00 2001 From: Aryan Motgi <85900811+aryanmotgi@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:28:19 -0700 Subject: [PATCH 55/60] docs: correct create_mcp_http_client default timeout docstring (#2683) Co-authored-by: aryanmotgi Co-authored-by: Marcelo Trylesinski --- src/mcp/shared/_httpx_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index 251469eaa1..6a121aff6d 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -27,14 +27,12 @@ def create_mcp_http_client( ) -> httpx.AsyncClient: """Create a standardized httpx AsyncClient with MCP defaults. - This function provides common defaults used throughout the MCP codebase: - - follow_redirects=True (always enabled) - - Default timeout of 30 seconds if not specified + Always enables follow_redirects and applies an SSE-friendly default timeout. Args: headers: Optional headers to include with all requests. - timeout: Request timeout as httpx.Timeout object. - Defaults to 30 seconds if not specified. + timeout: Request timeout as httpx.Timeout object. Defaults to 30s for + connect/write/pool and 300s for read (for long-lived SSE streams). auth: Optional authentication handler. 
Returns: From 4f6f0e8bf5544253e77b4f8d4cad804db5e2e5ef Mon Sep 17 00:00:00 2001 From: Zach Leventer Date: Tue, 2 Jun 2026 12:30:57 -0400 Subject: [PATCH 56/60] Remove dead commented-out code in register_client (#2500) --- src/mcp/client/auth/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index d75324f2f0..780a24e859 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -240,8 +240,6 @@ async def handle_registration_response(response: Response) -> OAuthClientInforma content = await response.aread() client_info = OAuthClientInformationFull.model_validate_json(content) return client_info - # self.context.client_info = client_info - # await self.context.storage.set_client_info(client_info) except ValidationError as e: # pragma: no cover raise OAuthRegistrationError(f"Invalid registration response: {e}") From a9381263275a86257002c1e34101c1dce70cbc05 Mon Sep 17 00:00:00 2001 From: yukawithdata <90426808+yuka-with-data@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:32:53 -0700 Subject: [PATCH 57/60] Doc: Clarify MCP Client-Server model in What is MCP section (#2459) --- README.v2.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.v2.md b/README.v2.md index a888f21bda..bae230c3f9 100644 --- a/README.v2.md +++ b/README.v2.md @@ -207,7 +207,11 @@ In the inspector UI, connect to `http://localhost:8000/mcp`. ## What is MCP? -The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you build servers that expose data and functionality to LLM applications in a secure, standardized way. Think of it like a web API, but specifically designed for LLM interactions. MCP servers can: +The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you build servers that expose data and functionality to LLM applications in a secure, standardized way. Think of it like a web API, but specifically designed for LLM interactions. 
+ +MCP follows a **client-server model**, where LLM applications act as clients and connect to MCP servers to access capabilities such as data retrieval and tool execution in a consistent format. + +MCP servers can: - Expose data through **Resources** (think of these sort of like GET endpoints; they are used to load information into the LLM's context) - Provide functionality through **Tools** (sort of like POST endpoints; they are used to execute code or otherwise produce a side effect) From ed6adeee7e698ba66ee6434ee2bd69d0ec8729ea Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 18:34:39 +0100 Subject: [PATCH 58/60] docs: require a passing conformance test for new 2026-07-28 spec features (#2761) --- AGENTS.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index 307bd81b3e..3d4c43dd72 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -65,6 +65,11 @@ tests. Don't silence warnings from your own code; fix the underlying cause. Scoped `ignore::` entries for upstream libraries are acceptable in `pyproject.toml` with a comment explaining why. +- New features from the 2026-07-28 spec must have a matching test in the + [conformance suite](https://github.com/modelcontextprotocol/conformance) + that passes against this SDK (CI runs it via + `.github/workflows/conformance.yml`). If no matching test exists, stop and + tell the user so they can raise an issue on the conformance repo. 
### Coverage From b3025f93b4ba63c6d0d04b1357947ecdb78eb0bd Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 20:30:17 +0100 Subject: [PATCH 59/60] Run transport security tests in process instead of over sockets (#2764) --- src/mcp/server/streamable_http.py | 2 +- tests/interaction/transports/__init__.py | 9 + tests/server/test_sse_security.py | 309 ++++++----------- tests/server/test_streamable_http_manager.py | 29 +- tests/server/test_streamable_http_security.py | 317 +++++------------- 5 files changed, 212 insertions(+), 454 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f2f4407cea..98948ff999 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -669,7 +669,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - if not await self._validate_request_headers(request, send): # pragma: no cover + if not await self._validate_request_headers(request, send): return # Handle resumability: check for Last-Event-ID header diff --git a/tests/interaction/transports/__init__.py b/tests/interaction/transports/__init__.py index e69de29bb2..b5bbb633c2 100644 --- a/tests/interaction/transports/__init__.py +++ b/tests/interaction/transports/__init__.py @@ -0,0 +1,9 @@ +"""Transport-specific interaction tests, and the in-process streaming bridge they are built on. + +`StreamingASGITransport` is re-exported here as the sanctioned import point for test code +outside this suite (the bridge module itself is suite-private). 
+""" + +from tests.interaction.transports._bridge import StreamingASGITransport + +__all__ = ["StreamingASGITransport"] diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index e95dc51b31..e77bd5e2c2 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,15 +1,12 @@ """Tests for SSE server request validation.""" import logging -import multiprocessing import re -import socket import anyio import httpx import pytest import sse_starlette.sse -import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -23,12 +20,16 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._stream_protocols import WriteStream from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCRequest, JSONRPCResponse, Tool -from tests.test_helpers import wait_for_server +from mcp.types import JSONRPCRequest, JSONRPCResponse +from tests.interaction.transports import StreamingASGITransport logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" +# The in-process app is mounted at this origin purely so URLs are well-formed and the default +# Host header is a localhost form; nothing listens here. 
+BASE_URL = "http://127.0.0.1:8000" + @pytest.fixture(autouse=True) def reset_sse_starlette_exit_event() -> None: @@ -39,275 +40,161 @@ def reset_sse_starlette_exit_event() -> None: app_status.should_exit_event = None -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - async def on_list_tools(self) -> list[Tool]: - return [] - - -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the SSE server with specified security settings.""" - app = SecurityTestServer() +def sse_security_client(security_settings: TransportSecuritySettings | None = None) -> httpx.AsyncClient: + """An httpx client whose requests are served in process by an SSE app with the given settings.""" + server = Server(SERVER_NAME) sse_transport = SseServerTransport("/messages/", security_settings) - async def handle_sse(request: Request): + async def handle_sse(request: Request) -> Response: try: - async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: - if streams: - await app.run(streams[0], streams[1], app.create_initialization_options()) + async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (read, write): + await server.run(read, write, server.create_initialization_options()) except ValueError as e: - # Validation error was already handled inside connect_sse + # Validation error was already handled inside connect_sse, which sent the rejection + # response itself; its non-empty body checkpoints, so the test reads the rejection + # status before the trailing Response() below sends a second response start. 
logger.debug(f"SSE connection failed validation: {e}") return Response() - routes = [ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse_transport.handle_post_message), - ] - - starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") - - -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse_transport.handle_post_message), + ] + ) + # The SSE GET runs until it observes a disconnect, so the bridge must let the application + # drain on close rather than cancelling it. + transport = StreamingASGITransport(app, cancel_on_close=False) + return httpx.AsyncClient(transport=transport, base_url=BASE_URL) @pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int): - """Test SSE with default security settings (protection disabled).""" - process = start_server_process(server_port) +async def test_sse_security_default_settings() -> None: + """With default security settings (protection disabled), any Host and Origin connect.""" + headers = {"Host": "evil.com", "Origin": "http://evil.com"} - try: - headers = {"Host": "evil.com", "Origin": "http://evil.com"} - - async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with sse_security_client() as client: + async with client.stream("GET", "/sse", headers=headers) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def 
test_sse_security_invalid_host_header(server_port: int): - """Test SSE with invalid Host header.""" - # Enable security by providing settings with an empty allowed_hosts list +async def test_sse_security_invalid_host_header() -> None: + """A Host header outside allowed_hosts is rejected with 421.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) - try: - # Test with invalid host header - headers = {"Host": "evil.com"} - - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + async with sse_security_client(security_settings) as client: + response = await client.get("/sse", headers={"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): - """Test SSE with invalid Origin header.""" - # Configure security to allow the host but restrict origins +async def test_sse_security_invalid_origin_header() -> None: + """An Origin header outside allowed_origins is rejected with 403.""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = {"Origin": "http://evil.com"} - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + async with sse_security_client(security_settings) 
as client: + response = await client.get("/sse", headers={"Origin": "http://evil.com"}) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): - """Test POST endpoint with invalid Content-Type header.""" - # Configure security to allow the host +async def test_sse_security_post_invalid_content_type() -> None: + """A POST whose Content-Type is not application/json (or is missing) is rejected with 400.""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": "text/plain"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" + fake_session_id = "12345678123456781234567812345678" - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" + async with sse_security_client(security_settings) as client: + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": "text/plain"}, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" - finally: - process.terminate() - process.join() + response = await client.post(f"/messages/?session_id={fake_session_id}", content="test") + assert response.status_code == 400 + assert response.text == 
"Invalid Content-Type header" @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): - """Test SSE with security disabled.""" +async def test_sse_security_disabled() -> None: + """With protection explicitly disabled, a disallowed Host still connects.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = {"Host": "evil.com"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + async with sse_security_client(settings) as client: + async with client.stream("GET", "/sse", headers={"Host": "evil.com"}) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): - """Test SSE with custom allowed hosts.""" +async def test_sse_security_custom_allowed_hosts() -> None: + """A custom entry in allowed_hosts connects; hosts outside the list are still rejected.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: - # Test with custom allowed host - headers = {"Host": "custom.host"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with custom host - assert response.status_code == 200 - # Test 
with non-allowed host - headers = {"Host": "evil.com"} + async with sse_security_client(settings) as client: + async with client.stream("GET", "/sse", headers={"Host": "custom.host"}) as response: + assert response.status_code == 200 - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + response = await client.get("/sse", headers={"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): - """Test SSE with wildcard port patterns.""" +async def test_sse_security_wildcard_ports() -> None: + """A `host:*` pattern accepts that host with any port, for Host and Origin alike.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) - try: - # Test with various port numbers + async with sse_security_client(settings) as client: for test_port in [8080, 3000, 9999]: - headers = {"Host": f"localhost:{test_port}"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 - - headers = {"Origin": f"http://localhost:{test_port}"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert 
response.status_code == 200 + async with client.stream("GET", "/sse", headers={"Host": f"localhost:{test_port}"}) as response: + assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with client.stream("GET", "/sse", headers={"Origin": f"http://localhost:{test_port}"}) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): - """Test POST endpoint with valid Content-Type headers.""" - # Configure security to allow the host +async def test_sse_security_post_valid_content_type() -> None: + """Every application/json Content-Type variant passes validation (reaching the session lookup).""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient() as client: - # Test with various valid content types - valid_content_types = [ - "application/json", - "application/json; charset=utf-8", - "application/json;charset=utf-8", - "APPLICATION/JSON", # Case insensitive - ] - - for content_type in valid_content_types: - # Use a valid UUID format (even though session won't exist) - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": content_type}, - json={"test": "data"}, - ) - # Will get 404 because session doesn't exist, but that's OK - # We're testing that it passes the content-type check - assert response.status_code == 404 - assert response.text == "Could not find session" - - finally: - process.terminate() - process.join() + valid_content_types = [ + "application/json", + "application/json; charset=utf-8", + "application/json;charset=utf-8", + "APPLICATION/JSON", # Case insensitive + ] + # A well-formed session ID that 
no live session owns. + fake_session_id = "12345678123456781234567812345678" + + async with sse_security_client(security_settings) as client: + for content_type in valid_content_types: + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": content_type}, + json={"test": "data"}, + ) + # 404 proves the request passed the content-type check and reached the session lookup. + assert response.status_code == 404 + assert response.text == "Could not find session" def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index ba75547964..f02e520eea 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -340,12 +340,33 @@ async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestP await client.list_tools() +class _IdleTimeoutObserver(logging.Handler): + """Resolves `reaped` when the manager logs that a session's idle timeout fired.""" + + def __init__(self) -> None: + super().__init__() + self.reaped = anyio.Event() + + def emit(self, record: logging.LogRecord) -> None: + if "idle timeout" in record.getMessage(): + self.reaped.set() + + @pytest.mark.anyio -async def test_idle_session_is_reaped(): +async def test_idle_session_is_reaped(caplog: pytest.LogCaptureFixture, request: pytest.FixtureRequest): """After idle timeout fires, the session returns 404.""" app = Server("test-idle-reap") manager = StreamableHTTPSessionManager(app=app, session_idle_timeout=0.05) + # The reap is observed through the manager's own "idle timeout" log record: the manager pops + # the session synchronously after emitting it, before its next await, so a waiter woken by + # the record always finds the session gone. caplog.set_level enables INFO so it is created. 
+ observer = _IdleTimeoutObserver() + manager_logger = logging.getLogger(streamable_http_manager.__name__) + manager_logger.addHandler(observer) + request.addfinalizer(lambda: manager_logger.removeHandler(observer)) + caplog.set_level(logging.INFO, logger=streamable_http_manager.__name__) + async with manager.run(): sent_messages: list[Message] = [] @@ -376,8 +397,10 @@ async def mock_receive(): # pragma: no cover assert session_id is not None, "Session ID not found in response headers" - # Wait for the 50ms idle timeout to fire and cleanup to complete - await anyio.sleep(0.1) + # Wait for the 50ms idle timeout to fire and the session to be unregistered. Re-requesting + # the session to poll for the 404 would push its idle deadline forward and keep it alive. + with anyio.fail_after(5): + await observer.reaped.wait() # Verify via public API: old session ID now returns 404 response_messages: list[Message] = [] diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353e..f13bb4a9bb 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -1,291 +1,130 @@ """Tests for StreamableHTTP server DNS rebinding protection.""" -import multiprocessing -import socket -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager import httpx import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import Mount -from starlette.types import Receive, Scope, Send from mcp.server import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport SERVER_NAME = "test_streamable_http_security_server" +# The in-process app is mounted at 
this origin purely so URLs are well-formed and the default +# Host header is a localhost form; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +@asynccontextmanager +async def streamable_http_security_client( + security_settings: TransportSecuritySettings | None = None, +) -> AsyncIterator[httpx.AsyncClient]: + """Yield an httpx client served in process by a StreamableHTTP app with the given settings.""" + session_manager = StreamableHTTPSessionManager(app=Server(SERVER_NAME), security_settings=security_settings) + app = Starlette(routes=[Mount("/", app=session_manager.handle_request)]) -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" + async with session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as client: + yield client -class SecurityTestServer(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) +def _base_headers() -> dict[str, str]: + """Headers every well-formed request carries, so each test varies only the header under test.""" + return {"Accept": "application/json, text/event-stream", "Content-Type": "application/json"} - async def on_list_tools(self) -> list[Tool]: - return [] - -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the StreamableHTTP server with specified security settings.""" - app = SecurityTestServer() - - # Create session manager with security settings - session_manager = StreamableHTTPSessionManager( - app=app, - json_response=False, - stateless=False, - security_settings=security_settings, - ) - - # Create the ASGI handler - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, 
send) - - # Create Starlette app with lifespan - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - async with session_manager.run(): - yield - - routes = [ - Mount("/", app=handle_streamable_http), - ] - - starlette_app = Starlette(routes=routes, lifespan=lifespan) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") - - -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process +def _initialize_body() -> dict[str, object]: + """A minimal initialize POST body; these tests assert header validation, not the handshake.""" + return {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} @pytest.mark.anyio -async def test_streamable_http_security_default_settings(server_port: int): - """Test StreamableHTTP with default security settings (protection enabled).""" - process = start_server_process(server_port) - - try: - # Test with valid localhost headers - async with httpx.AsyncClient(timeout=5.0) as client: - # POST request to initialize session - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - ) - assert response.status_code == 200 - assert "mcp-session-id" in response.headers - - finally: - process.terminate() - process.join() +async def test_streamable_http_security_default_settings() -> None: + """With default security settings, a request with localhost headers is served.""" + async with streamable_http_security_client() as client: + response = await client.post("/", json=_initialize_body(), 
headers=_base_headers()) + assert response.status_code == 200 + assert "mcp-session-id" in response.headers @pytest.mark.anyio -async def test_streamable_http_security_invalid_host_header(server_port: int): - """Test StreamableHTTP with invalid Host header.""" +async def test_streamable_http_security_invalid_host_header() -> None: + """A Host header outside allowed_hosts is rejected with 421.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid host header - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_origin_header(server_port: int): - """Test StreamableHTTP with invalid Origin header.""" +async def test_streamable_http_security_invalid_origin_header() -> None: + """An Origin header outside allowed_origins is rejected with 403.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = { - "Origin": "http://evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": 
"application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.post( + "/", json=_initialize_body(), headers=_base_headers() | {"Origin": "http://evil.com"} + ) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_content_type(server_port: int): - """Test StreamableHTTP POST with invalid Content-Type header.""" - process = start_server_process(server_port) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={ - "Content-Type": "text/plain", - "Accept": "application/json, text/event-stream", - }, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={"Accept": "application/json, text/event-stream"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - finally: - process.terminate() - process.join() +async def test_streamable_http_security_invalid_content_type() -> None: + """A POST whose Content-Type is not application/json (or is missing) is rejected with 400.""" + async with streamable_http_security_client() as client: + response = await client.post("/", headers=_base_headers() | {"Content-Type": "text/plain"}, content="test") + assert response.status_code == 400 + assert 
response.text == "Invalid Content-Type header" + + response = await client.post("/", headers={"Accept": "application/json, text/event-stream"}, content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_streamable_http_security_disabled(server_port: int): - """Test StreamableHTTP with security disabled.""" +async def test_streamable_http_security_disabled() -> None: + """With protection explicitly disabled, a disallowed Host is still served.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "evil.com"}) + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_custom_allowed_hosts(server_port: int): - """Test StreamableHTTP with custom allowed hosts.""" +async def test_streamable_http_security_custom_allowed_hosts() -> None: + """A custom entry in allowed_hosts is served.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: 
- # Test with custom allowed host - headers = { - "Host": "custom.host", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully with custom host - assert response.status_code == 200 - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "custom.host"}) + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_get_request(server_port: int): - """Test StreamableHTTP GET request with security.""" +async def test_streamable_http_security_get_request() -> None: + """GET requests pass the same Host validation before any session handling.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) - process = start_server_process(server_port, security_settings) - - try: - # Test GET request with invalid host header - headers = { - "Host": "evil.com", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - # Test GET request with valid host header - headers = { - "Host": "127.0.0.1", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - # GET requests need a session ID in StreamableHTTP - # So it will fail with "Missing session ID" not security error - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - # This should pass security but fail on session 
validation - assert response.status_code == 400 - body = response.json() - assert "Missing session ID" in body["error"]["message"] - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + + response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "127.0.0.1"}) + # An allowed host passes security and fails on session validation instead. + assert response.status_code == 400 + body = response.json() + assert "Missing session ID" in body["error"]["message"] From ed39e73c0ba16cd3ff9c3996421a17c21b4ed05b Mon Sep 17 00:00:00 2001 From: Max <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 21:46:53 +0100 Subject: [PATCH 60/60] Run SSE and Unicode transport tests in process instead of over sockets (#2765) --- tests/client/test_http_unicode.py | 288 ++++++++------------- tests/shared/test_sse.py | 409 ++++++++++-------------------- 2 files changed, 242 insertions(+), 455 deletions(-) diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index cc2e14e469..585a142617 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -4,11 +4,10 @@ (server→client and client→server) using the streamable HTTP transport. 
""" -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager +import httpx import pytest from starlette.applications import Starlette from starlette.routing import Mount @@ -19,7 +18,10 @@ from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport + +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" # Test constants with various Unicode characters UNICODE_TEST_STRINGS = { @@ -41,74 +43,62 @@ } -def run_unicode_server(port: int) -> None: # pragma: no cover - """Run the Unicode test server in a separate process.""" - import uvicorn - - # Need to recreate the server setup in this process - async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - Tool( - name="echo_unicode", - description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to echo back"}, - }, - "required": ["text"], +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + Tool( + name="echo_unicode", + description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", + input_schema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "Text to echo back"}, }, - ), - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: - 
if params.name == "echo_unicode": - text = params.arguments.get("text", "") if params.arguments else "" - return types.CallToolResult( - content=[ - TextContent( - type="text", - text=f"Echo: {text}", - ) - ] + "required": ["text"], + }, + ), + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + assert params.name == "echo_unicode" + assert params.arguments is not None + return types.CallToolResult(content=[TextContent(type="text", text=f"Echo: {params.arguments['text']}")]) + + +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="unicode_prompt", + description="Unicode prompt - Слой хранилища, где располагаются", + arguments=[], ) - else: - raise ValueError(f"Unknown tool: {params.name}") - - async def handle_list_prompts( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListPromptsResult: - return types.ListPromptsResult( - prompts=[ - types.Prompt( - name="unicode_prompt", - description="Unicode prompt - Слой хранилища, где располагаются", - arguments=[], - ) - ] - ) - - async def handle_get_prompt( - ctx: ServerRequestContext, params: types.GetPromptRequestParams - ) -> types.GetPromptResult: - if params.name == "unicode_prompt": - return types.GetPromptResult( - messages=[ - types.PromptMessage( - role="user", - content=types.TextContent( - type="text", - text="Hello世界🌍Привет안녕مرحباשלום", - ), - ) - ] + ] + ) + + +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: + assert params.name == "unicode_prompt" + return types.GetPromptResult( + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text="Hello世界🌍Привет안녕مرحباשלום"), ) - raise ValueError(f"Unknown prompt: {params.name}") + ] + ) + 
+@asynccontextmanager +async def unicode_session() -> AsyncIterator[ClientSession]: + """Yield an initialized ClientSession speaking streamable HTTP (SSE responses) to the + Unicode test server, entirely in process.""" server = Server( name="unicode_test_server", on_list_tools=handle_list_tools, @@ -116,122 +106,68 @@ async def handle_get_prompt( on_list_prompts=handle_list_prompts, on_get_prompt=handle_get_prompt, ) - - # Create the session manager - session_manager = StreamableHTTPSessionManager( - app=server, - json_response=False, # Use SSE for testing - ) - - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - async with session_manager.run(): - yield - - # Create an ASGI application - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lifespan, - ) - - # Run the server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - uvicorn_server = uvicorn.Server(config) - uvicorn_server.run() - - -@pytest.fixture -def unicode_server_port() -> int: - """Find an available port for the Unicode test server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]: - """Start a Unicode test server in a separate process.""" - proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True) - proc.start() - - # Wait for server to be ready - wait_for_server(unicode_server_port) - - try: - yield f"http://127.0.0.1:{unicode_server_port}" - finally: - # Clean up - try graceful termination first - proc.terminate() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - proc.kill() - proc.join(timeout=1) + # SSE response mode, so Unicode rides the SSE event encoding rather than a plain JSON body. 
+ session_manager = StreamableHTTPSessionManager(app=server, json_response=False) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + + async with ( + session_manager.run(), + # follow_redirects matches the SDK's own client factory; Starlette's Mount 307-redirects + # the bare /mcp path to /mcp/. + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, follow_redirects=True + ) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_unicode_tool_call(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_tool_call() -> None: """Test that Unicode text is correctly handled in tool calls via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List tools (server→client Unicode in descriptions) - tools = await session.list_tools() - assert len(tools.tools) == 1 + async with unicode_session() as session: + # Test 1: List tools (server→client Unicode in descriptions) + tools = await session.list_tools() + assert len(tools.tools) == 1 - # Check Unicode in tool descriptions - echo_tool = tools.tools[0] - assert echo_tool.name == "echo_unicode" - assert echo_tool.description is not None - assert "🔤" in echo_tool.description - assert "👋" in echo_tool.description + # Check Unicode in tool descriptions + echo_tool = tools.tools[0] + assert echo_tool.name == "echo_unicode" + assert echo_tool.description is not None + assert "🔤" in echo_tool.description + assert "👋" in echo_tool.description - # Test 2: Send Unicode text in tool call 
(client→server→client) - for test_name, test_string in UNICODE_TEST_STRINGS.items(): - result = await session.call_tool("echo_unicode", arguments={"text": test_string}) + # Test 2: Send Unicode text in tool call (client→server→client) + for test_name, test_string in UNICODE_TEST_STRINGS.items(): + result = await session.call_tool("echo_unicode", arguments={"text": test_string}) - # Verify server correctly received and echoed back Unicode - assert len(result.content) == 1 - content = result.content[0] - assert content.type == "text" - assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" + # Verify server correctly received and echoed back Unicode + assert len(result.content) == 1 + content = result.content[0] + assert content.type == "text" + assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" @pytest.mark.anyio -async def test_streamable_http_client_unicode_prompts(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_prompts() -> None: """Test that Unicode text is correctly handled in prompts via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List prompts (server→client Unicode in descriptions) - prompts = await session.list_prompts() - assert len(prompts.prompts) == 1 - - prompt = prompts.prompts[0] - assert prompt.name == "unicode_prompt" - assert prompt.description is not None - assert "Слой хранилища, где располагаются" in prompt.description - - # Test 2: Get prompt with Unicode content (server→client) - result = await session.get_prompt("unicode_prompt", arguments={}) - assert len(result.messages) == 1 - - message = result.messages[0] - assert message.role == "user" - assert message.content.type == "text" - assert message.content.text == 
"Hello世界🌍Привет안녕مرحباשלום" + async with unicode_session() as session: + # Test 1: List prompts (server→client Unicode in descriptions) + prompts = await session.list_prompts() + assert len(prompts.prompts) == 1 + + prompt = prompts.prompts[0] + assert prompt.name == "unicode_prompt" + assert prompt.description is not None + assert "Слой хранилища, где располагаются" in prompt.description + + # Test 2: Get prompt with Unicode content (server→client) + result = await session.get_prompt("unicode_prompt", arguments={}) + assert len(result.messages) == 1 + + message = result.messages[0] + assert message.role == "user" + assert message.content.type == "text" + assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 5629a5707b..675a4acb16 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,7 +1,7 @@ +"""Tests for the SSE client and server transports, driven entirely in process.""" + import json -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch from urllib.parse import urlparse @@ -9,7 +9,6 @@ import anyio import httpx import pytest -import uvicorn from httpx_sse import ServerSentEvent from inline_snapshot import snapshot from starlette.applications import Starlette @@ -24,6 +23,7 @@ from mcp.server import Server, ServerRequestContext from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._httpx_utils import McpHttpClientFactory from mcp.shared.exceptions import MCPError from mcp.types import ( CallToolRequestParams, @@ -41,171 +41,114 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport SERVER_NAME = "test_server_for_SSE" +# The in-process 
app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + + +def in_process_client_factory(app: Starlette) -> McpHttpClientFactory: + """An httpx_client_factory for sse_client whose clients are served in process by `app`.""" + + def factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + # The SSE GET runs until it observes a disconnect, so the bridge must let the + # application drain on close rather than cancelling it. follow_redirects matches + # create_mcp_http_client, the factory this one stands in for. + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + follow_redirects=True, + ) -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - + return factory -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" - -async def _handle_read_resource( # pragma: no cover - ctx: ServerRequestContext, params: ReadResourceRequestParams -) -> ReadResourceResult: +async def _handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: uri = str(params.uri) parsed = urlparse(uri) if parsed.scheme == "foobar": - text = f"Read {parsed.netloc}" - elif parsed.scheme == "slow": - await anyio.sleep(2.0) - text = f"Slow response from {parsed.netloc}" - else: - raise MCPError(code=404, message="OOPS! 
no resource with that URI was found") - return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) - - -async def _handle_list_tools( # pragma: no cover - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) - - -async def _handle_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) - - -def _create_server() -> Server: # pragma: no cover - return Server( - SERVER_NAME, - on_read_resource=_handle_read_resource, - on_list_tools=_handle_list_tools, - on_call_tool=_handle_call_tool, - ) + return ReadResourceResult( + contents=[TextResourceContents(uri=uri, text=f"Read {parsed.netloc}", mime_type="text/plain")] + ) + raise MCPError(code=404, message="OOPS! no resource with that URI was found") -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover - """Create test Starlette app with SSE transport""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] +def make_app(server: Server) -> Starlette: + """Mount `server` on a Starlette app exposing the SSE transport at /sse and /messages/.""" + # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot + # exist for an in-process app; the transport security behaviour itself is pinned by + # tests/server/test_sse_security.py. 
+ sse = SseServerTransport( + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - server = _create_server() async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await server.run(streams[0], streams[1], server.create_initialization_options()) + async with sse.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) return Response() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - return app - - -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) +def make_server_app() -> Starlette: + return make_app(Server(SERVER_NAME, on_read_resource=_handle_read_resource)) - yield - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: - """Create test client""" - async with httpx.AsyncClient(base_url=server_url) as client: - yield client +@pytest.mark.anyio +async def 
test_raw_sse_connection() -> None: + """The SSE GET responds 200 with an event-stream content type, announcing the session + endpoint as its first event.""" + http_client = httpx.AsyncClient( + transport=StreamingASGITransport(make_server_app(), cancel_on_close=False), base_url=BASE_URL + ) + with anyio.fail_after(5): + async with http_client, http_client.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" -# Tests -@pytest.mark.anyio -async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: - """Test the SSE connection establishment simply with an HTTP client.""" - async with anyio.create_task_group(): - - async def connection_test() -> None: - async with http_client.stream("GET", "/sse") as response: - assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" - - line_number = 0 - async for line in response.aiter_lines(): # pragma: no branch - if line_number == 0: - assert line == "event: endpoint" - elif line_number == 1: - assert line.startswith("data: /messages/?session_id=") - else: - return - line_number += 1 - - # Add timeout to prevent test from hanging if it fails - with anyio.fail_after(3): - await connection_test() + lines = response.aiter_lines() + assert await anext(lines) == "event: endpoint" + assert (await anext(lines)).startswith("data: /messages/?session_id=") @pytest.mark.anyio -async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse") as streams: +async def test_sse_client_basic_connection() -> None: + """A client initializes against, and pings, a server over the SSE transport.""" + factory = in_process_client_factory(make_server_app()) + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await 
session.initialize() assert isinstance(result, InitializeResult) assert result.server_info.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) @pytest.mark.anyio -async def test_sse_client_on_session_created(server: None, server_url: str) -> None: +async def test_sse_client_on_session_created() -> None: + """The session-created callback receives the new session ID before sse_client yields.""" + factory = in_process_client_factory(make_server_app()) captured: list[str] = [] - async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with sse_client( + f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=captured.append + ) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -226,13 +169,14 @@ async def test_sse_client_on_session_created(server: None, server_url: str) -> N ], ) def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | None) -> None: + """The session ID is read from the endpoint URL's sessionId/session_id query parameters.""" assert _extract_session_id_from_endpoint(endpoint_url) == expected @pytest.mark.anyio -async def test_sse_client_on_session_created_not_called_when_no_session_id( - server: None, server_url: str, monkeypatch: pytest.MonkeyPatch -) -> None: +async def test_sse_client_on_session_created_not_called_when_no_session_id(monkeypatch: pytest.MonkeyPatch) -> None: + """No session-created callback fires when the endpoint URL carries no session ID.""" + factory = in_process_client_factory(make_server_app()) callback_mock = Mock() def mock_extract(url: str) -> None: @@ -240,7 +184,7 @@ def mock_extract(url: str) -> None: monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract) - async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams: + async with 
sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=callback_mock) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -250,8 +194,9 @@ def mock_extract(url: str) -> None: @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: +async def initialized_sse_client_session() -> AsyncGenerator[ClientSession, None]: + factory = in_process_client_factory(make_server_app()) + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: await session.initialize() yield session @@ -261,6 +206,7 @@ async def initialized_sse_client_session(server: None, server_url: str) -> Async async def test_sse_client_happy_request_and_response( initialized_sse_client_session: ClientSession, ) -> None: + """A resource read round-trips its arguments and the handler's content over SSE.""" session = initialized_sse_client_session response = await session.read_resource(uri="foobar://should-work") assert len(response.contents) == 1 @@ -272,93 +218,45 @@ async def test_sse_client_happy_request_and_response( async def test_sse_client_exception_handling( initialized_sse_client_session: ClientSession, ) -> None: + """A server-side MCPError reaches the client with its message intact.""" session = initialized_sse_client_session with pytest.raises(MCPError, match="OOPS! 
no resource with that URI was found"): await session.read_resource(uri="xxx://will-not-work") @pytest.mark.anyio -@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling") -async def test_sse_client_timeout( # pragma: no cover - initialized_sse_client_session: ClientSession, -) -> None: - session = initialized_sse_client_session - - # sanity check that normal, fast responses are working - response = await session.read_resource(uri="foobar://1") - assert isinstance(response, ReadResourceResult) - - with anyio.move_on_after(3): - with pytest.raises(MCPError, match="Read timed out"): - response = await session.read_resource(uri="slow://2") - # we should receive an error here - return - - pytest.fail("the client should have timed out and returned an error already") - - -def run_mounted_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) +async def test_sse_client_basic_connection_mounted_app() -> None: + """The SSE transport works unchanged when its app is mounted under a sub-path.""" + main_app = Starlette(routes=[Mount("/mounted_app", app=make_server_app())]) + factory = in_process_client_factory(main_app) - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.mark.anyio 
-async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: + async with sse_client(f"{BASE_URL}/mounted_app/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.server_info.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) -async def _handle_context_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - headers_info: dict[str, Any] = {} - if ctx.request: - headers_info = dict(ctx.request.headers) +async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name in ("echo_headers", "echo_context") + assert ctx.request is not None + headers_info = dict(ctx.request.headers) if params.name == "echo_headers": return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) - elif params.name == "echo_context": - context_data = { - "request_id": (params.arguments or {}).get("request_id"), - "headers": headers_info, - } - return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + assert params.arguments is not None + context_data = { + "request_id": params.arguments.get("request_id"), + "headers": headers_info, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -381,115 +279,65 @@ async def _handle_context_list_tools( 
# pragma: no cover ) -def run_context_server(server_port: int) -> None: # pragma: no cover - """Run a server that captures request context""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = Server( - "request_context_server", - on_call_tool=_handle_context_call_tool, - on_list_tools=_handle_context_list_tools, - ) - - async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) - return Response() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ] +def make_context_server_app() -> Starlette: + return make_app( + Server( + "request_context_server", + on_call_tool=_handle_context_call_tool, + on_list_tools=_handle_context_list_tools, + ) ) - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() - - -@pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: - """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() - - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("context server process failed to terminate") - @pytest.mark.anyio 
-async def test_request_context_propagation(context_server: None, server_url: str) -> None: - """Test that request context is properly propagated through SSE transport.""" - # Test with custom headers +async def test_request_context_propagation() -> None: + """Custom HTTP headers on the SSE connection are visible to server handlers via ctx.request.""" + factory = in_process_client_factory(make_context_server_app()) + custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, headers=custom_headers) as streams: + async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - # Call the tool that echoes headers back tool_result = await session.call_tool("echo_headers", {}) - # Parse the JSON response - assert len(tool_result.content) == 1 - headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") + content = tool_result.content[0] + assert isinstance(content, TextContent) + headers_data = json.loads(content.text) - # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" assert headers_data.get("x-custom-header") == "test-value" assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: - """Test that request contexts are isolated between different SSE clients.""" +async def test_request_context_isolation() -> None: + """Each SSE connection's handlers see only that connection's request headers.""" + factory = in_process_client_factory(make_context_server_app()) contexts: 
list[dict[str, Any]] = [] - # Create multiple clients with different headers + # Connect three clients in turn, each with its own headers. for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, headers=headers) as streams: + async with ClientSession(*streams) as session: await session.initialize() - # Call the tool that echoes context tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) assert len(tool_result.content) == 1 - context_data = json.loads( - tool_result.content[0].text if tool_result.content[0].type == "text" else "{}" - ) - contexts.append(context_data) + content = tool_result.content[0] + assert isinstance(content, TextContent) + contexts.append(json.loads(content.text)) - # Verify each request had its own context assert len(contexts) == 3 for i, ctx in enumerate(contexts): assert ctx["request_id"] == f"request-{i}" @@ -497,7 +345,7 @@ async def test_request_context_isolation(context_server: None, server_url: str) assert ctx["headers"].get("x-custom-value") == f"value-{i}" -def test_sse_message_id_coercion(): +def test_sse_message_id_coercion() -> None: """Previously, the `RequestId` would coerce a string that looked like an integer into an integer. See for more details. 
@@ -531,7 +379,7 @@ def test_sse_message_id_coercion(): ("/messages/#fragment", ValueError), ], ) -def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]): +def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]) -> None: """Test that SseServerTransport properly validates and normalizes endpoints.""" if isinstance(expected_result, type): # Test invalid endpoints that should raise an exception @@ -605,7 +453,7 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: @pytest.mark.anyio -async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None: +async def test_sse_session_cleanup_on_disconnect() -> None: """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227 When a client disconnects, the server should remove the session from @@ -613,18 +461,21 @@ async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) POST requests to disconnected sessions return 202 Accepted followed by a ClosedResourceError when the server tries to write to the dead stream. """ + factory = in_process_client_factory(make_server_app()) captured: list[str] = [] # Connect a client session, then disconnect - async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with sse_client( + f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=captured.append + ) as streams: async with ClientSession(*streams) as session: await session.initialize() # After disconnect, POST to the stale session should return 404 # (not 202 as it did before the fix) - async with httpx.AsyncClient() as client: + async with factory() as client: response = await client.post( - f"{server_url}/messages/?session_id={captured[0]}", + f"/messages/?session_id={captured[0]}", json={"jsonrpc": "2.0", "method": "ping", "id": 99}, headers={"Content-Type": "application/json"}, )