diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml deleted file mode 100644 index 36c88040ea..0000000000 --- a/.github/workflows/claude-code-review.yml +++ /dev/null @@ -1,33 +0,0 @@ -# Source: https://github.com/anthropics/claude-code-action/blob/main/docs/code-review.md -name: Claude Code Review - -on: - pull_request: - types: [opened, synchronize, ready_for_review, reopened] - -jobs: - claude-review: - # Fork PRs don't have access to secrets or OIDC tokens, so the action - # cannot authenticate. See https://github.com/anthropics/claude-code-action/issues/339 - if: github.event.pull_request.head.repo.fork == false && github.actor != 'dependabot[bot]' - runs-on: ubuntu-latest - permissions: - contents: read - pull-requests: read - issues: read - id-token: write - - steps: - - name: Checkout repository - uses: actions/checkout@v6 - with: - fetch-depth: 1 - - - name: Run Claude Code Review - id: claude-review - uses: anthropics/claude-code-action@v1 - with: - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} - plugin_marketplaces: "https://github.com/anthropics/claude-code.git" - plugins: "code-review@claude-code-plugins" - prompt: "/code-review:code-review ${{ github.repository }}/pull/${{ github.event.pull_request.number }}" diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 490e9ae2cc..18f5377cb6 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -14,7 +14,7 @@ on: jobs: claude: if: | - (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude') && !startsWith(github.event.comment.body, '@claude review')) || (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) @@ -27,15 +27,16 @@ jobs: actions: read # Required for Claude to read CI results on PRs steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 1 + persist-credentials: false - name: Run Claude Code id: claude - uses: anthropics/claude-code-action@v1 + uses: anthropics/claude-code-action@2f8ba26a219c06cfb0f468eef8d97055fa814f97 # v1.0.53 with: - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} # zizmor: ignore[secrets-outside-env] use_commit_signing: true additional_permissions: | actions: read diff --git a/.github/workflows/comment-on-release.yml b/.github/workflows/comment-on-release.yml index 15d6a1d26a..66f1fcc32a 100644 --- a/.github/workflows/comment-on-release.yml +++ b/.github/workflows/comment-on-release.yml @@ -16,13 +16,16 @@ jobs: uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 with: fetch-depth: 0 + persist-credentials: false - name: Get previous release id: previous_release uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + env: + CURRENT_TAG: ${{ github.event.release.tag_name }} with: script: | - const currentTag = '${{ github.event.release.tag_name }}'; + const currentTag = process.env.CURRENT_TAG; // Get all releases const { data: releases } = await github.rest.repos.listReleases({ @@ -54,10 +57,13 @@ jobs: - name: Get merged PRs between releases id: get_prs uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + env: + CURRENT_TAG: ${{ github.event.release.tag_name }} + PREVIOUS_TAG_JSON: ${{ steps.previous_release.outputs.result }} with: script: | - const currentTag = '${{ github.event.release.tag_name }}'; - const previousTag = ${{ steps.previous_release.outputs.result }}; + const currentTag = process.env.CURRENT_TAG; + const previousTag = JSON.parse(process.env.PREVIOUS_TAG_JSON); if (!previousTag) { console.log('No previous release found, skipping'); @@ -104,11 +110,15 @@ jobs: - name: Comment on PRs uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + env: + PR_NUMBERS_JSON: ${{ steps.get_prs.outputs.result }} + RELEASE_TAG: ${{ github.event.release.tag_name }} + RELEASE_URL: ${{ github.event.release.html_url }} with: script: | - const prNumbers = ${{ steps.get_prs.outputs.result }}; - const releaseTag = '${{ github.event.release.tag_name }}'; - const releaseUrl = '${{ github.event.release.html_url }}'; + const prNumbers = JSON.parse(process.env.PR_NUMBERS_JSON); + const releaseTag = process.env.RELEASE_TAG; + const releaseUrl = process.env.RELEASE_URL; const comment = `This pull request is included in [${releaseTag}](${releaseUrl})`; diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index d876da00b0..9c33d2936b 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -19,6 +19,8 @@ jobs: continue-on-error: true steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: enable-cache: true @@ -34,6 +36,8 @@ jobs: continue-on-error: true steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: enable-cache: true diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml new file mode 100644 index 0000000000..fcf7a6c1a7 --- /dev/null +++ b/.github/workflows/deploy-docs.yml @@ -0,0 +1,59 @@ +name: Deploy Docs + +on: + push: + branches: + - main + - v1.x + paths: + - docs/** + - mkdocs.yml + - src/mcp/** + - scripts/build-docs.sh + - pyproject.toml + - uv.lock + - .github/workflows/deploy-docs.yml + workflow_dispatch: + +concurrency: + group: deploy-docs + cancel-in-progress: false + +jobs: + deploy-docs: + runs-on: ubuntu-latest + + permissions: + contents: read + pages: write + id-token: write + + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false + + - name: Install uv + uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 + with: + enable-cache: true + version: 0.9.5 + + - name: Build combined docs (v1.x at /, main at /v2/) + run: bash scripts/build-docs.sh site + + - name: Configure Pages + uses: actions/configure-pages@45bfe0192ca1faeb007ade9deae92b16b8254a0d # v6.0.0 + + - name: Upload Pages artifact + uses: actions/upload-pages-artifact@fc324d3547104276b827a68afc52ff2a11cc49c9 # v5.0.0 + with: + path: site + + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@cd2ce8fcbc39b97be8ca5fce6e763baed58fa128 # v5.0.0 diff --git a/.github/workflows/publish-docs-manually.yml b/.github/workflows/publish-docs-manually.yml deleted file mode 100644 index f058174ab1..0000000000 --- a/.github/workflows/publish-docs-manually.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: Publish Docs manually - -on: - workflow_dispatch: - -jobs: - docs-publish: - runs-on: ubuntu-latest - permissions: - contents: write - steps: - - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 - - name: Configure Git Credentials - run: | - git config user.name github-actions[bot] - git config user.email 41898282+github-actions[bot]@users.noreply.github.com - - - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 - with: - enable-cache: true - version: 0.9.5 - - - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - - uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 - with: - key: mkdocs-material-${{ env.cache_id }} - path: .cache - restore-keys: | - mkdocs-material- - - - run: uv sync --frozen --group docs - - run: uv run --frozen --no-sync mkdocs gh-deploy --force diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index bfc3cc64e1..c72061d12b 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -4,6 +4,9 @@ on: release: types: [published] +permissions: + contents: read + jobs: release-build: name: Build distribution @@ -11,11 +14,13 @@ jobs: needs: [checks] steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - name: Install uv uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: - enable-cache: true + enable-cache: false version: 0.9.5 - name: Set up Python 3.12 @@ -51,32 +56,3 @@ jobs: - name: Publish package distributions to PyPI uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1 - - docs-publish: - runs-on: ubuntu-latest - needs: ["pypi-publish"] - permissions: - contents: write - steps: - - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 - - name: Configure Git Credentials - run: | - git config user.name github-actions[bot] - git config user.email 41898282+github-actions[bot]@users.noreply.github.com - - - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 - with: - enable-cache: true - version: 0.9.5 - - - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - - uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 - with: - key: mkdocs-material-${{ env.cache_id }} - path: .cache - restore-keys: | - mkdocs-material- - - - run: uv sync --frozen --group docs - - run: uv run --frozen --no-sync mkdocs gh-deploy --force diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index fe97895f63..5f115aef89 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -14,6 +14,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: @@ -26,7 +28,19 @@ jobs: with: extra_args: --all-files --verbose env: - SKIP: no-commit-to-branch + SKIP: no-commit-to-branch,readme-v1-frozen + + # TODO(Max): Drop this in v2. + - name: Check README.md is not modified + if: github.event_name == 'pull_request' + run: | + git fetch --no-tags --depth=1 origin "$BASE_SHA" + if git diff --name-only "$BASE_SHA" -- README.md | grep -q .; then + echo "::error::README.md is frozen at v1. Edit README.v2.md instead." + exit 1 + fi + env: + BASE_SHA: ${{ github.event.pull_request.base.sha }} test: name: test (${{ matrix.python-version }}, ${{ matrix.dep-resolution.name }}, ${{ matrix.os }}) @@ -45,6 +59,8 @@ jobs: steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - name: Install uv uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 @@ -58,6 +74,7 @@ jobs: - name: Run pytest with coverage shell: bash run: | + uv run --frozen --no-sync coverage erase uv run --frozen --no-sync coverage run -m pytest -n auto uv run --frozen --no-sync coverage combine uv run --frozen --no-sync coverage report @@ -70,6 +87,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + with: + persist-credentials: false - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: diff --git a/.github/workflows/weekly-lockfile-update.yml b/.github/workflows/weekly-lockfile-update.yml index 8808822476..c30c72991a 100644 --- a/.github/workflows/weekly-lockfile-update.yml +++ b/.github/workflows/weekly-lockfile-update.yml @@ -14,9 +14,11 @@ jobs: update-lockfile: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false - - uses: astral-sh/setup-uv@v7.2.1 + - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: version: 0.9.5 @@ -32,6 +34,7 @@ jobs: uses: peter-evans/create-pull-request@c0f553fe549906ede9cf27b5156039d195d2ece0 # v7 with: commit-message: "chore: update uv.lock with latest dependencies" + sign-commits: true title: "chore: weekly dependency update" body-path: pr_body.md branch: weekly-lockfile-update diff --git a/.github/workflows/zizmor.yml b/.github/workflows/zizmor.yml new file mode 100644 index 0000000000..83d3eb3817 --- /dev/null +++ b/.github/workflows/zizmor.yml @@ -0,0 +1,25 @@ +name: GitHub Actions Security Analysis + +on: + push: + branches: ["main"] + pull_request: + branches: ["**"] + +permissions: {} + +jobs: + zizmor: + runs-on: ubuntu-latest + + permissions: + security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files. + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Run zizmor 🌈 + uses: zizmorcore/zizmor-action@5f14fd08f7cf1cb1609c1e344975f152c7ee938d # v0.5.6 diff --git a/.gitignore b/.gitignore index 5ff4ce9771..3443adf7c8 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,7 @@ venv.bak/ # mkdocs documentation /site +/.worktrees/ # mypy .mypy_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 03a8ae0389..42c12fdedd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,6 +55,12 @@ repos: language: system files: ^(pyproject\.toml|uv\.lock)$ pass_filenames: false + # TODO(Max): Drop this in v2. + - id: readme-v1-frozen + name: README.md is frozen (v1 docs) + entry: README.md is frozen at v1. Edit README.v2.md instead. + language: fail + files: ^README\.md$ - id: readme-snippets name: Check README snippets are up to date entry: uv run --frozen python scripts/update_readme_snippets.py --check diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000..307bd81b3e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,140 @@ +# Development Guidelines + +## Branching Model + + + +- `main` is currently the V2 rework. Breaking changes are expected here — when removing or + replacing an API, delete it outright and document the change in + `docs/migration.md`. Do not add `@deprecated` shims or backward-compat layers + on `main`. +- `v1.x` is the release branch for the current stable line. Backport PRs target + this branch and use a `[v1.x]` title prefix. +- `README.md` is frozen at v1 (a pre-commit hook rejects edits). Edit + `README.v2.md` instead. + +## Package Management + +- ONLY use uv, NEVER pip +- Installation: `uv add ` +- Running tools: `uv run --frozen `. Always pass `--frozen` so uv doesn't + rewrite `uv.lock` as a side effect. +- Cross-version testing: `uv run --frozen --python 3.10 pytest ...` to run + against a specific interpreter (CI covers 3.10–3.14). +- Upgrading: `uv lock --upgrade-package ` +- FORBIDDEN: `uv pip install`, `@latest` syntax +- Don't raise dependency floors for CVEs alone. The `>=` constraint already + lets users upgrade. Only raise a floor when the SDK needs functionality from + the newer version, and don't add SDK code to work around a dependency's + vulnerability. See Kludex/uvicorn#2643 and python-sdk #1552 for reasoning. + +## Code Quality + +- Type hints required for all code +- Public APIs must have docstrings. When a public API raises exceptions a + caller would reasonably catch, document them in a `Raises:` section. Don't + list exceptions from argument validation or programmer error. +- `src/mcp/__init__.py` defines the public API surface via `__all__`. Adding a + symbol there is a deliberate API decision, not a convenience re-export. +- IMPORTANT: All imports go at the top of the file — inline imports hide + dependencies and obscure circular-import bugs. Only exception: when a + top-level import genuinely can't work (lazy-loading optional deps, or + tests that re-import a module). + +## Testing + +- Framework: `uv run --frozen pytest` +- Async testing: use anyio, not asyncio +- Do not use `Test` prefixed classes — write plain top-level `test_*` functions. + Legacy files still contain `Test*` classes; do NOT follow that pattern for new + tests even when adding to such a file. +- IMPORTANT: Tests should be fast and deterministic. Prefer in-memory async execution; + reach for threads only when necessary, and subprocesses only as a last resort. +- For end-to-end behavior, an in-memory `Client(server)` is usually the + cleanest approach (see `tests/client/test_client.py` for the canonical + pattern). For narrower changes, testing the function directly is fine. Use + judgment. +- Test files mirror the source tree: `src/mcp/client/stdio.py` → + `tests/client/test_stdio.py`. Add tests to the existing file for that module. +- Avoid `anyio.sleep()` with a fixed duration to wait for async operations. Instead: + - Use `anyio.Event` — set it in the callback/handler, `await event.wait()` in the test + - For stream messages, use `await stream.receive()` instead of `sleep()` + `receive_nowait()` + - Exception: `sleep()` is appropriate when testing time-based features (e.g., timeouts) +- Wrap indefinite waits (`event.wait()`, `stream.receive()`) in `anyio.fail_after(5)` to prevent hangs +- Pytest is configured with `filterwarnings = ["error"]`, so warnings fail + tests. Don't silence warnings from your own code; fix the underlying cause. + Scoped `ignore::` entries for upstream libraries are acceptable in + `pyproject.toml` with a comment explaining why. + +### Coverage + +CI requires 100% (`fail_under = 100`, `branch = true`). + +- Full check: `./scripts/test` (~23s). Runs coverage + `strict-no-cover` on the + default Python. Not identical to CI: CI runs 3.10–3.14 × {ubuntu, windows} + × {locked, lowest-direct}, and some branch-coverage quirks only surface on + specific matrix entries. +- Targeted check while iterating (~4s, deterministic): + + ```bash + uv run --frozen coverage erase + uv run --frozen coverage run -m pytest tests/path/test_foo.py + uv run --frozen coverage combine + uv run --frozen coverage report --include='src/mcp/path/foo.py' --fail-under=0 + # UV_FROZEN=1 propagates --frozen to the uv subprocess strict-no-cover spawns + UV_FROZEN=1 uv run --frozen strict-no-cover + ``` + + Partial runs can't hit 100% (coverage tracks `tests/` too), so `--fail-under=0` + and `--include` scope the report. `strict-no-cover` has no false positives on + partial runs — if your new test executes a line marked `# pragma: no cover`, + even a single-file run catches it. + +Avoid adding new `# pragma: no cover`, `# type: ignore`, or `# noqa` comments. +In tests, use `assert isinstance(x, T)` to narrow types instead of +`# type: ignore`. In library code (`src/`), a `# pragma: no cover` needs very +good reasoning — it usually means a test is missing. Audit before pushing: + +```bash +git diff origin/main... | grep -E '^\+.*(pragma|type: ignore|noqa)' +``` + +What the existing pragmas mean: + +- `# pragma: no cover` — line is never executed. CI's `strict-no-cover` (skipped + on Windows runners) fails if it IS executed. When your test starts covering + such a line, remove the pragma. +- `# pragma: lax no cover` — excluded from coverage but not checked by + `strict-no-cover`. Use for lines covered on some platforms/versions but not + others. +- `# pragma: no branch` — excludes branch arcs only. coverage.py misreports the + `->exit` arc for nested `async with` on Python 3.11+ (worse on 3.14/Windows). + +## Breaking Changes + +When making breaking changes, document them in `docs/migration.md`. Include: + +- What changed +- Why it changed +- How to migrate existing code + +Search for related sections in the migration guide and group related changes together +rather than adding new standalone sections. + +## Formatting & Type Checking + +- Format: `uv run --frozen ruff format .` +- Lint: `uv run --frozen ruff check . --fix` +- Type check: `uv run --frozen pyright` +- Pre-commit runs all of the above plus markdownlint, a `uv.lock` consistency + check, and README checks — see `.pre-commit-config.yaml` + +## Exception Handling + +- **Always use `logger.exception()` instead of `logger.error()` when catching exceptions** + - Don't include the exception in the message: `logger.exception("Failed")` not `logger.exception(f"Failed: {e}")` +- **Catch specific exceptions** where possible: + - File ops: `except (OSError, PermissionError):` + - JSON: `except json.JSONDecodeError:` + - Network: `except (ConnectionError, TimeoutError):` +- **FORBIDDEN** `except Exception:` - unless in top-level handlers diff --git a/CLAUDE.md b/CLAUDE.md index d7b175636b..43c994c2d3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,143 +1 @@ -# Development Guidelines - -This document contains critical information about working with this codebase. Follow these guidelines precisely. - -## Core Development Rules - -1. Package Management - - ONLY use uv, NEVER pip - - Installation: `uv add ` - - Running tools: `uv run ` - - Upgrading: `uv lock --upgrade-package ` - - FORBIDDEN: `uv pip install`, `@latest` syntax - -2. Code Quality - - Type hints required for all code - - Public APIs must have docstrings - - Functions must be focused and small - - Follow existing patterns exactly - - Line length: 120 chars maximum - - FORBIDDEN: imports inside functions. THEY SHOULD BE AT THE TOP OF THE FILE. - -3. Testing Requirements - - Framework: `uv run --frozen pytest` - - Async testing: use anyio, not asyncio - - Do not use `Test` prefixed classes, use functions - - Coverage: test edge cases and errors - - New features require tests - - Bug fixes require regression tests - - IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns. - - IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible. - -Test files mirror the source tree: `src/mcp/client/streamable_http.py` → `tests/client/test_streamable_http.py` -Add tests to the existing file for that module. - -- For commits fixing bugs or adding features based on user reports add: - - ```bash - git commit --trailer "Reported-by:" - ``` - - Where `` is the name of the user. - -- For commits related to a Github issue, add - - ```bash - git commit --trailer "Github-Issue:#" - ``` - -- NEVER ever mention a `co-authored-by` or similar aspects. In particular, never - mention the tool used to create the commit message or PR. - -## Pull Requests - -- Create a detailed message of what changed. Focus on the high level description of - the problem it tries to solve, and how it is solved. Don't go into the specifics of the - code unless it adds clarity. - -- NEVER ever mention a `co-authored-by` or similar aspects. In particular, never - mention the tool used to create the commit message or PR. - -## Breaking Changes - -When making breaking changes, document them in `docs/migration.md`. Include: - -- What changed -- Why it changed -- How to migrate existing code - -Search for related sections in the migration guide and group related changes together -rather than adding new standalone sections. - -## Python Tools - -## Code Formatting - -1. Ruff - - Format: `uv run --frozen ruff format .` - - Check: `uv run --frozen ruff check .` - - Fix: `uv run --frozen ruff check . --fix` - - Critical issues: - - Line length (88 chars) - - Import sorting (I001) - - Unused imports - - Line wrapping: - - Strings: use parentheses - - Function calls: multi-line with proper indent - - Imports: try to use a single line - -2. Type Checking - - Tool: `uv run --frozen pyright` - - Requirements: - - Type narrowing for strings - - Version warnings can be ignored if checks pass - -3. Pre-commit - - Config: `.pre-commit-config.yaml` - - Runs: on git commit - - Tools: Prettier (YAML/JSON), Ruff (Python) - - Ruff updates: - - Check PyPI versions - - Update config rev - - Commit config first - -## Error Resolution - -1. CI Failures - - Fix order: - 1. Formatting - 2. Type errors - 3. Linting - - Type errors: - - Get full line context - - Check Optional types - - Add type narrowing - - Verify function signatures - -2. Common Issues - - Line length: - - Break strings with parentheses - - Multi-line function calls - - Split imports - - Types: - - Add None checks - - Narrow string types - - Match existing patterns - -3. Best Practices - - Check git status before commits - - Run formatters before type checks - - Keep changes minimal - - Follow existing patterns - - Document public APIs - - Test thoroughly - -## Exception Handling - -- **Always use `logger.exception()` instead of `logger.error()` when catching exceptions** - - Don't include the exception in the message: `logger.exception("Failed")` not `logger.exception(f"Failed: {e}")` -- **Catch specific exceptions** where possible: - - File ops: `except (OSError, PermissionError):` - - JSON: `except json.JSONDecodeError:` - - Network: `except (ConnectionError, TimeoutError):` -- **FORBIDDEN** `except Exception:` - unless in top-level handlers +@AGENTS.md diff --git a/README.md b/README.md index dc23d0d1d7..487d48bee4 100644 --- a/README.md +++ b/README.md @@ -13,12 +13,13 @@ -> [!IMPORTANT] -> **This is the `main` branch which contains v2 of the SDK (currently in development, pre-alpha).** -> -> We anticipate a stable v2 release in Q1 2026. Until then, **v1.x remains the recommended version** for production use. v1.x will continue to receive bug fixes and security updates for at least 6 months after v2 ships to give people time to upgrade. + + +> [!NOTE] +> **This README documents v1.x of the MCP Python SDK (the current stable release).** > -> For v1 documentation and code, see the [`v1.x` branch](https://github.com/modelcontextprotocol/python-sdk/tree/v1.x). +> For v1.x code and documentation, see the [`v1.x` branch](https://github.com/modelcontextprotocol/python-sdk/tree/v1.x). +> For the upcoming v2 documentation (pre-alpha, in development on `main`), see [`README.v2.md`](README.v2.md). ## Table of Contents @@ -45,7 +46,7 @@ - [Sampling](#sampling) - [Logging and Notifications](#logging-and-notifications) - [Authentication](#authentication) - - [MCPServer Properties](#mcpserver-properties) + - [FastMCP Properties](#fastmcp-properties) - [Session Properties and Methods](#session-properties-and-methods) - [Request Context Properties](#request-context-properties) - [Running Your Server](#running-your-server) @@ -134,18 +135,19 @@ uv run mcp Let's create a simple MCP server that exposes a calculator tool and some data: - + ```python -"""MCPServer quickstart example. +""" +FastMCP quickstart example. Run from the repository root: - uv run examples/snippets/servers/mcpserver_quickstart.py + uv run examples/snippets/servers/fastmcp_quickstart.py """ -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create an MCP server -mcp = MCPServer("Demo") +mcp = FastMCP("Demo", json_response=True) # Add an addition tool @@ -177,16 +179,16 @@ def greet_user(name: str, style: str = "friendly") -> str: # Run with streamable HTTP transport if __name__ == "__main__": - mcp.run(transport="streamable-http", json_response=True) + mcp.run(transport="streamable-http") ``` -_Full example: [examples/snippets/servers/mcpserver_quickstart.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/mcpserver_quickstart.py)_ +_Full example: [examples/snippets/servers/fastmcp_quickstart.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/fastmcp_quickstart.py)_ You can install this server in [Claude Code](https://docs.claude.com/en/docs/claude-code/mcp) and interact with it right away. First, run the server: ```bash -uv run --with mcp examples/snippets/servers/mcpserver_quickstart.py +uv run --with mcp examples/snippets/servers/fastmcp_quickstart.py ``` Then add it to Claude Code: @@ -216,7 +218,7 @@ The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you bui ### Server -The MCPServer server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing: +The FastMCP server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing: ```python @@ -226,7 +228,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession @@ -256,7 +258,7 @@ class AppContext: @asynccontextmanager -async def app_lifespan(server: MCPServer) -> AsyncIterator[AppContext]: +async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: """Manage application lifecycle with type-safe context.""" # Initialize on startup db = await Database.connect() @@ -268,7 +270,7 @@ async def app_lifespan(server: MCPServer) -> AsyncIterator[AppContext]: # Pass lifespan to server -mcp = MCPServer("My App", lifespan=app_lifespan) +mcp = FastMCP("My App", lifespan=app_lifespan) # Access type-safe lifespan context in tools @@ -288,9 +290,9 @@ Resources are how you expose data to LLMs. They're similar to GET endpoints in a ```python -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer(name="Resource Example") +mcp = FastMCP(name="Resource Example") @mcp.resource("file://documents/{name}") @@ -319,9 +321,9 @@ Tools let LLMs take actions through your server. Unlike resources, tools are exp ```python -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer(name="Tool Example") +mcp = FastMCP(name="Tool Example") @mcp.tool() @@ -340,14 +342,14 @@ def get_weather(city: str, unit: str = "celsius") -> str: _Full example: [examples/snippets/servers/basic_tool.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/basic_tool.py)_ -Tools can optionally receive a Context object by including a parameter with the `Context` type annotation. This context is automatically injected by the MCPServer framework and provides access to MCP capabilities: +Tools can optionally receive a Context object by including a parameter with the `Context` type annotation. This context is automatically injected by the FastMCP framework and provides access to MCP capabilities: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession -mcp = MCPServer(name="Progress Example") +mcp = FastMCP(name="Progress Example") @mcp.tool() @@ -395,7 +397,7 @@ validated data that clients can easily process. **Note:** For backward compatibility, unstructured results are also returned. Unstructured results are provided for backward compatibility with previous versions of the MCP specification, and are quirks-compatible -with previous versions of MCPServer in the current version of the SDK. +with previous versions of FastMCP in the current version of the SDK. **Note:** In cases where a tool function's return type annotation causes the tool to be classified as structured _and this is undesirable_, @@ -414,10 +416,10 @@ from typing import Annotated from pydantic import BaseModel -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP from mcp.types import CallToolResult, TextContent -mcp = MCPServer("CallToolResult Example") +mcp = FastMCP("CallToolResult Example") class ValidationModel(BaseModel): @@ -441,7 +443,7 @@ def validated_tool() -> Annotated[CallToolResult, ValidationModel]: """Return CallToolResult with structured output validation.""" return CallToolResult( content=[TextContent(type="text", text="Validated response")], - structured_content={"status": "success", "data": {"result": 42}}, + structuredContent={"status": "success", "data": {"result": 42}}, _meta={"internal": "metadata"}, ) @@ -465,9 +467,9 @@ from typing import TypedDict from pydantic import BaseModel, Field -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("Structured Output Example") +mcp = FastMCP("Structured Output Example") # Using Pydantic models for rich structured data @@ -567,10 +569,10 @@ Prompts are reusable templates that help LLMs interact with your server effectiv ```python -from mcp.server.mcpserver import MCPServer -from mcp.server.mcpserver.prompts import base +from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp.prompts import base -mcp = MCPServer(name="Prompt Example") +mcp = FastMCP(name="Prompt Example") @mcp.prompt(title="Code Review") @@ -595,7 +597,7 @@ _Full example: [examples/snippets/servers/basic_prompt.py](https://github.com/mo MCP servers can provide icons for UI display. Icons can be added to the server implementation, tools, resources, and prompts: ```python -from mcp.server.mcpserver import MCPServer, Icon +from mcp.server.fastmcp import FastMCP, Icon # Create an icon from a file path or URL icon = Icon( @@ -605,7 +607,7 @@ icon = Icon( ) # Add icons to server -mcp = MCPServer( +mcp = FastMCP( "My Server", website_url="https://example.com", icons=[icon] @@ -623,21 +625,21 @@ def my_resource(): return "content" ``` -_Full example: [examples/mcpserver/icons_demo.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/mcpserver/icons_demo.py)_ +_Full example: [examples/fastmcp/icons_demo.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/fastmcp/icons_demo.py)_ ### Images -MCPServer provides an `Image` class that automatically handles image data: +FastMCP provides an `Image` class that automatically handles image data: ```python -"""Example showing image handling with MCPServer.""" +"""Example showing image handling with FastMCP.""" from PIL import Image as PILImage -from mcp.server.mcpserver import Image, MCPServer +from mcp.server.fastmcp import FastMCP, Image -mcp = MCPServer("Image Example") +mcp = FastMCP("Image Example") @mcp.tool() @@ -660,9 +662,9 @@ The Context object is automatically injected into tool and resource functions th To use context in a tool or resource function, add a parameter with the `Context` type annotation: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP -mcp = MCPServer(name="Context Example") +mcp = FastMCP(name="Context Example") @mcp.tool() @@ -678,11 +680,11 @@ The Context object provides the following capabilities: - `ctx.request_id` - Unique ID for the current request - `ctx.client_id` - Client ID if available -- `ctx.mcp_server` - Access to the MCPServer server instance (see [MCPServer Properties](#mcpserver-properties)) +- `ctx.fastmcp` - Access to the FastMCP server instance (see [FastMCP Properties](#fastmcp-properties)) - `ctx.session` - Access to the underlying session for advanced communication (see [Session Properties and Methods](#session-properties-and-methods)) - `ctx.request_context` - Access to request-specific data and lifespan resources (see [Request Context Properties](#request-context-properties)) - `await ctx.debug(message)` - Send debug log message -- `await ctx.info(message)` - Send info log message +- `await ctx.info(message)` - Send info log message - `await ctx.warning(message)` - Send warning log message - `await ctx.error(message)` - Send error log message - `await ctx.log(level, message, logger_name=None)` - Send log with custom level @@ -692,10 +694,10 @@ The Context object provides the following capabilities: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession -mcp = MCPServer(name="Progress Example") +mcp = FastMCP(name="Progress Example") @mcp.tool() @@ -726,8 +728,9 @@ Client usage: ```python -"""cd to the `examples/snippets` directory and run: -uv run completion-client +""" +cd to the `examples/snippets` directory and run: + uv run completion-client """ import asyncio @@ -755,8 +758,8 @@ async def run(): # List available resource templates templates = await session.list_resource_templates() print("Available resource templates:") - for template in templates.resource_templates: - print(f" - {template.uri_template}") + for template in templates.resourceTemplates: + print(f" - {template.uriTemplate}") # List available prompts prompts = await session.list_prompts() @@ -765,20 +768,20 @@ async def run(): print(f" - {prompt.name}") # Complete resource template arguments - if templates.resource_templates: - template = templates.resource_templates[0] - print(f"\nCompleting arguments for resource template: {template.uri_template}") + if templates.resourceTemplates: + template = templates.resourceTemplates[0] + print(f"\nCompleting arguments for resource template: {template.uriTemplate}") # Complete without context result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri=template.uri_template), + ref=ResourceTemplateReference(type="ref/resource", uri=template.uriTemplate), argument={"name": "owner", "value": "model"}, ) print(f"Completions for 'owner' starting with 'model': {result.completion.values}") # Complete with context - repo suggestions based on owner result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri=template.uri_template), + ref=ResourceTemplateReference(type="ref/resource", uri=template.uriTemplate), argument={"name": "repo", "value": ""}, context_arguments={"owner": "modelcontextprotocol"}, ) @@ -824,12 +827,12 @@ import uuid from pydantic import BaseModel, Field -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams -mcp = MCPServer(name="Elicitation Example") +mcp = FastMCP(name="Elicitation Example") class BookingPreferences(BaseModel): @@ -908,7 +911,7 @@ async def connect_service(service_name: str, ctx: Context[ServerSession, None]) mode="url", message=f"Authorization required to connect to {service_name}", url=f"https://{service_name}.example.com/oauth/authorize?elicit={elicitation_id}", - elicitation_id=elicitation_id, + elicitationId=elicitation_id, ) ] ) @@ -931,11 +934,11 @@ Tools can interact with LLMs through sampling (generating text): ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.types import SamplingMessage, TextContent -mcp = MCPServer(name="Sampling Example") +mcp = FastMCP(name="Sampling Example") @mcp.tool() @@ -968,10 +971,10 @@ Tools can send logs and notifications through the context: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession -mcp = MCPServer(name="Notifications Example") +mcp = FastMCP(name="Notifications Example") @mcp.tool() @@ -1002,15 +1005,16 @@ MCP servers can use authentication by providing an implementation of the `TokenV ```python -"""Run from the repository root: -uv run examples/snippets/servers/oauth_server.py +""" +Run from the repository root: + uv run examples/snippets/servers/oauth_server.py """ from pydantic import AnyHttpUrl from mcp.server.auth.provider import AccessToken, TokenVerifier from mcp.server.auth.settings import AuthSettings -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP class SimpleTokenVerifier(TokenVerifier): @@ -1020,9 +1024,10 @@ class SimpleTokenVerifier(TokenVerifier): pass # This is where you would implement actual token validation -# Create MCPServer instance as a Resource Server -mcp = MCPServer( +# Create FastMCP instance as a Resource Server +mcp = FastMCP( "Weather Service", + json_response=True, # Token verifier for authentication token_verifier=SimpleTokenVerifier(), # Auth settings for RFC 9728 Protected Resource Metadata @@ -1046,7 +1051,7 @@ async def get_weather(city: str = "London") -> dict[str, str]: if __name__ == "__main__": - mcp.run(transport="streamable-http", json_response=True) + mcp.run(transport="streamable-http") ``` _Full example: [examples/snippets/servers/oauth_server.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/oauth_server.py)_ @@ -1062,19 +1067,19 @@ For a complete example with separate Authorization Server and Resource Server im See [TokenVerifier](src/mcp/server/auth/provider.py) for more details on implementing token validation. -### MCPServer Properties +### FastMCP Properties -The MCPServer server instance accessible via `ctx.mcp_server` provides access to server configuration and metadata: +The FastMCP server instance accessible via `ctx.fastmcp` provides access to server configuration and metadata: -- `ctx.mcp_server.name` - The server's name as defined during initialization -- `ctx.mcp_server.instructions` - Server instructions/description provided to clients -- `ctx.mcp_server.website_url` - Optional website URL for the server -- `ctx.mcp_server.icons` - Optional list of icons for UI display -- `ctx.mcp_server.settings` - Complete server configuration object containing: +- `ctx.fastmcp.name` - The server's name as defined during initialization +- `ctx.fastmcp.instructions` - Server instructions/description provided to clients +- `ctx.fastmcp.website_url` - Optional website URL for the server +- `ctx.fastmcp.icons` - Optional list of icons for UI display +- `ctx.fastmcp.settings` - Complete server configuration object containing: - `debug` - Debug mode flag - `log_level` - Current logging level - `host` and `port` - Server network configuration - - `sse_path`, `streamable_http_path` - Transport paths + - `mount_path`, `sse_path`, `streamable_http_path` - Transport paths - `stateless_http` - Whether the server operates in stateless mode - And other configuration options @@ -1083,12 +1088,12 @@ The MCPServer server instance accessible via `ctx.mcp_server` provides access to def server_info(ctx: Context) -> dict: """Get information about the current server.""" return { - "name": ctx.mcp_server.name, - "instructions": ctx.mcp_server.instructions, - "debug_mode": ctx.mcp_server.settings.debug, - "log_level": ctx.mcp_server.settings.log_level, - "host": ctx.mcp_server.settings.host, - "port": ctx.mcp_server.settings.port, + "name": ctx.fastmcp.name, + "instructions": ctx.fastmcp.instructions, + "debug_mode": ctx.fastmcp.settings.debug, + "log_level": ctx.fastmcp.settings.log_level, + "host": ctx.fastmcp.settings.host, + "port": ctx.fastmcp.settings.port, } ``` @@ -1110,13 +1115,13 @@ The session object accessible via `ctx.session` provides advanced control over c async def notify_data_update(resource_uri: str, ctx: Context) -> str: """Update data and notify clients of the change.""" # Perform data update logic here - + # Notify clients that this specific resource changed await ctx.session.send_resource_updated(AnyUrl(resource_uri)) - + # If this affects the overall resource list, notify about that too await ctx.session.send_resource_list_changed() - + return f"Updated {resource_uri} and notified clients" ``` @@ -1145,11 +1150,11 @@ def query_with_config(query: str, ctx: Context) -> str: """Execute a query using shared database and configuration.""" # Access typed lifespan context app_ctx: AppContext = ctx.request_context.lifespan_context - + # Use shared resources connection = app_ctx.db settings = app_ctx.config - + # Execute query with configuration result = connection.execute(query, timeout=settings.query_timeout) return str(result) @@ -1203,9 +1208,9 @@ cd to the `examples/snippets` directory and run: python servers/direct_execution.py """ -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("My App") +mcp = FastMCP("My App") @mcp.tool() @@ -1234,7 +1239,7 @@ python servers/direct_execution.py uv run mcp run servers/direct_execution.py ``` -Note that `uv run mcp run` or `uv run mcp dev` only supports server using MCPServer and not the low-level server variant. +Note that `uv run mcp run` or `uv run mcp dev` only supports server using FastMCP and not the low-level server variant. ### Streamable HTTP Transport @@ -1242,13 +1247,22 @@ Note that `uv run mcp run` or `uv run mcp dev` only supports server using MCPSer ```python -"""Run from the repository root: -uv run examples/snippets/servers/streamable_config.py +""" +Run from the repository root: + uv run examples/snippets/servers/streamable_config.py """ -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("StatelessServer") +# Stateless server with JSON responses (recommended) +mcp = FastMCP("StatelessServer", stateless_http=True, json_response=True) + +# Other configuration options: +# Stateless server with SSE streaming responses +# mcp = FastMCP("StatelessServer", stateless_http=True) + +# Stateful server with session persistence +# mcp = FastMCP("StatefulServer") # Add a simple tool to demonstrate the server @@ -1259,28 +1273,20 @@ def greet(name: str = "World") -> str: # Run server with streamable_http transport -# Transport-specific options (stateless_http, json_response) are passed to run() if __name__ == "__main__": - # Stateless server with JSON responses (recommended) - mcp.run(transport="streamable-http", stateless_http=True, json_response=True) - - # Other configuration options: - # Stateless server with SSE streaming responses - # mcp.run(transport="streamable-http", stateless_http=True) - - # Stateful server with session persistence - # mcp.run(transport="streamable-http") + mcp.run(transport="streamable-http") ``` _Full example: [examples/snippets/servers/streamable_config.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/streamable_config.py)_ -You can mount multiple MCPServer servers in a Starlette application: +You can mount multiple FastMCP servers in a Starlette application: ```python -"""Run from the repository root: -uvicorn examples.snippets.servers.streamable_starlette_mount:app --reload +""" +Run from the repository root: + uvicorn examples.snippets.servers.streamable_starlette_mount:app --reload """ import contextlib @@ -1288,10 +1294,10 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create the Echo server -echo_mcp = MCPServer(name="EchoServer") +echo_mcp = FastMCP(name="EchoServer", stateless_http=True, json_response=True) @echo_mcp.tool() @@ -1301,7 +1307,7 @@ def echo(message: str) -> str: # Create the Math server -math_mcp = MCPServer(name="MathServer") +math_mcp = FastMCP(name="MathServer", stateless_http=True, json_response=True) @math_mcp.tool() @@ -1322,16 +1328,16 @@ async def lifespan(app: Starlette): # Create the Starlette app and mount the MCP servers app = Starlette( routes=[ - Mount("/echo", echo_mcp.streamable_http_app(stateless_http=True, json_response=True)), - Mount("/math", math_mcp.streamable_http_app(stateless_http=True, json_response=True)), + Mount("/echo", echo_mcp.streamable_http_app()), + Mount("/math", math_mcp.streamable_http_app()), ], lifespan=lifespan, ) # Note: Clients connect to http://localhost:8000/echo/mcp and http://localhost:8000/math/mcp # To mount at the root of each path (e.g., /echo instead of /echo/mcp): -# echo_mcp.streamable_http_app(streamable_http_path="/", stateless_http=True, json_response=True) -# math_mcp.streamable_http_app(streamable_http_path="/", stateless_http=True, json_response=True) +# echo_mcp.settings.streamable_http_path = "/" +# math_mcp.settings.streamable_http_path = "/" ``` _Full example: [examples/snippets/servers/streamable_starlette_mount.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/streamable_starlette_mount.py)_ @@ -1389,7 +1395,8 @@ You can mount the StreamableHTTP server to an existing ASGI server using the `st ```python -"""Basic example showing how to mount StreamableHTTP server in Starlette. +""" +Basic example showing how to mount StreamableHTTP server in Starlette. Run from the repository root: uvicorn examples.snippets.servers.streamable_http_basic_mounting:app --reload @@ -1400,10 +1407,10 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create MCP server -mcp = MCPServer("My App") +mcp = FastMCP("My App", json_response=True) @mcp.tool() @@ -1420,10 +1427,9 @@ async def lifespan(app: Starlette): # Mount the StreamableHTTP server to the existing ASGI server -# Transport-specific options are passed to streamable_http_app() app = Starlette( routes=[ - Mount("/", app=mcp.streamable_http_app(json_response=True)), + Mount("/", app=mcp.streamable_http_app()), ], lifespan=lifespan, ) @@ -1436,7 +1442,8 @@ _Full example: [examples/snippets/servers/streamable_http_basic_mounting.py](htt ```python -"""Example showing how to mount StreamableHTTP server using Host-based routing. +""" +Example showing how to mount StreamableHTTP server using Host-based routing. Run from the repository root: uvicorn examples.snippets.servers.streamable_http_host_mounting:app --reload @@ -1447,10 +1454,10 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Host -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create MCP server -mcp = MCPServer("MCP Host App") +mcp = FastMCP("MCP Host App", json_response=True) @mcp.tool() @@ -1467,10 +1474,9 @@ async def lifespan(app: Starlette): # Mount using Host-based routing -# Transport-specific options are passed to streamable_http_app() app = Starlette( routes=[ - Host("mcp.acme.corp", app=mcp.streamable_http_app(json_response=True)), + Host("mcp.acme.corp", app=mcp.streamable_http_app()), ], lifespan=lifespan, ) @@ -1483,7 +1489,8 @@ _Full example: [examples/snippets/servers/streamable_http_host_mounting.py](http ```python -"""Example showing how to mount multiple StreamableHTTP servers with path configuration. +""" +Example showing how to mount multiple StreamableHTTP servers with path configuration. Run from the repository root: uvicorn examples.snippets.servers.streamable_http_multiple_servers:app --reload @@ -1494,11 +1501,11 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create multiple MCP servers -api_mcp = MCPServer("API Server") -chat_mcp = MCPServer("Chat Server") +api_mcp = FastMCP("API Server", json_response=True) +chat_mcp = FastMCP("Chat Server", json_response=True) @api_mcp.tool() @@ -1513,6 +1520,12 @@ def send_message(message: str) -> str: return f"Message sent: {message}" +# Configure servers to mount at the root of each path +# This means endpoints will be at /api and /chat instead of /api/mcp and /chat/mcp +api_mcp.settings.streamable_http_path = "/" +chat_mcp.settings.streamable_http_path = "/" + + # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager async def lifespan(app: Starlette): @@ -1522,12 +1535,11 @@ async def lifespan(app: Starlette): yield -# Mount the servers with transport-specific options passed to streamable_http_app() -# streamable_http_path="/" means endpoints will be at /api and /chat instead of /api/mcp and /chat/mcp +# Mount the servers app = Starlette( routes=[ - Mount("/api", app=api_mcp.streamable_http_app(json_response=True, streamable_http_path="/")), - Mount("/chat", app=chat_mcp.streamable_http_app(json_response=True, streamable_http_path="/")), + Mount("/api", app=api_mcp.streamable_http_app()), + Mount("/chat", app=chat_mcp.streamable_http_app()), ], lifespan=lifespan, ) @@ -1540,7 +1552,8 @@ _Full example: [examples/snippets/servers/streamable_http_multiple_servers.py](h ```python -"""Example showing path configuration when mounting MCPServer. +""" +Example showing path configuration during FastMCP initialization. Run from the repository root: uvicorn examples.snippets.servers.streamable_http_path_config:app --reload @@ -1549,10 +1562,15 @@ Run from the repository root: from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -# Create a simple MCPServer server -mcp_at_root = MCPServer("My Server") +# Configure streamable_http_path during initialization +# This server will mount at the root of wherever it's mounted +mcp_at_root = FastMCP( + "My Server", + json_response=True, + streamable_http_path="/", +) @mcp_at_root.tool() @@ -1561,14 +1579,10 @@ def process_data(data: str) -> str: return f"Processed: {data}" -# Mount at /process with streamable_http_path="/" so the endpoint is /process (not /process/mcp) -# Transport-specific options like json_response are passed to streamable_http_app() +# Mount at /process - endpoints will be at /process instead of /process/mcp app = Starlette( routes=[ - Mount( - "/process", - app=mcp_at_root.streamable_http_app(json_response=True, streamable_http_path="/"), - ), + Mount("/process", app=mcp_at_root.streamable_http_app()), ] ) ``` @@ -1585,10 +1599,10 @@ You can mount the SSE server to an existing ASGI server using the `sse_app` meth ```python from starlette.applications import Starlette from starlette.routing import Mount, Host -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("My App") +mcp = FastMCP("My App") # Mount the SSE server to the existing ASGI server app = Starlette( @@ -1601,28 +1615,41 @@ app = Starlette( app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app())) ``` -You can also mount multiple MCP servers at different sub-paths. The SSE transport automatically detects the mount path via ASGI's `root_path` mechanism, so message endpoints are correctly routed: +When mounting multiple MCP servers under different paths, you can configure the mount path in several ways: ```python from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create multiple MCP servers -github_mcp = MCPServer("GitHub API") -browser_mcp = MCPServer("Browser") -search_mcp = MCPServer("Search") +github_mcp = FastMCP("GitHub API") +browser_mcp = FastMCP("Browser") +curl_mcp = FastMCP("Curl") +search_mcp = FastMCP("Search") + +# Method 1: Configure mount paths via settings (recommended for persistent configuration) +github_mcp.settings.mount_path = "/github" +browser_mcp.settings.mount_path = "/browser" -# Mount each server at its own sub-path -# The SSE transport automatically uses ASGI's root_path to construct -# the correct message endpoint (e.g., /github/messages/, /browser/messages/) +# Method 2: Pass mount path directly to sse_app (preferred for ad-hoc mounting) +# This approach doesn't modify the server's settings permanently + +# Create Starlette app with multiple mounted servers app = Starlette( routes=[ + # Using settings-based configuration Mount("/github", app=github_mcp.sse_app()), Mount("/browser", app=browser_mcp.sse_app()), - Mount("/search", app=search_mcp.sse_app()), + # Using direct mount path parameter + Mount("/curl", app=curl_mcp.sse_app("/curl")), + Mount("/search", app=search_mcp.sse_app("/search")), ] ) + +# Method 3: For direct execution, you can also pass the mount path to run() +if __name__ == "__main__": + search_mcp.run(transport="sse", mount_path="/search") ``` For more information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes). @@ -1635,8 +1662,9 @@ For more control, you can use the low-level server implementation directly. This ```python -"""Run from the repository root: -uv run examples/snippets/servers/lowlevel/lifespan.py +""" +Run from the repository root: + uv run examples/snippets/servers/lowlevel/lifespan.py """ from collections.abc import AsyncIterator @@ -1644,7 +1672,7 @@ from contextlib import asynccontextmanager from typing import Any import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1692,7 +1720,7 @@ async def handle_list_tools() -> list[types.Tool]: types.Tool( name="query_db", description="Query the database", - input_schema={ + inputSchema={ "type": "object", "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, "required": ["query"], @@ -1751,14 +1779,15 @@ The lifespan API provides: ```python -"""Run from the repository root: +""" +Run from the repository root: uv run examples/snippets/servers/lowlevel/basic.py """ import asyncio import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1829,15 +1858,16 @@ The low-level server supports structured output for tools, allowing you to retur ```python -"""Run from the repository root: -uv run examples/snippets/servers/lowlevel/structured_output.py +""" +Run from the repository root: + uv run examples/snippets/servers/lowlevel/structured_output.py """ import asyncio from typing import Any import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1851,12 +1881,12 @@ async def list_tools() -> list[types.Tool]: types.Tool( name="get_weather", description="Get current weather for a city", - input_schema={ + inputSchema={ "type": "object", "properties": {"city": {"type": "string", "description": "City name"}}, "required": ["city"], }, - output_schema={ + outputSchema={ "type": "object", "properties": { "temperature": {"type": "number", "description": "Temperature in Celsius"}, @@ -1931,15 +1961,16 @@ For full control over the response including the `_meta` field (for passing data ```python -"""Run from the repository root: -uv run examples/snippets/servers/lowlevel/direct_call_tool_result.py +""" +Run from the repository root: + uv run examples/snippets/servers/lowlevel/direct_call_tool_result.py """ import asyncio from typing import Any import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1953,7 +1984,7 @@ async def list_tools() -> list[types.Tool]: types.Tool( name="advanced_tool", description="Tool with full control including _meta field", - input_schema={ + inputSchema={ "type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"], @@ -1969,7 +2000,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallTo message = str(arguments.get("message", "")) return types.CallToolResult( content=[types.TextContent(type="text", text=f"Processed: {message}")], - structured_content={"result": "success", "message": message}, + structuredContent={"result": "success", "message": message}, _meta={"hidden": "data for client applications only"}, ) @@ -2010,9 +2041,13 @@ For servers that need to handle large datasets, the low-level server provides pa ```python -"""Example of implementing pagination with MCP server decorators.""" +""" +Example of implementing pagination with MCP server decorators. +""" + +from pydantic import AnyUrl -from mcp import types +import mcp.types as types from mcp.server.lowlevel import Server # Initialize the server @@ -2036,14 +2071,14 @@ async def list_resources_paginated(request: types.ListResourcesRequest) -> types # Get page of resources page_items = [ - types.Resource(uri=f"resource://items/{item}", name=item, description=f"Description for {item}") + types.Resource(uri=AnyUrl(f"resource://items/{item}"), name=item, description=f"Description for {item}") for item in ITEMS[start:end] ] # Determine next cursor next_cursor = str(end) if end < len(ITEMS) else None - return types.ListResourcesResult(resources=page_items, next_cursor=next_cursor) + return types.ListResourcesResult(resources=page_items, nextCursor=next_cursor) ``` _Full example: [examples/snippets/servers/pagination_example.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/pagination_example.py)_ @@ -2053,7 +2088,9 @@ _Full example: [examples/snippets/servers/pagination_example.py](https://github. ```python -"""Example of consuming paginated MCP endpoints from a client.""" +""" +Example of consuming paginated MCP endpoints from a client. +""" import asyncio @@ -2082,8 +2119,8 @@ async def list_all_resources() -> None: print(f"Fetched {len(result.resources)} resources") # Check if there are more pages - if result.next_cursor: - cursor = result.next_cursor + if result.nextCursor: + cursor = result.nextCursor else: break @@ -2112,28 +2149,31 @@ The SDK provides a high-level client interface for connecting to MCP servers usi ```python -"""cd to the `examples/snippets/clients` directory and run: -uv run client +""" +cd to the `examples/snippets/clients` directory and run: + uv run client """ import asyncio import os +from pydantic import AnyUrl + from mcp import ClientSession, StdioServerParameters, types -from mcp.client.context import ClientRequestContext from mcp.client.stdio import stdio_client +from mcp.shared.context import RequestContext # Create server parameters for stdio connection server_params = StdioServerParameters( command="uv", # Using uv to run the server - args=["run", "server", "mcpserver_quickstart", "stdio"], # We're already in snippets dir + args=["run", "server", "fastmcp_quickstart", "stdio"], # We're already in snippets dir env={"UV_INDEX": os.environ.get("UV_INDEX", "")}, ) # Optional: create a sampling callback async def handle_sampling_message( - context: ClientRequestContext, params: types.CreateMessageRequestParams + context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( @@ -2143,7 +2183,7 @@ async def handle_sampling_message( text="Hello, world! from model", ), model="gpt-3.5-turbo", - stop_reason="endTurn", + stopReason="endTurn", ) @@ -2157,7 +2197,7 @@ async def run(): prompts = await session.list_prompts() print(f"Available prompts: {[p.name for p in prompts.prompts]}") - # Get a prompt (greet_user prompt from mcpserver_quickstart) + # Get a prompt (greet_user prompt from fastmcp_quickstart) if prompts.prompts: prompt = await session.get_prompt("greet_user", arguments={"name": "Alice", "style": "friendly"}) print(f"Prompt result: {prompt.messages[0].content}") @@ -2170,18 +2210,18 @@ async def run(): tools = await session.list_tools() print(f"Available tools: {[t.name for t in tools.tools]}") - # Read a resource (greeting resource from mcpserver_quickstart) - resource_content = await session.read_resource("greeting://World") + # Read a resource (greeting resource from fastmcp_quickstart) + resource_content = await session.read_resource(AnyUrl("greeting://World")) content_block = resource_content.contents[0] if isinstance(content_block, types.TextContent): print(f"Resource content: {content_block.text}") - # Call a tool (add tool from mcpserver_quickstart) + # Call a tool (add tool from fastmcp_quickstart) result = await session.call_tool("add", arguments={"a": 5, "b": 3}) result_unstructured = result.content[0] if isinstance(result_unstructured, types.TextContent): print(f"Tool result: {result_unstructured.text}") - result_structured = result.structured_content + result_structured = result.structuredContent print(f"Structured tool result: {result_structured}") @@ -2201,8 +2241,9 @@ Clients can also connect using [Streamable HTTP transport](https://modelcontextp ```python -"""Run from the repository root: -uv run examples/snippets/clients/streamable_basic.py +""" +Run from the repository root: + uv run examples/snippets/clients/streamable_basic.py """ import asyncio @@ -2240,8 +2281,9 @@ When building MCP clients, the SDK provides utilities to help display human-read ```python -"""cd to the `examples/snippets` directory and run: -uv run display-utilities-client +""" +cd to the `examples/snippets` directory and run: + uv run display-utilities-client """ import asyncio @@ -2254,7 +2296,7 @@ from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection server_params = StdioServerParameters( command="uv", # Using uv to run the server - args=["run", "server", "mcpserver_quickstart", "stdio"], + args=["run", "server", "fastmcp_quickstart", "stdio"], env={"UV_INDEX": os.environ.get("UV_INDEX", "")}, ) @@ -2280,7 +2322,7 @@ async def display_resources(session: ClientSession): print(f"Resource: {display_name} ({resource.uri})") templates_response = await session.list_resource_templates() - for template in templates_response.resource_templates: + for template in templates_response.resourceTemplates: display_name = get_display_name(template) print(f"Resource Template: {display_name}") @@ -2324,7 +2366,8 @@ The SDK includes [authorization support](https://modelcontextprotocol.io/specifi ```python -"""Before running, specify running MCP RS server URL. +""" +Before running, specify running MCP RS server URL. To spin up RS server locally, see examples/servers/simple-auth/README.md diff --git a/README.v2.md b/README.v2.md index 67f181811f..d0851c04e5 100644 --- a/README.v2.md +++ b/README.v2.md @@ -346,13 +346,12 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -682,11 +681,11 @@ The Context object provides the following capabilities: - `ctx.mcp_server` - Access to the MCPServer server instance (see [MCPServer Properties](#mcpserver-properties)) - `ctx.session` - Access to the underlying session for advanced communication (see [Session Properties and Methods](#session-properties-and-methods)) - `ctx.request_context` - Access to request-specific data and lifespan resources (see [Request Context Properties](#request-context-properties)) -- `await ctx.debug(message)` - Send debug log message -- `await ctx.info(message)` - Send info log message -- `await ctx.warning(message)` - Send warning log message -- `await ctx.error(message)` - Send error log message -- `await ctx.log(level, message, logger_name=None)` - Send log with custom level +- `await ctx.debug(data)` - Send debug log message +- `await ctx.info(data)` - Send info log message +- `await ctx.warning(data)` - Send warning log message +- `await ctx.error(data)` - Send error log message +- `await ctx.log(level, data, logger_name=None)` - Send log with custom level - `await ctx.report_progress(progress, total=None, message=None)` - Report operation progress - `await ctx.read_resource(uri)` - Read a resource by URI - `await ctx.elicit(message, schema)` - Request additional information from user with validation @@ -694,13 +693,12 @@ The Context object provides the following capabilities: ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -826,7 +824,6 @@ import uuid from pydantic import BaseModel, Field from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams @@ -844,7 +841,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context) -> str: """Book a table with date availability check. This demonstrates form mode elicitation for collecting non-sensitive user input. @@ -868,7 +865,7 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerS @mcp.tool() -async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> str: +async def secure_payment(amount: float, ctx: Context) -> str: """Process a secure payment requiring URL confirmation. This demonstrates URL mode elicitation using ctx.elicit_url() for @@ -892,7 +889,7 @@ async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> st @mcp.tool() -async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: +async def connect_service(service_name: str, ctx: Context) -> str: """Connect to a third-party service requiring OAuth authorization. This demonstrates the "throw error" pattern using UrlElicitationRequiredError. @@ -933,14 +930,13 @@ Tools can interact with LLMs through sampling (generating text): ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.types import SamplingMessage, TextContent mcp = MCPServer(name="Sampling Example") @mcp.tool() -async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str: +async def generate_poem(topic: str, ctx: Context) -> str: """Generate a poem using LLM sampling.""" prompt = f"Write a short poem about {topic}" @@ -970,13 +966,12 @@ Tools can send logs and notifications through the context: ```python from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") @@ -1642,12 +1637,11 @@ uv run examples/snippets/servers/lowlevel/lifespan.py from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any +from typing import TypedDict import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext # Mock database class for example @@ -1670,52 +1664,58 @@ class Database: return [{"id": "1", "name": "Example", "query": query_str}] +class AppContext(TypedDict): + db: Database + + @asynccontextmanager -async def server_lifespan(_server: Server) -> AsyncIterator[dict[str, Any]]: +async def server_lifespan(_server: Server[AppContext]) -> AsyncIterator[AppContext]: """Manage server startup and shutdown lifecycle.""" - # Initialize resources on startup db = await Database.connect() try: yield {"db": db} finally: - # Clean up on shutdown await db.disconnect() -# Pass lifespan to server -server = Server("example-server", lifespan=server_lifespan) - - -@server.list_tools() -async def handle_list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="query_db", - description="Query the database", - input_schema={ - "type": "object", - "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, - "required": ["query"], - }, - ) - ] + return types.ListToolsResult( + tools=[ + types.Tool( + name="query_db", + description="Query the database", + input_schema={ + "type": "object", + "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, + "required": ["query"], + }, + ) + ] + ) -@server.call_tool() -async def query_db(name: str, arguments: dict[str, Any]) -> list[types.TextContent]: +async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: types.CallToolRequestParams +) -> types.CallToolResult: """Handle database query tool call.""" - if name != "query_db": - raise ValueError(f"Unknown tool: {name}") + if params.name != "query_db": + raise ValueError(f"Unknown tool: {params.name}") - # Access lifespan context - ctx = server.request_context db = ctx.lifespan_context["db"] + results = await db.query((params.arguments or {})["query"]) - # Execute query - results = await db.query(arguments["query"]) + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Query results: {results}")]) - return [types.TextContent(type="text", text=f"Query results: {results}")] + +server = Server( + "example-server", + lifespan=server_lifespan, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -1724,14 +1724,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example-server", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -1760,32 +1753,30 @@ import asyncio import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions - -# Create a server instance -server = Server("example-server") +from mcp.server import Server, ServerRequestContext -@server.list_prompts() -async def handle_list_prompts() -> list[types.Prompt]: +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: """List available prompts.""" - return [ - types.Prompt( - name="example-prompt", - description="An example prompt template", - arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], - ) - ] + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="example-prompt", + description="An example prompt template", + arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], + ) + ] + ) -@server.get_prompt() -async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: """Get a specific prompt by name.""" - if name != "example-prompt": - raise ValueError(f"Unknown prompt: {name}") + if params.name != "example-prompt": + raise ValueError(f"Unknown prompt: {params.name}") - arg1_value = (arguments or {}).get("arg1", "default") + arg1_value = (params.arguments or {}).get("arg1", "default") return types.GetPromptResult( description="Example prompt", @@ -1798,20 +1789,20 @@ async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> type ) +server = Server( + "example-server", + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, +) + + async def run(): """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -1835,62 +1826,67 @@ uv run examples/snippets/servers/lowlevel/structured_output.py """ import asyncio -from typing import Any +import json import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -server = Server("example-server") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools with structured output schemas.""" - return [ - types.Tool( - name="get_weather", - description="Get current weather for a city", - input_schema={ - "type": "object", - "properties": {"city": {"type": "string", "description": "City name"}}, - "required": ["city"], - }, - output_schema={ - "type": "object", - "properties": { - "temperature": {"type": "number", "description": "Temperature in Celsius"}, - "condition": {"type": "string", "description": "Weather condition"}, - "humidity": {"type": "number", "description": "Humidity percentage"}, - "city": {"type": "string", "description": "City name"}, + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Get current weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string", "description": "City name"}}, + "required": ["city"], }, - "required": ["temperature", "condition", "humidity", "city"], - }, - ) - ] + output_schema={ + "type": "object", + "properties": { + "temperature": {"type": "number", "description": "Temperature in Celsius"}, + "condition": {"type": "string", "description": "Weather condition"}, + "humidity": {"type": "number", "description": "Humidity percentage"}, + "city": {"type": "string", "description": "City name"}, + }, + "required": ["temperature", "condition", "humidity", "city"], + }, + ) + ] + ) -@server.call_tool() -async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls with structured output.""" - if name == "get_weather": - city = arguments["city"] + if params.name == "get_weather": + city = (params.arguments or {})["city"] - # Simulated weather data - in production, call a weather API weather_data = { "temperature": 22.5, "condition": "partly cloudy", "humidity": 65, - "city": city, # Include the requested city + "city": city, } - # low-level server will validate structured output against the tool's - # output schema, and additionally serialize it into a TextContent block - # for backwards compatibility with pre-2025-06-18 clients. - return weather_data - else: - raise ValueError(f"Unknown tool: {name}") + return types.CallToolResult( + content=[types.TextContent(type="text", text=json.dumps(weather_data, indent=2))], + structured_content=weather_data, + ) + + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -1899,14 +1895,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="structured-output-example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -1917,18 +1906,11 @@ if __name__ == "__main__": _Full example: [examples/snippets/servers/lowlevel/structured_output.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/lowlevel/structured_output.py)_ -Tools can return data in four ways: +With the low-level server, handlers always return `CallToolResult` directly. You construct both the human-readable `content` and the machine-readable `structured_content` yourself, giving you full control over the response. -1. **Content only**: Return a list of content blocks (default behavior before spec revision 2025-06-18) -2. **Structured data only**: Return a dictionary that will be serialized to JSON (Introduced in spec revision 2025-06-18) -3. **Both**: Return a tuple of (content, structured_data) preferred option to use for backwards compatibility -4. **Direct CallToolResult**: Return `CallToolResult` directly for full control (including `_meta` field) +##### Returning CallToolResult with `_meta` -When an `outputSchema` is defined, the server automatically validates the structured output against the schema. This ensures type safety and helps catch errors early. - -##### Returning CallToolResult Directly - -For full control over the response including the `_meta` field (for passing data to client applications without exposing it to the model), return `CallToolResult` directly: +For passing data to client applications without exposing it to the model, use the `_meta` field on `CallToolResult`: ```python @@ -1937,44 +1919,49 @@ uv run examples/snippets/servers/lowlevel/direct_call_tool_result.py """ import asyncio -from typing import Any import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions - -server = Server("example-server") +from mcp.server import Server, ServerRequestContext -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="advanced_tool", - description="Tool with full control including _meta field", - input_schema={ - "type": "object", - "properties": {"message": {"type": "string"}}, - "required": ["message"], - }, - ) - ] + return types.ListToolsResult( + tools=[ + types.Tool( + name="advanced_tool", + description="Tool with full control including _meta field", + input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + ) + ] + ) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls by returning CallToolResult directly.""" - if name == "advanced_tool": - message = str(arguments.get("message", "")) + if params.name == "advanced_tool": + message = (params.arguments or {}).get("message", "") return types.CallToolResult( content=[types.TextContent(type="text", text=f"Processed: {message}")], structured_content={"result": "success", "message": message}, _meta={"hidden": "data for client applications only"}, ) - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -1983,14 +1970,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -2001,8 +1981,6 @@ if __name__ == "__main__": _Full example: [examples/snippets/servers/lowlevel/direct_call_tool_result.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/lowlevel/direct_call_tool_result.py)_ -**Note:** When returning `CallToolResult`, you bypass the automatic content/structured conversion. You must construct the complete response yourself. - ### Pagination (Advanced) For servers that need to handle large datasets, the low-level server provides paginated versions of list operations. This is an optional optimization - most servers won't need pagination unless they're dealing with hundreds or thousands of items. @@ -2011,25 +1989,23 @@ For servers that need to handle large datasets, the low-level server provides pa ```python -"""Example of implementing pagination with MCP server decorators.""" +"""Example of implementing pagination with the low-level MCP server.""" from mcp import types -from mcp.server.lowlevel import Server - -# Initialize the server -server = Server("paginated-server") +from mcp.server import Server, ServerRequestContext # Sample data to paginate ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items -@server.list_resources() -async def list_resources_paginated(request: types.ListResourcesRequest) -> types.ListResourcesResult: +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: """List resources with pagination support.""" page_size = 10 # Extract cursor from request params - cursor = request.params.cursor if request.params is not None else None + cursor = params.cursor if params is not None else None # Parse cursor to get offset start = 0 if cursor is None else int(cursor) @@ -2045,6 +2021,9 @@ async def list_resources_paginated(request: types.ListResourcesRequest) -> types next_cursor = str(end) if end < len(ITEMS) else None return types.ListResourcesResult(resources=page_items, next_cursor=next_cursor) + + +server = Server("paginated-server", on_list_resources=handle_list_resources) ``` _Full example: [examples/snippets/servers/pagination_example.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/pagination_example.py)_ diff --git a/SECURITY.md b/SECURITY.md index 6545156105..5029242009 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,15 +1,21 @@ # Security Policy -Thank you for helping us keep the SDKs and systems they interact with secure. +Thank you for helping keep the Model Context Protocol and its ecosystem secure. ## Reporting Security Issues -This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model Context Protocol project. +If you discover a security vulnerability in this repository, please report it through +the [GitHub Security Advisory process](https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing-information-about-vulnerabilities/privately-reporting-a-security-vulnerability) +for this repository. -The security of our systems and user data is Anthropic’s top priority. We appreciate the work of security researchers acting in good faith in identifying and reporting potential vulnerabilities. +Please **do not** report security vulnerabilities through public GitHub issues, discussions, +or pull requests. -Our security program is managed on HackerOne and we ask that any validated vulnerability in this functionality be reported through their [submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability). +## What to Include -## Vulnerability Disclosure Program +To help us triage and respond quickly, please include: -Our Vulnerability Program Guidelines are defined on our [HackerOne program page](https://hackerone.com/anthropic-vdp). +- A description of the vulnerability +- Steps to reproduce the issue +- The potential impact +- Any suggested fixes (optional) diff --git a/docs/api.md b/docs/api.md deleted file mode 100644 index 3f696af543..0000000000 --- a/docs/api.md +++ /dev/null @@ -1 +0,0 @@ -::: mcp diff --git a/docs/experimental/index.md b/docs/experimental/index.md index 1d496b3f10..c97fe2a3d6 100644 --- a/docs/experimental/index.md +++ b/docs/experimental/index.md @@ -27,10 +27,9 @@ Tasks are useful for: Experimental features are accessed via the `.experimental` property: ```python -# Server-side -@server.experimental.get_task() -async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - ... +# Server-side: enable task support (auto-registers default handlers) +server = Server(name="my-server") +server.experimental.enable_tasks() # Client-side result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) diff --git a/docs/experimental/tasks-server.md b/docs/experimental/tasks-server.md index 761dc5de5c..b350ee3bb6 100644 --- a/docs/experimental/tasks-server.md +++ b/docs/experimental/tasks-server.md @@ -408,16 +408,10 @@ For custom error messages, call `task.fail()` before raising. For web applications, use the Streamable HTTP transport: ```python -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager - import uvicorn -from starlette.applications import Starlette -from starlette.routing import Mount from mcp.server import Server from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import ( CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, TASK_REQUIRED, ) @@ -462,22 +456,8 @@ async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTask return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) -def create_app(): - session_manager = StreamableHTTPSessionManager(app=server) - - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncIterator[None]: - async with session_manager.run(): - yield - - return Starlette( - routes=[Mount("/mcp", app=session_manager.handle_request)], - lifespan=lifespan, - ) - - if __name__ == "__main__": - uvicorn.run(create_app(), host="127.0.0.1", port=8000) + uvicorn.run(server.streamable_http_app(), host="127.0.0.1", port=8000) ``` ## Testing Task Servers diff --git a/docs/hooks/gen_ref_pages.py b/docs/hooks/gen_ref_pages.py new file mode 100644 index 0000000000..ad8c19b45f --- /dev/null +++ b/docs/hooks/gen_ref_pages.py @@ -0,0 +1,35 @@ +"""Generate the code reference pages and navigation.""" + +from pathlib import Path + +import mkdocs_gen_files + +nav = mkdocs_gen_files.Nav() + +root = Path(__file__).parent.parent.parent +src = root / "src" + +for path in sorted(src.rglob("*.py")): + module_path = path.relative_to(src).with_suffix("") + doc_path = path.relative_to(src).with_suffix(".md") + full_doc_path = Path("api", doc_path) + + parts = tuple(module_path.parts) + + if parts[-1] == "__init__": + parts = parts[:-1] + doc_path = doc_path.with_name("index.md") + full_doc_path = full_doc_path.with_name("index.md") + elif parts[-1].startswith("_"): + continue + + nav[parts] = doc_path.as_posix() + + with mkdocs_gen_files.open(full_doc_path, "w") as fd: + ident = ".".join(parts) + fd.write(f"::: {ident}") + + mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root)) + +with mkdocs_gen_files.open("api/SUMMARY.md", "w") as nav_file: + nav_file.writelines(nav.build_literate_nav()) diff --git a/docs/index.md b/docs/index.md index e096d910b4..6a937da67f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,5 +1,8 @@ # MCP Python SDK +!!! info "You are viewing the in-development v2 documentation" + For the current stable release, see the [v1.x documentation](https://py.sdk.modelcontextprotocol.io/). + The **Model Context Protocol (MCP)** allows applications to provide context for LLMs in a standardized way, separating the concerns of providing context from the actual LLM interaction. This Python SDK implements the full MCP specification, making it easy to: @@ -64,4 +67,4 @@ npx -y @modelcontextprotocol/inspector ## API Reference -Full API documentation is available in the [API Reference](api.md). +Full API documentation is available in the [API Reference](api/mcp/index.md). diff --git a/docs/migration.md b/docs/migration.md index 7d30f0ac92..8b70885e8d 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -38,6 +38,7 @@ http_client = httpx.AsyncClient( headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30, read=300), auth=my_auth, + follow_redirects=True, ) async with http_client: @@ -48,6 +49,8 @@ async with http_client: ... ``` +v1's internal client set `follow_redirects=True`; set it explicitly when supplying your own `httpx.AsyncClient` to preserve that behavior. + ### `get_session_id` callback removed from `streamable_http_client` The `get_session_id` callback (third element of the returned tuple) has been removed from `streamable_http_client`. The function now returns a 2-tuple `(read_stream, write_stream)` instead of a 3-tuple. @@ -100,6 +103,8 @@ async with http_client: The `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters have been removed from `StreamableHTTPTransport`. Configure these on the `httpx.AsyncClient` instead (see example above). +Note: `sse_client` retains its `headers`, `timeout`, `sse_read_timeout`, and `auth` parameters — only the streamable HTTP transport changed. + ### Removed type aliases and classes The following deprecated type aliases and classes have been removed from `mcp.types`: @@ -126,6 +131,52 @@ from mcp.types import ContentBlock, ResourceTemplateReference # Use `str` instead of `Cursor` for pagination cursors ``` +### Field names changed from camelCase to snake_case + +All Pydantic model fields in `mcp.types` now use snake_case names for Python attribute access. The JSON wire format is unchanged — serialization still uses camelCase via Pydantic aliases. + +**Before (v1):** + +```python +result = await session.call_tool("my_tool", {"x": 1}) +if result.isError: + ... + +tools = await session.list_tools() +cursor = tools.nextCursor +schema = tools.tools[0].inputSchema +``` + +**After (v2):** + +```python +result = await session.call_tool("my_tool", {"x": 1}) +if result.is_error: + ... + +tools = await session.list_tools() +cursor = tools.next_cursor +schema = tools.tools[0].input_schema +``` + +Common renames: + +| v1 (camelCase) | v2 (snake_case) | +|----------------|-----------------| +| `inputSchema` | `input_schema` | +| `outputSchema` | `output_schema` | +| `isError` | `is_error` | +| `nextCursor` | `next_cursor` | +| `mimeType` | `mime_type` | +| `structuredContent` | `structured_content` | +| `serverInfo` | `server_info` | +| `protocolVersion` | `protocol_version` | +| `uriTemplate` | `uri_template` | +| `listChanged` | `list_changed` | +| `progressToken` | `progress_token` | + +Because `populate_by_name=True` is set, the old camelCase names still work as constructor kwargs (e.g., `Tool(inputSchema={...})` is accepted), but attribute access must use snake_case (`tool.input_schema`). + ### `args` parameter removed from `ClientSessionGroup.call_tool()` The deprecated `args` parameter has been removed from `ClientSessionGroup.call_tool()`. Use `arguments` instead. @@ -169,6 +220,30 @@ result = await session.list_resources(params=PaginatedRequestParams(cursor="next result = await session.list_tools(params=PaginatedRequestParams(cursor="next_page_token")) ``` +### `ClientSession.get_server_capabilities()` replaced by `initialize_result` property + +`ClientSession` now stores the full `InitializeResult` via an `initialize_result` property. This provides access to `server_info`, `capabilities`, `instructions`, and the negotiated `protocol_version` through a single property. The `get_server_capabilities()` method has been removed. + +**Before (v1):** + +```python +capabilities = session.get_server_capabilities() +# server_info, instructions, protocol_version were not stored — had to capture initialize() return value +``` + +**After (v2):** + +```python +result = session.initialize_result +if result is not None: + capabilities = result.capabilities + server_info = result.server_info + instructions = result.instructions + version = result.protocol_version +``` + +The high-level `Client.initialize_result` returns the same `InitializeResult` but is non-nullable — initialization is guaranteed inside the context manager, so no `None` check is needed. This replaces v1's `Client.server_capabilities`; use `client.initialize_result.capabilities` instead. + ### `McpError` renamed to `MCPError` The `McpError` exception class has been renamed to `MCPError` for consistent naming with the MCP acronym style used throughout the SDK. @@ -201,6 +276,28 @@ except MCPError as e: from mcp import MCPError ``` +The constructor signature also changed — it now takes `code`, `message`, and optional `data` directly instead of wrapping an `ErrorData`: + +**Before (v1):** + +```python +from mcp.shared.exceptions import McpError +from mcp.types import ErrorData, INVALID_REQUEST + +raise McpError(ErrorData(code=INVALID_REQUEST, message="bad input")) +``` + +**After (v2):** + +```python +from mcp.shared.exceptions import MCPError +from mcp.types import INVALID_REQUEST + +raise MCPError(INVALID_REQUEST, "bad input") +# or, if you already have an ErrorData: +raise MCPError.from_error_data(error_data) +``` + ### `FastMCP` renamed to `MCPServer` The `FastMCP` class has been renamed to `MCPServer` to better reflect its role as the main server class in the SDK. This is a simple rename with no functional changes to the class itself. @@ -216,11 +313,19 @@ mcp = FastMCP("Demo") **After (v2):** ```python -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import MCPServer, Context mcp = MCPServer("Demo") ``` +`Context` is the type annotation for the `ctx` parameter injected into tools, resources, and prompts (see [`get_context()` removed](#mcpserverget_context-removed) below). + +All submodules under `mcp.server.fastmcp.*` are now under `mcp.server.mcpserver.*` with the same structure. Common imports: + +- `Image`, `Audio` — from `mcp.server.mcpserver` (or `.utilities.types`) +- `UserMessage`, `AssistantMessage` — from `mcp.server.mcpserver.prompts.base` +- `ToolError`, `ResourceError` — from `mcp.server.mcpserver.exceptions` + ### `mount_path` parameter removed from MCPServer The `mount_path` parameter has been removed from `MCPServer.__init__()`, `MCPServer.run()`, `MCPServer.run_sse_async()`, and `MCPServer.sse_app()`. It was also removed from the `Settings` class. @@ -288,6 +393,100 @@ app = Starlette(routes=[Mount("/", app=mcp.streamable_http_app(json_response=Tru **Note:** DNS rebinding protection is automatically enabled when `host` is `127.0.0.1`, `localhost`, or `::1`. This now happens in `sse_app()` and `streamable_http_app()` instead of the constructor. +If you were mutating these via `mcp.settings` after construction (e.g., `mcp.settings.port = 9000`), pass them to `run()` / `sse_app()` / `streamable_http_app()` instead — these fields no longer exist on `Settings`. The `debug` and `log_level` parameters remain on the constructor. + +### `MCPServer.get_context()` removed + +`MCPServer.get_context()` has been removed. Context is now injected by the framework and passed explicitly — there is no ambient ContextVar to read from. + +**If you were calling `get_context()` from inside a tool/resource/prompt:** use the `ctx: Context` parameter injection instead. + +**Before (v1):** + +```python +@mcp.tool() +async def my_tool(x: int) -> str: + ctx = mcp.get_context() + await ctx.info("Processing...") + return str(x) +``` + +**After (v2):** + +```python +from mcp.server.mcpserver import Context + +@mcp.tool() +async def my_tool(x: int, ctx: Context) -> str: + await ctx.info("Processing...") + return str(x) +``` + +### `MCPServer.call_tool()`, `read_resource()`, `get_prompt()` now accept a `context` parameter + +`MCPServer.call_tool()`, `MCPServer.read_resource()`, and `MCPServer.get_prompt()` now accept an optional `context: Context | None = None` parameter. The framework passes this automatically during normal request handling. If you call these methods directly and omit `context`, a Context with no active request is constructed for you — tools that don't use `ctx` work normally, but any attempt to use `ctx.session`, `ctx.request_id`, etc. will raise. + +The internal layers (`ToolManager.call_tool`, `Tool.run`, `Prompt.render`, `ResourceTemplate.create_resource`, etc.) now require `context` as a positional argument. + +### Registering lowlevel handlers on `MCPServer` (workaround) + +`MCPServer` does not expose public APIs for `subscribe_resource`, `unsubscribe_resource`, or `set_logging_level` handlers. In v1, the workaround was to reach into the private lowlevel server and use its decorator methods: + +**Before (v1):** + +```python +@mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage] +async def handle_set_logging_level(level: str) -> None: + ... + +mcp._mcp_server.subscribe_resource()(handle_subscribe) # pyright: ignore[reportPrivateUsage] +``` + +In v2, the lowlevel `Server` no longer has decorator methods (handlers are constructor-only), so the equivalent workaround is `_add_request_handler`: + +**After (v2):** + +```python +from mcp.server import ServerRequestContext +from mcp.types import EmptyResult, SetLevelRequestParams, SubscribeRequestParams + + +async def handle_set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + ... + return EmptyResult() + + +async def handle_subscribe(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: + ... + return EmptyResult() + + +mcp._lowlevel_server._add_request_handler("logging/setLevel", handle_set_logging_level) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscribe) # pyright: ignore[reportPrivateUsage] +``` + +This is a private API and may change. A public way to register these handlers on `MCPServer` is planned; until then, use this workaround or use the lowlevel `Server` directly. + +### `MCPServer`'s `Context` logging: `message` renamed to `data`, `extra` removed + +On the high-level `Context` object (`mcp.server.mcpserver.Context`), `log()`, `.debug()`, `.info()`, `.warning()`, and `.error()` now take `data: Any` instead of `message: str`, matching the MCP spec's `LoggingMessageNotificationParams.data` field which allows any JSON-serializable value. The `extra` parameter has been removed — pass structured data directly as `data`. + +The lowlevel `ServerSession.send_log_message(data: Any)` already accepted arbitrary data and is unchanged. + +`Context.log()` also now accepts all eight RFC-5424 log levels (`debug`, `info`, `notice`, `warning`, `error`, `critical`, `alert`, `emergency`) via the `LoggingLevel` type, not just the four it previously allowed. + +```python +# Before +await ctx.info("Connection failed", extra={"host": "localhost", "port": 5432}) +await ctx.log(level="info", message="hello") + +# After +await ctx.info({"message": "Connection failed", "host": "localhost", "port": 5432}) +await ctx.log(level="info", data="hello") +``` + +Positional calls (`await ctx.info("hello")`) are unaffected. + ### Replace `RootModel` by union types with `TypeAdapter` validation The following union types are no longer `RootModel` subclasses: @@ -328,6 +527,22 @@ notification = server_notification_adapter.validate_python(data) # No .root access needed - notification is the actual type ``` +The same applies when constructing values — the wrapper call is no longer needed: + +**Before (v1):** + +```python +await session.send_notification(ClientNotification(InitializedNotification())) +await session.send_request(ClientRequest(PingRequest()), EmptyResult) +``` + +**After (v2):** + +```python +await session.send_notification(InitializedNotification()) +await session.send_request(PingRequest(), EmptyResult) +``` + **Available adapters:** | Union Type | Adapter | @@ -351,7 +566,6 @@ The nested `RequestParams.Meta` Pydantic model class has been replaced with a to - `RequestParams.Meta` (Pydantic model) → `RequestParamsMeta` (TypedDict) - Attribute access (`meta.progress_token`) → Dictionary access (`meta.get("progress_token")`) - `progress_token` field changed from `ProgressToken | None = None` to `NotRequired[ProgressToken]` -` **In request context handlers:** @@ -364,14 +578,17 @@ async def handle_tool(name: str, arguments: dict) -> list[TextContent]: await ctx.session.send_progress_notification(ctx.meta.progress_token, 0.5, 100) # After (v2) -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> list[TextContent]: - ctx = server.request_context +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: if ctx.meta and "progress_token" in ctx.meta: await ctx.session.send_progress_notification(ctx.meta["progress_token"], 0.5, 100) + ... + +server = Server("my-server", on_call_tool=handle_call_tool) ``` -### `RequestContext` and `ProgressContext` type parameters simplified +### `RequestContext` type parameters simplified + +The `mcp.shared.context` module has been removed. `RequestContext` is now split into `ClientRequestContext` (in `mcp.client.context`) and `ServerRequestContext` (in `mcp.server.context`). The `RequestContext` class has been split to separate shared fields from server-specific fields. The shared `RequestContext` now only takes 1 type parameter (the session type) instead of 3. @@ -380,40 +597,115 @@ The `RequestContext` class has been split to separate shared fields from server- - Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]` - Server-specific fields (`lifespan_context`, `experimental`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context` -**`ProgressContext` changes:** - -- Type parameters reduced from `ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]` to `ProgressContext[SessionT]` - **Before (v1):** ```python from mcp.client.session import ClientSession from mcp.shared.context import RequestContext, LifespanContextT, RequestT -from mcp.shared.progress import ProgressContext # RequestContext with 3 type parameters ctx: RequestContext[ClientSession, LifespanContextT, RequestT] - -# ProgressContext with 5 type parameters -progress_ctx: ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT] ``` **After (v2):** ```python from mcp.client.context import ClientRequestContext -from mcp.client.session import ClientSession from mcp.server.context import ServerRequestContext, LifespanContextT, RequestT -from mcp.shared.progress import ProgressContext # For client-side context (sampling, elicitation, list_roots callbacks) ctx: ClientRequestContext # For server-specific context with lifespan and request types server_ctx: ServerRequestContext[LifespanContextT, RequestT] +``` + +The high-level `Context` class (injected into `@mcp.tool()` etc.) similarly dropped its `ServerSessionT` parameter: `Context[ServerSessionT, LifespanContextT, RequestT]` → `Context[LifespanContextT, RequestT]`. Both remaining parameters have defaults, so bare `Context` is usually sufficient: + +**Before (v1):** -# ProgressContext with 1 type parameter -progress_ctx: ProgressContext[ClientSession] +```python +async def my_tool(ctx: Context[ServerSession, None]) -> str: ... +``` + +**After (v2):** + +```python +async def my_tool(ctx: Context) -> str: ... +# or, with an explicit lifespan type: +async def my_tool(ctx: Context[MyLifespanState]) -> str: ... +``` + +### `ProgressContext` and `progress()` context manager removed + +The `mcp.shared.progress` module (`ProgressContext`, `Progress`, and the `progress()` context manager) has been removed. This module had no real-world adoption — all users send progress notifications via `Context.report_progress()` or `session.send_progress_notification()` directly. + +**Before (v1):** + +```python +from mcp.shared.progress import progress + +with progress(ctx, total=100) as p: + await p.progress(25) +``` + +**After — use `Context.report_progress()` (recommended):** + +```python +@server.tool() +async def my_tool(x: int, ctx: Context) -> str: + await ctx.report_progress(25, 100) + return "done" +``` + +**After — use `session.send_progress_notification()` (low-level):** + +```python +await session.send_progress_notification( + progress_token=progress_token, + progress=25, + total=100, +) +``` + +### `create_connected_server_and_client_session` removed + +The `create_connected_server_and_client_session` helper in `mcp.shared.memory` has been removed. Use `mcp.client.Client` instead — it accepts a `Server` or `MCPServer` instance directly and handles the in-memory transport and session setup for you. + +**Before (v1):** + +```python +from mcp.shared.memory import create_connected_server_and_client_session + +async with create_connected_server_and_client_session(server) as session: + result = await session.call_tool("my_tool", {"x": 1}) +``` + +**After (v2):** + +```python +from mcp.client import Client + +async with Client(server) as client: + result = await client.call_tool("my_tool", {"x": 1}) +``` + +`Client` accepts the same callback parameters the old helper did (`sampling_callback`, `list_roots_callback`, `logging_callback`, `message_handler`, `elicitation_callback`, `client_info`) plus `raise_exceptions` to surface server-side errors. + +If you need direct access to the underlying `ClientSession` and memory streams (e.g., for low-level transport testing), `create_client_server_memory_streams` is still available in `mcp.shared.memory`: + +```python +import anyio +from mcp.client.session import ClientSession +from mcp.shared.memory import create_client_server_memory_streams + +async with create_client_server_memory_streams() as (client_streams, server_streams): + async with anyio.create_task_group() as tg: + tg.start_soon(lambda: server.run(*server_streams, server.create_initialization_options())) + async with ClientSession(*client_streams) as session: + await session.initialize() + ... + tg.cancel_scope.cancel() ``` ### Resource URI type changed from `AnyUrl` to `str` @@ -471,12 +763,324 @@ await client.read_resource("test://resource") await client.read_resource(str(my_any_url)) ``` +### Lowlevel `Server`: constructor parameters are now keyword-only + +All parameters after `name` are now keyword-only. If you were passing `version` or other parameters positionally, use keyword arguments instead: + +```python +# Before (v1) +server = Server("my-server", "1.0") + +# After (v2) +server = Server("my-server", version="1.0") +``` + +### Lowlevel `Server`: type parameter reduced from 2 to 1 + +The `Server` class previously had two type parameters: `Server[LifespanResultT, RequestT]`. The `RequestT` parameter has been removed — handlers now receive typed params directly rather than a generic request type. + +```python +# Before (v1) +from typing import Any + +from mcp.server.lowlevel.server import Server + +server: Server[dict[str, Any], Any] = Server(...) + +# After (v2) +from typing import Any + +from mcp.server import Server + +server: Server[dict[str, Any]] = Server(...) +``` + +### Lowlevel `Server`: `request_handlers` and `notification_handlers` attributes removed + +The public `server.request_handlers` and `server.notification_handlers` dictionaries have been removed. Handler registration is now done exclusively through constructor `on_*` keyword arguments. There is no public API to register handlers after construction. + +```python +# Before (v1) — direct dict access +from mcp.types import ListToolsRequest + +if ListToolsRequest in server.request_handlers: + ... + +# After (v2) — no public access to handler dicts +# Use the on_* constructor params to register handlers +server = Server("my-server", on_list_tools=handle_list_tools) +``` + +If you need to check whether a handler is registered, track this yourself — there is currently no public introspection API. + +### Lowlevel `Server`: decorator-based handlers replaced with constructor `on_*` params + +The lowlevel `Server` class no longer uses decorator methods for handler registration. Instead, handlers are passed as `on_*` keyword arguments to the constructor. + +**Before (v1):** + +```python +from mcp.server.lowlevel.server import Server + +server = Server("my-server") + +@server.list_tools() +async def handle_list_tools(): + return [types.Tool(name="my_tool", description="A tool", inputSchema={})] + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict): + return [types.TextContent(type="text", text=f"Called {name}")] +``` + +**After (v2):** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + +async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="my_tool", description="A tool", input_schema={})]) + + +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text=f"Called {params.name}")], + is_error=False, + ) + +server = Server("my-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) +``` + +**Key differences:** + +- Handlers receive `(ctx, params)` instead of the full request object or unpacked arguments. `ctx` is a `ServerRequestContext` with `session`, `lifespan_context`, and `experimental` fields (plus `request_id`, `meta`, etc. for request handlers). `params` is the typed request params object. +- Handlers return the full result type (e.g. `ListToolsResult`) rather than unwrapped values (e.g. `list[Tool]`). +- The automatic `jsonschema` input/output validation that the old `call_tool()` decorator performed has been removed. There is no built-in replacement — if you relied on schema validation in the lowlevel server, you will need to validate inputs yourself in your handler. + +**Complete handler reference:** + +All handlers receive `ctx: ServerRequestContext` as the first argument. The second argument and return type are: + +| v1 decorator | v2 constructor kwarg | `params` type | return type | +|---|---|---|---| +| `@server.list_tools()` | `on_list_tools` | `PaginatedRequestParams \| None` | `ListToolsResult` | +| `@server.call_tool()` | `on_call_tool` | `CallToolRequestParams` | `CallToolResult \| CreateTaskResult` | +| `@server.list_resources()` | `on_list_resources` | `PaginatedRequestParams \| None` | `ListResourcesResult` | +| `@server.list_resource_templates()` | `on_list_resource_templates` | `PaginatedRequestParams \| None` | `ListResourceTemplatesResult` | +| `@server.read_resource()` | `on_read_resource` | `ReadResourceRequestParams` | `ReadResourceResult` | +| `@server.subscribe_resource()` | `on_subscribe_resource` | `SubscribeRequestParams` | `EmptyResult` | +| `@server.unsubscribe_resource()` | `on_unsubscribe_resource` | `UnsubscribeRequestParams` | `EmptyResult` | +| `@server.list_prompts()` | `on_list_prompts` | `PaginatedRequestParams \| None` | `ListPromptsResult` | +| `@server.get_prompt()` | `on_get_prompt` | `GetPromptRequestParams` | `GetPromptResult` | +| `@server.completion()` | `on_completion` | `CompleteRequestParams` | `CompleteResult` | +| `@server.set_logging_level()` | `on_set_logging_level` | `SetLevelRequestParams` | `EmptyResult` | +| — | `on_ping` | `RequestParams \| None` | `EmptyResult` | +| `@server.progress_notification()` | `on_progress` | `ProgressNotificationParams` | `None` | +| — | `on_roots_list_changed` | `NotificationParams \| None` | `None` | + +All `params` and return types are importable from `mcp.types`. + +**Notification handlers:** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import ProgressNotificationParams + + +async def handle_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None: + print(f"Progress: {params.progress}/{params.total}") + +server = Server("my-server", on_progress=handle_progress) +``` + +### Lowlevel `Server`: automatic return value wrapping removed + +The old decorator-based handlers performed significant automatic wrapping of return values. This magic has been removed — handlers now return fully constructed result types. If you want these conveniences, use `MCPServer` (previously `FastMCP`) instead of the lowlevel `Server`. + +**`call_tool()` — structured output wrapping removed:** + +The old decorator accepted several return types and auto-wrapped them into `CallToolResult`: + +```python +# Before (v1) — returning a dict auto-wrapped into structured_content + JSON TextContent +@server.call_tool() +async def handle(name: str, arguments: dict) -> dict: + return {"temperature": 22.5, "city": "London"} + +# Before (v1) — returning a list auto-wrapped into CallToolResult.content +@server.call_tool() +async def handle(name: str, arguments: dict) -> list[TextContent]: + return [TextContent(type="text", text="Done")] +``` + +```python +# After (v2) — construct the full result yourself +import json + +async def handle(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + data = {"temperature": 22.5, "city": "London"} + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(data, indent=2))], + structured_content=data, + ) +``` + +Note: `params.arguments` can be `None` (the old decorator defaulted it to `{}`). Use `params.arguments or {}` to preserve the old behavior. + +**`read_resource()` — content type wrapping removed:** + +The old decorator auto-wrapped `Iterable[ReadResourceContents]` (and the deprecated `str`/`bytes` shorthand) into `TextResourceContents`/`BlobResourceContents`, handling base64 encoding and mime-type defaulting: + +```python +# Before (v1) — Iterable[ReadResourceContents] auto-wrapped +from mcp.server.lowlevel.helper_types import ReadResourceContents + +@server.read_resource() +async def handle(uri: AnyUrl) -> Iterable[ReadResourceContents]: + return [ReadResourceContents(content="file contents", mime_type="text/plain")] + +# Before (v1) — str/bytes shorthand (already deprecated in v1) +@server.read_resource() +async def handle(uri: str) -> str: + return "file contents" + +@server.read_resource() +async def handle(uri: str) -> bytes: + return b"\x89PNG..." +``` + +```python +# After (v2) — construct TextResourceContents or BlobResourceContents yourself +import base64 + +async def handle_read(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + # Text content + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text="file contents", mime_type="text/plain")] + ) + +async def handle_read(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + # Binary content — you must base64-encode it yourself + return ReadResourceResult( + contents=[BlobResourceContents( + uri=str(params.uri), + blob=base64.b64encode(b"\x89PNG...").decode("utf-8"), + mime_type="image/png", + )] + ) +``` + +**`list_tools()`, `list_resources()`, `list_prompts()` — list wrapping removed:** + +The old decorators accepted bare lists and wrapped them into the result type: + +```python +# Before (v1) +@server.list_tools() +async def handle() -> list[Tool]: + return [Tool(name="my_tool", ...)] + +# After (v2) +async def handle(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="my_tool", ...)]) +``` + +**Using `MCPServer` instead:** + +If you prefer the convenience of automatic wrapping, use `MCPServer` which still provides these features through its `@mcp.tool()`, `@mcp.resource()`, and `@mcp.prompt()` decorators. The lowlevel `Server` is intentionally minimal — it provides no magic and gives you full control over the MCP protocol types. + +### Lowlevel `Server`: `request_context` property removed + +The `server.request_context` property has been removed. Request context is now passed directly to handlers as the first argument (`ctx`). The `request_ctx` module-level contextvar has been removed entirely. + +**Before (v1):** + +```python +from mcp.server.lowlevel.server import request_ctx + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict): + ctx = server.request_context # or request_ctx.get() + await ctx.session.send_log_message(level="info", data="Processing...") + return [types.TextContent(type="text", text="Done")] +``` + +**After (v2):** + +```python +from mcp.server import ServerRequestContext +from mcp.types import CallToolRequestParams, CallToolResult, TextContent + + +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + await ctx.session.send_log_message(level="info", data="Processing...") + return CallToolResult( + content=[TextContent(type="text", text="Done")], + is_error=False, + ) +``` + +### `RequestContext`: request-specific fields are now optional + +The `RequestContext` class now uses optional fields for request-specific data (`request_id`, `meta`, etc.) so it can be used for both request and notification handlers. In notification handlers, these fields are `None`. + +```python +from mcp.server import ServerRequestContext + +# request_id, meta, etc. are available in request handlers +# but None in notification handlers +``` + +### Experimental: task handler decorators removed + +The experimental decorator methods on `ExperimentalHandlers` (`@server.experimental.list_tasks()`, `@server.experimental.get_task()`, etc.) have been removed. + +Default task handlers are still registered automatically via `server.experimental.enable_tasks()`. Custom handlers can be passed as `on_*` kwargs to override specific defaults. + +**Before (v1):** + +```python +server = Server("my-server") +server.experimental.enable_tasks() + +@server.experimental.get_task() +async def custom_get_task(request: GetTaskRequest) -> GetTaskResult: + ... +``` + +**After (v2):** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import GetTaskRequestParams, GetTaskResult + + +async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + ... + + +server = Server("my-server") +server.experimental.enable_tasks(on_get_task=custom_get_task) +``` + ## Deprecations ## Bug Fixes +### Lowlevel `Server`: `subscribe` capability now correctly reported + +Previously, the lowlevel `Server` hardcoded `subscribe=False` in resource capabilities even when a `subscribe_resource()` handler was registered. The `subscribe` capability is now dynamically set to `True` when an `on_subscribe_resource` handler is provided. Clients that previously didn't see `subscribe: true` in capabilities will now see it when a handler is registered, which may change client behavior. + ### Extra fields no longer allowed on top-level MCP types MCP protocol types no longer accept arbitrary extra fields at the top level. This matches the MCP specification which only allows extra fields within `_meta` objects, not on the types themselves. @@ -495,7 +1099,7 @@ params = CallToolRequestParams( params = CallToolRequestParams( name="my_tool", arguments={}, - _meta={"progressToken": "tok", "customField": "value"}, # OK + _meta={"my_custom_key": "value", "another": 123}, # OK ) ``` @@ -506,16 +1110,16 @@ params = CallToolRequestParams( The `streamable_http_app()` method is now available directly on the lowlevel `Server` class, not just `MCPServer`. This allows using the streamable HTTP transport without the MCPServer wrapper. ```python -from mcp.server.lowlevel.server import Server +from mcp.server import Server, ServerRequestContext +from mcp.types import ListToolsResult, PaginatedRequestParams -server = Server("my-server") -# Register handlers... -@server.list_tools() -async def list_tools(): - return [...] +async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[...]) + + +server = Server("my-server", on_list_tools=handle_list_tools) -# Create a Starlette app for streamable HTTP app = server.streamable_http_app( streamable_http_path="/mcp", json_response=False, @@ -529,6 +1133,6 @@ The lowlevel `Server` also now exposes a `session_manager` property to access th If you encounter issues during migration: -1. Check the [API Reference](api.md) for updated method signatures +1. Check the [API Reference](api/mcp/index.md) for updated method signatures 2. Review the [examples](https://github.com/modelcontextprotocol/python-sdk/tree/main/examples) for updated usage patterns 3. Open an issue on [GitHub](https://github.com/modelcontextprotocol/python-sdk/issues) if you find a bug or need further assistance diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 5fac56be5d..6ef2f0b113 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -18,7 +18,7 @@ from urllib.parse import parse_qs, urlparse import httpx -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp.client._transport import ReadStream, WriteStream from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.session import ClientSession from mcp.client.sse import sse_client @@ -241,8 +241,8 @@ async def _default_redirect_handler(authorization_url: str) -> None: async def _run_session( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], ): """Run the MCP session with the given streams.""" print("🤝 Initializing MCP session...") diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index 4fb7d9a1d3..a0620b9c1d 100644 --- a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -10,9 +10,9 @@ import logging import click +from mcp.server import ServerRequestContext from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.prompts.base import UserMessage -from mcp.server.session import ServerSession from mcp.server.streamable_http import EventCallback, EventMessage, EventStore from mcp.types import ( AudioContent, @@ -20,13 +20,17 @@ CompletionArgument, CompletionContext, EmbeddedResource, + EmptyResult, ImageContent, JSONRPCMessage, PromptReference, ResourceTemplateReference, SamplingMessage, + SetLevelRequestParams, + SubscribeRequestParams, TextContent, TextResourceContents, + UnsubscribeRequestParams, ) from pydantic import BaseModel, Field @@ -137,7 +141,7 @@ def test_multiple_content_types() -> list[TextContent | ImageContent | EmbeddedR @mcp.tool() -async def test_tool_with_logging(ctx: Context[ServerSession, None]) -> str: +async def test_tool_with_logging(ctx: Context) -> str: """Tests tool that emits log messages during execution""" await ctx.info("Tool execution started") await asyncio.sleep(0.05) @@ -150,7 +154,7 @@ async def test_tool_with_logging(ctx: Context[ServerSession, None]) -> str: @mcp.tool() -async def test_tool_with_progress(ctx: Context[ServerSession, None]) -> str: +async def test_tool_with_progress(ctx: Context) -> str: """Tests tool that reports progress notifications""" await ctx.report_progress(progress=0, total=100, message="Completed step 0 of 100") await asyncio.sleep(0.05) @@ -168,7 +172,7 @@ async def test_tool_with_progress(ctx: Context[ServerSession, None]) -> str: @mcp.tool() -async def test_sampling(prompt: str, ctx: Context[ServerSession, None]) -> str: +async def test_sampling(prompt: str, ctx: Context) -> str: """Tests server-initiated sampling (LLM completion request)""" try: # Request sampling from client @@ -193,7 +197,7 @@ class UserResponse(BaseModel): @mcp.tool() -async def test_elicitation(message: str, ctx: Context[ServerSession, None]) -> str: +async def test_elicitation(message: str, ctx: Context) -> str: """Tests server-initiated elicitation (user input request)""" try: # Request user input from client @@ -225,7 +229,7 @@ class SEP1034DefaultsSchema(BaseModel): @mcp.tool() -async def test_elicitation_sep1034_defaults(ctx: Context[ServerSession, None]) -> str: +async def test_elicitation_sep1034_defaults(ctx: Context) -> str: """Tests elicitation with default values for all primitive types (SEP-1034)""" try: # Request user input with defaults for all primitive types @@ -284,7 +288,7 @@ class EnumSchemasTestSchema(BaseModel): @mcp.tool() -async def test_elicitation_sep1330_enums(ctx: Context[ServerSession, None]) -> str: +async def test_elicitation_sep1330_enums(ctx: Context) -> str: """Tests elicitation with enum schema variations per SEP-1330""" try: result = await ctx.elicit( @@ -308,7 +312,7 @@ def test_error_handling() -> str: @mcp.tool() -async def test_reconnection(ctx: Context[ServerSession, None]) -> str: +async def test_reconnection(ctx: Context) -> str: """Tests SSE polling by closing stream mid-call (SEP-1699)""" await ctx.info("Before disconnect") @@ -393,28 +397,29 @@ def test_prompt_with_image() -> list[UserMessage]: # Custom request handlers # TODO(felix): Add public APIs to MCPServer for subscribe_resource, unsubscribe_resource, # and set_logging_level to avoid accessing protected _lowlevel_server attribute. -@mcp._lowlevel_server.set_logging_level() # pyright: ignore[reportPrivateUsage] -async def handle_set_logging_level(level: str) -> None: +async def handle_set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: """Handle logging level changes""" - logger.info(f"Log level set to: {level}") - # In a real implementation, you would adjust the logging level here - # For conformance testing, we just acknowledge the request + logger.info(f"Log level set to: {params.level}") + return EmptyResult() -async def handle_subscribe(uri: str) -> None: +async def handle_subscribe(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: """Handle resource subscription""" - resource_subscriptions.add(str(uri)) - logger.info(f"Subscribed to resource: {uri}") + resource_subscriptions.add(str(params.uri)) + logger.info(f"Subscribed to resource: {params.uri}") + return EmptyResult() -async def handle_unsubscribe(uri: str) -> None: +async def handle_unsubscribe(ctx: ServerRequestContext, params: UnsubscribeRequestParams) -> EmptyResult: """Handle resource unsubscription""" - resource_subscriptions.discard(str(uri)) - logger.info(f"Unsubscribed from resource: {uri}") + resource_subscriptions.discard(str(params.uri)) + logger.info(f"Unsubscribed from resource: {params.uri}") + return EmptyResult() -mcp._lowlevel_server.subscribe_resource()(handle_subscribe) # pyright: ignore[reportPrivateUsage] -mcp._lowlevel_server.unsubscribe_resource()(handle_unsubscribe) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("logging/setLevel", handle_set_logging_level) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscribe) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("resources/unsubscribe", handle_unsubscribe) # pyright: ignore[reportPrivateUsage] @mcp.completion() diff --git a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py index 9d13fffe42..26c87c5ef2 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py @@ -120,6 +120,8 @@ async def introspect_handler(request: Request) -> Response: "iat": int(time.time()), "token_type": "Bearer", "aud": access_token.resource, # RFC 8707 audience claim + "sub": access_token.subject, # RFC 7662 subject + "iss": str(server_settings.server_url), } ) diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index 3a3895cc57..48eb9a8414 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -181,6 +181,7 @@ async def handle_simple_callback(self, username: str, password: str, state: str) scopes=[self.settings.mcp_scope], code_challenge=code_challenge, resource=resource, # RFC 8707 + subject=username, ) self.auth_codes[new_code] = auth_code @@ -219,6 +220,7 @@ async def exchange_authorization_code( scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, resource=authorization_code.resource, # RFC 8707 + subject=authorization_code.subject, ) # Store user data mapping for this token diff --git a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py index 5228d034e4..641095a125 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py +++ b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py @@ -75,6 +75,8 @@ async def verify_token(self, token: str) -> AccessToken | None: scopes=data.get("scope", "").split() if data.get("scope") else [], expires_at=data.get("exp"), resource=data.get("aud"), # Include resource in token + subject=data.get("sub"), # RFC 7662 subject (resource owner) + claims=data, ) except Exception as e: logger.warning(f"Token introspection failed: {e}") diff --git a/examples/servers/simple-pagination/README.md b/examples/servers/simple-pagination/README.md index e732b8efbe..4cab40fd34 100644 --- a/examples/servers/simple-pagination/README.md +++ b/examples/servers/simple-pagination/README.md @@ -4,14 +4,14 @@ A simple MCP server demonstrating pagination for tools, resources, and prompts u ## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-pagination -# Using SSE transport on custom port -uv run mcp-simple-pagination --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-pagination --transport streamable-http --port 8000 ``` The server exposes: diff --git a/examples/servers/simple-pagination/mcp_simple_pagination/server.py b/examples/servers/simple-pagination/mcp_simple_pagination/server.py index ff45ae2245..c94f2ac3d1 100644 --- a/examples/servers/simple-pagination/mcp_simple_pagination/server.py +++ b/examples/servers/simple-pagination/mcp_simple_pagination/server.py @@ -1,16 +1,17 @@ """Simple MCP server demonstrating pagination for tools, resources, and prompts. -This example shows how to use the paginated decorators to handle large lists -of items that need to be split across multiple pages. +This example shows how to implement pagination with the low-level server API +to handle large lists of items that need to be split across multiple pages. """ -from typing import Any +from typing import TypeVar import anyio import click from mcp import types -from mcp.server.lowlevel import Server -from starlette.requests import Request +from mcp.server import Server, ServerRequestContext + +T = TypeVar("T") # Sample data - in real scenarios, this might come from a database SAMPLE_TOOLS = [ @@ -44,176 +45,125 @@ ] +def _paginate(cursor: str | None, items: list[T], page_size: int) -> tuple[list[T], str | None]: + """Helper to paginate a list of items given a cursor.""" + if cursor is not None: + try: + start_idx = int(cursor) + except (ValueError, TypeError): + return [], None + else: + start_idx = 0 + + page = items[start_idx : start_idx + page_size] + next_cursor = str(start_idx + page_size) if start_idx + page_size < len(items) else None + return page, next_cursor + + +# Paginated list_tools - returns 5 tools per page +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + cursor = params.cursor if params is not None else None + page, next_cursor = _paginate(cursor, SAMPLE_TOOLS, page_size=5) + return types.ListToolsResult(tools=page, next_cursor=next_cursor) + + +# Paginated list_resources - returns 10 resources per page +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: + cursor = params.cursor if params is not None else None + page, next_cursor = _paginate(cursor, SAMPLE_RESOURCES, page_size=10) + return types.ListResourcesResult(resources=page, next_cursor=next_cursor) + + +# Paginated list_prompts - returns 7 prompts per page +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + cursor = params.cursor if params is not None else None + page, next_cursor = _paginate(cursor, SAMPLE_PROMPTS, page_size=7) + return types.ListPromptsResult(prompts=page, next_cursor=next_cursor) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + # Find the tool in our sample data + tool = next((t for t in SAMPLE_TOOLS if t.name == params.name), None) + if not tool: + raise ValueError(f"Unknown tool: {params.name}") + + return types.CallToolResult( + content=[ + types.TextContent( + type="text", + text=f"Called tool '{params.name}' with arguments: {params.arguments}", + ) + ] + ) + + +async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: + resource = next((r for r in SAMPLE_RESOURCES if r.uri == str(params.uri)), None) + if not resource: + raise ValueError(f"Unknown resource: {params.uri}") + + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=str(params.uri), + text=f"Content of {resource.name}: This is sample content for the resource.", + mime_type="text/plain", + ) + ] + ) + + +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: + prompt = next((p for p in SAMPLE_PROMPTS if p.name == params.name), None) + if not prompt: + raise ValueError(f"Unknown prompt: {params.name}") + + message_text = f"This is the prompt '{params.name}'" + if params.arguments: + message_text += f" with arguments: {params.arguments}" + + return types.GetPromptResult( + description=prompt.description, + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text=message_text), + ) + ], + ) + + @click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option("--port", default=8000, help="Port to listen on for HTTP") @click.option( "--transport", - type=click.Choice(["stdio", "sse"]), + type=click.Choice(["stdio", "streamable-http"]), default="stdio", help="Transport type", ) def main(port: int, transport: str) -> int: - app = Server("mcp-simple-pagination") - - # Paginated list_tools - returns 5 tools per page - @app.list_tools() - async def list_tools_paginated(request: types.ListToolsRequest) -> types.ListToolsResult: - page_size = 5 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListToolsResult(tools=[], next_cursor=None) - - # Get the page of tools - page_tools = SAMPLE_TOOLS[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_TOOLS): - next_cursor = str(start_idx + page_size) - - return types.ListToolsResult(tools=page_tools, next_cursor=next_cursor) - - # Paginated list_resources - returns 10 resources per page - @app.list_resources() - async def list_resources_paginated( - request: types.ListResourcesRequest, - ) -> types.ListResourcesResult: - page_size = 10 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListResourcesResult(resources=[], next_cursor=None) - - # Get the page of resources - page_resources = SAMPLE_RESOURCES[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_RESOURCES): - next_cursor = str(start_idx + page_size) - - return types.ListResourcesResult(resources=page_resources, next_cursor=next_cursor) - - # Paginated list_prompts - returns 7 prompts per page - @app.list_prompts() - async def list_prompts_paginated( - request: types.ListPromptsRequest, - ) -> types.ListPromptsResult: - page_size = 7 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListPromptsResult(prompts=[], next_cursor=None) - - # Get the page of prompts - page_prompts = SAMPLE_PROMPTS[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_PROMPTS): - next_cursor = str(start_idx + page_size) - - return types.ListPromptsResult(prompts=page_prompts, next_cursor=next_cursor) - - # Implement call_tool handler - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - # Find the tool in our sample data - tool = next((t for t in SAMPLE_TOOLS if t.name == name), None) - if not tool: - raise ValueError(f"Unknown tool: {name}") - - # Simple mock response - return [ - types.TextContent( - type="text", - text=f"Called tool '{name}' with arguments: {arguments}", - ) - ] - - # Implement read_resource handler - @app.read_resource() - async def read_resource(uri: str) -> str: - # Find the resource in our sample data - resource = next((r for r in SAMPLE_RESOURCES if r.uri == uri), None) - if not resource: - raise ValueError(f"Unknown resource: {uri}") - - # Return a simple string - the decorator will convert it to TextResourceContents - return f"Content of {resource.name}: This is sample content for the resource." - - # Implement get_prompt handler - @app.get_prompt() - async def get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: - # Find the prompt in our sample data - prompt = next((p for p in SAMPLE_PROMPTS if p.name == name), None) - if not prompt: - raise ValueError(f"Unknown prompt: {name}") - - # Simple mock response - message_text = f"This is the prompt '{name}'" - if arguments: - message_text += f" with arguments: {arguments}" - - return types.GetPromptResult( - description=prompt.description, - messages=[ - types.PromptMessage( - role="user", - content=types.TextContent(type="text", text=message_text), - ) - ], - ) - - if transport == "sse": - from mcp.server.sse import SseServerTransport - from starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route - - sse = SseServerTransport("/messages/") - - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], app.create_initialization_options()) - return Response() - - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse, methods=["GET"]), - Mount("/messages/", app=sse.handle_post_message), - ], - ) + app = Server( + "mcp-simple-pagination", + on_list_tools=handle_list_tools, + on_list_resources=handle_list_resources, + on_list_prompts=handle_list_prompts, + on_call_tool=handle_call_tool, + on_read_resource=handle_read_resource, + on_get_prompt=handle_get_prompt, + ) + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff --git a/examples/servers/simple-prompt/README.md b/examples/servers/simple-prompt/README.md index 48e796e198..c837da876e 100644 --- a/examples/servers/simple-prompt/README.md +++ b/examples/servers/simple-prompt/README.md @@ -4,14 +4,14 @@ A simple MCP server that exposes a customizable prompt template with optional co ## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-prompt -# Using SSE transport on custom port -uv run mcp-simple-prompt --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-prompt --transport streamable-http --port 8000 ``` The server exposes a prompt named "simple" that accepts two optional arguments: diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index cbc5a9d68f..74b71b3f38 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -1,8 +1,7 @@ import anyio import click from mcp import types -from mcp.server.lowlevel import Server -from starlette.requests import Request +from mcp.server import Server, ServerRequestContext def create_messages(context: str | None = None, topic: str | None = None) -> list[types.PromptMessage]: @@ -30,20 +29,11 @@ def create_messages(context: str | None = None, topic: str | None = None) -> lis return messages -@click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-simple-prompt") - - @app.list_prompts() - async def list_prompts() -> list[types.Prompt]: - return [ +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ types.Prompt( name="simple", title="Simple Assistant Prompt", @@ -62,44 +52,40 @@ async def list_prompts() -> list[types.Prompt]: ], ) ] + ) - @app.get_prompt() - async def get_prompt(name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: - if name != "simple": - raise ValueError(f"Unknown prompt: {name}") - if arguments is None: - arguments = {} +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: + if params.name != "simple": + raise ValueError(f"Unknown prompt: {params.name}") - return types.GetPromptResult( - messages=create_messages(context=arguments.get("context"), topic=arguments.get("topic")), - description="A simple prompt with optional context and topic arguments", - ) + arguments = params.arguments or {} - if transport == "sse": - from mcp.server.sse import SseServerTransport - from starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route + return types.GetPromptResult( + messages=create_messages(context=arguments.get("context"), topic=arguments.get("topic")), + description="A simple prompt with optional context and topic arguments", + ) - sse = SseServerTransport("/messages/") - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], app.create_initialization_options()) - return Response() - - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - ) +@click.command() +@click.option("--port", default=8000, help="Port to listen on for HTTP") +@click.option( + "--transport", + type=click.Choice(["stdio", "streamable-http"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-simple-prompt", + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, + ) + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff --git a/examples/servers/simple-resource/README.md b/examples/servers/simple-resource/README.md index df674e91e4..7fb2ab7cdc 100644 --- a/examples/servers/simple-resource/README.md +++ b/examples/servers/simple-resource/README.md @@ -4,14 +4,14 @@ A simple MCP server that exposes sample text files as resources. ## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-resource -# Using SSE transport on custom port -uv run mcp-simple-resource --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-resource --transport streamable-http --port 8000 ``` The server exposes some basic text file resources that can be read by clients. diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 588d1044a8..8d11054145 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -1,9 +1,9 @@ +from urllib.parse import urlparse + import anyio import click from mcp import types -from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents -from starlette.requests import Request +from mcp.server import Server, ServerRequestContext SAMPLE_RESOURCES = { "greeting": { @@ -21,20 +21,11 @@ } -@click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-simple-resource") - - @app.list_resources() - async def list_resources() -> list[types.Resource]: - return [ +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: + return types.ListResourcesResult( + resources=[ types.Resource( uri=f"file:///{name}.txt", name=name, @@ -44,45 +35,50 @@ async def list_resources() -> list[types.Resource]: ) for name in SAMPLE_RESOURCES.keys() ] + ) - @app.read_resource() - async def read_resource(uri: str): - from urllib.parse import urlparse - - parsed = urlparse(uri) - if not parsed.path: - raise ValueError(f"Invalid resource path: {uri}") - name = parsed.path.replace(".txt", "").lstrip("/") - - if name not in SAMPLE_RESOURCES: - raise ValueError(f"Unknown resource: {uri}") - return [ReadResourceContents(content=SAMPLE_RESOURCES[name]["content"], mime_type="text/plain")] +async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: + parsed = urlparse(str(params.uri)) + if not parsed.path: + raise ValueError(f"Invalid resource path: {params.uri}") + name = parsed.path.replace(".txt", "").lstrip("/") - if transport == "sse": - from mcp.server.sse import SseServerTransport - from starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route + if name not in SAMPLE_RESOURCES: + raise ValueError(f"Unknown resource: {params.uri}") - sse = SseServerTransport("/messages/") + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=str(params.uri), + text=SAMPLE_RESOURCES[name]["content"], + mime_type="text/plain", + ) + ] + ) - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], app.create_initialization_options()) - return Response() - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse, methods=["GET"]), - Mount("/messages/", app=sse.handle_post_message), - ], - ) +@click.command() +@click.option("--port", default=8000, help="Port to listen on for HTTP") +@click.option( + "--transport", + type=click.Choice(["stdio", "streamable-http"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-simple-resource", + on_list_resources=handle_list_resources, + on_read_resource=handle_read_resource, + ) + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff --git a/examples/servers/simple-streamablehttp-stateless/README.md b/examples/servers/simple-streamablehttp-stateless/README.md index b87250b353..a254f88d14 100644 --- a/examples/servers/simple-streamablehttp-stateless/README.md +++ b/examples/servers/simple-streamablehttp-stateless/README.md @@ -7,7 +7,6 @@ A stateless MCP server example demonstrating the StreamableHttp transport withou - Uses the StreamableHTTP transport in stateless mode (mcp_session_id=None) - Each request creates a new ephemeral connection - No session state maintained between requests -- Task lifecycle scoped to individual requests - Suitable for deployment in multi-node environments ## Usage diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index 9fed2f0aa6..e2b8d2ef2f 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -1,76 +1,20 @@ -import contextlib import logging -from collections.abc import AsyncIterator -from typing import Any import anyio import click import uvicorn from mcp import types -from mcp.server.lowlevel import Server -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette +from mcp.server import Server, ServerRequestContext from starlette.middleware.cors import CORSMiddleware -from starlette.routing import Mount -from starlette.types import Receive, Scope, Send logger = logging.getLogger(__name__) -@click.command() -@click.option("--port", default=3000, help="Port to listen on for HTTP") -@click.option( - "--log-level", - default="INFO", - help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", -) -@click.option( - "--json-response", - is_flag=True, - default=False, - help="Enable JSON responses instead of SSE streams", -) -def main( - port: int, - log_level: str, - json_response: bool, -) -> None: - # Configure logging - logging.basicConfig( - level=getattr(logging, log_level.upper()), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - app = Server("mcp-streamable-http-stateless-demo") - - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - ctx = app.request_context - interval = arguments.get("interval", 1.0) - count = arguments.get("count", 5) - caller = arguments.get("caller", "unknown") - - # Send the specified number of notifications with the given interval - for i in range(count): - await ctx.session.send_log_message( - level="info", - data=f"Notification {i + 1}/{count} from caller: {caller}", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - if i < count - 1: # Don't wait after the last notification - await anyio.sleep(interval) - - return [ - types.TextContent( - type="text", - text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), - ) - ] - - @app.list_tools() - async def list_tools() -> list[types.Tool]: - return [ +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ types.Tool( name="start-notification-stream", description=("Sends a stream of notifications with configurable count and interval"), @@ -94,40 +38,77 @@ async def list_tools() -> list[types.Tool]: }, ) ] + ) - # Create the session manager with true stateless mode - session_manager = StreamableHTTPSessionManager( - app=app, - event_store=None, - json_response=json_response, - stateless=True, + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + arguments = params.arguments or {} + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + await ctx.session.send_log_message( + level="info", + data=f"Notification {i + 1}/{count} from caller: {caller}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + return types.CallToolResult( + content=[ + types.TextContent( + type="text", + text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), + ) + ] ) - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, send) - @contextlib.asynccontextmanager - async def lifespan(app: Starlette) -> AsyncIterator[None]: - """Context manager for session manager.""" - async with session_manager.run(): - logger.info("Application started with StreamableHTTP session manager!") - try: - yield - finally: - logger.info("Application shutting down...") +@click.command() +@click.option("--port", default=3000, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +@click.option( + "--json-response", + is_flag=True, + default=False, + help="Enable JSON responses instead of SSE streams", +) +def main( + port: int, + log_level: str, + json_response: bool, +) -> None: + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server( + "mcp-streamable-http-stateless-demo", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) - # Create an ASGI application using the transport - starlette_app = Starlette( + starlette_app = app.streamable_http_app( + stateless_http=True, + json_response=json_response, debug=True, - routes=[Mount("/mcp", app=handle_streamable_http)], - lifespan=lifespan, ) # Wrap ASGI application with CORS middleware to expose Mcp-Session-Id header # for browser-based clients (ensures 500 errors get proper CORS headers) starlette_app = CORSMiddleware( starlette_app, - allow_origins=["*"], # Allow all origins - adjust as needed for production + allow_origins=["*"], # Note: streamable_http_app() enforces localhost-only Origin by default allow_methods=["GET", "POST", "DELETE"], # MCP streamable HTTP methods expose_headers=["Mcp-Session-Id"], ) diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md index 9836367170..3eed3320e7 100644 --- a/examples/servers/simple-streamablehttp/README.md +++ b/examples/servers/simple-streamablehttp/README.md @@ -6,9 +6,7 @@ A simple MCP server example demonstrating the StreamableHttp transport, which en - Uses the StreamableHTTP transport for server-client communication - Supports REST API operations (POST, GET, DELETE) for `/mcp` endpoint -- Task management with anyio task groups - Ability to send multiple notifications over time to the client -- Proper resource cleanup and lifespan management - Resumability support via InMemoryEventStore ## Usage diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index ef03d9b08f..ec9761d1b0 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -1,17 +1,11 @@ -import contextlib import logging -from collections.abc import AsyncIterator -from typing import Any import anyio import click +import uvicorn from mcp import types -from mcp.server.lowlevel import Server -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette +from mcp.server import Server, ServerRequestContext from starlette.middleware.cors import CORSMiddleware -from starlette.routing import Mount -from starlette.types import Receive, Scope, Send from .event_store import InMemoryEventStore @@ -19,6 +13,75 @@ logger = logging.getLogger(__name__) +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="start-notification-stream", + description="Sends a stream of notifications with configurable count and interval", + input_schema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": "Identifier of the caller to include in notifications", + }, + }, + }, + ) + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + arguments = params.arguments or {} + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + # Include more detailed message for resumability demonstration + notification_msg = f"[{i + 1}/{count}] Event from '{caller}' - Use Last-Event-ID to resume if disconnected" + await ctx.session.send_log_message( + level="info", + data=notification_msg, + logger="notification_stream", + # Associates this notification with the original request + # Ensures notifications are sent to the correct response stream + # Without this, notifications will either go to: + # - a standalone SSE stream (if GET request is supported) + # - nowhere (if GET request isn't supported) + related_request_id=ctx.request_id, + ) + logger.debug(f"Sent notification {i + 1}/{count} for caller: {caller}") + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + # This will send a resource notification through standalone SSE + # established by GET request + await ctx.session.send_resource_updated(uri="http:///test_resource") + return types.CallToolResult( + content=[ + types.TextContent( + type="text", + text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), + ) + ] + ) + + @click.command() @click.option("--port", default=3000, help="Port to listen on for HTTP") @click.option( @@ -43,70 +106,11 @@ def main( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - app = Server("mcp-streamable-http-demo") - - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - ctx = app.request_context - interval = arguments.get("interval", 1.0) - count = arguments.get("count", 5) - caller = arguments.get("caller", "unknown") - - # Send the specified number of notifications with the given interval - for i in range(count): - # Include more detailed message for resumability demonstration - notification_msg = f"[{i + 1}/{count}] Event from '{caller}' - Use Last-Event-ID to resume if disconnected" - await ctx.session.send_log_message( - level="info", - data=notification_msg, - logger="notification_stream", - # Associates this notification with the original request - # Ensures notifications are sent to the correct response stream - # Without this, notifications will either go to: - # - a standalone SSE stream (if GET request is supported) - # - nowhere (if GET request isn't supported) - related_request_id=ctx.request_id, - ) - logger.debug(f"Sent notification {i + 1}/{count} for caller: {caller}") - if i < count - 1: # Don't wait after the last notification - await anyio.sleep(interval) - - # This will send a resource notificaiton though standalone SSE - # established by GET request - await ctx.session.send_resource_updated(uri="http:///test_resource") - return [ - types.TextContent( - type="text", - text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), - ) - ] - - @app.list_tools() - async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="start-notification-stream", - description=("Sends a stream of notifications with configurable count and interval"), - input_schema={ - "type": "object", - "required": ["interval", "count", "caller"], - "properties": { - "interval": { - "type": "number", - "description": "Interval between notifications in seconds", - }, - "count": { - "type": "number", - "description": "Number of notifications to send", - }, - "caller": { - "type": "string", - "description": ("Identifier of the caller to include in notifications"), - }, - }, - }, - ) - ] + app = Server( + "mcp-streamable-http-demo", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Create event store for resumability # The InMemoryEventStore enables resumability support for StreamableHTTP transport. @@ -118,47 +122,21 @@ async def list_tools() -> list[types.Tool]: # For production, use a persistent storage solution. event_store = InMemoryEventStore() - # Create the session manager with our app and event store - session_manager = StreamableHTTPSessionManager( - app=app, - event_store=event_store, # Enable resumability + starlette_app = app.streamable_http_app( + event_store=event_store, json_response=json_response, - ) - - # ASGI handler for streamable HTTP connections - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, send) - - @contextlib.asynccontextmanager - async def lifespan(app: Starlette) -> AsyncIterator[None]: - """Context manager for managing session manager lifecycle.""" - async with session_manager.run(): - logger.info("Application started with StreamableHTTP session manager!") - try: - yield - finally: - logger.info("Application shutting down...") - - # Create an ASGI application using the transport - starlette_app = Starlette( debug=True, - routes=[ - Mount("/mcp", app=handle_streamable_http), - ], - lifespan=lifespan, ) # Wrap ASGI application with CORS middleware to expose Mcp-Session-Id header # for browser-based clients (ensures 500 errors get proper CORS headers) starlette_app = CORSMiddleware( starlette_app, - allow_origins=["*"], # Allow all origins - adjust as needed for production + allow_origins=["*"], # Note: streamable_http_app() enforces localhost-only Origin by default allow_methods=["GET", "POST", "DELETE"], # MCP streamable HTTP methods expose_headers=["Mcp-Session-Id"], ) - import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) return 0 diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py index dc689ed942..bc06e12088 100644 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py +++ b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py @@ -6,49 +6,41 @@ - ServerTaskContext.elicit() and ServerTaskContext.create_message() queue requests properly """ -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager from typing import Any import click import uvicorn from mcp import types +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.lowlevel import Server -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette -from starlette.routing import Mount -server = Server("simple-task-interactive") -# Enable task support - this auto-registers all handlers -server.experimental.enable_tasks() +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="confirm_delete", + description="Asks for confirmation before deleting (demonstrates elicitation)", + input_schema={ + "type": "object", + "properties": {"filename": {"type": "string"}}, + }, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ), + types.Tool( + name="write_haiku", + description="Asks LLM to write a haiku (demonstrates sampling)", + input_schema={"type": "object", "properties": {"topic": {"type": "string"}}}, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ), + ] + ) -@server.list_tools() -async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="confirm_delete", - description="Asks for confirmation before deleting (demonstrates elicitation)", - input_schema={ - "type": "object", - "properties": {"filename": {"type": "string"}}, - }, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - types.Tool( - name="write_haiku", - description="Asks LLM to write a haiku (demonstrates sampling)", - input_schema={"type": "object", "properties": {"topic": {"type": "string"}}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - ] - - -async def handle_confirm_delete(arguments: dict[str, Any]) -> types.CreateTaskResult: +async def handle_confirm_delete(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: """Handle the confirm_delete tool - demonstrates elicitation.""" - ctx = server.request_context ctx.experimental.validate_task_mode(types.TASK_REQUIRED) filename = arguments.get("filename", "unknown.txt") @@ -80,9 +72,8 @@ async def work(task: ServerTaskContext) -> types.CallToolResult: return await ctx.experimental.run_task(work) -async def handle_write_haiku(arguments: dict[str, Any]) -> types.CreateTaskResult: +async def handle_write_haiku(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: """Handle the write_haiku tool - demonstrates sampling.""" - ctx = server.request_context ctx.experimental.validate_task_mode(types.TASK_REQUIRED) topic = arguments.get("topic", "nature") @@ -111,37 +102,37 @@ async def work(task: ServerTaskContext) -> types.CallToolResult: return await ctx.experimental.run_task(work) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: +async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.CreateTaskResult: """Dispatch tool calls to their handlers.""" - if name == "confirm_delete": - return await handle_confirm_delete(arguments) - elif name == "write_haiku": - return await handle_write_haiku(arguments) - else: - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], - is_error=True, - ) + arguments = params.arguments or {} + if params.name == "confirm_delete": + return await handle_confirm_delete(ctx, arguments) + elif params.name == "write_haiku": + return await handle_write_haiku(ctx, arguments) -def create_app(session_manager: StreamableHTTPSessionManager) -> Starlette: - @asynccontextmanager - async def app_lifespan(app: Starlette) -> AsyncIterator[None]: - async with session_manager.run(): - yield - - return Starlette( - routes=[Mount("/mcp", app=session_manager.handle_request)], - lifespan=app_lifespan, + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], + is_error=True, ) +server = Server( + "simple-task-interactive", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) + +# Enable task support - this auto-registers all handlers +server.experimental.enable_tasks() + + @click.command() @click.option("--port", default=8000, help="Port to listen on") def main(port: int) -> int: - session_manager = StreamableHTTPSessionManager(app=server) - starlette_app = create_app(session_manager) + starlette_app = server.streamable_http_app() print(f"Starting server on http://localhost:{port}/mcp") uvicorn.run(starlette_app, host="127.0.0.1", port=port) return 0 diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py index ec16b15ae8..7583cd8f0e 100644 --- a/examples/servers/simple-task/mcp_simple_task/server.py +++ b/examples/servers/simple-task/mcp_simple_task/server.py @@ -1,83 +1,69 @@ """Simple task server demonstrating MCP tasks over streamable HTTP.""" -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from typing import Any - import anyio import click import uvicorn from mcp import types +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.lowlevel import Server -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette -from starlette.routing import Mount -server = Server("simple-task-server") -# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task -server.experimental.enable_tasks() +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="long_running_task", + description="A task that takes a few seconds to complete with status updates", + input_schema={"type": "object", "properties": {}}, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ) + ] + ) -@server.list_tools() -async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="long_running_task", - description="A task that takes a few seconds to complete with status updates", - input_schema={"type": "object", "properties": {}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ) - ] +async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.CreateTaskResult: + """Dispatch tool calls to their handlers.""" + if params.name == "long_running_task": + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Starting work...") + await anyio.sleep(1) -async def handle_long_running_task(arguments: dict[str, Any]) -> types.CreateTaskResult: - """Handle the long_running_task tool - demonstrates status updates.""" - ctx = server.request_context - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + await task.update_status("Processing step 1...") + await anyio.sleep(1) - async def work(task: ServerTaskContext) -> types.CallToolResult: - await task.update_status("Starting work...") - await anyio.sleep(1) + await task.update_status("Processing step 2...") + await anyio.sleep(1) - await task.update_status("Processing step 1...") - await anyio.sleep(1) + return types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) - await task.update_status("Processing step 2...") - await anyio.sleep(1) + return await ctx.experimental.run_task(work) - return types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], + is_error=True, + ) - return await ctx.experimental.run_task(work) +server = Server( + "simple-task-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: - """Dispatch tool calls to their handlers.""" - if name == "long_running_task": - return await handle_long_running_task(arguments) - else: - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], - is_error=True, - ) +# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task +server.experimental.enable_tasks() @click.command() @click.option("--port", default=8000, help="Port to listen on") def main(port: int) -> int: - session_manager = StreamableHTTPSessionManager(app=server) - - @asynccontextmanager - async def app_lifespan(app: Starlette) -> AsyncIterator[None]: - async with session_manager.run(): - yield - - starlette_app = Starlette( - routes=[Mount("/mcp", app=session_manager.handle_request)], - lifespan=app_lifespan, - ) + starlette_app = server.streamable_http_app() print(f"Starting server on http://localhost:{port}/mcp") uvicorn.run(starlette_app, host="127.0.0.1", port=port) diff --git a/examples/servers/simple-tool/README.md b/examples/servers/simple-tool/README.md index 06020b4b0e..7d3759f9de 100644 --- a/examples/servers/simple-tool/README.md +++ b/examples/servers/simple-tool/README.md @@ -3,14 +3,14 @@ A simple MCP server that exposes a website fetching tool. ## Usage -Start the server using either stdio (default) or SSE transport: +Start the server using either stdio (default) or Streamable HTTP transport: ```bash # Using stdio transport (default) uv run mcp-simple-tool -# Using SSE transport on custom port -uv run mcp-simple-tool --transport sse --port 8000 +# Using Streamable HTTP transport on custom port +uv run mcp-simple-tool --transport streamable-http --port 8000 ``` The server exposes a tool named "fetch" that accepts one required argument: diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 1c253a22ec..226058b955 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -1,11 +1,8 @@ -from typing import Any - import anyio import click from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.shared._httpx_utils import create_mcp_http_client -from starlette.requests import Request async def fetch_website( @@ -18,28 +15,11 @@ async def fetch_website( return [types.TextContent(type="text", text=response.text)] -@click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-website-fetcher") - - @app.call_tool() - async def fetch_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - if name != "fetch": - raise ValueError(f"Unknown tool: {name}") - if "url" not in arguments: - raise ValueError("Missing required argument 'url'") - return await fetch_website(arguments["url"]) - - @app.list_tools() - async def list_tools() -> list[types.Tool]: - return [ +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ types.Tool( name="fetch", title="Website Fetcher", @@ -56,31 +36,38 @@ async def list_tools() -> list[types.Tool]: }, ) ] + ) - if transport == "sse": - from mcp.server.sse import SseServerTransport - from starlette.applications import Starlette - from starlette.responses import Response - from starlette.routing import Mount, Route - sse = SseServerTransport("/messages/") +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name != "fetch": + raise ValueError(f"Unknown tool: {params.name}") + arguments = params.arguments or {} + if "url" not in arguments: + raise ValueError("Missing required argument 'url'") + content = await fetch_website(arguments["url"]) + return types.CallToolResult(content=content) - async def handle_sse(request: Request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] - await app.run(streams[0], streams[1], app.create_initialization_options()) - return Response() - starlette_app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse, methods=["GET"]), - Mount("/messages/", app=sse.handle_post_message), - ], - ) +@click.command() +@click.option("--port", default=8000, help="Port to listen on for HTTP") +@click.option( + "--transport", + type=click.Choice(["stdio", "streamable-http"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-website-fetcher", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + if transport == "streamable-http": import uvicorn - uvicorn.run(starlette_app, host="127.0.0.1", port=port) + uvicorn.run(app.streamable_http_app(), host="127.0.0.1", port=port) else: from mcp.server.stdio import stdio_server diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py index 9d7071ca70..14bc174c47 100644 --- a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py @@ -12,107 +12,25 @@ uv run mcp-sse-polling-demo --port 3000 """ -import contextlib import logging -from collections.abc import AsyncIterator -from typing import Any import anyio import click +import uvicorn from mcp import types -from mcp.server.lowlevel import Server -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.applications import Starlette -from starlette.routing import Mount -from starlette.types import Receive, Scope, Send +from mcp.server import Server, ServerRequestContext from .event_store import InMemoryEventStore logger = logging.getLogger(__name__) -@click.command() -@click.option("--port", default=3000, help="Port to listen on") -@click.option( - "--log-level", - default="INFO", - help="Logging level (DEBUG, INFO, WARNING, ERROR)", -) -@click.option( - "--retry-interval", - default=100, - help="SSE retry interval in milliseconds (sent to client)", -) -def main(port: int, log_level: str, retry_interval: int) -> int: - """Run the SSE Polling Demo server.""" - logging.basicConfig( - level=getattr(logging, log_level.upper()), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - # Create the lowlevel server - app = Server("sse-polling-demo") - - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - """Handle tool calls.""" - ctx = app.request_context - - if name == "process_batch": - items = arguments.get("items", 10) - checkpoint_every = arguments.get("checkpoint_every", 3) - - if items < 1 or items > 100: - return [types.TextContent(type="text", text="Error: items must be between 1 and 100")] - if checkpoint_every < 1 or checkpoint_every > 20: - return [types.TextContent(type="text", text="Error: checkpoint_every must be between 1 and 20")] - - await ctx.session.send_log_message( - level="info", - data=f"Starting batch processing of {items} items...", - logger="process_batch", - related_request_id=ctx.request_id, - ) - - for i in range(1, items + 1): - # Simulate work - await anyio.sleep(0.5) - - # Report progress - await ctx.session.send_log_message( - level="info", - data=f"[{i}/{items}] Processing item {i}", - logger="process_batch", - related_request_id=ctx.request_id, - ) - - # Checkpoint: close stream to trigger client reconnect - if i % checkpoint_every == 0 and i < items: - await ctx.session.send_log_message( - level="info", - data=f"Checkpoint at item {i} - closing SSE stream for polling", - logger="process_batch", - related_request_id=ctx.request_id, - ) - if ctx.close_sse_stream: - logger.info(f"Closing SSE stream at checkpoint {i}") - await ctx.close_sse_stream() - # Wait for client to reconnect (must be > retry_interval of 100ms) - await anyio.sleep(0.2) - - return [ - types.TextContent( - type="text", - text=f"Successfully processed {items} items with checkpoints every {checkpoint_every} items", - ) - ] - - return [types.TextContent(type="text", text=f"Unknown tool: {name}")] - - @app.list_tools() - async def list_tools() -> list[types.Tool]: - """List available tools.""" - return [ +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + """List available tools.""" + return types.ListToolsResult( + tools=[ types.Tool( name="process_batch", description=( @@ -136,38 +54,104 @@ async def list_tools() -> list[types.Tool]: }, ) ] + ) - # Create event store for resumability - event_store = InMemoryEventStore() - # Create session manager with event store and retry interval - session_manager = StreamableHTTPSessionManager( - app=app, - event_store=event_store, - retry_interval=retry_interval, - ) +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + """Handle tool calls.""" + arguments = params.arguments or {} - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, send) + if params.name == "process_batch": + items = arguments.get("items", 10) + checkpoint_every = arguments.get("checkpoint_every", 3) - @contextlib.asynccontextmanager - async def lifespan(starlette_app: Starlette) -> AsyncIterator[None]: - async with session_manager.run(): - logger.info(f"SSE Polling Demo server started on port {port}") - logger.info("Try: POST /mcp with tools/call for 'process_batch'") - yield - logger.info("Server shutting down...") + if items < 1 or items > 100: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Error: items must be between 1 and 100")] + ) + if checkpoint_every < 1 or checkpoint_every > 20: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Error: checkpoint_every must be between 1 and 20")] + ) - starlette_app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=handle_streamable_http), - ], - lifespan=lifespan, + await ctx.session.send_log_message( + level="info", + data=f"Starting batch processing of {items} items...", + logger="process_batch", + related_request_id=ctx.request_id, + ) + + for i in range(1, items + 1): + # Simulate work + await anyio.sleep(0.5) + + # Report progress + await ctx.session.send_log_message( + level="info", + data=f"[{i}/{items}] Processing item {i}", + logger="process_batch", + related_request_id=ctx.request_id, + ) + + # Checkpoint: close stream to trigger client reconnect + if i % checkpoint_every == 0 and i < items: + await ctx.session.send_log_message( + level="info", + data=f"Checkpoint at item {i} - closing SSE stream for polling", + logger="process_batch", + related_request_id=ctx.request_id, + ) + if ctx.close_sse_stream: + logger.info(f"Closing SSE stream at checkpoint {i}") + await ctx.close_sse_stream() + # Wait for client to reconnect (must be > retry_interval of 100ms) + await anyio.sleep(0.2) + + return types.CallToolResult( + content=[ + types.TextContent( + type="text", + text=f"Successfully processed {items} items with checkpoints every {checkpoint_every} items", + ) + ] + ) + + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")]) + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR)", +) +@click.option( + "--retry-interval", + default=100, + help="SSE retry interval in milliseconds (sent to client)", +) +def main(port: int, log_level: str, retry_interval: int) -> int: + """Run the SSE Polling Demo server.""" + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - import uvicorn + app = Server( + "sse-polling-demo", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + + starlette_app = app.streamable_http_app( + event_store=InMemoryEventStore(), + retry_interval=retry_interval, + debug=True, + ) + logger.info(f"SSE Polling Demo server starting on port {port}") + logger.info("Try: POST /mcp with tools/call for 'process_batch'") uvicorn.run(starlette_app, host="127.0.0.1", port=port) return 0 diff --git a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py index fd73a54cdb..95fb908540 100644 --- a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py +++ b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py @@ -2,60 +2,54 @@ """Example low-level MCP server demonstrating structured output support. This example shows how to use the low-level server API to return -structured data from tools, with automatic validation against output -schemas. +structured data from tools. """ import asyncio +import json +import random from datetime import datetime -from typing import Any import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -# Create low-level server instance -server = Server("structured-output-lowlevel-example") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools with their schemas.""" - return [ - types.Tool( - name="get_weather", - description="Get weather information (simulated)", - input_schema={ - "type": "object", - "properties": {"city": {"type": "string", "description": "City name"}}, - "required": ["city"], - }, - output_schema={ - "type": "object", - "properties": { - "temperature": {"type": "number"}, - "conditions": {"type": "string"}, - "humidity": {"type": "integer", "minimum": 0, "maximum": 100}, - "wind_speed": {"type": "number"}, - "timestamp": {"type": "string", "format": "date-time"}, + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Get weather information (simulated)", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string", "description": "City name"}}, + "required": ["city"], + }, + output_schema={ + "type": "object", + "properties": { + "temperature": {"type": "number"}, + "conditions": {"type": "string"}, + "humidity": {"type": "integer", "minimum": 0, "maximum": 100}, + "wind_speed": {"type": "number"}, + "timestamp": {"type": "string", "format": "date-time"}, + }, + "required": ["temperature", "conditions", "humidity", "wind_speed", "timestamp"], }, - "required": ["temperature", "conditions", "humidity", "wind_speed", "timestamp"], - }, - ), - ] + ), + ] + ) -@server.call_tool() -async def call_tool(name: str, arguments: dict[str, Any]) -> Any: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool call with structured output.""" - if name == "get_weather": - # city = arguments["city"] # Would be used with real weather API - + if params.name == "get_weather": # Simulate weather data (in production, call a real weather API) - import random - weather_conditions = ["sunny", "cloudy", "rainy", "partly cloudy", "foggy"] weather_data = { @@ -66,12 +60,19 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any: "timestamp": datetime.now().isoformat(), } - # Return structured data only - # The low-level server will serialize this to JSON content automatically - return weather_data + return types.CallToolResult( + content=[types.TextContent(type="text", text=json.dumps(weather_data, indent=2))], + structured_content=weather_data, + ) + + raise ValueError(f"Unknown tool: {params.name}") - else: - raise ValueError(f"Unknown tool: {name}") + +server = Server( + "structured-output-lowlevel-example", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -80,14 +81,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="structured-output-lowlevel-example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py index 9888c588e6..2aecbeeee6 100644 --- a/examples/snippets/clients/url_elicitation_client.py +++ b/examples/snippets/clients/url_elicitation_client.py @@ -24,8 +24,6 @@ import asyncio import json -import subprocess -import sys import webbrowser from typing import Any from urllib.parse import urlparse @@ -56,15 +54,19 @@ async def handle_elicitation( ) +ALLOWED_SCHEMES = {"http", "https"} + + async def handle_url_elicitation( params: types.ElicitRequestParams, ) -> types.ElicitResult: """Handle URL mode elicitation - show security warning and optionally open browser. This function demonstrates the security-conscious approach to URL elicitation: - 1. Display the full URL and domain for user inspection - 2. Show the server's reason for requesting this interaction - 3. Require explicit user consent before opening any URL + 1. Validate the URL scheme before prompting the user + 2. Display the full URL and domain for user inspection + 3. Show the server's reason for requesting this interaction + 4. Require explicit user consent before opening any URL """ # Extract URL parameters - these are available on URL mode requests url = getattr(params, "url", None) @@ -75,6 +77,12 @@ async def handle_url_elicitation( print("Error: No URL provided in elicitation request") return types.ElicitResult(action="cancel") + # Reject dangerous URL schemes before prompting the user + parsed = urlparse(str(url)) + if parsed.scheme.lower() not in ALLOWED_SCHEMES: + print(f"\nRejecting URL with disallowed scheme '{parsed.scheme}': {url}") + return types.ElicitResult(action="decline") + # Extract domain for security display domain = extract_domain(url) @@ -105,7 +113,11 @@ async def handle_url_elicitation( # Open the browser print(f"\nOpening browser to: {url}") - open_browser(url) + try: + webbrowser.open(url) + except Exception as e: + print(f"Failed to open browser: {e}") + print(f"Please manually open: {url}") print("Waiting for you to complete the interaction in your browser...") print("(The server will continue once you've finished)") @@ -121,20 +133,6 @@ def extract_domain(url: str) -> str: return "unknown" -def open_browser(url: str) -> None: - """Open URL in the default browser.""" - try: - if sys.platform == "darwin": - subprocess.run(["open", url], check=False) - elif sys.platform == "win32": - subprocess.run(["start", url], shell=True, check=False) - else: - webbrowser.open(url) - except Exception as e: - print(f"Failed to open browser: {e}") - print(f"Please manually open: {url}") - - async def call_tool_with_error_handling( session: ClientSession, tool_name: str, diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 70e515c75f..79453f543e 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -10,7 +10,6 @@ from pydantic import BaseModel, Field from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams @@ -28,7 +27,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context) -> str: """Book a table with date availability check. This demonstrates form mode elicitation for collecting non-sensitive user input. @@ -52,7 +51,7 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerS @mcp.tool() -async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> str: +async def secure_payment(amount: float, ctx: Context) -> str: """Process a secure payment requiring URL confirmation. This demonstrates URL mode elicitation using ctx.elicit_url() for @@ -76,7 +75,7 @@ async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> st @mcp.tool() -async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: +async def connect_service(service_name: str, ctx: Context) -> str: """Connect to a third-party service requiring OAuth authorization. This demonstrates the "throw error" pattern using UrlElicitationRequiredError. diff --git a/examples/snippets/servers/lowlevel/basic.py b/examples/snippets/servers/lowlevel/basic.py index 0d44325048..81f40e9945 100644 --- a/examples/snippets/servers/lowlevel/basic.py +++ b/examples/snippets/servers/lowlevel/basic.py @@ -6,32 +6,30 @@ import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -# Create a server instance -server = Server("example-server") - -@server.list_prompts() -async def handle_list_prompts() -> list[types.Prompt]: +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: """List available prompts.""" - return [ - types.Prompt( - name="example-prompt", - description="An example prompt template", - arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], - ) - ] + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="example-prompt", + description="An example prompt template", + arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], + ) + ] + ) -@server.get_prompt() -async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: """Get a specific prompt by name.""" - if name != "example-prompt": - raise ValueError(f"Unknown prompt: {name}") + if params.name != "example-prompt": + raise ValueError(f"Unknown prompt: {params.name}") - arg1_value = (arguments or {}).get("arg1", "default") + arg1_value = (params.arguments or {}).get("arg1", "default") return types.GetPromptResult( description="Example prompt", @@ -44,20 +42,20 @@ async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> type ) +server = Server( + "example-server", + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, +) + + async def run(): """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/lowlevel/direct_call_tool_result.py b/examples/snippets/servers/lowlevel/direct_call_tool_result.py index 725f5711af..7e8fc4dcb3 100644 --- a/examples/snippets/servers/lowlevel/direct_call_tool_result.py +++ b/examples/snippets/servers/lowlevel/direct_call_tool_result.py @@ -3,44 +3,49 @@ """ import asyncio -from typing import Any import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -server = Server("example-server") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="advanced_tool", - description="Tool with full control including _meta field", - input_schema={ - "type": "object", - "properties": {"message": {"type": "string"}}, - "required": ["message"], - }, - ) - ] - - -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="advanced_tool", + description="Tool with full control including _meta field", + input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + ) + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls by returning CallToolResult directly.""" - if name == "advanced_tool": - message = str(arguments.get("message", "")) + if params.name == "advanced_tool": + message = (params.arguments or {}).get("message", "") return types.CallToolResult( content=[types.TextContent(type="text", text=f"Processed: {message}")], structured_content={"result": "success", "message": message}, _meta={"hidden": "data for client applications only"}, ) - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -49,14 +54,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/lowlevel/lifespan.py b/examples/snippets/servers/lowlevel/lifespan.py index da8ff7bdfd..bcd96c8935 100644 --- a/examples/snippets/servers/lowlevel/lifespan.py +++ b/examples/snippets/servers/lowlevel/lifespan.py @@ -4,12 +4,11 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any +from typing import TypedDict import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext # Mock database class for example @@ -32,52 +31,58 @@ async def query(self, query_str: str) -> list[dict[str, str]]: return [{"id": "1", "name": "Example", "query": query_str}] +class AppContext(TypedDict): + db: Database + + @asynccontextmanager -async def server_lifespan(_server: Server) -> AsyncIterator[dict[str, Any]]: +async def server_lifespan(_server: Server[AppContext]) -> AsyncIterator[AppContext]: """Manage server startup and shutdown lifecycle.""" - # Initialize resources on startup db = await Database.connect() try: yield {"db": db} finally: - # Clean up on shutdown await db.disconnect() -# Pass lifespan to server -server = Server("example-server", lifespan=server_lifespan) - - -@server.list_tools() -async def handle_list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="query_db", - description="Query the database", - input_schema={ - "type": "object", - "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, - "required": ["query"], - }, - ) - ] - - -@server.call_tool() -async def query_db(name: str, arguments: dict[str, Any]) -> list[types.TextContent]: + return types.ListToolsResult( + tools=[ + types.Tool( + name="query_db", + description="Query the database", + input_schema={ + "type": "object", + "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, + "required": ["query"], + }, + ) + ] + ) + + +async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: types.CallToolRequestParams +) -> types.CallToolResult: """Handle database query tool call.""" - if name != "query_db": - raise ValueError(f"Unknown tool: {name}") + if params.name != "query_db": + raise ValueError(f"Unknown tool: {params.name}") - # Access lifespan context - ctx = server.request_context db = ctx.lifespan_context["db"] + results = await db.query((params.arguments or {})["query"]) + + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Query results: {results}")]) - # Execute query - results = await db.query(arguments["query"]) - return [types.TextContent(type="text", text=f"Query results: {results}")] +server = Server( + "example-server", + lifespan=server_lifespan, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -86,14 +91,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example-server", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/lowlevel/structured_output.py b/examples/snippets/servers/lowlevel/structured_output.py index cad8f67da6..f93c8875fd 100644 --- a/examples/snippets/servers/lowlevel/structured_output.py +++ b/examples/snippets/servers/lowlevel/structured_output.py @@ -3,62 +3,67 @@ """ import asyncio -from typing import Any +import json import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -server = Server("example-server") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools with structured output schemas.""" - return [ - types.Tool( - name="get_weather", - description="Get current weather for a city", - input_schema={ - "type": "object", - "properties": {"city": {"type": "string", "description": "City name"}}, - "required": ["city"], - }, - output_schema={ - "type": "object", - "properties": { - "temperature": {"type": "number", "description": "Temperature in Celsius"}, - "condition": {"type": "string", "description": "Weather condition"}, - "humidity": {"type": "number", "description": "Humidity percentage"}, - "city": {"type": "string", "description": "City name"}, + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Get current weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string", "description": "City name"}}, + "required": ["city"], }, - "required": ["temperature", "condition", "humidity", "city"], - }, - ) - ] + output_schema={ + "type": "object", + "properties": { + "temperature": {"type": "number", "description": "Temperature in Celsius"}, + "condition": {"type": "string", "description": "Weather condition"}, + "humidity": {"type": "number", "description": "Humidity percentage"}, + "city": {"type": "string", "description": "City name"}, + }, + "required": ["temperature", "condition", "humidity", "city"], + }, + ) + ] + ) -@server.call_tool() -async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls with structured output.""" - if name == "get_weather": - city = arguments["city"] + if params.name == "get_weather": + city = (params.arguments or {})["city"] - # Simulated weather data - in production, call a weather API weather_data = { "temperature": 22.5, "condition": "partly cloudy", "humidity": 65, - "city": city, # Include the requested city + "city": city, } - # low-level server will validate structured output against the tool's - # output schema, and additionally serialize it into a TextContent block - # for backwards compatibility with pre-2025-06-18 clients. - return weather_data - else: - raise ValueError(f"Unknown tool: {name}") + return types.CallToolResult( + content=[types.TextContent(type="text", text=json.dumps(weather_data, indent=2))], + structured_content=weather_data, + ) + + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -67,14 +72,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="structured-output-example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 5579af510f..d6d903cc7f 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,11 +1,10 @@ from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") diff --git a/examples/snippets/servers/pagination_example.py b/examples/snippets/servers/pagination_example.py index bb406653e5..bcd0ffb106 100644 --- a/examples/snippets/servers/pagination_example.py +++ b/examples/snippets/servers/pagination_example.py @@ -1,22 +1,20 @@ -"""Example of implementing pagination with MCP server decorators.""" +"""Example of implementing pagination with the low-level MCP server.""" from mcp import types -from mcp.server.lowlevel import Server - -# Initialize the server -server = Server("paginated-server") +from mcp.server import Server, ServerRequestContext # Sample data to paginate ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items -@server.list_resources() -async def list_resources_paginated(request: types.ListResourcesRequest) -> types.ListResourcesResult: +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: """List resources with pagination support.""" page_size = 10 # Extract cursor from request params - cursor = request.params.cursor if request.params is not None else None + cursor = params.cursor if params is not None else None # Parse cursor to get offset start = 0 if cursor is None else int(cursor) @@ -32,3 +30,6 @@ async def list_resources_paginated(request: types.ListResourcesRequest) -> types next_cursor = str(end) if end < len(ITEMS) else None return types.ListResourcesResult(resources=page_items, next_cursor=next_cursor) + + +server = Server("paginated-server", on_list_resources=handle_list_resources) diff --git a/examples/snippets/servers/sampling.py b/examples/snippets/servers/sampling.py index 4ffeeda726..43259589a4 100644 --- a/examples/snippets/servers/sampling.py +++ b/examples/snippets/servers/sampling.py @@ -1,12 +1,11 @@ from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.types import SamplingMessage, TextContent mcp = MCPServer(name="Sampling Example") @mcp.tool() -async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str: +async def generate_poem(topic: str, ctx: Context) -> str: """Generate a poem using LLM sampling.""" prompt = f"Write a short poem about {topic}" diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index 0b283cb1f5..376dbc5db8 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,11 +1,10 @@ from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession mcp = MCPServer(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") diff --git a/mkdocs.yml b/mkdocs.yml index 3019f5214b..e48c64242d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -5,7 +5,7 @@ strict: true repo_name: modelcontextprotocol/python-sdk repo_url: https://github.com/modelcontextprotocol/python-sdk edit_uri: edit/main/docs/ -site_url: https://modelcontextprotocol.github.io/python-sdk +site_url: https://py.sdk.modelcontextprotocol.io/v2/ # TODO(Marcelo): Add Anthropic copyright? # copyright: © Model Context Protocol 2025 to present @@ -25,7 +25,7 @@ nav: - Introduction: experimental/tasks.md - Server Implementation: experimental/tasks-server.md - Client Usage: experimental/tasks-client.md - - API Reference: api.md + - API Reference: api/ theme: name: "material" @@ -112,12 +112,18 @@ watch: plugins: - search - - social + - social: + enabled: !ENV [ENABLE_SOCIAL_CARDS, false] - glightbox + - gen-files: + scripts: + - docs/hooks/gen_ref_pages.py + - literate-nav: + nav_file: SUMMARY.md - mkdocstrings: handlers: python: - paths: [src/mcp] + paths: [src] options: relative_crossrefs: true members_order: source @@ -125,8 +131,6 @@ plugins: show_signature_annotations: true signature_crossrefs: true group_by_category: false - # 3 because docs are in pages with an H2 just above them - heading_level: 3 inventories: - url: https://docs.python.org/3/objects.inv - url: https://docs.pydantic.dev/latest/objects.inv diff --git a/pyproject.toml b/pyproject.toml index bfc3067137..6d2319621a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,8 +25,8 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] dependencies = [ - "anyio>=4.5", - "httpx>=0.27.1", + "anyio>=4.9", + "httpx>=0.27.1,<1.0.0", "httpx-sse>=0.4", "pydantic>=2.12.0", "starlette>=0.48.0; python_version >= '3.14'", @@ -40,6 +40,7 @@ dependencies = [ "pyjwt[crypto]>=2.10.1", "typing-extensions>=4.13.0", "typing-inspection>=0.4.1", + "opentelemetry-api>=1.28.0", ] [project.optional-dependencies] @@ -53,13 +54,31 @@ mcp = "mcp.cli:app [cli]" [tool.uv] default-groups = ["dev", "docs"] required-version = ">=0.9.5" +# PEP 517 build isolation fetches [build-system].requires (and transitives) at +# floating-latest with no hash check on every fresh sync; uv does not lock them +# (astral-sh/uv#5190). Pinning here narrows that to known-good versions. Covers +# the workspace builds (hatchling + uv-dynamic-versioning) and the legacy +# setuptools fallback used by the strict-no-cover git dep. +build-constraint-dependencies = [ + "hatchling==1.29.0", + "uv-dynamic-versioning==0.14.0", + "dunamai==1.26.1", + "jinja2==3.1.6", + "markupsafe==3.0.3", + "packaging==26.1", + "pathspec==1.0.4", + "pluggy==1.6.0", + "tomlkit==0.14.0", + "trove-classifiers==2026.1.14.14", + "setuptools==82.0.1", +] [dependency-groups] dev = [ # We add mcp[cli,ws] so `uv sync` considers the extras. "mcp[cli,ws]", "pyright>=1.1.400", - "pytest>=8.3.4", + "pytest>=8.4.0", "ruff>=0.8.5", "trio>=0.26.2", "pytest-flakefinder>=1.1.0", @@ -71,11 +90,14 @@ dev = [ "coverage[toml]>=7.10.7,<=7.13", "pillow>=12.0", "strict-no-cover", + "logfire>=3.0.0", ] docs = [ "mkdocs>=1.6.1", + "mkdocs-gen-files>=0.5.0", "mkdocs-glightbox>=0.4.0", - "mkdocs-material>=9.5.45", + "mkdocs-literate-nav>=0.6.1", + "mkdocs-material[imaging]>=9.5.45", "mkdocstrings-python>=2.0.1", ] @@ -93,6 +115,7 @@ bump = true [project.urls] Homepage = "https://modelcontextprotocol.io" +Documentation = "https://py.sdk.modelcontextprotocol.io/v2/" Repository = "https://github.com/modelcontextprotocol/python-sdk" Issues = "https://github.com/modelcontextprotocol/python-sdk/issues" @@ -170,16 +193,20 @@ strict-no-cover = { git = "https://github.com/pydantic/strict-no-cover" } [tool.pytest.ini_options] log_cli = true xfail_strict = true +markers = [ + "requirement(id): links a test to the entry in tests/interaction/_requirements.py it exercises", +] addopts = """ --color=yes --capture=fd + -p anyio + -p examples """ filterwarnings = [ "error", # This should be fixed on Uvicorn's side. "ignore::DeprecationWarning:websockets", "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", - "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", # pywin32 internal deprecation warning "ignore:getargs.*The 'u' format is deprecated:DeprecationWarning", ] @@ -215,13 +242,10 @@ skip_covered = true show_missing = true ignore_errors = true precision = 2 -exclude_lines = [ - "pragma: no cover", +exclude_also = [ "pragma: lax no cover", - "if TYPE_CHECKING:", "@overload", "raise NotImplementedError", - "^\\s*\\.\\.\\.\\s*$", ] # https://coverage.readthedocs.io/en/latest/config.html#paths diff --git a/scripts/build-docs.sh b/scripts/build-docs.sh new file mode 100755 index 0000000000..5a61309acf --- /dev/null +++ b/scripts/build-docs.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +# +# Build combined v1 + v2 MkDocs documentation for GitHub Pages. +# +# v1 docs (from the v1.x branch) are placed at the site root. +# v2 docs (from main) are placed under /v2/. +# +# Both branches are fetched fresh from origin, so the output is identical +# regardless of which branch triggered the workflow. This script is intended +# to run in CI; for local single-branch preview use `uv run mkdocs serve`. +# +# Usage: +# scripts/build-docs.sh [output-dir] +# +# Default output directory: site +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +OUTPUT_DIR="$(cd "$REPO_ROOT" && mkdir -p "${1:-site}" && cd "${1:-site}" && pwd)" +V1_WORKTREE="$REPO_ROOT/.worktrees/v1-docs" +V2_WORKTREE="$REPO_ROOT/.worktrees/v2-docs" + +cleanup() { + cd "$REPO_ROOT" + git worktree remove --force "$V1_WORKTREE" 2>/dev/null || true + git worktree remove --force "$V2_WORKTREE" 2>/dev/null || true + rmdir "$REPO_ROOT/.worktrees" 2>/dev/null || true +} +trap cleanup EXIT + +rm -rf "${OUTPUT_DIR:?}"/* + +build_branch() { + local branch="$1" worktree="$2" dest="$3" + + echo "=== Building docs for ${branch} ===" + git fetch origin "$branch" + git worktree remove --force "$worktree" 2>/dev/null || true + rm -rf "$worktree" + git worktree add --detach "$worktree" "origin/${branch}" + + ( + cd "$worktree" + uv sync --frozen --group docs + uv run --frozen --no-sync mkdocs build --site-dir "$dest" + ) +} + +build_branch v1.x "$V1_WORKTREE" "$OUTPUT_DIR" +build_branch main "$V2_WORKTREE" "$OUTPUT_DIR/v2" + +echo "=== Combined docs built at $OUTPUT_DIR ===" diff --git a/scripts/test b/scripts/test index 0d08e47b1b..dc43f351dd 100755 --- a/scripts/test +++ b/scripts/test @@ -2,6 +2,10 @@ set -ex +uv run --frozen coverage erase uv run --frozen coverage run -m pytest -n auto $@ uv run --frozen coverage combine uv run --frozen coverage report +# strict-no-cover spawns `uv run coverage json` internally without --frozen; +# UV_FROZEN=1 propagates to that subprocess so it doesn't touch uv.lock. +UV_FROZEN=1 uv run --frozen strict-no-cover diff --git a/src/mcp/cli/claude.py b/src/mcp/cli/claude.py index 071b4b6fbb..93bf218fbb 100644 --- a/src/mcp/cli/claude.py +++ b/src/mcp/cli/claude.py @@ -33,7 +33,7 @@ def get_claude_config_path() -> Path | None: # pragma: no cover def get_uv_path() -> str: """Get the full path to the uv executable.""" uv_path = shutil.which("uv") - if not uv_path: # pragma: no cover + if not uv_path: logger.error( "uv executable not found in PATH, falling back to 'uv'. Please ensure uv is installed and in your PATH" ) @@ -65,7 +65,7 @@ def update_claude_config( """ config_dir = get_claude_config_path() uv_path = get_uv_path() - if not config_dir: # pragma: no cover + if not config_dir: raise RuntimeError( "Claude Desktop config directory not found. Please ensure Claude Desktop" " is installed and has been run at least once to initialize its config." @@ -90,7 +90,7 @@ def update_claude_config( config["mcpServers"] = {} # Always preserve existing env vars and merge with new ones - if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: # pragma: no cover + if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: existing_env = config["mcpServers"][server_name]["env"] if env_vars: # New vars take precedence over existing ones @@ -103,22 +103,26 @@ def update_claude_config( # Collect all packages in a set to deduplicate packages = {MCP_PACKAGE} - if with_packages: # pragma: no cover + if with_packages: packages.update(pkg for pkg in with_packages if pkg) # Add all packages with --with for pkg in sorted(packages): args.extend(["--with", pkg]) - if with_editable: # pragma: no cover + if with_editable: args.extend(["--with-editable", str(with_editable)]) # Convert file path to absolute before adding to command # Split off any :object suffix first - if ":" in file_spec: + # First check if we have a Windows path (e.g., C:\...) + has_windows_drive = len(file_spec) > 1 and file_spec[1] == ":" + + # Split on the last colon, but only if it's not part of the Windows drive letter + if ":" in (file_spec[2:] if has_windows_drive else file_spec): file_path, server_object = file_spec.rsplit(":", 1) file_spec = f"{Path(file_path).resolve()}:{server_object}" - else: # pragma: no cover + else: file_spec = str(Path(file_spec).resolve()) # Add mcp run command @@ -127,7 +131,7 @@ def update_claude_config( server_config: dict[str, Any] = {"command": uv_path, "args": args} # Add environment variables if specified - if env_vars: # pragma: no cover + if env_vars: server_config["env"] = env_vars config["mcpServers"][server_name] = server_config diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index 858ab7db29..62334a4a2c 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -317,12 +317,12 @@ def run( ) -> None: # pragma: no cover """Run an MCP server. - The server can be specified in two ways:\n - 1. Module approach: server.py - runs the module directly, expecting a server.run() call.\n - 2. Import approach: server.py:app - imports and runs the specified server object.\n\n + The server can be specified in two ways: + 1. Module approach: server.py - runs the module directly, expecting a server.run() call. + 2. Import approach: server.py:app - imports and runs the specified server object. Note: This command runs the server directly. You are responsible for ensuring - all dependencies are available.\n + all dependencies are available. For dependency management, use `mcp install` or `mcp dev` instead. """ # noqa: E501 file, server_object = _parse_file_path(file_spec) diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index f3db17906d..b9ec344226 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -6,9 +6,9 @@ from urllib.parse import urlparse import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import types +from mcp.client._transport import ReadStream, WriteStream from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client @@ -33,8 +33,8 @@ async def message_handler( async def run_session( - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], client_info: types.Implementation | None = None, ): async with ClientSession( diff --git a/src/mcp/client/_transport.py b/src/mcp/client/_transport.py index a863629005..0163fef950 100644 --- a/src/mcp/client/_transport.py +++ b/src/mcp/client/_transport.py @@ -5,11 +5,12 @@ from contextlib import AbstractAsyncContextManager from typing import Protocol -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.message import SessionMessage -TransportStreams = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] +__all__ = ["ReadStream", "WriteStream", "Transport", "TransportStreams"] + +TransportStreams = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]] class Transport(AbstractAsyncContextManager[TransportStreams], Protocol): diff --git a/src/mcp/client/auth/extensions/client_credentials.py b/src/mcp/client/auth/extensions/client_credentials.py index 07f6180bf1..cb6dafb407 100644 --- a/src/mcp/client/auth/extensions/client_credentials.py +++ b/src/mcp/client/auth/extensions/client_credentials.py @@ -450,7 +450,7 @@ def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # prag # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2 token_data["client_assertion"] = assertion token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" - # We need to set the audience to the resource server, the audience is difference from the one in claims + # We need to set the audience to the resource server, the audience is different from the one in claims # it represents the resource server that will validate the token token_data["audience"] = self.context.get_resource_url() diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 41aecc6f28..3c546fda2b 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -205,8 +205,9 @@ def prepare_token_auth( headers["Authorization"] = f"Basic {encoded_credentials}" # Don't include client_secret in body for basic auth data = {k: v for k, v in data.items() if k != "client_secret"} - elif auth_method == "client_secret_post" and self.client_info.client_secret: - # Include client_secret in request body + elif auth_method == "client_secret_post" and self.client_info.client_id and self.client_info.client_secret: + # Include client_id and client_secret in request body (RFC 6749 §2.3.1) + data["client_id"] = self.client_info.client_id data["client_secret"] = self.client_info.client_secret # For auth_method == "none", don't add any client_secret @@ -215,6 +216,7 @@ def prepare_token_auth( class OAuthClientProvider(httpx.Auth): """OAuth2 authentication for httpx. + Handles OAuth flow with automatic client registration and token storage. """ @@ -241,7 +243,7 @@ def __init__( callback_handler: Handler for authorization callbacks. timeout: Timeout for the OAuth flow. client_metadata_url: URL-based client ID. When provided and the server - advertises client_id_metadata_document_supported=true, this URL will be + advertises client_id_metadata_document_supported=True, this URL will be used as the client_id instead of performing dynamic client registration. Must be a valid HTTPS URL with a non-root pathname. validate_resource_url: Optional callback to override resource URL validation. @@ -318,7 +320,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: - auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover + auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) else: auth_base_url = self.context.get_authorization_base_url(self.context.server_url) auth_endpoint = urljoin(auth_base_url, "/authorize") @@ -341,11 +343,16 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: # Only include resource param if conditions are met if self.context.should_include_resource_param(self.context.protocol_version): - auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover + auth_params["resource"] = self.context.get_resource_url() # RFC 8707 if self.context.client_metadata.scope: # pragma: no branch auth_params["scope"] = self.context.client_metadata.scope + # OIDC requires prompt=consent when offline_access is requested + # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess + if "offline_access" in self.context.client_metadata.scope.split(): + auth_params["prompt"] = "consent" + authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" await self.context.redirect_handler(authorization_url) @@ -353,10 +360,10 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: auth_code, returned_state = await self.context.callback_handler() if returned_state is None or not secrets.compare_digest(returned_state, state): - raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover + raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") if not auth_code: - raise OAuthFlowError("No authorization code received") # pragma: no cover + raise OAuthFlowError("No authorization code received") # Return auth code and code verifier for token exchange return auth_code, pkce_params.code_verifier @@ -445,7 +452,7 @@ async def _refresh_token(self) -> httpx.Request: return httpx.Request("POST", token_url, data=refresh_data, headers=headers) - async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover + async def _handle_refresh_response(self, response: httpx.Response) -> bool: """Handle token refresh response. Returns True if successful.""" if response.status_code != 200: logger.warning(f"Token refresh failed: {response.status_code}") @@ -461,12 +468,12 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p await self.context.storage.set_tokens(token_response) return True - except ValidationError: + except ValidationError: # pragma: no cover logger.exception("Invalid refresh response") self.context.clear_tokens() return False - async def _initialize(self) -> None: # pragma: no cover + async def _initialize(self) -> None: """Load stored tokens and client info.""" self.context.current_tokens = await self.context.storage.get_tokens() self.context.client_info = await self.context.storage.get_client_info() @@ -493,12 +500,6 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None if not prm_resource: return # pragma: no cover default_resource = resource_url_from_server_url(self.context.server_url) - # Normalize: Pydantic AnyHttpUrl adds trailing slash to root URLs - # (e.g. "https://example.com/") while resource_url_from_server_url may not. - if not default_resource.endswith("/"): - default_resource += "/" - if not prm_resource.endswith("/"): - prm_resource += "/" if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource): raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}") @@ -506,17 +507,17 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. """HTTPX auth flow integration.""" async with self.context.lock: if not self._initialized: - await self._initialize() # pragma: no cover + await self._initialize() # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token - refresh_request = await self._refresh_token() # pragma: no cover - refresh_response = yield refresh_request # pragma: no cover + refresh_request = await self._refresh_token() + refresh_response = yield refresh_request - if not await self._handle_refresh_response(refresh_response): # pragma: no cover + if not await self._handle_refresh_response(refresh_response): # Refresh failed, need full re-authentication self._initialized = False @@ -580,6 +581,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. extract_scope_from_www_auth(response), self.context.protected_resource_metadata, self.context.oauth_metadata, + self.context.client_metadata.grant_types, ) # Step 4: Register client or use URL-based client ID (CIMD) @@ -610,7 +612,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) - except Exception: # pragma: no cover + except Exception: logger.exception("OAuth flow error") raise @@ -626,7 +628,10 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. try: # Step 2a: Update the required scopes self.context.client_metadata.scope = get_client_metadata_scopes( - extract_scope_from_www_auth(response), self.context.protected_resource_metadata + extract_scope_from_www_auth(response), + self.context.protected_resource_metadata, + self.context.oauth_metadata, + self.context.client_metadata.grant_types, ) # Step 2b: Perform (re-)authorization and token exchange diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index 1aa960b9ce..d75324f2f0 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -38,7 +38,7 @@ def extract_field_from_www_auth(response: Response, field_name: str) -> str | No def extract_scope_from_www_auth(response: Response) -> str | None: - """Extract scope parameter from WWW-Authenticate header as per RFC6750. + """Extract scope parameter from WWW-Authenticate header as per RFC 6750. Returns: Scope string if found in WWW-Authenticate header, None otherwise @@ -47,7 +47,7 @@ def extract_scope_from_www_auth(response: Response) -> str | None: def extract_resource_metadata_from_www_auth(response: Response) -> str | None: - """Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + """Extract protected resource metadata URL from WWW-Authenticate header as per RFC 9728. Returns: Resource metadata URL if found in WWW-Authenticate header, None otherwise @@ -67,8 +67,8 @@ def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, s 3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource Args: - www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header - server_url: server url + www_auth_url: Optional resource_metadata URL extracted from the WWW-Authenticate header + server_url: Server URL Returns: Ordered list of URLs to try for discovery @@ -99,31 +99,43 @@ def get_client_metadata_scopes( www_authenticate_scope: str | None, protected_resource_metadata: ProtectedResourceMetadata | None, authorization_server_metadata: OAuthMetadata | None = None, + client_grant_types: list[str] | None = None, ) -> str | None: - """Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec.""" - # Per MCP spec, scope selection priority order: - # 1. Use scope from WWW-Authenticate header (if provided) - # 2. Use all scopes from PRM scopes_supported (if available) - # 3. Omit scope parameter if neither is available - + """Select effective scopes and augment for refresh token support.""" + selected_scope: str | None = None + + # MCP spec scope selection priority: + # 1. WWW-Authenticate header scope + # 2. PRM scopes_supported + # 3. AS scopes_supported (SDK fallback) + # 4. Omit scope parameter if www_authenticate_scope is not None: - # Priority 1: WWW-Authenticate header scope - return www_authenticate_scope + selected_scope = www_authenticate_scope elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None: - # Priority 2: PRM scopes_supported - return " ".join(protected_resource_metadata.scopes_supported) + selected_scope = " ".join(protected_resource_metadata.scopes_supported) elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None: - return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover - else: - # Priority 3: Omit scope parameter - return None + selected_scope = " ".join(authorization_server_metadata.scopes_supported) + + # SEP-2207: append offline_access when the AS supports it and the client can use refresh tokens + if ( + selected_scope is not None + and authorization_server_metadata is not None + and authorization_server_metadata.scopes_supported is not None + and "offline_access" in authorization_server_metadata.scopes_supported + and client_grant_types is not None + and "refresh_token" in client_grant_types + and "offline_access" not in selected_scope.split() + ): + selected_scope = f"{selected_scope} offline_access" + + return selected_scope def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]: - """Generate ordered list of (url, type) tuples for discovery attempts. + """Generate an ordered list of URLs for authorization server metadata discovery. Args: - auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None + auth_server_url: OAuth Authorization Server Metadata URL if found, otherwise None server_url: URL for the MCP server, used as a fallback if auth_server_url is None """ @@ -170,7 +182,7 @@ async def handle_protected_resource_response( Per SEP-985, supports fallback when discovery fails at one URL. Returns: - True if metadata was successfully discovered, False if we should try next URL + ProtectedResourceMetadata if successfully discovered, None if we should try next URL """ if response.status_code == 200: try: @@ -206,7 +218,7 @@ def create_oauth_metadata_request(url: str) -> Request: def create_client_registration_request( auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str ) -> Request: - """Build registration request or skip if already registered.""" + """Build a client registration request.""" if auth_server_metadata and auth_server_metadata.registration_endpoint: registration_url = str(auth_server_metadata.registration_endpoint) @@ -261,7 +273,7 @@ def should_use_client_metadata_url( """Determine if URL-based client ID (CIMD) should be used instead of DCR. URL-based client IDs should be used when: - 1. The server advertises client_id_metadata_document_supported=true + 1. The server advertises client_id_metadata_document_supported=True 2. The client has a valid client_metadata_url configured Args: @@ -306,7 +318,7 @@ def create_client_info_from_metadata_url( async def handle_token_response_scopes( response: Response, ) -> OAuthToken: - """Parse and validate token response with optional scope validation. + """Parse and validate a token response. Parses token response JSON. Callers should check response.status_code before calling. diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 29d4a7035d..b33fea4052 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -19,6 +19,7 @@ EmptyResult, GetPromptResult, Implementation, + InitializeResult, ListPromptsResult, ListResourcesResult, ListResourceTemplatesResult, @@ -29,7 +30,6 @@ ReadResourceResult, RequestParamsMeta, ResourceTemplateReference, - ServerCapabilities, ) @@ -37,8 +37,8 @@ class Client: """A high-level MCP client for connecting to MCP servers. - Currently supports in-memory transport for testing. Pass a Server or - MCPServer instance directly to the constructor. + Supports in-memory transport for testing (pass a Server or MCPServer instance), + Streamable HTTP transport (pass a URL string), or a custom Transport instance. Example: ```python @@ -155,9 +155,16 @@ def session(self) -> ClientSession: return self._session @property - def server_capabilities(self) -> ServerCapabilities | None: - """The server capabilities received during initialization, or None if not yet initialized.""" - return self.session.get_server_capabilities() + def initialize_result(self) -> InitializeResult: + """The server's InitializeResult. + + Contains server_info, capabilities, instructions, and the negotiated protocol_version. + Raises RuntimeError if accessed outside the context manager. + """ + result = self.session.initialize_result + if result is None: # pragma: no cover + raise RuntimeError("Client must be used within an async context manager") + return result async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> EmptyResult: """Send a ping request to the server.""" @@ -205,7 +212,7 @@ async def read_resource(self, uri: str, *, meta: RequestParamsMeta | None = None Args: uri: The URI of the resource to read. - meta: Additional metadata for the request + meta: Additional metadata for the request. Returns: The resource content. @@ -239,7 +246,7 @@ async def call_tool( meta: Additional metadata for the request Returns: - The tool result + The tool result. """ return await self.session.call_tool( name=name, @@ -298,4 +305,4 @@ async def list_tools(self, *, cursor: str | None = None, meta: RequestParamsMeta async def send_roots_list_changed(self) -> None: """Send a notification that the roots list has changed.""" # TODO(Marcelo): Currently, there is no way for the server to handle this. We should add support. - await self.session.send_roots_list_changed() # pragma: no cover + await self.session.send_roots_list_changed() diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py index 28ff2b1f29..0ab513236a 100644 --- a/src/mcp/client/experimental/task_handlers.py +++ b/src/mcp/client/experimental/task_handlers.py @@ -187,11 +187,13 @@ class ExperimentalTaskHandlers: WARNING: These APIs are experimental and may change without notice. Example: + ```python handlers = ExperimentalTaskHandlers( get_task=my_get_task_handler, list_tasks=my_list_tasks_handler, ) session = ClientSession(..., experimental_task_handlers=handlers) + ``` """ # Pure task request handlers diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py index 8ddc4faceb..a566df766b 100644 --- a/src/mcp/client/experimental/tasks.py +++ b/src/mcp/client/experimental/tasks.py @@ -5,6 +5,7 @@ WARNING: These APIs are experimental and may change without notice. Example: + ```python # Call a tool as a task result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) task_id = result.task.task_id @@ -21,6 +22,7 @@ # Cancel a task await session.experimental.cancel_task(task_id) + ``` """ from collections.abc import AsyncIterator @@ -72,6 +74,7 @@ async def call_tool_as_task( CreateTaskResult containing the task reference Example: + ```python # Create task result = await session.experimental.call_tool_as_task( "long_running_tool", {"input": "data"} @@ -83,10 +86,11 @@ async def call_tool_as_task( status = await session.experimental.get_task(task_id) if status.status == "completed": break - await asyncio.sleep(0.5) + await anyio.sleep(0.5) # Get result final = await session.experimental.get_task_result(task_id, CallToolResult) + ``` """ return await self._session.send_request( types.CallToolRequest( @@ -177,7 +181,7 @@ async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: """Poll a task until it reaches a terminal status. Yields GetTaskResult for each poll, allowing the caller to react to - status changes (e.g., handle input_required). Exits when task reaches + status changes (e.g., handle input_required). Exits when the task reaches a terminal status (completed, failed, cancelled). Respects the pollInterval hint from the server. @@ -189,6 +193,7 @@ async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: GetTaskResult for each poll Example: + ```python async for status in session.experimental.poll_task(task_id): print(f"Status: {status.status}") if status.status == "input_required": @@ -197,6 +202,7 @@ async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: # Task is now terminal, get the result result = await session.experimental.get_task_result(task_id, CallToolResult) + ``` """ async for status in poll_until_terminal(self.get_task, task_id): yield status diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0687f98c3a..86113874be 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -4,10 +4,10 @@ from typing import Any, Protocol import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import TypeAdapter from mcp import types +from mcp.client._transport import ReadStream, WriteStream from mcp.client.experimental import ExperimentalClientFeatures from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared._context import RequestContext @@ -74,7 +74,7 @@ async def _default_elicitation_callback( context: RequestContext[ClientSession], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: - return types.ErrorData( # pragma: no cover + return types.ErrorData( code=types.INVALID_REQUEST, message="Elicitation not supported", ) @@ -109,8 +109,8 @@ class ClientSession( ): def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, @@ -131,7 +131,7 @@ def __init__( self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} - self._server_capabilities: types.ServerCapabilities | None = None + self._initialize_result: types.InitializeResult | None = None self._experimental_features: ExperimentalClientFeatures | None = None # Experimental: Task handlers (use defaults if not provided) @@ -185,18 +185,19 @@ async def initialize(self) -> types.InitializeResult: if result.protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocol_version}") - self._server_capabilities = result.capabilities + self._initialize_result = result await self.send_notification(types.InitializedNotification()) return result - def get_server_capabilities(self) -> types.ServerCapabilities | None: - """Return the server capabilities received during initialization. + @property + def initialize_result(self) -> types.InitializeResult | None: + """The server's InitializeResult. None until initialize() has been called. - Returns None if the session has not been initialized yet. + Contains server_info, capabilities, instructions, and the negotiated protocol_version. """ - return self._server_capabilities + return self._initialize_result @property def experimental(self) -> ExperimentalClientFeatures: @@ -206,8 +207,10 @@ def experimental(self) -> ExperimentalClientFeatures: These APIs are experimental and may change without notice. Example: + ```python status = await session.experimental.get_task(task_id) result = await session.experimental.get_task_result(task_id, CallToolResult) + ``` """ if self._experimental_features is None: self._experimental_features = ExperimentalClientFeatures(self) @@ -334,9 +337,7 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) - from jsonschema import SchemaError, ValidationError, validate if result.structured_content is None: - raise RuntimeError( - f"Tool {name} has an output schema but did not return structured content" - ) # pragma: no cover + raise RuntimeError(f"Tool {name} has an output schema but did not return structured content") try: validate(result.structured_content, output_schema) except ValidationError as e: @@ -405,7 +406,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None return result - async def send_roots_list_changed(self) -> None: # pragma: no cover + async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" await self.send_notification(types.RootsListChangedNotification()) @@ -446,7 +447,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques client_response = ClientResponse.validate_python(response) await responder.respond(client_response) - case types.PingRequest(): # pragma: no cover + case types.PingRequest(): with responder: return await responder.respond(types.EmptyResult()) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index f4e6293b71..9610212642 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -3,7 +3,7 @@ Tools, resources, and prompts are aggregated across servers. Servers may be connected to or disconnected from at any point after initialization. -This abstractions can handle naming collisions using a custom user-provided hook. +This abstraction can handle naming collisions using a custom user-provided hook. """ import contextlib @@ -30,7 +30,7 @@ class SseServerParameters(BaseModel): - """Parameters for initializing a sse_client.""" + """Parameters for initializing an sse_client.""" # The endpoint URL. url: str @@ -67,8 +67,8 @@ class StreamableHttpParameters(BaseModel): ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters -# Use dataclass instead of pydantic BaseModel -# because pydantic BaseModel cannot handle Protocol fields. +# Use dataclass instead of Pydantic BaseModel +# because Pydantic BaseModel cannot handle Protocol fields. @dataclass class ClientSessionParameters: """Parameters for establishing a client session to an MCP server.""" @@ -91,13 +91,14 @@ class ClientSessionGroup: For auxiliary handlers, such as resource subscription, this is delegated to the client and can be accessed via the session. - Example Usage: + Example: + ```python name_fn = lambda name, server_info: f"{(server_info.name)}_{name}" async with ClientSessionGroup(component_name_hook=name_fn) as group: for server_param in server_params: await group.connect_to_server(server_param) ... - + ``` """ class _ComponentNames(BaseModel): @@ -119,7 +120,7 @@ class _ComponentNames(BaseModel): _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] # Optional fn consuming (component_name, server_info) for custom names. - # This is provide a means to mitigate naming conflicts across servers. + # This is to provide a means to mitigate naming conflicts across servers. # Example: (tool_name, server_info) => "{result.server_info.name}.{tool_name}" _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str] _component_name_hook: _ComponentNameHook | None diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 8f8e4dadc7..74e5ba8062 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -7,11 +7,10 @@ import anyio import httpx from anyio.abc import TaskStatus -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from httpx_sse import aconnect_sse -from httpx_sse._exceptions import SSEError +from httpx_sse import SSEError, aconnect_sse from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import SessionMessage @@ -47,115 +46,114 @@ async def sse_client( headers: Optional headers to include in requests. timeout: HTTP timeout for regular operations (in seconds). sse_read_timeout: Timeout for SSE read operations (in seconds). + httpx_client_factory: Factory function for creating the HTTPX client. auth: Optional HTTPX authentication handler. on_session_created: Optional callback invoked with the session ID when received. """ - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx_client_factory( - headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) - ) as client: - async with aconnect_sse( - client, - "GET", - url, - ) as event_source: - event_source.response.raise_for_status() - logger.debug("SSE connection established") - - async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): - try: - async for sse in event_source.aiter_sse(): # pragma: no branch - logger.debug(f"Received SSE event: {sse.event}") - match sse.event: - case "endpoint": - endpoint_url = urljoin(url, sse.data) - logger.debug(f"Received endpoint URL: {endpoint_url}") - - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - if ( # pragma: no cover - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme != endpoint_parsed.scheme - ): - error_msg = ( # pragma: no cover - f"Endpoint origin does not match connection origin: {endpoint_url}" - ) - logger.error(error_msg) # pragma: no cover - raise ValueError(error_msg) # pragma: no cover - - if on_session_created: - session_id = _extract_session_id_from_endpoint(endpoint_url) - if session_id: - on_session_created(session_id) - - task_status.started(endpoint_url) - - case "message": - # Skip empty data (keep-alive pings) - if not sse.data: - continue - try: - message = types.jsonrpc_message_adapter.validate_json( - sse.data, by_name=False - ) - logger.debug(f"Received server message: {message}") - except Exception as exc: # pragma: no cover - logger.exception("Error parsing server message") # pragma: no cover - await read_stream_writer.send(exc) # pragma: no cover - continue # pragma: no cover - - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - case _: # pragma: no cover - logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover - except SSEError as sse_exc: # pragma: lax no cover - logger.exception("Encountered SSE exception") - raise sse_exc - except Exception as exc: # pragma: lax no cover - logger.exception("Error in sse_reader") - await read_stream_writer.send(exc) - finally: - await read_stream_writer.aclose() - - async def post_writer(endpoint_url: str): - try: - async with write_stream_reader: - async for session_message in write_stream_reader: - logger.debug(f"Sending client message: {session_message}") - response = await client.post( - endpoint_url, - json=session_message.message.model_dump( - by_alias=True, - mode="json", - exclude_none=True, - ), + logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") + async with httpx_client_factory( + headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) + ) as client: + async with aconnect_sse(client, "GET", url) as event_source: + event_source.response.raise_for_status() + logger.debug("SSE connection established") + + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) + + async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): + try: + async for sse in event_source.aiter_sse(): # pragma: no branch + logger.debug(f"Received SSE event: {sse.event}") + match sse.event: + case "endpoint": + endpoint_url = urljoin(url, sse.data) + logger.debug(f"Received endpoint URL: {endpoint_url}") + + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) + if ( # pragma: no cover + url_parsed.netloc != endpoint_parsed.netloc + or url_parsed.scheme != endpoint_parsed.scheme + ): + error_msg = ( # pragma: no cover + f"Endpoint origin does not match connection origin: {endpoint_url}" ) - response.raise_for_status() - logger.debug(f"Client message sent successfully: {response.status_code}") - except Exception: # pragma: lax no cover - logger.exception("Error in post_writer") - finally: - await write_stream.aclose() - - endpoint_url = await tg.start(sse_reader) - logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") - tg.start_soon(post_writer, endpoint_url) - - try: - yield read_stream, write_stream - finally: - tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() + logger.error(error_msg) # pragma: no cover + raise ValueError(error_msg) # pragma: no cover + + if on_session_created: + session_id = _extract_session_id_from_endpoint(endpoint_url) + if session_id: + on_session_created(session_id) + + task_status.started(endpoint_url) + + case "message": + # Skip empty data (keep-alive pings) + if not sse.data: + continue + try: + message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False) + logger.debug(f"Received server message: {message}") + except Exception as exc: # pragma: no cover + logger.exception("Error parsing server message") # pragma: no cover + await read_stream_writer.send(exc) # pragma: no cover + continue # pragma: no cover + + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + case _: # pragma: no cover + logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover + except SSEError as sse_exc: # pragma: lax no cover + logger.exception("Encountered SSE exception") + raise sse_exc + except Exception as exc: # pragma: lax no cover + logger.exception("Error in sse_reader") + await read_stream_writer.send(exc) + finally: + await read_stream_writer.aclose() + + async def post_writer(endpoint_url: str): + try: + async with write_stream_reader, write_stream: + + async def _send_message(session_message: SessionMessage) -> None: + logger.debug(f"Sending client message: {session_message}") + response = await client.post( + endpoint_url, + json=session_message.message.model_dump( + by_alias=True, + mode="json", + exclude_unset=True, + ), + ) + response.raise_for_status() + logger.debug(f"Client message sent successfully: {response.status_code}") + + async for session_message in write_stream_reader: + sender_ctx = write_stream_reader.last_context + if sender_ctx is not None: + async with anyio.create_task_group() as tg: + sender_ctx.run(tg.start_soon, _send_message, session_message) + else: + await _send_message(session_message) # pragma: no cover + except Exception: # pragma: lax no cover + logger.exception("Error in post_writer") + + # On Python 3.14, coverage.py reports a phantom branch arc on this + # line (->yield) when nested two async-with levels deep. The branch + # is the unreachable "did __aexit__ suppress?" arm for memory streams. + async with ( # pragma: no branch + read_stream_writer, + read_stream, + write_stream, + write_stream_reader, + anyio.create_task_group() as tg, + ): + endpoint_url = await tg.start(sse_reader) + logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") + tg.start_soon(post_writer, endpoint_url) + + yield read_stream, write_stream + tg.cancel_scope.cancel() diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 605c5ea24e..902dc8576c 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -87,9 +87,9 @@ class StdioServerParameters(BaseModel): encoding: str = "utf-8" """ - The text encoding used when sending/receiving messages to the server + The text encoding used when sending/receiving messages to the server. - defaults to utf-8 + Defaults to utf-8. """ encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict" @@ -97,7 +97,7 @@ class StdioServerParameters(BaseModel): The text encoding error handler. See https://docs.python.org/3/library/codecs.html#codec-base-classes for - explanations of possible values + explanations of possible values. """ @@ -167,7 +167,7 @@ async def stdin_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + json = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await process.stdin.send( (json + "\n").encode( encoding=server.encoding, diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9d45bec6ee..aa3e50e07e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -11,14 +11,15 @@ import anyio import httpx from anyio.abc import TaskGroup -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse from pydantic import ValidationError from mcp.client._transport import TransportStreams +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( + INTERNAL_ERROR, INVALID_REQUEST, PARSE_ERROR, ErrorData, @@ -37,8 +38,8 @@ # TODO(Marcelo): Put the TransportStreams in a module under shared, so we can import here. SessionMessageOrError = SessionMessage | Exception -StreamWriter = MemoryObjectSendStream[SessionMessageOrError] -StreamReader = MemoryObjectReceiveStream[SessionMessage] +StreamWriter = ContextSendStream[SessionMessageOrError] +StreamReader = ContextReceiveStream[SessionMessage] MCP_SESSION_ID = "mcp-session-id" MCP_PROTOCOL_VERSION = "mcp-protocol-version" @@ -209,7 +210,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: # Stream ended normally (server closed) - reset attempt counter attempt = 0 - except Exception: # pragma: lax no cover + except Exception: logger.debug("GET stream error", exc_info=True) attempt += 1 @@ -259,21 +260,27 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: async with ctx.client.stream( "POST", self.url, - json=message.model_dump(by_alias=True, mode="json", exclude_none=True), + json=message.model_dump(by_alias=True, mode="json", exclude_unset=True), headers=headers, ) as response: if response.status_code == 202: logger.debug("Received 202 Accepted") return - if response.status_code == 404: # pragma: no branch - if isinstance(message, JSONRPCRequest): # pragma: no branch + if response.status_code == 404: + if isinstance(message, JSONRPCRequest): error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) await ctx.read_stream_writer.send(session_message) return - response.raise_for_status() + if response.status_code >= 400: + if isinstance(message, JSONRPCRequest): + error_data = ErrorData(code=INTERNAL_ERROR, message="Server returned an error response") + session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) + await ctx.read_stream_writer.send(session_message) + return + if is_initialization: self._maybe_extract_session_id_from_response(response) @@ -351,7 +358,7 @@ async def _handle_sse_response( resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), is_initialization=is_initialization, ) - # If the SSE event indicates completion, like returning respose/error + # If the SSE event indicates completion, like returning response/error # break the loop if is_complete: await response.aclose() @@ -427,14 +434,15 @@ async def post_writer( client: httpx.AsyncClient, write_stream_reader: StreamReader, read_stream_writer: StreamWriter, - write_stream: MemoryObjectSendStream[SessionMessage], + write_stream: ContextSendStream[SessionMessage], start_get_stream: Callable[[], None], tg: TaskGroup, ) -> None: """Handle writing requests to the server.""" try: - async with write_stream_reader: - async for session_message in write_stream_reader: + async with write_stream_reader, read_stream_writer, write_stream: + + async def _handle_message(session_message: SessionMessage) -> None: message = session_message.message metadata = ( session_message.metadata @@ -471,25 +479,30 @@ async def handle_request_async(): else: await handle_request_async() + async for session_message in write_stream_reader: + sender_ctx = write_stream_reader.last_context + if sender_ctx is not None: + async with anyio.create_task_group() as tg_local: + sender_ctx.run(tg_local.start_soon, _handle_message, session_message) + else: + await _handle_message(session_message) # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Error in post_writer") - finally: - await read_stream_writer.aclose() - await write_stream.aclose() async def terminate_session(self, client: httpx.AsyncClient) -> None: """Terminate the session by sending a DELETE request.""" - if not self.session_id: # pragma: lax no cover - return + if not self.session_id: + return # pragma: no cover try: headers = self._prepare_headers() response = await client.delete(self.url, headers=headers) - if response.status_code == 405: # pragma: lax no cover + if response.status_code == 405: logger.debug("Server does not allow session termination") - elif response.status_code not in (200, 204): # pragma: lax no cover - logger.warning(f"Session termination failed: {response.status_code}") + elif response.status_code not in (200, 204): + logger.warning(f"Session termination failed: {response.status_code}") # pragma: no cover except Exception as exc: # pragma: no cover logger.warning(f"Session termination failed: {exc}") @@ -526,9 +539,6 @@ async def streamable_http_client( Example: See examples/snippets/clients/ for usage patterns. """ - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - # Determine if we need to create and manage the client client_provided = http_client is not None client = http_client @@ -539,34 +549,40 @@ async def streamable_http_client( transport = StreamableHTTPTransport(url) - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") - - async with contextlib.AsyncExitStack() as stack: - # Only manage client lifecycle if we created it - if not client_provided: - await stack.enter_async_context(client) - - def start_get_stream() -> None: - tg.start_soon(transport.handle_get_stream, client, read_stream_writer) - - tg.start_soon( - transport.post_writer, - client, - write_stream_reader, - read_stream_writer, - write_stream, - start_get_stream, - tg, - ) + logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") + + async with contextlib.AsyncExitStack() as stack: + # Only manage client lifecycle if we created it + if not client_provided: + await stack.enter_async_context(client) + + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) + + async with ( + read_stream_writer, + read_stream, + write_stream, + write_stream_reader, + anyio.create_task_group() as tg, + ): + + def start_get_stream() -> None: + tg.start_soon(transport.handle_get_stream, client, read_stream_writer) + + tg.start_soon( + transport.post_writer, + client, + write_stream_reader, + read_stream_writer, + write_stream, + start_get_stream, + tg, + ) - try: - yield read_stream, write_stream - finally: - if transport.session_id and terminate_on_close: - await transport.terminate_session(client) - tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() + try: + yield read_stream, write_stream + finally: + if transport.session_id and terminate_on_close: + await transport.terminate_session(client) + tg.cancel_scope.cancel() diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index cf4b86e992..de473f36d3 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -25,8 +25,8 @@ async def websocket_client( (read_stream, write_stream) - read_stream: As you read from this stream, you'll receive either valid - JSONRPCMessage objects or Exception objects (when validation fails). - - write_stream: Write JSONRPCMessage objects to this stream to send them + SessionMessage objects or Exception objects (when validation fails). + - write_stream: Write SessionMessage objects to this stream to send them over the WebSocket to the server. """ @@ -38,11 +38,10 @@ async def websocket_client( write_stream: MemoryObjectSendStream[SessionMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - # Connect using websockets, requesting the "mcp" subprotocol async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws: + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) async def ws_reader(): """Reads text messages from the WebSocket, parses them as JSON-RPC messages, @@ -65,10 +64,16 @@ async def ws_writer(): async with write_stream_reader: async for session_message in write_stream_reader: # Convert to a dict, then to JSON - msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True) + msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True) await ws.send(json.dumps(msg_dict)) - async with anyio.create_task_group() as tg: + async with ( + read_stream_writer, + read_stream, + write_stream, + write_stream_reader, + anyio.create_task_group() as tg, + ): # Start reader and writer tasks tg.start_soon(ws_reader) tg.start_soon(ws_writer) diff --git a/src/mcp/os/win32/utilities.py b/src/mcp/os/win32/utilities.py index fa4e4b399b..6f68405f78 100644 --- a/src/mcp/os/win32/utilities.py +++ b/src/mcp/os/win32/utilities.py @@ -123,6 +123,11 @@ def pid(self) -> int: """Return the process ID.""" return self.popen.pid + @property + def returncode(self) -> int | None: + """Return the exit code, or ``None`` if the process has not yet terminated.""" + return self.popen.returncode + # ------------------------ # Updated function @@ -138,9 +143,9 @@ async def create_windows_process( ) -> Process | FallbackProcess: """Creates a subprocess in a Windows-compatible way with Job Object support. - Attempt to use anyio's open_process for async subprocess creation. - In some cases this will throw NotImplementedError on Windows, e.g. - when using the SelectorEventLoop which does not support async subprocesses. + Attempts to use anyio's open_process for async subprocess creation. + In some cases this will throw NotImplementedError on Windows, e.g., + when using the SelectorEventLoop, which does not support async subprocesses. In that case, we fall back to using subprocess.Popen. The process is automatically added to a Job Object to ensure all child @@ -242,8 +247,9 @@ def _create_job_object() -> int | None: def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: JobHandle | None) -> None: - """Try to assign a process to a job object. If assignment fails - for any reason, the job handle is closed. + """Try to assign a process to a job object. + + If assignment fails for any reason, the job handle is closed. """ if not job: return @@ -312,8 +318,8 @@ async def terminate_windows_process(process: Process | FallbackProcess): Note: On Windows, terminating a process with process.terminate() doesn't always guarantee immediate process termination. - So we give it 2s to exit, or we call process.kill() - which sends a SIGKILL equivalent signal. + If the process does not exit within 2 seconds, process.kill() is called + to send a SIGKILL-equivalent signal. Args: process: The process to terminate diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index a2dada3af7..aab5c33f7d 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -1,5 +1,6 @@ +from .context import ServerRequestContext from .lowlevel import NotificationOptions, Server from .mcpserver import MCPServer from .models import InitializationOptions -__all__ = ["Server", "MCPServer", "NotificationOptions", "InitializationOptions"] +__all__ = ["Server", "ServerRequestContext", "MCPServer", "NotificationOptions", "InitializationOptions"] diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index dec6713b13..5cf93cf8c2 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -117,7 +117,7 @@ async def error_response( pass # the error response MUST contain the state specified by the client, if any - if state is None: # pragma: no cover + if state is None: # make last-ditch effort to load state state = best_effort_extract_string("state", params) diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 68a3392b4f..4efd154001 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -15,7 +15,7 @@ class RevocationRequest(BaseModel): - """# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1""" + """See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1""" token: str token_type_hint: Literal["access_token", "refresh_token"] | None = None diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 6825c00b9e..ba66e94226 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,6 +1,6 @@ import json import time -from typing import Any +from typing import Any, TypedDict from pydantic import AnyHttpUrl from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser @@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken): self.scopes = auth_info.scopes +class AuthorizationContext(TypedDict): + client_id: str + issuer: str | None + subject: str | None + + +def authorization_context(user: AuthenticatedUser) -> AuthorizationContext: + """Identify the principal `user` represents, for transports to compare + against the principal that created a session. Components the token + verifier does not supply are `None`, so the comparison degrades to the + remaining components. + + See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for + a verifier that populates `subject` and `claims` from an introspection + response.""" + token = user.access_token + issuer = (token.claims or {}).get("iss") + return AuthorizationContext( + client_id=token.client_id, + issuer=str(issuer) if issuer is not None else None, + subject=token.subject, + ) + + class BearerAuthBackend(AuthenticationBackend): """Authentication backend that validates Bearer tokens using a TokenVerifier.""" @@ -95,7 +119,7 @@ async def _send_auth_error(self, send: Send, status_code: int, error: str, descr """Send an authentication error response with WWW-Authenticate header.""" # Build WWW-Authenticate header value www_auth_parts = [f'error="{error}"', f'error_description="{description}"'] - if self.resource_metadata_url: # pragma: no cover + if self.resource_metadata_url: www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"') www_authenticate = f"Bearer {', '.join(www_auth_parts)}" diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 8a6a1b5188..2832f83523 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -19,15 +19,17 @@ def __init__(self, message: str): class ClientAuthenticator: """ClientAuthenticator is a callable which validates requests from a client application, used to verify /token calls. + If, during registration, the client requested to be issued a secret, the authenticator asserts that /token calls must be authenticated with - that same token. + that same secret. + NOTE: clients can opt for no authentication during registration, in which case this logic is skipped. """ def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]): - """Initialize the dependency. + """Initialize the authenticator. Args: provider: Provider to look up client information @@ -83,7 +85,7 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation elif client.token_endpoint_auth_method == "client_secret_post": raw_form_data = form_data.get("client_secret") - # form_data.get() can return a UploadFile or None, so we need to check if it's a string + # form_data.get() can return an UploadFile or None, so we need to check if it's a string if isinstance(raw_form_data, str): request_client_secret = str(raw_form_data) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 5eb577fd43..4ce1137575 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Generic, Literal, Protocol, TypeVar +from typing import Any, Generic, Literal, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyUrl, BaseModel @@ -25,6 +25,7 @@ class AuthorizationCode(BaseModel): redirect_uri: AnyUrl redirect_uri_provided_explicitly: bool resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None # resource owner; propagate to the issued AccessToken class RefreshToken(BaseModel): @@ -32,6 +33,7 @@ class RefreshToken(BaseModel): client_id: str scopes: list[str] expires_at: int | None = None + subject: str | None = None # resource owner; propagate to refreshed AccessTokens class AccessToken(BaseModel): @@ -40,6 +42,8 @@ class AccessToken(BaseModel): scopes: list[str] expires_at: int | None = None resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None # RFC 7662/9068 `sub`: resource owner; unique only per issuer + claims: dict[str, Any] | None = None # additional claims (e.g. `iss`, `act`) RegistrationErrorCode = Literal[ @@ -131,8 +135,9 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None """ async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: - """Called as part of the /authorize endpoint, and returns a URL that the client + """Handle the /authorize endpoint and return a URL that the client will be redirected to. + Many MCP implementations will redirect to a third-party provider to perform a second OAuth exchange with that provider. In this sort of setup, the client has an OAuth connection with the MCP server, and the MCP server has an OAuth @@ -151,7 +156,7 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat | | +------------+ - Implementations will need to define another handler on the MCP server return + Implementations will need to define another handler on the MCP server's return flow to perform the second redirect, and generate and store an authorization code as part of completing the OAuth authorization step. @@ -182,7 +187,7 @@ async def load_authorization_code( authorization_code: The authorization code to get the challenge for. Returns: - The AuthorizationCode, or None if not found + The AuthorizationCode, or None if not found. """ ... @@ -199,7 +204,7 @@ async def exchange_authorization_code( The OAuth token, containing access and refresh tokens. Raises: - TokenError: If the request is invalid + TokenError: If the request is invalid. """ ... @@ -234,18 +239,18 @@ async def exchange_refresh_token( The OAuth token, containing access and refresh tokens. Raises: - TokenError: If the request is invalid + TokenError: If the request is invalid. """ ... async def load_access_token(self, token: str) -> AccessTokenT | None: - """Loads an access token by its token. + """Loads an access token by its token string. Args: token: The access token to verify. Returns: - The AuthInfo, or None if the token is invalid. + The access token, or None if the token is invalid. """ async def revoke_token( @@ -261,7 +266,7 @@ async def revoke_token( provided. Args: - token: the token to revoke + token: The token to revoke. """ diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 08f735f362..a72e819477 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -25,25 +25,21 @@ def validate_issuer_url(url: AnyHttpUrl): """Validate that the issuer URL meets OAuth 2.0 requirements. Args: - url: The issuer URL to validate + url: The issuer URL to validate. Raises: - ValueError: If the issuer URL is invalid + ValueError: If the issuer URL is invalid. """ - # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing - if ( - url.scheme != "https" - and url.host != "localhost" - and (url.host is not None and not url.host.startswith("127.0.0.1")) - ): - raise ValueError("Issuer URL must be HTTPS") # pragma: no cover + # RFC 8414 requires HTTPS, but we allow loopback/localhost HTTP for testing + if url.scheme != "https" and url.host not in ("localhost", "127.0.0.1", "[::1]"): + raise ValueError("Issuer URL must be HTTPS") # No fragments or query parameters allowed if url.fragment: - raise ValueError("Issuer URL must not have a fragment") # pragma: no cover + raise ValueError("Issuer URL must not have a fragment") if url.query: - raise ValueError("Issuer URL must not have a query string") # pragma: no cover + raise ValueError("Issuer URL must not have a query string") AUTHORIZATION_PATH = "/authorize" @@ -217,6 +213,8 @@ def create_protected_resource_routes( resource_url: The URL of this resource server authorization_servers: List of authorization servers that can issue tokens scopes_supported: Optional list of scopes supported by this resource + resource_name: Optional human-readable name for this resource + resource_documentation: Optional URL to documentation for this resource Returns: List of Starlette routes for protected resource metadata diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 43b9d3800c..d8e11d78b2 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -10,7 +10,7 @@ from mcp.shared._context import RequestContext from mcp.shared.message import CloseSSEStreamCallback -LifespanContextT = TypeVar("LifespanContextT") +LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any]) RequestT = TypeVar("RequestT", default=Any) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 58e9fe4485..731c914edc 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -112,8 +112,8 @@ async def elicit_with_validation( This method can be used to interactively ask for additional information from the client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. Or in case a - client is an agent, it might decide how to handle the elicitation -- either by asking + user and collect a response according to the provided schema. If the client + is an agent, it might decide how to handle the elicitation -- either by asking the user or automatically generating a response. For sensitive data like credentials or OAuth flows, use elicit_url() instead. diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 80ae5912b0..3eba65822a 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -62,8 +62,8 @@ def validate_task_mode( """Validate that the request is compatible with the tool's task execution mode. Per MCP spec: - - "required": Clients MUST invoke as task. Server returns -32601 if not. - - "forbidden" (or None): Clients MUST NOT invoke as task. Server returns -32601 if they do. + - "required": Clients MUST invoke as a task. Server returns -32601 if not. + - "forbidden" (or None): Clients MUST NOT invoke as a task. Server returns -32601 if they do. - "optional": Either is acceptable. Args: @@ -111,7 +111,7 @@ def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool: """Check if this client can use a tool with the given task mode. Useful for filtering tool lists or providing warnings. - Returns False if tool requires "required" but client doesn't support tasks. + Returns False if the tool's task mode is "required" but the client doesn't support tasks. Args: tool_task_mode: The tool's execution.taskSupport value @@ -160,19 +160,18 @@ async def run_task( RuntimeError: If task support is not enabled or task_metadata is missing Example: - @server.call_tool() - async def handle_tool(name: str, args: dict): - ctx = server.request_context - + ```python + async def handle_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: async def work(task: ServerTaskContext) -> CallToolResult: result = await task.elicit( message="Are you sure?", - requestedSchema={"type": "object", ...} + requested_schema={"type": "object", ...} ) confirmed = result.content.get("confirm", False) return CallToolResult(content=[TextContent(text="Done" if confirmed else "Cancelled")]) return await ctx.experimental.run_task(work) + ``` WARNING: This API is experimental and may change without notice. """ diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py index 9b626c9862..1fc45badfd 100644 --- a/src/mcp/server/experimental/task_context.py +++ b/src/mcp/server/experimental/task_context.py @@ -56,6 +56,7 @@ class ServerTaskContext: - Status notifications via the session Example: + ```python async def my_task_work(task: ServerTaskContext) -> CallToolResult: await task.update_status("Starting...") @@ -68,6 +69,7 @@ async def my_task_work(task: ServerTaskContext) -> CallToolResult: return CallToolResult(content=[TextContent(text="Done!")]) else: return CallToolResult(content=[TextContent(text="Cancelled")]) + ``` """ def __init__( diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 991221bd0b..b2268bc1c8 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -44,17 +44,14 @@ class TaskResultHandler: 5. Returns the final result Usage: - # Create handler with store and queue - handler = TaskResultHandler(task_store, message_queue) - - # Register it with the server - @server.experimental.get_task_result() - async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - ctx = server.request_context - return await handler.handle(req, ctx.session, ctx.request_id) - - # Or use the convenience method - handler.register(server) + async def handle_task_result( + ctx: ServerRequestContext, params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: + ... + + server.experimental.enable_tasks( + on_task_result=handle_task_result, + ) """ def __init__( diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py index 23b5d9cc89..b542195048 100644 --- a/src/mcp/server/experimental/task_support.py +++ b/src/mcp/server/experimental/task_support.py @@ -31,14 +31,20 @@ class TaskSupport: - Manages a task group for background task execution Example: - # Simple in-memory setup + Simple in-memory setup: + + ```python server.experimental.enable_tasks() + ``` + + Custom store/queue for distributed systems: - # Custom store/queue for distributed systems + ```python server.experimental.enable_tasks( store=RedisTaskStore(redis_url), queue=RedisTaskMessageQueue(redis_url), ) + ``` """ store: TaskStore diff --git a/src/mcp/server/lowlevel/__init__.py b/src/mcp/server/lowlevel/__init__.py index 66df389916..37191ba1a0 100644 --- a/src/mcp/server/lowlevel/__init__.py +++ b/src/mcp/server/lowlevel/__init__.py @@ -1,3 +1,3 @@ from .server import NotificationOptions, Server -__all__ = ["Server", "NotificationOptions"] +__all__ = ["NotificationOptions", "Server"] diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 9b472c0232..5a907b6407 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -7,10 +7,12 @@ import logging from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING +from typing import Any, Generic +from typing_extensions import TypeVar + +from mcp.server.context import ServerRequestContext from mcp.server.experimental.task_support import TaskSupport -from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.shared.exceptions import MCPError from mcp.shared.experimental.tasks.helpers import cancel_task from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore @@ -18,16 +20,16 @@ from mcp.shared.experimental.tasks.store import TaskStore from mcp.types import ( INVALID_PARAMS, - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, GetTaskPayloadRequest, + GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + PaginatedRequestParams, ServerCapabilities, - ServerResult, ServerTasksCapability, ServerTasksRequestsCapability, TasksCallCapability, @@ -36,13 +38,12 @@ TasksToolsCapability, ) -if TYPE_CHECKING: - from mcp.server.lowlevel.server import Server - logger = logging.getLogger(__name__) +LifespanResultT = TypeVar("LifespanResultT", default=Any) -class ExperimentalHandlers: + +class ExperimentalHandlers(Generic[LifespanResultT]): """Experimental request/notification handlers. WARNING: These APIs are experimental and may change without notice. @@ -50,13 +51,13 @@ class ExperimentalHandlers: def __init__( self, - server: Server, - request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]], - notification_handlers: dict[type, Callable[..., Awaitable[None]]], - ): - self._server = server - self._request_handlers = request_handlers - self._notification_handlers = notification_handlers + add_request_handler: Callable[ + [str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]], None + ], + has_handler: Callable[[str], bool], + ) -> None: + self._add_request_handler = add_request_handler + self._has_handler = has_handler self._task_support: TaskSupport | None = None @property @@ -66,16 +67,13 @@ def task_support(self) -> TaskSupport | None: def update_capabilities(self, capabilities: ServerCapabilities) -> None: # Only add tasks capability if handlers are registered - if not any( - req_type in self._request_handlers - for req_type in [GetTaskRequest, ListTasksRequest, CancelTaskRequest, GetTaskPayloadRequest] - ): + if not any(self._has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"]): return capabilities.tasks = ServerTasksCapability() - if ListTasksRequest in self._request_handlers: + if self._has_handler("tasks/list"): capabilities.tasks.list = TasksListCapability() - if CancelTaskRequest in self._request_handlers: + if self._has_handler("tasks/cancel"): capabilities.tasks.cancel = TasksCancelCapability() capabilities.tasks.requests = ServerTasksRequestsCapability( @@ -86,28 +84,54 @@ def enable_tasks( self, store: TaskStore | None = None, queue: TaskMessageQueue | None = None, + *, + on_get_task: Callable[[ServerRequestContext[LifespanResultT], GetTaskRequestParams], Awaitable[GetTaskResult]] + | None = None, + on_task_result: Callable[ + [ServerRequestContext[LifespanResultT], GetTaskPayloadRequestParams], Awaitable[GetTaskPayloadResult] + ] + | None = None, + on_list_tasks: Callable[ + [ServerRequestContext[LifespanResultT], PaginatedRequestParams | None], Awaitable[ListTasksResult] + ] + | None = None, + on_cancel_task: Callable[ + [ServerRequestContext[LifespanResultT], CancelTaskRequestParams], Awaitable[CancelTaskResult] + ] + | None = None, ) -> TaskSupport: """Enable experimental task support. - This sets up the task infrastructure and auto-registers default handlers - for tasks/get, tasks/result, tasks/list, and tasks/cancel. + This sets up the task infrastructure and registers handlers for + tasks/get, tasks/result, tasks/list, and tasks/cancel. Custom handlers + can be provided via the on_* kwargs; any not provided will use defaults. Args: store: Custom TaskStore implementation (defaults to InMemoryTaskStore) queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue) + on_get_task: Custom handler for tasks/get + on_task_result: Custom handler for tasks/result + on_list_tasks: Custom handler for tasks/list + on_cancel_task: Custom handler for tasks/cancel Returns: The TaskSupport configuration object Example: - # Simple in-memory setup + Simple in-memory setup: + + ```python server.experimental.enable_tasks() + ``` - # Custom store/queue for distributed systems + Custom store/queue for distributed systems: + + ```python server.experimental.enable_tasks( store=RedisTaskStore(redis_url), queue=RedisTaskMessageQueue(redis_url), ) + ``` WARNING: This API is experimental and may change without notice. """ @@ -117,24 +141,27 @@ def enable_tasks( queue = InMemoryTaskMessageQueue() self._task_support = TaskSupport(store=store, queue=queue) - - # Auto-register default handlers - self._register_default_task_handlers() - - return self._task_support - - def _register_default_task_handlers(self) -> None: - """Register default handlers for task operations.""" - assert self._task_support is not None - support = self._task_support - - # Register get_task handler if not already registered - if GetTaskRequest not in self._request_handlers: - - async def _default_get_task(req: GetTaskRequest) -> ServerResult: - task = await support.store.get_task(req.params.task_id) + task_support = self._task_support + + # Register user-provided handlers + if on_get_task is not None: + self._add_request_handler("tasks/get", on_get_task) + if on_task_result is not None: + self._add_request_handler("tasks/result", on_task_result) + if on_list_tasks is not None: + self._add_request_handler("tasks/list", on_list_tasks) + if on_cancel_task is not None: + self._add_request_handler("tasks/cancel", on_cancel_task) + + # Fill in defaults for any not provided + if not self._has_handler("tasks/get"): + + async def _default_get_task( + ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams + ) -> GetTaskResult: + task = await task_support.store.get_task(params.task_id) if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {req.params.task_id}") + raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") return GetTaskResult( task_id=task.task_id, status=task.status, @@ -145,136 +172,39 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: poll_interval=task.poll_interval, ) - self._request_handlers[GetTaskRequest] = _default_get_task + self._add_request_handler("tasks/get", _default_get_task) - # Register get_task_result handler if not already registered - if GetTaskPayloadRequest not in self._request_handlers: + if not self._has_handler("tasks/result"): - async def _default_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - ctx = self._server.request_context - result = await support.handler.handle(req, ctx.session, ctx.request_id) + async def _default_get_task_result( + ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: + assert ctx.request_id is not None + req = GetTaskPayloadRequest(params=params) + result = await task_support.handler.handle(req, ctx.session, ctx.request_id) return result - self._request_handlers[GetTaskPayloadRequest] = _default_get_task_result + self._add_request_handler("tasks/result", _default_get_task_result) - # Register list_tasks handler if not already registered - if ListTasksRequest not in self._request_handlers: + if not self._has_handler("tasks/list"): - async def _default_list_tasks(req: ListTasksRequest) -> ListTasksResult: - cursor = req.params.cursor if req.params else None - tasks, next_cursor = await support.store.list_tasks(cursor) + async def _default_list_tasks( + ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListTasksResult: + cursor = params.cursor if params else None + tasks, next_cursor = await task_support.store.list_tasks(cursor) return ListTasksResult(tasks=tasks, next_cursor=next_cursor) - self._request_handlers[ListTasksRequest] = _default_list_tasks - - # Register cancel_task handler if not already registered - if CancelTaskRequest not in self._request_handlers: - - async def _default_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: - result = await cancel_task(support.store, req.params.task_id) - return result - - self._request_handlers[CancelTaskRequest] = _default_cancel_task - - def list_tasks( - self, - ) -> Callable[ - [Callable[[ListTasksRequest], Awaitable[ListTasksResult]]], - Callable[[ListTasksRequest], Awaitable[ListTasksResult]], - ]: - """Register a handler for listing tasks. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]], - ) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]: - logger.debug("Registering handler for ListTasksRequest") - wrapper = create_call_wrapper(func, ListTasksRequest) - - async def handler(req: ListTasksRequest) -> ListTasksResult: - result = await wrapper(req) - return result - - self._request_handlers[ListTasksRequest] = handler - return func - - return decorator - - def get_task( - self, - ) -> Callable[ - [Callable[[GetTaskRequest], Awaitable[GetTaskResult]]], Callable[[GetTaskRequest], Awaitable[GetTaskResult]] - ]: - """Register a handler for getting task status. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]], - ) -> Callable[[GetTaskRequest], Awaitable[GetTaskResult]]: - logger.debug("Registering handler for GetTaskRequest") - wrapper = create_call_wrapper(func, GetTaskRequest) - - async def handler(req: GetTaskRequest) -> GetTaskResult: - result = await wrapper(req) - return result - - self._request_handlers[GetTaskRequest] = handler - return func - - return decorator - - def get_task_result( - self, - ) -> Callable[ - [Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]], - Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], - ]: - """Register a handler for getting task results/payload. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], - ) -> Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]: - logger.debug("Registering handler for GetTaskPayloadRequest") - wrapper = create_call_wrapper(func, GetTaskPayloadRequest) - - async def handler(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - result = await wrapper(req) - return result - - self._request_handlers[GetTaskPayloadRequest] = handler - return func - - return decorator - - def cancel_task( - self, - ) -> Callable[ - [Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]], - Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], - ]: - """Register a handler for cancelling tasks. - - WARNING: This API is experimental and may change without notice. - """ + self._add_request_handler("tasks/list", _default_list_tasks) - def decorator( - func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], - ) -> Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]: - logger.debug("Registering handler for CancelTaskRequest") - wrapper = create_call_wrapper(func, CancelTaskRequest) + if not self._has_handler("tasks/cancel"): - async def handler(req: CancelTaskRequest) -> CancelTaskResult: - result = await wrapper(req) + async def _default_cancel_task( + ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams + ) -> CancelTaskResult: + result = await cancel_task(task_support.store, params.task_id) return result - self._request_handlers[CancelTaskRequest] = handler - return func + self._add_request_handler("tasks/cancel", _default_cancel_task) - return decorator + return task_support diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py deleted file mode 100644 index d176970902..0000000000 --- a/src/mcp/server/lowlevel/func_inspection.py +++ /dev/null @@ -1,53 +0,0 @@ -import inspect -from collections.abc import Callable -from typing import Any, TypeVar, get_type_hints - -T = TypeVar("T") -R = TypeVar("R") - - -def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callable[[T], R]: - """Create a wrapper function that knows how to call func with the request object. - - Returns a wrapper function that takes the request and calls func appropriately. - - The wrapper handles three calling patterns: - 1. Positional-only parameter typed as request_type (no default): func(req) - 2. Positional/keyword parameter typed as request_type (no default): func(**{param_name: req}) - 3. No request parameter or parameter with default: func() - """ - try: - sig = inspect.signature(func) - type_hints = get_type_hints(func) - except (ValueError, TypeError, NameError): # pragma: no cover - return lambda _: func() - - # Check for positional-only parameter typed as request_type - for param_name, param in sig.parameters.items(): - if param.kind == inspect.Parameter.POSITIONAL_ONLY: - param_type = type_hints.get(param_name) - if param_type == request_type: # pragma: no branch - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: # pragma: no cover - return lambda _: func() - # Found positional-only parameter with correct type and no default - return lambda req: func(req) - - # Check for any positional/keyword parameter typed as request_type - for param_name, param in sig.parameters.items(): - if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY): # pragma: no branch - param_type = type_hints.get(param_name) - if param_type == request_type: - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: # pragma: no cover - return lambda _: func() - - # Found keyword parameter with correct type and no default - # Need to capture param_name in closure properly - def make_keyword_wrapper(name: str) -> Callable[[Any], Any]: - return lambda req: func(**{name: req}) - - return make_keyword_wrapper(param_name) - - # No request parameter found - use old style - return lambda _: func() diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 7bd79bb37c..5e4e2e6f5b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -2,83 +2,50 @@ This module provides a framework for creating an MCP (Model Context Protocol) server. It allows you to easily define and handle various types of requests and notifications -in an asynchronous manner. +using constructor-based handler registration. Usage: -1. Create a Server instance: - server = Server("your_server_name") - -2. Define request handlers using decorators: - @server.list_prompts() - async def handle_list_prompts(request: types.ListPromptsRequest) -> types.ListPromptsResult: - # Implementation - - @server.get_prompt() - async def handle_get_prompt( - name: str, arguments: dict[str, str] | None - ) -> types.GetPromptResult: - # Implementation - - @server.list_tools() - async def handle_list_tools(request: types.ListToolsRequest) -> types.ListToolsResult: - # Implementation - - @server.call_tool() - async def handle_call_tool( - name: str, arguments: dict | None - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - # Implementation - - @server.list_resource_templates() - async def handle_list_resource_templates() -> list[types.ResourceTemplate]: - # Implementation - -3. Define notification handlers if needed: - @server.progress_notification() - async def handle_progress( - progress_token: str | int, progress: float, total: float | None, - message: str | None - ) -> None: - # Implementation - -4. Run the server: +1. Define handler functions: + async def my_list_tools(ctx, params): + return types.ListToolsResult(tools=[...]) + + async def my_call_tool(ctx, params): + return types.CallToolResult(content=[...]) + +2. Create a Server instance with on_* handlers: + server = Server( + "your_server_name", + on_list_tools=my_list_tools, + on_call_tool=my_call_tool, + ) + +3. Run the server: async def main(): async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="your_server_name", - server_version="your_version", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) asyncio.run(main()) -The Server class provides methods to register handlers for various MCP requests and -notifications. It automatically manages the request context and handles incoming -messages from the client. +The Server class dispatches incoming requests and notifications to registered +handler callables by method string. """ from __future__ import annotations -import base64 import contextvars -import json import logging import warnings -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from importlib.metadata import version as importlib_version -from typing import Any, Generic, TypeAlias, cast +from typing import Any, Generic, cast import anyio -import jsonschema -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from opentelemetry.trace import SpanKind, StatusCode from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -94,30 +61,20 @@ async def main(): from mcp.server.context import ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers -from mcp.server.lowlevel.func_inspection import create_call_wrapper -from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError +from mcp.shared._otel import extract_trace_context, otel_span +from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.exceptions import MCPError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder -from mcp.shared.tool_name_validation import validate_and_warn_tool_name logger = logging.getLogger(__name__) LifespanResultT = TypeVar("LifespanResultT", default=Any) -RequestT = TypeVar("RequestT", default=Any) - -# type aliases for tool call results -StructuredContent: TypeAlias = dict[str, Any] -UnstructuredContent: TypeAlias = Iterable[types.ContentBlock] -CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] - -# This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[ServerRequestContext[Any, Any]] = contextvars.ContextVar("request_ctx") class NotificationOptions: @@ -128,22 +85,24 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager -async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]: +async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. - Args: - server: The server instance this lifespan is managing - Returns: An empty context object """ yield {} -class Server(Generic[LifespanResultT, RequestT]): +async def _ping_handler(ctx: ServerRequestContext[Any], params: types.RequestParams | None) -> types.EmptyResult: + return types.EmptyResult() + + +class Server(Generic[LifespanResultT]): def __init__( self, name: str, + *, version: str | None = None, title: str | None = None, description: str | None = None, @@ -151,9 +110,80 @@ def __init__( website_url: str | None = None, icons: list[types.Icon] | None = None, lifespan: Callable[ - [Server[LifespanResultT, RequestT]], + [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, + # Request handlers + on_list_tools: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListToolsResult], + ] + | None = None, + on_call_tool: Callable[ + [ServerRequestContext[LifespanResultT], types.CallToolRequestParams], + Awaitable[types.CallToolResult | types.CreateTaskResult], + ] + | None = None, + on_list_resources: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListResourcesResult], + ] + | None = None, + on_list_resource_templates: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListResourceTemplatesResult], + ] + | None = None, + on_read_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.ReadResourceRequestParams], + Awaitable[types.ReadResourceResult], + ] + | None = None, + on_subscribe_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.SubscribeRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_unsubscribe_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.UnsubscribeRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_list_prompts: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListPromptsResult], + ] + | None = None, + on_get_prompt: Callable[ + [ServerRequestContext[LifespanResultT], types.GetPromptRequestParams], + Awaitable[types.GetPromptResult], + ] + | None = None, + on_completion: Callable[ + [ServerRequestContext[LifespanResultT], types.CompleteRequestParams], + Awaitable[types.CompleteResult], + ] + | None = None, + on_set_logging_level: Callable[ + [ServerRequestContext[LifespanResultT], types.SetLevelRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_ping: Callable[ + [ServerRequestContext[LifespanResultT], types.RequestParams | None], + Awaitable[types.EmptyResult], + ] = _ping_handler, + # Notification handlers + on_roots_list_changed: Callable[ + [ServerRequestContext[LifespanResultT], types.NotificationParams | None], + Awaitable[None], + ] + | None = None, + on_progress: Callable[ + [ServerRequestContext[LifespanResultT], types.ProgressNotificationParams], + Awaitable[None], + ] + | None = None, ): self.name = name self.version = version @@ -163,15 +193,64 @@ def __init__( self.website_url = website_url self.icons = icons self.lifespan = lifespan - self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { - types.PingRequest: _ping_handler, - } - self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} - self._tool_cache: dict[str, types.Tool] = {} - self._experimental_handlers: ExperimentalHandlers | None = None + self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {} + self._notification_handlers: dict[ + str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] + ] = {} + self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None logger.debug("Initializing server %r", name) + # Populate internal handler dicts from on_* kwargs + self._request_handlers.update( + { + method: handler + for method, handler in { + "ping": on_ping, + "prompts/list": on_list_prompts, + "prompts/get": on_get_prompt, + "resources/list": on_list_resources, + "resources/templates/list": on_list_resource_templates, + "resources/read": on_read_resource, + "resources/subscribe": on_subscribe_resource, + "resources/unsubscribe": on_unsubscribe_resource, + "tools/list": on_list_tools, + "tools/call": on_call_tool, + "logging/setLevel": on_set_logging_level, + "completion/complete": on_completion, + }.items() + if handler is not None + } + ) + + self._notification_handlers.update( + { + method: handler + for method, handler in { + "notifications/roots/list_changed": on_roots_list_changed, + "notifications/progress": on_progress, + }.items() + if handler is not None + } + ) + + def _add_request_handler( + self, + method: str, + handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]], + ) -> None: + """Add a request handler, silently replacing any existing handler for the same method.""" + self._request_handlers[method] = handler + + def _has_handler(self, method: str) -> bool: + """Check if a handler is registered for the given method.""" + return method in self._request_handlers or method in self._notification_handlers + + # TODO: Rethink capabilities API. Currently capabilities are derived from registered + # handlers but require NotificationOptions to be passed externally for list_changed + # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities + # entirely from server state (e.g. constructor params for list_changed) instead of + # requiring callers to assemble them at create_initialization_options() time. def create_initialization_options( self, notification_options: NotificationOptions | None = None, @@ -214,25 +293,26 @@ def get_capabilities( completions_capability = None # Set prompt capabilities if handler exists - if types.ListPromptsRequest in self.request_handlers: + if "prompts/list" in self._request_handlers: prompts_capability = types.PromptsCapability(list_changed=notification_options.prompts_changed) # Set resource capabilities if handler exists - if types.ListResourcesRequest in self.request_handlers: + if "resources/list" in self._request_handlers: resources_capability = types.ResourcesCapability( - subscribe=False, list_changed=notification_options.resources_changed + subscribe="resources/subscribe" in self._request_handlers, + list_changed=notification_options.resources_changed, ) # Set tool capabilities if handler exists - if types.ListToolsRequest in self.request_handlers: + if "tools/list" in self._request_handlers: tools_capability = types.ToolsCapability(list_changed=notification_options.tools_changed) # Set logging capabilities if handler exists - if types.SetLevelRequest in self.request_handlers: + if "logging/setLevel" in self._request_handlers: logging_capability = types.LoggingCapability() # Set completions capabilities if handler exists - if types.CompleteRequest in self.request_handlers: + if "completion/complete" in self._request_handlers: completions_capability = types.CompletionsCapability() capabilities = types.ServerCapabilities( @@ -248,12 +328,7 @@ def get_capabilities( return capabilities @property - def request_context(self) -> ServerRequestContext[LifespanResultT, RequestT]: - """If called outside of a request context, this will raise a LookupError.""" - return request_ctx.get() - - @property - def experimental(self) -> ExperimentalHandlers: + def experimental(self) -> ExperimentalHandlers[LifespanResultT]: """Experimental APIs for tasks and other features. WARNING: These APIs are experimental and may change without notice. @@ -261,7 +336,10 @@ def experimental(self) -> ExperimentalHandlers: # We create this inline so we only add these capabilities _if_ they're actually used if self._experimental_handlers is None: - self._experimental_handlers = ExperimentalHandlers(self, self.request_handlers, self.notification_handlers) + self._experimental_handlers = ExperimentalHandlers( + add_request_handler=self._add_request_handler, + has_handler=self._has_handler, + ) return self._experimental_handlers @property @@ -271,385 +349,17 @@ def session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - if self._session_manager is None: # pragma: no cover - raise RuntimeError( + if self._session_manager is None: + raise RuntimeError( # pragma: no cover "Session manager can only be accessed after calling streamable_http_app(). " "The session manager is created lazily to avoid unnecessary initialization." ) - return self._session_manager # pragma: no cover - - def list_prompts(self): - def decorator( - func: Callable[[], Awaitable[list[types.Prompt]]] - | Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], - ): - logger.debug("Registering handler for PromptListRequest") - - wrapper = create_call_wrapper(func, types.ListPromptsRequest) - - async def handler(req: types.ListPromptsRequest): - result = await wrapper(req) - # Handle both old style (list[Prompt]) and new style (ListPromptsResult) - if isinstance(result, types.ListPromptsResult): - return result - else: - # Old style returns list[Prompt] - return types.ListPromptsResult(prompts=result) - - self.request_handlers[types.ListPromptsRequest] = handler - return func - - return decorator - - def get_prompt(self): - def decorator( - func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]], - ): - logger.debug("Registering handler for GetPromptRequest") - - async def handler(req: types.GetPromptRequest): - prompt_get = await func(req.params.name, req.params.arguments) - return prompt_get - - self.request_handlers[types.GetPromptRequest] = handler - return func - - return decorator - - def list_resources(self): - def decorator( - func: Callable[[], Awaitable[list[types.Resource]]] - | Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], - ): - logger.debug("Registering handler for ListResourcesRequest") - - wrapper = create_call_wrapper(func, types.ListResourcesRequest) - - async def handler(req: types.ListResourcesRequest): - result = await wrapper(req) - # Handle both old style (list[Resource]) and new style (ListResourcesResult) - if isinstance(result, types.ListResourcesResult): - return result - else: - # Old style returns list[Resource] - return types.ListResourcesResult(resources=result) - - self.request_handlers[types.ListResourcesRequest] = handler - return func - - return decorator - - def list_resource_templates(self): - def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): - logger.debug("Registering handler for ListResourceTemplatesRequest") - - async def handler(_: Any): - templates = await func() - return types.ListResourceTemplatesResult(resource_templates=templates) - - self.request_handlers[types.ListResourceTemplatesRequest] = handler - return func - - return decorator - - def read_resource(self): - def decorator( - func: Callable[[str], Awaitable[str | bytes | Iterable[ReadResourceContents]]], - ): - logger.debug("Registering handler for ReadResourceRequest") - - async def handler(req: types.ReadResourceRequest): - result = await func(req.params.uri) - - def create_content(data: str | bytes, mime_type: str | None, meta: dict[str, Any] | None = None): - # Note: ResourceContents uses Field(alias="_meta"), so we must use the alias key - meta_kwargs: dict[str, Any] = {"_meta": meta} if meta is not None else {} - match data: - case str() as data: - return types.TextResourceContents( - uri=req.params.uri, - text=data, - mime_type=mime_type or "text/plain", - **meta_kwargs, - ) - case bytes() as data: # pragma: no branch - return types.BlobResourceContents( - uri=req.params.uri, - blob=base64.b64encode(data).decode(), - mime_type=mime_type or "application/octet-stream", - **meta_kwargs, - ) - - match result: - case str() | bytes() as data: # pragma: lax no cover - warnings.warn( - "Returning str or bytes from read_resource is deprecated. " - "Use Iterable[ReadResourceContents] instead.", - DeprecationWarning, - stacklevel=2, - ) - content = create_content(data, None) - case Iterable() as contents: - contents_list = [ - create_content( - content_item.content, content_item.mime_type, getattr(content_item, "meta", None) - ) - for content_item in contents - ] - return types.ReadResourceResult(contents=contents_list) - case _: # pragma: no cover - raise ValueError(f"Unexpected return type from read_resource: {type(result)}") - - return types.ReadResourceResult(contents=[content]) # pragma: no cover - - self.request_handlers[types.ReadResourceRequest] = handler - return func - - return decorator - - def set_logging_level(self): - def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): - logger.debug("Registering handler for SetLevelRequest") - - async def handler(req: types.SetLevelRequest): - await func(req.params.level) - return types.EmptyResult() - - self.request_handlers[types.SetLevelRequest] = handler - return func - - return decorator - - def subscribe_resource(self): - def decorator(func: Callable[[str], Awaitable[None]]): - logger.debug("Registering handler for SubscribeRequest") - - async def handler(req: types.SubscribeRequest): - await func(req.params.uri) - return types.EmptyResult() - - self.request_handlers[types.SubscribeRequest] = handler - return func - - return decorator - - def unsubscribe_resource(self): - def decorator(func: Callable[[str], Awaitable[None]]): - logger.debug("Registering handler for UnsubscribeRequest") - - async def handler(req: types.UnsubscribeRequest): - await func(req.params.uri) - return types.EmptyResult() - - self.request_handlers[types.UnsubscribeRequest] = handler - return func - - return decorator - - def list_tools(self): - def decorator( - func: Callable[[], Awaitable[list[types.Tool]]] - | Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], - ): - logger.debug("Registering handler for ListToolsRequest") - - wrapper = create_call_wrapper(func, types.ListToolsRequest) - - async def handler(req: types.ListToolsRequest): - result = await wrapper(req) - - # Handle both old style (list[Tool]) and new style (ListToolsResult) - if isinstance(result, types.ListToolsResult): - # Refresh the tool cache with returned tools - for tool in result.tools: - validate_and_warn_tool_name(tool.name) - self._tool_cache[tool.name] = tool - return result - else: - # Old style returns list[Tool] - # Clear and refresh the entire tool cache - self._tool_cache.clear() - for tool in result: - validate_and_warn_tool_name(tool.name) - self._tool_cache[tool.name] = tool - return types.ListToolsResult(tools=result) - - self.request_handlers[types.ListToolsRequest] = handler - return func - - return decorator - - def _make_error_result(self, error_message: str) -> types.CallToolResult: - """Create a CallToolResult with an error.""" - return types.CallToolResult( - content=[types.TextContent(type="text", text=error_message)], - is_error=True, - ) - - async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None: - """Get tool definition from cache, refreshing if necessary. - - Returns the Tool object if found, None otherwise. - """ - if tool_name not in self._tool_cache: - if types.ListToolsRequest in self.request_handlers: - logger.debug("Tool cache miss for %s, refreshing cache", tool_name) - await self.request_handlers[types.ListToolsRequest](None) - - tool = self._tool_cache.get(tool_name) - if tool is None: - logger.warning("Tool '%s' not listed, no validation will be performed", tool_name) - - return tool - - def call_tool(self, *, validate_input: bool = True): - """Register a tool call handler. - - Args: - validate_input: If True, validates input against inputSchema. Default is True. - - The handler validates input against inputSchema (if validate_input=True), calls the tool function, - and builds a CallToolResult with the results: - - Unstructured content (iterable of ContentBlock): returned in content - - Structured content (dict): returned in structuredContent, serialized JSON text returned in content - - Both: returned in content and structuredContent - - If outputSchema is defined, validates structuredContent or errors if missing. - """ - - def decorator( - func: Callable[ - [str, dict[str, Any]], - Awaitable[ - UnstructuredContent - | StructuredContent - | CombinationContent - | types.CallToolResult - | types.CreateTaskResult - ], - ], - ): - logger.debug("Registering handler for CallToolRequest") - - async def handler(req: types.CallToolRequest): - try: - tool_name = req.params.name - arguments = req.params.arguments or {} - tool = await self._get_cached_tool_definition(tool_name) - - # input validation - if validate_input and tool: - try: - jsonschema.validate(instance=arguments, schema=tool.input_schema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Input validation error: {e.message}") - - # tool call - results = await func(tool_name, arguments) - - # output normalization - unstructured_content: UnstructuredContent - maybe_structured_content: StructuredContent | None - if isinstance(results, types.CallToolResult): - return results - elif isinstance(results, types.CreateTaskResult): - # Task-augmented execution returns task info instead of result - return results - elif isinstance(results, tuple) and len(results) == 2: - # tool returned both structured and unstructured content - unstructured_content, maybe_structured_content = cast(CombinationContent, results) - elif isinstance(results, dict): - # tool returned structured content only - maybe_structured_content = cast(StructuredContent, results) - unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))] - elif hasattr(results, "__iter__"): - # tool returned unstructured content only - unstructured_content = cast(UnstructuredContent, results) - maybe_structured_content = None - else: # pragma: no cover - return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}") - - # output validation - if tool and tool.output_schema is not None: - if maybe_structured_content is None: - return self._make_error_result( - "Output validation error: outputSchema defined but no structured output returned" - ) - else: - try: - jsonschema.validate(instance=maybe_structured_content, schema=tool.output_schema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Output validation error: {e.message}") - - # result - return types.CallToolResult( - content=list(unstructured_content), - structured_content=maybe_structured_content, - is_error=False, - ) - except UrlElicitationRequiredError: - # Re-raise UrlElicitationRequiredError so it can be properly handled - # by _handle_request, which converts it to an error response with code -32042 - raise - except Exception as e: - return self._make_error_result(str(e)) - - self.request_handlers[types.CallToolRequest] = handler - return func - - return decorator - - def progress_notification(self): - def decorator( - func: Callable[[str | int, float, float | None, str | None], Awaitable[None]], - ): - logger.debug("Registering handler for ProgressNotification") - - async def handler(req: types.ProgressNotification): - await func( - req.params.progress_token, - req.params.progress, - req.params.total, - req.params.message, - ) - - self.notification_handlers[types.ProgressNotification] = handler - return func - - return decorator - - def completion(self): - """Provides completions for prompts and resource templates""" - - def decorator( - func: Callable[ - [ - types.PromptReference | types.ResourceTemplateReference, - types.CompletionArgument, - types.CompletionContext | None, - ], - Awaitable[types.Completion | None], - ], - ): - logger.debug("Registering handler for CompleteRequest") - - async def handler(req: types.CompleteRequest): - completion = await func(req.params.ref, req.params.argument, req.params.context) - return types.CompleteResult( - completion=completion - if completion is not None - else types.Completion(values=[], total=None, has_more=None), - ) - - self.request_handlers[types.CompleteRequest] = handler - return func - - return decorator + return self._session_manager async def run( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], initialization_options: InitializationOptions, # When False, exceptions are returned as messages to the client. # When True, exceptions are raised, which will cause the server to shut down @@ -680,16 +390,29 @@ async def run( await stack.enter_async_context(task_support.run()) async with anyio.create_task_group() as tg: - async for message in session.incoming_messages: - logger.debug("Received message: %s", message) - - tg.start_soon( - self._handle_message, - message, - session, - lifespan_context, - raise_exceptions, - ) + try: + async for message in session.incoming_messages: + logger.debug("Received message: %s", message) + + if isinstance(message, RequestResponder) and message.context is not None: + context = message.context + else: + context = contextvars.copy_context() + + context.run( + tg.start_soon, + self._handle_message, + message, + session, + lifespan_context, + raise_exceptions, + ) + finally: + # Transport closed: cancel in-flight handlers. Without this the + # TG join waits for them, and when they eventually try to + # respond they hit a closed write stream (the session's + # _receive_loop closed it when the read stream ended). + tg.cancel_scope.cancel() async def _handle_message( self, @@ -707,15 +430,10 @@ async def _handle_message( ) case Exception(): logger.error(f"Received exception from stream: {message}") - await session.send_log_message( - level="error", - data="Internal Server Error", - logger="mcp.server.exception_handler", - ) if raise_exceptions: raise message case _: - await self._handle_notification(message) + await self._handle_notification(message, session, lifespan_context) for warning in w: # pragma: lax no cover logger.info("Warning: %s: %s", warning.category.__name__, warning.message) @@ -730,30 +448,41 @@ async def _handle_request( ): logger.info("Processing request of type %s", type(req).__name__) - if handler := self.request_handlers.get(type(req)): - logger.debug("Dispatching request of type %s", type(req).__name__) + target = getattr(req.params, "name", None) if req.params else None + span_name = f"MCP handle {req.method} {target}" if target else f"MCP handle {req.method}" - token = None - try: - # Extract request context and close_sse_stream from message metadata - request_data = None - close_sse_stream_cb = None - close_standalone_sse_stream_cb = None - if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata): - request_data = message.message_metadata.request_context - close_sse_stream_cb = message.message_metadata.close_sse_stream - close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream - - # Set our global state that can be retrieved via - # app.get_request_context() - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - # Get task metadata from request params if present - task_metadata = None - if hasattr(req, "params") and req.params is not None: - task_metadata = getattr(req.params, "task", None) - token = request_ctx.set( - ServerRequestContext( + # Extract W3C trace context from _meta (SEP-414). + meta = cast(dict[str, Any] | None, getattr(req.params, "meta", None)) if req.params else None + parent_context = extract_trace_context(meta) if meta is not None else None + + with otel_span( + span_name, + kind=SpanKind.SERVER, + attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id}, + context=parent_context, + ) as span: + if handler := self._request_handlers.get(req.method): + logger.debug("Dispatching request of type %s", type(req).__name__) + + try: + # Extract request context and close_sse_stream from message metadata + request_data = None + close_sse_stream_cb = None + close_standalone_sse_stream_cb = None + if message.message_metadata is not None and isinstance( + message.message_metadata, ServerMessageMetadata + ): + request_data = message.message_metadata.request_context + close_sse_stream_cb = message.message_metadata.close_sse_stream + close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream + + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + # Get task metadata from request params if present + task_metadata = None + if hasattr(req, "params") and req.params is not None: # pragma: no branch + task_metadata = getattr(req.params, "task", None) + ctx = ServerRequestContext( request_id=message.request_id, meta=message.request_meta, session=session, @@ -768,34 +497,65 @@ async def _handle_request( close_sse_stream=close_sse_stream_cb, close_standalone_sse_stream=close_standalone_sse_stream_cb, ) - ) - response = await handler(req) - except MCPError as err: - response = err.error - except anyio.get_cancelled_exc_class(): - logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) + response = await handler(ctx, req.params) + except MCPError as err: + response = err.error + except anyio.get_cancelled_exc_class(): + if message.cancelled: + # Client sent CancelledNotification; responder.cancel() already + # sent an error response, so skip the duplicate. + logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) + return + # Transport-close cancellation from the TG in run(); re-raise so the + # TG swallows its own cancellation. + raise + except Exception as err: + if raise_exceptions: # pragma: no cover + raise err + response = types.ErrorData(code=0, message=str(err)) + else: + response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") + + if isinstance(response, types.ErrorData) and span is not None: + span.set_status(StatusCode.ERROR, response.message) + + try: + await message.respond(response) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + # Transport closed between handler unblocking and respond. Happens + # when _receive_loop's finally wakes a handler blocked on + # send_request: the handler runs to respond() before run()'s TG + # cancel fires, but after the write stream closed. Closed if our + # end closed (_receive_loop's async-with exit); Broken if the peer + # end closed first (streamable_http terminate()). + logger.debug("Response for %s dropped - transport closed", message.request_id) return - except Exception as err: - if raise_exceptions: # pragma: no cover - raise err - response = types.ErrorData(code=0, message=str(err), data=None) - finally: - # Reset the global state after we are done - if token is not None: # pragma: no branch - request_ctx.reset(token) - - await message.respond(response) - else: # pragma: no cover - await message.respond(types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found")) - - logger.debug("Response sent") - - async def _handle_notification(self, notify: Any): - if handler := self.notification_handlers.get(type(notify)): # type: ignore + + logger.debug("Response sent") + + async def _handle_notification( + self, + notify: types.ClientNotification, + session: ServerSession, + lifespan_context: LifespanResultT, + ) -> None: + if handler := self._notification_handlers.get(notify.method): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - await handler(notify) + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + ctx = ServerRequestContext( + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=None, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + ) + await handler(ctx, notify.params) except Exception: # pragma: no cover logger.exception("Uncaught exception in notification handler") @@ -843,7 +603,7 @@ def streamable_http_app( required_scopes: list[str] = [] # Set up auth if configured - if auth: # pragma: no cover + if auth: required_scopes = auth.required_scopes or [] # Add auth middleware if token verifier is available @@ -869,10 +629,10 @@ def streamable_http_app( ) # Set up routes with or without auth - if token_verifier: # pragma: no cover + if token_verifier: # Determine resource metadata URL resource_metadata_url = None - if auth and auth.resource_server_url: + if auth and auth.resource_server_url: # pragma: no branch # Build compliant metadata URL for WWW-Authenticate header resource_metadata_url = build_resource_metadata_url(auth.resource_server_url) @@ -892,7 +652,7 @@ def streamable_http_app( ) # Add protected resource metadata endpoint if configured as RS - if auth and auth.resource_server_url: # pragma: no cover + if auth and auth.resource_server_url: routes.extend( create_protected_resource_routes( resource_url=auth.resource_server_url, @@ -910,7 +670,3 @@ def streamable_http_app( middleware=middleware, lifespan=lambda app: session_manager.run(), ) - - -async def _ping_handler(request: types.PingRequest) -> types.ServerResult: - return types.EmptyResult() diff --git a/src/mcp/server/mcpserver/__init__.py b/src/mcp/server/mcpserver/__init__.py index f51c0b0ed8..0857e38bd4 100644 --- a/src/mcp/server/mcpserver/__init__.py +++ b/src/mcp/server/mcpserver/__init__.py @@ -2,7 +2,8 @@ from mcp.types import Icon -from .server import Context, MCPServer +from .context import Context +from .server import MCPServer from .utilities.types import Audio, Image __all__ = ["MCPServer", "Context", "Image", "Audio", "Icon"] diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py new file mode 100644 index 0000000000..92de074d34 --- /dev/null +++ b/src/mcp/server/mcpserver/context.py @@ -0,0 +1,278 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Generic + +from pydantic import AnyUrl, BaseModel + +from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext +from mcp.server.elicitation import ( + ElicitationResult, + ElicitSchemaModelT, + UrlElicitationResult, + elicit_url, + elicit_with_validation, +) +from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.types import LoggingLevel + +if TYPE_CHECKING: + from mcp.server.mcpserver.server import MCPServer + + +class Context(BaseModel, Generic[LifespanContextT, RequestT]): + """Context object providing access to MCP capabilities. + + This provides a cleaner interface to MCP's RequestContext functionality. + It gets injected into tool and resource functions that request it via type hints. + + To use context in a tool function, add a parameter with the Context type annotation: + + ```python + @server.tool() + async def my_tool(x: int, ctx: Context) -> str: + # Log messages to the client + await ctx.info(f"Processing {x}") + await ctx.debug("Debug info") + await ctx.warning("Warning message") + await ctx.error("Error message") + + # Report progress + await ctx.report_progress(50, 100) + + # Access resources + data = await ctx.read_resource("resource://data") + + # Get request info + request_id = ctx.request_id + client_id = ctx.client_id + + return str(x) + ``` + + The context parameter name can be anything as long as it's annotated with Context. + The context is optional - tools that don't need it can omit the parameter. + """ + + _request_context: ServerRequestContext[LifespanContextT, RequestT] | None + _mcp_server: MCPServer | None + + # TODO(maxisbey): Consider making request_context/mcp_server required, or refactor Context entirely. + def __init__( + self, + *, + request_context: ServerRequestContext[LifespanContextT, RequestT] | None = None, + mcp_server: MCPServer | None = None, + # TODO(Marcelo): We should drop this kwargs parameter. + **kwargs: Any, + ): + super().__init__(**kwargs) + self._request_context = request_context + self._mcp_server = mcp_server + + @property + def mcp_server(self) -> MCPServer: + """Access to the MCPServer instance.""" + if self._mcp_server is None: # pragma: no cover + raise ValueError("Context is not available outside of a request") + return self._mcp_server # pragma: no cover + + @property + def request_context(self) -> ServerRequestContext[LifespanContextT, RequestT]: + """Access to the underlying request context.""" + if self._request_context is None: # pragma: no cover + raise ValueError("Context is not available outside of a request") + return self._request_context + + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for the current operation. + + Args: + progress: Current progress value (e.g., 24) + total: Optional total value (e.g., 100) + message: Optional message (e.g., "Starting render...") + """ + progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None + + if progress_token is None: + return + + await self.request_context.session.send_progress_notification( + progress_token=progress_token, + progress=progress, + total=total, + message=message, + related_request_id=self.request_id, + ) + + async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: + """Read a resource by URI. + + Args: + uri: Resource URI to read + + Returns: + The resource content as either text or bytes + """ + assert self._mcp_server is not None, "Context is not available outside of a request" + return await self._mcp_server.read_resource(uri, self) + + async def elicit( + self, + message: str, + schema: type[ElicitSchemaModelT], + ) -> ElicitationResult[ElicitSchemaModelT]: + """Elicit information from the client/user. + + This method can be used to interactively ask for additional information from the + client within a tool's execution. The client might display the message to the + user and collect a response according to the provided schema. If the client + is an agent, it might decide how to handle the elicitation -- either by asking + the user or automatically generating a response. + + Args: + message: Message to present to the user + schema: A Pydantic model class defining the expected response structure. + According to the specification, only primitive types are allowed. + + Returns: + An ElicitationResult containing the action taken and the data if accepted + + Note: + Check the result.action to determine if the user accepted, declined, or cancelled. + The result.data will only be populated if action is "accept" and validation succeeded. + """ + + return await elicit_with_validation( + session=self.request_context.session, + message=message, + schema=schema, + related_request_id=self.request_id, + ) + + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + ) -> UrlElicitationResult: + """Request URL mode elicitation from the client. + + This directs the user to an external URL for out-of-band interactions + that must not pass through the MCP client. Use this for: + - Collecting sensitive credentials (API keys, passwords) + - OAuth authorization flows with third-party services + - Payment and subscription flows + - Any interaction where data should not pass through the LLM context + + The response indicates whether the user consented to navigate to the URL. + The actual interaction happens out-of-band. When the elicitation completes, + call `ctx.session.send_elicit_complete(elicitation_id)` to notify the client. + + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + + Returns: + UrlElicitationResult indicating accept, decline, or cancel + """ + return await elicit_url( + session=self.request_context.session, + message=message, + url=url, + elicitation_id=elicitation_id, + related_request_id=self.request_id, + ) + + async def log( + self, + level: LoggingLevel, + data: Any, + *, + logger_name: str | None = None, + ) -> None: + """Send a log message to the client. + + Args: + level: Log level (debug, info, notice, warning, error, critical, + alert, emergency) + data: The data to be logged. Any JSON serializable type is allowed + (string, dict, list, number, bool, etc.) per the MCP specification. + logger_name: Optional logger name + """ + await self.request_context.session.send_log_message( + level=level, + data=data, + logger=logger_name, + related_request_id=self.request_id, + ) + + # TODO(maxisbey): see if this is needed otherwise remove + @property + def client_id(self) -> str | None: + """Get the client ID if available. + + Note: this reads from the MCP request's `_meta` params, not the OAuth + bearer token. For that, use `get_access_token().client_id`. + """ + return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover + + @property + def request_id(self) -> str: + """Get the unique ID for this request.""" + return str(self.request_context.request_id) + + @property + def session(self): + """Access to the underlying session for advanced usage.""" + return self.request_context.session + + async def close_sse_stream(self) -> None: + """Close the SSE stream to trigger client reconnection. + + This method closes the HTTP connection for the current request, triggering + client reconnection. Events continue to be stored in the event store and will + be replayed when the client reconnects with Last-Event-ID. + + Use this to implement polling behavior during long-running operations - + the client will reconnect after the retry interval specified in the priming event. + + Note: + This is a no-op if not using StreamableHTTP transport with event_store. + The callback is only available when event_store is configured. + """ + if self._request_context and self._request_context.close_sse_stream: # pragma: no branch + await self._request_context.close_sse_stream() + + async def close_standalone_sse_stream(self) -> None: + """Close the standalone GET SSE stream to trigger client reconnection. + + This method closes the HTTP connection for the standalone GET stream used + for unsolicited server-to-client notifications. The client SHOULD reconnect + with Last-Event-ID to resume receiving notifications. + + Note: + This is a no-op if not using StreamableHTTP transport with event_store. + Currently, client reconnection for standalone GET streams is NOT + implemented - this is a known gap. + """ + if self._request_context and self._request_context.close_standalone_sse_stream: # pragma: no cover + await self._request_context.close_standalone_sse_stream() + + # Convenience methods for common log levels + async def debug(self, data: Any, *, logger_name: str | None = None) -> None: + """Send a debug log message.""" + await self.log("debug", data, logger_name=logger_name) + + async def info(self, data: Any, *, logger_name: str | None = None) -> None: + """Send an info log message.""" + await self.log("info", data, logger_name=logger_name) + + async def warning(self, data: Any, *, logger_name: str | None = None) -> None: + """Send a warning log message.""" + await self.log("warning", data, logger_name=logger_name) + + async def error(self, data: Any, *, logger_name: str | None = None) -> None: + """Send an error log message.""" + await self.log("error", data, logger_name=logger_name) diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 17744a6707..2f778eb514 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -2,20 +2,22 @@ from __future__ import annotations -import inspect +import functools from collections.abc import Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any, Literal +import anyio.to_thread import pydantic_core from pydantic import BaseModel, Field, TypeAdapter, validate_call from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context from mcp.server.mcpserver.utilities.func_metadata import func_metadata +from mcp.shared._callable_inspection import is_async_callable from mcp.types import ContentBlock, Icon, TextContent if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context class Message(BaseModel): @@ -135,10 +137,14 @@ def from_function( async def render( self, - arguments: dict[str, Any] | None = None, - context: Context[LifespanContextT, RequestT] | None = None, + arguments: dict[str, Any] | None, + context: Context[LifespanContextT, RequestT], ) -> list[Message]: - """Render the prompt with arguments.""" + """Render the prompt with arguments. + + Raises: + ValueError: If required arguments are missing, or if rendering fails. + """ # Validate required arguments if self.arguments: required = {arg.name for arg in self.arguments if arg.required} @@ -151,10 +157,11 @@ async def render( # Add context to arguments if needed call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg) - # Call function and check if result is a coroutine - result = self.fn(**call_args) - if inspect.iscoroutine(result): - result = await result + fn = self.fn + if is_async_callable(fn): + result = await fn(**call_args) + else: + result = await anyio.to_thread.run_sync(functools.partial(self.fn, **call_args)) # Validate messages if not isinstance(result, list | tuple): @@ -178,5 +185,5 @@ async def render( raise ValueError(f"Could not convert prompt result to message: {msg}") return messages - except Exception as e: # pragma: no cover + except Exception as e: raise ValueError(f"Error rendering prompt {self.name}: {e}") diff --git a/src/mcp/server/mcpserver/prompts/manager.py b/src/mcp/server/mcpserver/prompts/manager.py index 21b9741318..28a7a6e98c 100644 --- a/src/mcp/server/mcpserver/prompts/manager.py +++ b/src/mcp/server/mcpserver/prompts/manager.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context logger = get_logger(__name__) @@ -48,12 +48,12 @@ def add_prompt( async def render_prompt( self, name: str, - arguments: dict[str, Any] | None = None, - context: Context[LifespanContextT, RequestT] | None = None, + arguments: dict[str, Any] | None, + context: Context[LifespanContextT, RequestT], ) -> list[Message]: """Render a prompt by name with arguments.""" prompt = self.get_prompt(name) if not prompt: raise ValueError(f"Unknown prompt: {name}") - return await prompt.render(arguments, context=context) + return await prompt.render(arguments, context) diff --git a/src/mcp/server/mcpserver/resources/base.py b/src/mcp/server/mcpserver/resources/base.py index d3ccc425ef..d48e0695cb 100644 --- a/src/mcp/server/mcpserver/resources/base.py +++ b/src/mcp/server/mcpserver/resources/base.py @@ -23,11 +23,7 @@ class Resource(BaseModel, abc.ABC): name: str | None = Field(description="Name of the resource", default=None) title: str | None = Field(description="Human-readable title of the resource", default=None) description: str | None = Field(description="Description of the resource", default=None) - mime_type: str = Field( - default="text/plain", - description="MIME type of the resource content", - pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+(;\s*[a-zA-Z0-9\-_.]+=[a-zA-Z0-9\-_.]+)*$", - ) + mime_type: str = Field(default="text/plain", description="MIME type of the resource content") icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this resource") annotations: Annotations | None = Field(default=None, description="Optional annotations for the resource") meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this resource") diff --git a/src/mcp/server/mcpserver/resources/resource_manager.py b/src/mcp/server/mcpserver/resources/resource_manager.py index ed5b741239..766cf51aea 100644 --- a/src/mcp/server/mcpserver/resources/resource_manager.py +++ b/src/mcp/server/mcpserver/resources/resource_manager.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context logger = get_logger(__name__) @@ -22,28 +22,26 @@ class ResourceManager: """Manages MCPServer resources.""" - def __init__(self, warn_on_duplicate_resources: bool = True): + def __init__(self, warn_on_duplicate_resources: bool = True, *, resources: list[Resource] | None = None): self._resources: dict[str, Resource] = {} self._templates: dict[str, ResourceTemplate] = {} self.warn_on_duplicate_resources = warn_on_duplicate_resources + for resource in resources or (): + self.add_resource(resource) + def add_resource(self, resource: Resource) -> Resource: """Add a resource to the manager. Args: - resource: A Resource instance to add + resource: A Resource instance to add. Returns: - The added resource. If a resource with the same URI already exists, - returns the existing resource. + The added resource. If a resource with the same URI already exists, returns the existing resource. """ logger.debug( "Adding resource", - extra={ - "uri": resource.uri, - "type": type(resource).__name__, - "resource_name": resource.name, - }, + extra={"uri": resource.uri, "type": type(resource).__name__, "resource_name": resource.name}, ) existing = self._resources.get(str(resource.uri)) if existing: @@ -80,9 +78,7 @@ def add_template( self._templates[template.uri_template] = template return template - async def get_resource( - self, uri: AnyUrl | str, context: Context[LifespanContextT, RequestT] | None = None - ) -> Resource: + async def get_resource(self, uri: AnyUrl | str, context: Context[LifespanContextT, RequestT]) -> Resource: """Get resource by URI, checking concrete resources first, then templates.""" uri_str = str(uri) logger.debug("Getting resource", extra={"uri": uri_str}) diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index e796823d9d..f1ee29a37f 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -2,22 +2,24 @@ from __future__ import annotations -import inspect +import functools import re from collections.abc import Callable from typing import TYPE_CHECKING, Any from urllib.parse import unquote +import anyio.to_thread from pydantic import BaseModel, Field, validate_call from mcp.server.mcpserver.resources.types import FunctionResource, Resource from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context from mcp.server.mcpserver.utilities.func_metadata import func_metadata +from mcp.shared._callable_inspection import is_async_callable from mcp.types import Annotations, Icon if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context class ResourceTemplate(BaseModel): @@ -99,17 +101,22 @@ async def create_resource( self, uri: str, params: dict[str, Any], - context: Context[LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT], ) -> Resource: - """Create a resource from the template with the given parameters.""" + """Create a resource from the template with the given parameters. + + Raises: + ValueError: If creating the resource fails. + """ try: # Add context to params if needed params = inject_context(self.fn, params, context, self.context_kwarg) - # Call function and check if result is a coroutine - result = self.fn(**params) - if inspect.iscoroutine(result): - result = await result + fn = self.fn + if is_async_callable(fn): + result = await fn(**params) + else: + result = await anyio.to_thread.run_sync(functools.partial(self.fn, **params)) return FunctionResource( uri=uri, # type: ignore diff --git a/src/mcp/server/mcpserver/resources/types.py b/src/mcp/server/mcpserver/resources/types.py index 64e6338060..d9e472e362 100644 --- a/src/mcp/server/mcpserver/resources/types.py +++ b/src/mcp/server/mcpserver/resources/types.py @@ -1,6 +1,7 @@ """Concrete resource implementations.""" -import inspect +from __future__ import annotations + import json from collections.abc import Callable from pathlib import Path @@ -14,6 +15,7 @@ from pydantic import Field, ValidationInfo, validate_call from mcp.server.mcpserver.resources.base import Resource +from mcp.shared._callable_inspection import is_async_callable from mcp.types import Annotations, Icon @@ -55,11 +57,11 @@ class FunctionResource(Resource): async def read(self) -> str | bytes: """Read the resource by calling the wrapped function.""" try: - # Call the function first to see if it returns a coroutine - result = self.fn() - # If it's a coroutine, await it - if inspect.iscoroutine(result): - result = await result + fn = self.fn + if is_async_callable(fn): + result = await fn() + else: + result = await anyio.to_thread.run_sync(self.fn) if isinstance(result, Resource): # pragma: no cover return await result.read() @@ -84,7 +86,7 @@ def from_function( icons: list[Icon] | None = None, annotations: Annotations | None = None, meta: dict[str, Any] | None = None, - ) -> "FunctionResource": + ) -> FunctionResource: """Create a FunctionResource from a function.""" func_name = name or fn.__name__ if func_name == "": # pragma: no cover @@ -109,7 +111,7 @@ def from_function( class FileResource(Resource): """A resource that reads from a file. - Set is_binary=True to read file as binary data instead of text. + Set is_binary=True to read the file as binary data instead of text. """ path: Path = Field(description="Path to the file") diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 8c1fc342be..ec2365810e 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -2,7 +2,9 @@ from __future__ import annotations +import base64 import inspect +import json import re from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager @@ -10,7 +12,6 @@ import anyio import pydantic_core -from pydantic import BaseModel from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette @@ -25,12 +26,11 @@ from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier from mcp.server.auth.settings import AuthSettings -from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext -from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, UrlElicitationResult, elicit_with_validation -from mcp.server.elicitation import elicit_url as _elicit_url +from mcp.server.context import ServerRequestContext from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.lowlevel.server import LifespanResultT, Server from mcp.server.lowlevel.server import lifespan as default_lifespan +from mcp.server.mcpserver.context import Context from mcp.server.mcpserver.exceptions import ResourceError from mcp.server.mcpserver.prompts import Prompt, PromptManager from mcp.server.mcpserver.resources import FunctionResource, Resource, ResourceManager @@ -42,7 +42,30 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Annotations, ContentBlock, GetPromptResult, Icon, ToolAnnotations +from mcp.shared.exceptions import MCPError +from mcp.types import ( + Annotations, + BlobResourceContents, + CallToolRequestParams, + CallToolResult, + CompleteRequestParams, + CompleteResult, + Completion, + ContentBlock, + GetPromptRequestParams, + GetPromptResult, + Icon, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextContent, + TextResourceContents, + ToolAnnotations, +) from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument from mcp.types import Resource as MCPResource @@ -82,8 +105,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # prompt settings warn_on_duplicate_prompts: bool + dependencies: list[str] + """List of dependencies to install in the server environment. Used by the `mcp install` and `mcp dev` CLI.""" + lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None - """A async context manager that will be called when the server is started.""" + """An async context manager that will be called when the server is started.""" auth: AuthSettings | None @@ -91,9 +117,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]): def lifespan_wrapper( app: MCPServer[LifespanResultT], lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[Server[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]: +) -> Callable[[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]]: @asynccontextmanager - async def wrap(_: Server[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]: + async def wrap(_: Server[LifespanResultT]) -> AsyncIterator[LifespanResultT]: async with lifespan(app) as context: yield context @@ -114,11 +140,13 @@ def __init__( token_verifier: TokenVerifier | None = None, *, tools: list[Tool] | None = None, + resources: list[Resource] | None = None, debug: bool = False, log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", warn_on_duplicate_resources: bool = True, warn_on_duplicate_tools: bool = True, warn_on_duplicate_prompts: bool = True, + dependencies: list[str] | None = None, lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, auth: AuthSettings | None = None, ): @@ -128,10 +156,17 @@ def __init__( warn_on_duplicate_resources=warn_on_duplicate_resources, warn_on_duplicate_tools=warn_on_duplicate_tools, warn_on_duplicate_prompts=warn_on_duplicate_prompts, + dependencies=dependencies or [], lifespan=lifespan, auth=auth, ) + self.dependencies = self.settings.dependencies + self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) + self._resource_manager = ResourceManager( + resources=resources, warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources + ) + self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) self._lowlevel_server = Server( name=name or "mcp-server", title=title, @@ -140,13 +175,17 @@ def __init__( website_url=website_url, icons=icons, version=version, + on_list_tools=self._handle_list_tools, + on_call_tool=self._handle_call_tool, + on_list_resources=self._handle_list_resources, + on_read_resource=self._handle_read_resource, + on_list_resource_templates=self._handle_list_resource_templates, + on_list_prompts=self._handle_list_prompts, + on_get_prompt=self._handle_get_prompt, # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an MCPServer and Server. # We need to create a Lifespan type that is a generic on the server type, like Starlette does. lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore ) - self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) - self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) # Validate auth configuration if self.settings.auth is not None: if auth_server_provider and token_verifier: # pragma: no cover @@ -164,9 +203,6 @@ def __init__( self._token_verifier = ProviderTokenVerifier(auth_server_provider) self._custom_starlette_routes: list[Route] = [] - # Set up MCP protocol handlers - self._setup_handlers() - # Configure logging configure_logging(self.settings.log_level) @@ -208,7 +244,7 @@ def session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - return self._lowlevel_server.session_manager # pragma: no cover + return self._lowlevel_server.session_manager @overload def run(self, transport: Literal["stdio"] = ...) -> None: ... @@ -263,18 +299,86 @@ def run( case "streamable-http": # pragma: no cover anyio.run(lambda: self.run_streamable_http_async(**kwargs)) - def _setup_handlers(self) -> None: - """Set up core MCP protocol handlers.""" - self._lowlevel_server.list_tools()(self.list_tools) - # Note: we disable the lowlevel server's input validation. - # MCPServer does ad hoc conversion of incoming data before validating - - # for now we preserve this for backwards compatibility. - self._lowlevel_server.call_tool(validate_input=False)(self.call_tool) - self._lowlevel_server.list_resources()(self.list_resources) - self._lowlevel_server.read_resource()(self.read_resource) - self._lowlevel_server.list_prompts()(self.list_prompts) - self._lowlevel_server.get_prompt()(self.get_prompt) - self._lowlevel_server.list_resource_templates()(self.list_resource_templates) + async def _handle_list_tools( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult(tools=await self.list_tools()) + + async def _handle_call_tool( + self, ctx: ServerRequestContext[LifespanResultT], params: CallToolRequestParams + ) -> CallToolResult: + context = Context(request_context=ctx, mcp_server=self) + try: + result = await self.call_tool(params.name, params.arguments or {}, context) + except MCPError: + raise + except Exception as e: + return CallToolResult(content=[TextContent(type="text", text=str(e))], is_error=True) + if isinstance(result, CallToolResult): + return result + if isinstance(result, tuple) and len(result) == 2: + unstructured_content, structured_content = result + return CallToolResult( + content=list(unstructured_content), # type: ignore[arg-type] + structured_content=structured_content, # type: ignore[arg-type] + ) + if isinstance(result, dict): # pragma: no cover + # TODO: this code path is unreachable — convert_result never returns a raw dict. + # The call_tool return type (Sequence[ContentBlock] | dict[str, Any]) is wrong + # and needs to be cleaned up. + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(result, indent=2))], + structured_content=result, + ) + return CallToolResult(content=list(result)) + + async def _handle_list_resources( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=await self.list_resources()) + + async def _handle_read_resource( + self, ctx: ServerRequestContext[LifespanResultT], params: ReadResourceRequestParams + ) -> ReadResourceResult: + context = Context(request_context=ctx, mcp_server=self) + results = await self.read_resource(params.uri, context) + contents: list[TextResourceContents | BlobResourceContents] = [] + for item in results: + if isinstance(item.content, bytes): + contents.append( + BlobResourceContents( + uri=params.uri, + blob=base64.b64encode(item.content).decode(), + mime_type=item.mime_type or "application/octet-stream", + _meta=item.meta, + ) + ) + else: + contents.append( + TextResourceContents( + uri=params.uri, + text=item.content, + mime_type=item.mime_type or "text/plain", + _meta=item.meta, + ) + ) + return ReadResourceResult(contents=contents) + + async def _handle_list_resource_templates( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + return ListResourceTemplatesResult(resource_templates=await self.list_resource_templates()) + + async def _handle_list_prompts( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=await self.list_prompts()) + + async def _handle_get_prompt( + self, ctx: ServerRequestContext[LifespanResultT], params: GetPromptRequestParams + ) -> GetPromptResult: + context = Context(request_context=ctx, mcp_server=self) + return await self.get_prompt(params.name, params.arguments, context) async def list_tools(self) -> list[MCPTool]: """List all available tools.""" @@ -293,20 +397,13 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> Context[LifespanResultT, Request]: - """Returns a Context object. Note that the context will only be valid - during a request; outside a request, most methods will error. - """ - try: - request_context = self._lowlevel_server.request_context - except LookupError: - request_context = None - return Context(request_context=request_context, mcp_server=self) - - async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock] | dict[str, Any]: + async def call_tool( + self, name: str, arguments: dict[str, Any], context: Context[LifespanResultT, Any] | None = None + ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" - context = self.get_context() - return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) + if context is None: + context = Context(mcp_server=self) + return await self._tool_manager.call_tool(name, arguments, context, convert_result=True) async def list_resources(self) -> list[MCPResource]: """List all available resources.""" @@ -342,14 +439,16 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]: for template in templates ] - async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]: + async def read_resource( + self, uri: AnyUrl | str, context: Context[LifespanResultT, Any] | None = None + ) -> Iterable[ReadResourceContents]: """Read a resource by URI.""" - - context = self.get_context() + if context is None: + context = Context(mcp_server=self) try: - resource = await self._resource_manager.get_resource(uri, context=context) - except ValueError: - raise ResourceError(f"Unknown resource: {uri}") + resource = await self._resource_manager.get_resource(uri, context) + except ValueError as exc: + raise ResourceError(f"Unknown resource: {uri}") from exc try: content = await resource.read() @@ -381,6 +480,8 @@ def add_tool( title: Optional human-readable title for the tool description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information + icons: Optional list of icons for the tool + meta: Optional metadata dictionary for the tool structured_output: Controls whether the tool's output is structured or unstructured - If None, auto-detects based on the function's return type annotation - If True, creates a structured tool (return type annotation permitting) @@ -429,25 +530,33 @@ def tool( title: Optional human-readable title for the tool description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information + icons: Optional list of icons for the tool + meta: Optional metadata dictionary for the tool structured_output: Controls whether the tool's output is structured or unstructured - If None, auto-detects based on the function's return type annotation - If True, creates a structured tool (return type annotation permitting) - If False, unconditionally creates an unstructured tool Example: + ```python @server.tool() def my_tool(x: int) -> str: return str(x) + ``` + ```python @server.tool() - def tool_with_context(x: int, ctx: Context) -> str: - ctx.info(f"Processing {x}") + async def tool_with_context(x: int, ctx: Context) -> str: + await ctx.info(f"Processing {x}") return str(x) + ``` + ```python @server.tool() async def async_tool(x: int, context: Context) -> str: await context.report_progress(50, 100) return str(x) + ``` """ # Check if user passed function directly instead of calling decorator if callable(name): @@ -479,14 +588,33 @@ def completion(self): - context: Optional CompletionContext with previously resolved arguments Example: + ```python @mcp.completion() async def handle_completion(ref, argument, context): if isinstance(ref, ResourceTemplateReference): # Return completions based on ref, argument, and context return Completion(values=["option1", "option2"]) return None + ``` """ - return self._lowlevel_server.completion() + + def decorator(func: _CallableT) -> _CallableT: + async def handler( + ctx: ServerRequestContext[LifespanResultT], params: CompleteRequestParams + ) -> CompleteResult: + result = await func(params.ref, params.argument, params.context) + return CompleteResult( + completion=result if result is not None else Completion(values=[], total=None, has_more=None), + ) + + # TODO(maxisbey): remove private access — completion needs post-construction + # handler registration, find a better pattern for this + self._lowlevel_server._add_request_handler( # pyright: ignore[reportPrivateUsage] + "completion/complete", handler + ) + return func + + return decorator def add_resource(self, resource: Resource) -> None: """Add a resource to the server. @@ -525,15 +653,18 @@ def resource( title: Optional human-readable title for the resource description: Optional description of the resource mime_type: Optional MIME type for the resource + icons: Optional list of icons for the resource + annotations: Optional annotations for the resource meta: Optional metadata dictionary for the resource Example: + ```python @server.resource("resource://my-resource") def get_data() -> str: return "Hello, world!" @server.resource("resource://my-resource") - async get_data() -> str: + async def get_data() -> str: data = await fetch_data() return f"Hello, world! {data}" @@ -545,6 +676,7 @@ def get_weather(city: str) -> str: async def get_weather(city: str) -> str: data = await fetch_weather(city) return f"Weather for {city}: {data}" + ``` """ # Check if user passed function directly instead of calling decorator if callable(uri): @@ -625,8 +757,10 @@ def prompt( name: Optional name for the prompt (defaults to function name) title: Optional human-readable title for the prompt description: Optional description of what the prompt does + icons: Optional list of icons for the prompt Example: + ```python @server.prompt() def analyze_table(table_name: str) -> list[Message]: schema = read_table_schema(table_name) @@ -652,6 +786,7 @@ async def analyze_file(path: str) -> list[Message]: } } ] + ``` """ # Check if user passed function directly instead of calling decorator if callable(name): @@ -693,9 +828,11 @@ def custom_route( include_in_schema: Whether to include in OpenAPI schema, defaults to True Example: + ```python @server.custom_route("/health", methods=["GET"]) async def health_check(request: Request) -> Response: return JSONResponse({"status": "ok"}) + ``` """ def decorator( # pragma: no cover @@ -953,14 +1090,18 @@ async def list_prompts(self) -> list[MCPPrompt]: for prompt in prompts ] - async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult: + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None, context: Context[LifespanResultT, Any] | None = None + ) -> GetPromptResult: """Get a prompt by name with arguments.""" + if context is None: + context = Context(mcp_server=self) try: prompt = self._prompt_manager.get_prompt(name) if not prompt: raise ValueError(f"Unknown prompt: {name}") - messages = await prompt.render(arguments, context=self.get_context()) + messages = await prompt.render(arguments, context) return GetPromptResult( description=prompt.description, @@ -968,264 +1109,4 @@ async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) - ) except Exception as e: logger.exception(f"Error getting prompt {name}") - raise ValueError(str(e)) - - -class Context(BaseModel, Generic[LifespanContextT, RequestT]): - """Context object providing access to MCP capabilities. - - This provides a cleaner interface to MCP's RequestContext functionality. - It gets injected into tool and resource functions that request it via type hints. - - To use context in a tool function, add a parameter with the Context type annotation: - - ```python - @server.tool() - def my_tool(x: int, ctx: Context) -> str: - # Log messages to the client - ctx.info(f"Processing {x}") - ctx.debug("Debug info") - ctx.warning("Warning message") - ctx.error("Error message") - - # Report progress - ctx.report_progress(50, 100) - - # Access resources - data = ctx.read_resource("resource://data") - - # Get request info - request_id = ctx.request_id - client_id = ctx.client_id - - return str(x) - ``` - - The context parameter name can be anything as long as it's annotated with Context. - The context is optional - tools that don't need it can omit the parameter. - """ - - _request_context: ServerRequestContext[LifespanContextT, RequestT] | None - _mcp_server: MCPServer | None - - def __init__( - self, - *, - request_context: ServerRequestContext[LifespanContextT, RequestT] | None = None, - mcp_server: MCPServer | None = None, - # TODO(Marcelo): We should drop this kwargs parameter. - **kwargs: Any, - ): - super().__init__(**kwargs) - self._request_context = request_context - self._mcp_server = mcp_server - - @property - def mcp_server(self) -> MCPServer: - """Access to the MCPServer instance.""" - if self._mcp_server is None: # pragma: no cover - raise ValueError("Context is not available outside of a request") - return self._mcp_server # pragma: no cover - - @property - def request_context(self) -> ServerRequestContext[LifespanContextT, RequestT]: - """Access to the underlying request context.""" - if self._request_context is None: # pragma: no cover - raise ValueError("Context is not available outside of a request") - return self._request_context - - async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: - """Report progress for the current operation. - - Args: - progress: Current progress value e.g. 24 - total: Optional total value e.g. 100 - message: Optional message e.g. Starting render... - """ - progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None - - if progress_token is None: # pragma: no cover - return - - await self.request_context.session.send_progress_notification( - progress_token=progress_token, - progress=progress, - total=total, - message=message, - ) - - async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: - """Read a resource by URI. - - Args: - uri: Resource URI to read - - Returns: - The resource content as either text or bytes - """ - assert self._mcp_server is not None, "Context is not available outside of a request" - return await self._mcp_server.read_resource(uri) - - async def elicit( - self, - message: str, - schema: type[ElicitSchemaModelT], - ) -> ElicitationResult[ElicitSchemaModelT]: - """Elicit information from the client/user. - - This method can be used to interactively ask for additional information from the - client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. Or in case a - client is an agent, it might decide how to handle the elicitation -- either by asking - the user or automatically generating a response. - - Args: - schema: A Pydantic model class defining the expected response structure, according to the specification, - only primitive types are allowed. - message: Optional message to present to the user. If not provided, will use - a default message based on the schema - - Returns: - An ElicitationResult containing the action taken and the data if accepted - - Note: - Check the result.action to determine if the user accepted, declined, or cancelled. - The result.data will only be populated if action is "accept" and validation succeeded. - """ - - return await elicit_with_validation( - session=self.request_context.session, - message=message, - schema=schema, - related_request_id=self.request_id, - ) - - async def elicit_url( - self, - message: str, - url: str, - elicitation_id: str, - ) -> UrlElicitationResult: - """Request URL mode elicitation from the client. - - This directs the user to an external URL for out-of-band interactions - that must not pass through the MCP client. Use this for: - - Collecting sensitive credentials (API keys, passwords) - - OAuth authorization flows with third-party services - - Payment and subscription flows - - Any interaction where data should not pass through the LLM context - - The response indicates whether the user consented to navigate to the URL. - The actual interaction happens out-of-band. When the elicitation completes, - call `self.session.send_elicit_complete(elicitation_id)` to notify the client. - - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - - Returns: - UrlElicitationResult indicating accept, decline, or cancel - """ - return await _elicit_url( - session=self.request_context.session, - message=message, - url=url, - elicitation_id=elicitation_id, - related_request_id=self.request_id, - ) - - async def log( - self, - level: Literal["debug", "info", "warning", "error"], - message: str, - *, - logger_name: str | None = None, - extra: dict[str, Any] | None = None, - ) -> None: - """Send a log message to the client. - - Args: - level: Log level (debug, info, warning, error) - message: Log message - logger_name: Optional logger name - extra: Optional dictionary with additional structured data to include - """ - - if extra: - log_data = {"message": message, **extra} - else: - log_data = message - - await self.request_context.session.send_log_message( - level=level, - data=log_data, - logger=logger_name, - related_request_id=self.request_id, - ) - - @property - def client_id(self) -> str | None: - """Get the client ID if available.""" - return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover - - @property - def request_id(self) -> str: - """Get the unique ID for this request.""" - return str(self.request_context.request_id) - - @property - def session(self): - """Access to the underlying session for advanced usage.""" - return self.request_context.session - - async def close_sse_stream(self) -> None: - """Close the SSE stream to trigger client reconnection. - - This method closes the HTTP connection for the current request, triggering - client reconnection. Events continue to be stored in the event store and will - be replayed when the client reconnects with Last-Event-ID. - - Use this to implement polling behavior during long-running operations - - client will reconnect after the retry interval specified in the priming event. - - Note: - This is a no-op if not using StreamableHTTP transport with event_store. - The callback is only available when event_store is configured. - """ - if self._request_context and self._request_context.close_sse_stream: # pragma: no cover - await self._request_context.close_sse_stream() - - async def close_standalone_sse_stream(self) -> None: - """Close the standalone GET SSE stream to trigger client reconnection. - - This method closes the HTTP connection for the standalone GET stream used - for unsolicited server-to-client notifications. The client SHOULD reconnect - with Last-Event-ID to resume receiving notifications. - - Note: - This is a no-op if not using StreamableHTTP transport with event_store. - Currently, client reconnection for standalone GET streams is NOT - implemented - this is a known gap. - """ - if self._request_context and self._request_context.close_standalone_sse_stream: # pragma: no cover - await self._request_context.close_standalone_sse_stream() - - # Convenience methods for common log levels - async def debug(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: - """Send a debug log message.""" - await self.log("debug", message, logger_name=logger_name, extra=extra) - - async def info(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: - """Send an info log message.""" - await self.log("info", message, logger_name=logger_name, extra=extra) - - async def warning( - self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None - ) -> None: - """Send a warning log message.""" - await self.log("warning", message, logger_name=logger_name, extra=extra) - - async def error(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: - """Send an error log message.""" - await self.log("error", message, logger_name=logger_name, extra=extra) + raise ValueError(str(e)) from e diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index f6bfadbc4d..754313eb8a 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -1,7 +1,5 @@ from __future__ import annotations -import functools -import inspect from collections.abc import Callable from functools import cached_property from typing import TYPE_CHECKING, Any @@ -11,13 +9,14 @@ from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata +from mcp.shared._callable_inspection import is_async_callable from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.shared.tool_name_validation import validate_and_warn_tool_name from mcp.types import Icon, ToolAnnotations if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context class Tool(BaseModel): @@ -63,7 +62,7 @@ def from_function( raise ValueError("You must provide a name for lambda functions") func_doc = description or fn.__doc__ or "" - is_async = _is_async_callable(fn) + is_async = is_async_callable(fn) if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) @@ -92,10 +91,14 @@ def from_function( async def run( self, arguments: dict[str, Any], - context: Context[LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT], convert_result: bool = False, ) -> Any: - """Run the tool with arguments.""" + """Run the tool with arguments. + + Raises: + ToolError: If the tool function raises during execution. + """ try: result = await self.fn_metadata.call_fn_with_arg_validation( self.fn, @@ -114,12 +117,3 @@ async def run( raise except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e - - -def _is_async_callable(obj: Any) -> bool: - while isinstance(obj, functools.partial): # pragma: lax no cover - obj = obj.func - - return inspect.iscoroutinefunction(obj) or ( - callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) - ) diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index c6f8384bdb..eef4911f9e 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from mcp.server.context import LifespanContextT, RequestT - from mcp.server.mcpserver.server import Context + from mcp.server.mcpserver.context import Context logger = get_logger(__name__) @@ -18,18 +18,12 @@ class ToolManager: """Manages MCPServer tools.""" - def __init__( - self, - warn_on_duplicate_tools: bool = True, - *, - tools: list[Tool] | None = None, - ): + def __init__(self, warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | None = None): self._tools: dict[str, Tool] = {} - if tools is not None: - for tool in tools: - if warn_on_duplicate_tools and tool.name in self._tools: - logger.warning(f"Tool already exists: {tool.name}") - self._tools[tool.name] = tool + for tool in tools or (): + if warn_on_duplicate_tools and tool.name in self._tools: + logger.warning(f"Tool already exists: {tool.name}") + self._tools[tool.name] = tool self.warn_on_duplicate_tools = warn_on_duplicate_tools @@ -81,7 +75,7 @@ async def call_tool( self, name: str, arguments: dict[str, Any], - context: Context[LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT], convert_result: bool = False, ) -> Any: """Call a tool by name with arguments.""" @@ -89,4 +83,4 @@ async def call_tool( if not tool: raise ToolError(f"Unknown tool: {name}") - return await tool.run(arguments, context=context, convert_result=convert_result) + return await tool.run(arguments, context, convert_result=convert_result) diff --git a/src/mcp/server/mcpserver/utilities/context_injection.py b/src/mcp/server/mcpserver/utilities/context_injection.py index 9cba83e860..ac7ab82d05 100644 --- a/src/mcp/server/mcpserver/utilities/context_injection.py +++ b/src/mcp/server/mcpserver/utilities/context_injection.py @@ -7,6 +7,8 @@ from collections.abc import Callable from typing import Any +from mcp.server.mcpserver.context import Context + def find_context_parameter(fn: Callable[..., Any]) -> str | None: """Find the parameter that should receive the Context object. @@ -20,8 +22,6 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None: Returns: The name of the context parameter, or None if not found """ - from mcp.server.mcpserver.server import Context - # Get type hints to properly resolve string annotations try: hints = typing.get_type_hints(fn) diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 4b539ce1f2..4a76106371 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -9,7 +9,7 @@ import anyio import anyio.to_thread import pydantic_core -from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema, create_model +from pydantic import BaseModel, ConfigDict, Field, PydanticUserError, WithJsonSchema, create_model from pydantic.fields import FieldInfo from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind from typing_extensions import is_typeddict @@ -46,7 +46,7 @@ class ArgModelBase(BaseModel): def model_dump_one_level(self) -> dict[str, Any]: """Return a dict of the model's fields, one level deep. - That is, sub-models etc are not dumped - they are kept as pydantic models. + That is, sub-models etc are not dumped - they are kept as Pydantic models. """ kwargs: dict[str, Any] = {} for field_name, field_info in self.__class__.model_fields.items(): @@ -89,8 +89,7 @@ async def call_fn_with_arg_validation( return await anyio.to_thread.run_sync(functools.partial(fn, **arguments_parsed_dict)) def convert_result(self, result: Any) -> Any: - """Convert the result of a function call to the appropriate format for - the lowlevel server tool call handler: + """Convert a function call result to the format for the lowlevel tool call handler. - If output_model is None, return the unstructured content directly. - If output_model is not None, convert the result to structured output format @@ -126,11 +125,11 @@ def convert_result(self, result: Any) -> Any: def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: """Pre-parse data from JSON. - Return a dict with same keys as input but with values parsed from JSON + Return a dict with the same keys as input but with values parsed from JSON if appropriate. This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside - a string rather than an actual list. Claude desktop is prone to this - in fact + a string rather than an actual list. Claude Desktop is prone to this - in fact it seems incapable of NOT doing this. For sub-models, it tends to pass dicts (JSON objects) as JSON strings, which can be pre-parsed here. """ @@ -173,8 +172,7 @@ def func_metadata( skip_names: Sequence[str] = (), structured_output: bool | None = None, ) -> FuncMetadata: - """Given a function, return metadata including a pydantic model representing its - signature. + """Given a function, return metadata including a Pydantic model representing its signature. The use case for this is ``` @@ -183,11 +181,11 @@ def func_metadata( return func(**validated_args.model_dump_one_level()) ``` - **critically** it also provides pre-parse helper to attempt to parse things from + **critically** it also provides a pre-parse helper to attempt to parse things from JSON. Args: - func: The function to convert to a pydantic model + func: The function to convert to a Pydantic model skip_names: A list of parameter names to skip. These will not be included in the model. structured_output: Controls whether the tool's output is structured or unstructured @@ -195,8 +193,8 @@ def func_metadata( - If True, creates a structured tool (return type annotation permitting) - If False, unconditionally creates an unstructured tool - If structured, creates a Pydantic model for the function's result based on its annotation. - Supports various return types: + If structured, creates a Pydantic model for the function's result based on its annotation. + Supports various return types: - BaseModel subclasses (used directly) - Primitive types (str, int, float, bool, bytes, None) - wrapped in a model with a 'result' field @@ -206,9 +204,9 @@ def func_metadata( Returns: A FuncMetadata object containing: - - arg_model: A pydantic model representing the function's arguments - - output_model: A pydantic model for the return type if output is structured - - output_conversion: Records how function output should be converted before returning. + - arg_model: A Pydantic model representing the function's arguments + - output_model: A Pydantic model for the return type if the output is structured + - wrap_output: Whether the function result needs to be wrapped in `{"result": ...}` for structured output. """ try: sig = inspect.signature(func, eval_str=True) @@ -296,7 +294,7 @@ def func_metadata( ] # pragma: no cover else: # We only had `Annotated[CallToolResult, ReturnType]`, treat the original annotation - # as beging `ReturnType`: + # as being `ReturnType`: original_annotation = return_type_expr else: return FuncMetadata(arg_model=arguments_model) @@ -355,7 +353,7 @@ def _try_create_model_and_schema( if origin is dict: args = get_args(type_expr) if len(args) == 2 and args[0] is str: - # TODO: should we use the original annotation? We are loosing any potential `Annotated` + # TODO: should we use the original annotation? We are losing any potential `Annotated` # metadata for Pydantic here: model = _create_dict_model(func_name, type_expr) else: @@ -404,9 +402,16 @@ def _try_create_model_and_schema( # Use StrictJsonSchema to raise exceptions instead of warnings try: schema = model.model_json_schema(schema_generator=StrictJsonSchema) - except (TypeError, ValueError, pydantic_core.SchemaError, pydantic_core.ValidationError) as e: + except ( + PydanticUserError, + TypeError, + ValueError, + pydantic_core.SchemaError, + pydantic_core.ValidationError, + ) as e: # These are expected errors when a type can't be converted to a Pydantic schema - # TypeError: When Pydantic can't handle the type + # PydanticUserError: When Pydantic can't handle the type (e.g. PydanticInvalidForJsonSchema); + # subclasses TypeError on pydantic <2.13 and RuntimeError on pydantic >=2.13 # ValueError: When there are issues with the type definition (including our custom warnings) # SchemaError: When Pydantic can't build a schema # ValidationError: When validation fails diff --git a/src/mcp/server/mcpserver/utilities/logging.py b/src/mcp/server/mcpserver/utilities/logging.py index c394f2bfaf..04ca38853b 100644 --- a/src/mcp/server/mcpserver/utilities/logging.py +++ b/src/mcp/server/mcpserver/utilities/logging.py @@ -8,10 +8,10 @@ def get_logger(name: str) -> logging.Logger: """Get a logger nested under MCP namespace. Args: - name: the name of the logger + name: The name of the logger. Returns: - a configured logger instance + A configured logger instance. """ return logging.getLogger(name) @@ -22,7 +22,7 @@ def configure_logging( """Configure logging for MCP. Args: - level: the log level to use + level: The log level to use. """ handlers: list[logging.Handler] = [] try: diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py index 41b9224c1d..3861f42a7e 100644 --- a/src/mcp/server/models.py +++ b/src/mcp/server/models.py @@ -1,4 +1,4 @@ -"""This module provides simpler types to use with the server for managing prompts +"""This module provides simplified types to use with the server for managing prompts and tools. """ diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f496121a37..fc2f97a9cb 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -6,30 +6,22 @@ Common usage pattern: ``` - server = Server(name) - - @server.call_tool() - async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> Any: + async def handle_call_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: # Check client capabilities before proceeding if ctx.session.check_client_capability( types.ClientCapabilities(experimental={"advanced_tools": dict()}) ): - # Perform advanced tool operations - result = await perform_advanced_tool_operation(arguments) + result = await perform_advanced_tool_operation(params.arguments) else: - # Fall back to basic tool operations - result = await perform_basic_tool_operation(arguments) - + result = await perform_basic_tool_operation(params.arguments) return result - @server.list_prompts() - async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: - # Access session for any necessary checks or operations + async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: if ctx.session.client_params: - # Customize prompts based on client initialization parameters - return generate_custom_prompts(ctx.session.client_params) - else: - return default_prompts + return ListPromptsResult(prompts=generate_custom_prompts(ctx.session.client_params)) + return ListPromptsResult(prompts=default_prompts) + + server = Server(name, on_call_tool=handle_call_tool, on_list_prompts=handle_list_prompts) ``` The ServerSession class is typically used internally by the Server class and should not @@ -41,13 +33,14 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.streams.memory import MemoryObjectReceiveStream from pydantic import AnyUrl, TypeAdapter from mcp import types from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY @@ -87,8 +80,8 @@ class ServerSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, ) -> None: @@ -230,7 +223,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover + async def send_resource_updated(self, uri: str | AnyUrl) -> None: """Send a resource updated notification.""" await self.send_notification( types.ResourceUpdatedNotification( @@ -371,12 +364,12 @@ async def elicit( """Send a form mode elicitation/create request. Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - related_request_id: Optional ID of the request that triggered this elicitation + message: The message to present to the user. + requested_schema: Schema defining the expected response structure. + related_request_id: Optional ID of the request that triggered this elicitation. Returns: - The client's response + The client's response. Note: This method is deprecated in favor of elicit_form(). It remains for @@ -393,12 +386,12 @@ async def elicit_form( """Send a form mode elicitation/create request. Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - related_request_id: Optional ID of the request that triggered this elicitation + message: The message to present to the user. + requested_schema: Schema defining the expected response structure. + related_request_id: Optional ID of the request that triggered this elicitation. Returns: - The client's response with form data + The client's response with form data. Raises: StatelessModeNotSupported: If called in stateless HTTP mode. @@ -429,13 +422,13 @@ async def elicit_url( like OAuth flows, credential collection, or payment processing. Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - related_request_id: Optional ID of the request that triggered this elicitation + message: Human-readable explanation of why the interaction is needed. + url: The URL the user should navigate to. + elicitation_id: Unique identifier for tracking this elicitation. + related_request_id: Optional ID of the request that triggered this elicitation. Returns: - The client's response indicating acceptance, decline, or cancellation + The client's response indicating acceptance, decline, or cancellation. Raises: StatelessModeNotSupported: If called in stateless HTTP mode. @@ -454,7 +447,7 @@ async def elicit_url( metadata=ServerMessageMetadata(related_request_id=related_request_id), ) - async def send_ping(self) -> types.EmptyResult: # pragma: no cover + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( types.PingRequest(), @@ -482,15 +475,15 @@ async def send_progress_notification( related_request_id, ) - async def send_resource_list_changed(self) -> None: # pragma: no cover + async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" await self.send_notification(types.ResourceListChangedNotification()) - async def send_tool_list_changed(self) -> None: # pragma: no cover + async def send_tool_list_changed(self) -> None: """Send a tool list changed notification.""" await self.send_notification(types.ToolListChangedNotification()) - async def send_prompt_list_changed(self) -> None: # pragma: no cover + async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" await self.send_notification(types.PromptListChangedNotification()) @@ -507,7 +500,7 @@ async def send_elicit_complete( Args: elicitation_id: The unique identifier of the completed elicitation - related_request_id: Optional ID of the request that triggered this + related_request_id: Optional ID of the request that triggered this notification """ await self.send_notification( types.ElicitCompleteNotification( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 5be6b78ca9..05e948332b 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -2,8 +2,8 @@ This module implements a Server-Sent Events (SSE) transport layer for MCP servers. -Example usage: -``` +Example: + ```python # Create an SSE transport at an endpoint sse = SseServerTransport("/messages/") @@ -27,10 +27,10 @@ async def handle_sse(request): # Create and run Starlette app starlette_app = Starlette(routes=routes) uvicorn.run(starlette_app, host="127.0.0.1", port=port) -``` + ``` -Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType' -object is not callable" error when client disconnects. The example above returns +Note: The handle_sse function must return a Response to avoid a +"TypeError: 'NoneType' object is not callable" error when client disconnects. The example above returns an empty Response() after the SSE connection ends to fix this. See SseServerTransport class documentation for more details. @@ -43,7 +43,6 @@ async def handle_sse(request): from uuid import UUID, uuid4 import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from sse_starlette import EventSourceResponse from starlette.requests import Request @@ -51,18 +50,20 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send from mcp import types +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.transport_security import ( TransportSecurityMiddleware, TransportSecuritySettings, ) +from mcp.shared._context_streams import ContextSendStream, create_context_streams from mcp.shared.message import ServerMessageMetadata, SessionMessage logger = logging.getLogger(__name__) class SseServerTransport: - """SSE server transport for MCP. This class provides _two_ ASGI applications, - suitable to be used with a framework like Starlette and a server like Hypercorn: + """SSE server transport for MCP. This class provides two ASGI applications, + suitable for use with a framework like Starlette and a server like Hypercorn: 1. connect_sse() is an ASGI application which receives incoming GET requests, and sets up a new SSE stream to send server messages to the client. @@ -72,7 +73,10 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] + _read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]] + # Identity of the credential that created each session; requests for a + # session must present the same credential. + _session_owners: dict[UUID, AuthorizationContext] _security: TransportSecurityMiddleware def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: @@ -112,11 +116,12 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | self._endpoint = endpoint self._read_stream_writers = {} + self._session_owners = {} self._security = TransportSecurityMiddleware(security_settings) logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover + async def connect_sse(self, scope: Scope, receive: Receive, send: Send): if scope["type"] != "http": logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") @@ -129,16 +134,14 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag raise ValueError("Request validation failed") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) session_id = uuid4() + user = scope.get("user") + if isinstance(user, AuthenticatedUser): + self._session_owners[session_id] = authorization_context(user) self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") @@ -170,31 +173,36 @@ async def sse_writer(): await sse_stream_writer.send( { "event": "message", - "data": session_message.message.model_dump_json(by_alias=True, exclude_none=True), + "data": session_message.message.model_dump_json(by_alias=True, exclude_unset=True), } ) - async with anyio.create_task_group() as tg: - - async def response_wrapper(scope: Scope, receive: Receive, send: Send): - """The EventSourceResponse returning signals a client close / disconnect. - In this case we close our side of the streams to signal the client that - the connection has been closed. - """ - await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( - scope, receive, send - ) - await read_stream_writer.aclose() - await write_stream_reader.aclose() - logging.debug(f"Client session disconnected {session_id}") + try: + async with anyio.create_task_group() as tg: + + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. + """ + await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( + scope, receive, send + ) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + await sse_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") - logger.debug("Starting SSE response task") - tg.start_soon(response_wrapper, scope, receive, send) + logger.debug("Starting SSE response task") + tg.start_soon(response_wrapper, scope, receive, send) - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) + finally: + self._read_stream_writers.pop(session_id, None) + self._session_owners.pop(session_id, None) - async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") request = Request(scope, receive) @@ -223,6 +231,15 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) response = Response("Could not find session", status_code=404) return await response(scope, receive, send) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + if requestor != self._session_owners.get(session_id): + # A session can only be used with the credential that created it. + # Respond exactly as if the session did not exist. + logger.warning("Rejecting message for session %s: credential does not match", session_id) + response = Response("Could not find session", status_code=404) + return await response(scope, receive, send) + body = await request.body() logger.debug(f"Received JSON: {body}") diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 7f3aa2ac2f..5c1459dff6 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -4,8 +4,8 @@ that can be used to communicate with an MCP client through standard input/output streams. -Example usage: -``` +Example: + ```python async def run_server(): async with stdio_server() as (read_stream, write_stream): # read_stream contains incoming JSONRPCMessages from stdin @@ -14,7 +14,7 @@ async def run_server(): await server.run(read_stream, write_stream, init_options) anyio.run(run_server) -``` + ``` """ import sys @@ -23,9 +23,9 @@ async def run_server(): import anyio import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared.message import SessionMessage @@ -39,18 +39,12 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. # python is platform-dependent (Windows is particularly problematic), so we # re-wrap the underlying binary stream to ensure UTF-8. if not stdin: - stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8")) + stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace")) if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async def stdin_reader(): try: @@ -58,7 +52,7 @@ async def stdin_reader(): async for line in stdin: try: message = types.jsonrpc_message_adapter.validate_json(line, by_name=False) - except Exception as exc: # pragma: no cover + except Exception as exc: await read_stream_writer.send(exc) continue @@ -71,7 +65,7 @@ async def stdout_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + json = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await stdout.write(json + "\n") await stdout.flush() except anyio.ClosedResourceError: # pragma: no cover diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 54ac7374a1..f2f4407cea 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -25,6 +25,8 @@ from starlette.types import Receive, Scope, Send from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( @@ -89,7 +91,7 @@ async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) message: The JSON-RPC message to store, or None for priming events Returns: - The generated event ID for the stored event + The generated event ID for the stored event. """ pass # pragma: no cover @@ -106,7 +108,7 @@ async def replay_events_after( send_callback: A callback function to send events to the client Returns: - The stream ID of the replayed events + The stream ID of the replayed events, or None if no events were found. """ pass # pragma: no cover @@ -119,10 +121,10 @@ class StreamableHTTPServerTransport: """ # Server notification streams for POST requests as well as standalone SSE stream - _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None - _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None - _write_stream: MemoryObjectSendStream[SessionMessage] | None = None - _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None + _read_stream_writer: ContextSendStream[SessionMessage | Exception] | None = None + _read_stream: ContextReceiveStream[SessionMessage | Exception] | None = None + _write_stream: ContextSendStream[SessionMessage] | None = None + _write_stream_reader: ContextReceiveStream[SessionMessage] | None = None _security: TransportSecurityMiddleware def __init__( @@ -169,13 +171,15 @@ def __init__( ] = {} self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {} self._terminated = False + # Idle timeout cancel scope; managed by the session manager. + self.idle_scope: anyio.CancelScope | None = None @property def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated - def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover + def close_sse_stream(self, request_id: RequestId) -> None: """Close SSE connection for a specific request without terminating the stream. This method closes the HTTP connection for the specified request, triggering @@ -183,7 +187,7 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover be replayed when the client reconnects with Last-Event-ID. Use this to implement polling behavior during long-running operations - - client will reconnect after the retry interval specified in the priming event. + the client will reconnect after the retry interval specified in the priming event. Args: request_id: The request ID whose SSE stream should be closed. @@ -194,11 +198,11 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover the disconnect. """ writer = self._sse_stream_writers.pop(request_id, None) - if writer: + if writer: # pragma: no branch writer.close() # Also close and remove request streams - if request_id in self._request_streams: + if request_id in self._request_streams: # pragma: no branch send_stream, receive_stream = self._request_streams.pop(request_id) send_stream.close() receive_stream.close() @@ -211,7 +215,7 @@ def close_standalone_sse_stream(self) -> None: # pragma: no cover with Last-Event-ID to resume receiving notifications. Use this to implement polling behavior for the notification stream - - client will reconnect after the retry interval specified in the priming event. + the client will reconnect after the retry interval specified in the priming event. Note: This is a no-op if there is no active standalone SSE stream. @@ -238,7 +242,7 @@ def _create_session_message( # Only provide close callbacks when client supports resumability if self._event_store and protocol_version >= "2025-11-25": - async def close_stream_callback() -> None: # pragma: no cover + async def close_stream_callback() -> None: self.close_sse_stream(request_id) async def close_standalone_stream_callback() -> None: # pragma: no cover @@ -289,7 +293,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: no cover + if headers: response_headers.update(headers) if self.mcp_session_id: @@ -298,12 +302,12 @@ def _create_error_response( # Return a properly formatted JSON error response error_response = JSONRPCError( jsonrpc="2.0", - id="server-error", # We don't have a request ID for general errors + id=None, error=ErrorData(code=error_code, message=error_message), ) return Response( - error_response.model_dump_json(by_alias=True, exclude_none=True), + error_response.model_dump_json(by_alias=True, exclude_unset=True), status_code=status_code, headers=response_headers, ) @@ -314,16 +318,16 @@ def _create_json_response( status_code: HTTPStatus = HTTPStatus.OK, headers: dict[str, str] | None = None, ) -> Response: - """Create a JSON response from a JSONRPCMessage""" + """Create a JSON response from a JSONRPCMessage.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: lax no cover - response_headers.update(headers) + if headers: + response_headers.update(headers) # pragma: no cover - if self.mcp_session_id: # pragma: lax no cover + if self.mcp_session_id: response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id return Response( - response_message.model_dump_json(by_alias=True, exclude_none=True) if response_message else None, + response_message.model_dump_json(by_alias=True, exclude_unset=True) if response_message else None, status_code=status_code, headers=response_headers, ) @@ -336,11 +340,11 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: """Create event data dictionary from an EventMessage.""" event_data = { "event": "message", - "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True), + "data": event_message.message.model_dump_json(by_alias=True, exclude_unset=True), } # If an event ID was provided, include it - if event_message.event_id: # pragma: no cover + if event_message.event_id: event_data["id"] = event_message.event_id return event_data @@ -360,7 +364,7 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None: self._request_streams.pop(request_id, None) async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: - """Application entry point that handles all HTTP requests""" + """Application entry point that handles all HTTP requests.""" request = Request(scope, receive) # Validate request headers for DNS rebinding protection @@ -370,7 +374,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await error_response(scope, receive, send) return - if self._terminated: # pragma: no cover + if self._terminated: # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -385,16 +389,23 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_get_request(request, send) elif request.method == "DELETE": await self._handle_delete_request(request, send) - else: # pragma: no cover + else: await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: - """Check if the request accepts the required media types.""" + """Check if the request accepts the required media types. + + Supports wildcard media types per RFC 7231, section 5.3.2: + - */* matches any media type + - application/* matches any application/ subtype + - text/* matches any text/ subtype + """ accept_header = request.headers.get("accept", "") - accept_types = [media_type.strip() for media_type in accept_header.split(",")] + accept_types = [media_type.strip().split(";")[0].strip().lower() for media_type in accept_header.split(",")] - has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types) - has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types) + has_wildcard = "*/*" in accept_types + has_json = has_wildcard or any(t in (CONTENT_TYPE_JSON, "application/*") for t in accept_types) + has_sse = has_wildcard or any(t in (CONTENT_TYPE_SSE, "text/*") for t in accept_types) return has_json, has_sse @@ -410,7 +421,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se has_json, has_sse = self._check_accept_headers(request) if self.is_json_response_enabled: # For JSON-only responses, only require application/json - if not has_json: # pragma: lax no cover + if not has_json: # pragma: no cover response = self._create_error_response( "Not Acceptable: Client must accept application/json", HTTPStatus.NOT_ACCEPTABLE, @@ -458,7 +469,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re try: message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) - except ValidationError as e: # pragma: no cover + except ValidationError as e: response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, @@ -484,7 +495,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re ) await response(scope, receive, send) return - elif not await self._validate_request_headers(request, send): # pragma: no cover + elif not await self._validate_request_headers(request, send): return # For notifications and responses only, return 202 Accepted @@ -534,7 +545,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): response_message = event_message.message break - # For notifications and request, keep waiting + # For notifications and requests, keep waiting else: # pragma: no cover logger.debug(f"received: {event_message.message.method}") @@ -568,7 +579,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Store writer reference so close_sse_stream() can close it self._sse_stream_writers[request_id] = sse_stream_writer - async def sse_writer(): # pragma: lax no cover + async def sse_writer(): # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: @@ -584,10 +595,10 @@ async def sse_writer(): # pragma: lax no cover # If response, remove from pending streams and close if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): break - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in SSE writer") finally: logger.debug("Closing SSE writer") @@ -617,14 +628,14 @@ async def sse_writer(): # pragma: lax no cover # Then send the message to be processed by the server session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("SSE response error") await sse_stream_writer.aclose() await self._clean_up_memory_streams(request_id) finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: no cover + except Exception as err: logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -632,9 +643,9 @@ async def sse_writer(): # pragma: lax no cover INTERNAL_ERROR, ) await response(scope, receive, send) - if writer: + if writer: # pragma: no cover await writer.send(Exception(err)) - return + return # pragma: no cover async def _handle_get_request(self, request: Request, send: Send) -> None: """Handle GET request to establish SSE. @@ -650,7 +661,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) - if not has_sse: # pragma: no cover + if not has_sse: response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, @@ -662,7 +673,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: return # Handle resumability: check for Last-Event-ID header - if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): await self._replay_events(last_event_id, request, send) return @@ -672,11 +683,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Check if we already have an active GET stream - if GET_STREAM_KEY in self._request_streams: # pragma: no cover + if GET_STREAM_KEY in self._request_streams: response = self._create_error_response( "Conflict: Only one SSE stream is allowed per session", HTTPStatus.CONFLICT, @@ -696,7 +707,7 @@ async def standalone_sse_writer(): async with sse_stream_writer, standalone_stream_reader: # Process messages from the standalone stream - async for event_message in standalone_stream_reader: # pragma: lax no cover + async for event_message in standalone_stream_reader: # For the standalone stream, we handle: # - JSONRPCNotification (server sends notifications to client) # - JSONRPCRequest (server sends requests to client) @@ -705,8 +716,8 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except Exception: # pragma: no cover - logger.exception("Error in standalone SSE writer") + except Exception: + logger.exception("Error in standalone SSE writer") # pragma: no cover finally: logger.debug("Closing standalone SSE writer") await self._clean_up_memory_streams(GET_STREAM_KEY) @@ -764,7 +775,7 @@ async def terminate(self) -> None: request_stream_keys = list(self._request_streams.keys()) # Close all request streams asynchronously - for key in request_stream_keys: # pragma: lax no cover + for key in request_stream_keys: await self._clean_up_memory_streams(key) # Clear the request streams dictionary immediately @@ -782,13 +793,13 @@ async def terminate(self) -> None: # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, "Allow": "GET, POST, DELETE", } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id response = self._create_error_response( @@ -798,7 +809,7 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non ) await response(request.scope, request.receive, send) - async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: lax no cover + async def _validate_request_headers(self, request: Request, send: Send) -> bool: if not await self._validate_session(request, send): return False if not await self._validate_protocol_version(request, send): @@ -807,7 +818,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool: async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" - if not self.mcp_session_id: # pragma: no cover + if not self.mcp_session_id: # If we're not using session IDs, return True return True @@ -815,7 +826,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: request_session_id = self._get_session_id(request) # If no session ID provided but required, return error - if not request_session_id: # pragma: no cover + if not request_session_id: response = self._create_error_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, @@ -840,11 +851,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: # pragma: no cover + if protocol_version is None: protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) response = self._create_error_response( f"Bad Request: Unsupported protocol version: {protocol_version}. " @@ -856,13 +867,14 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """Replays events that would have been sent after the specified event ID. + Only used when resumability is enabled. """ event_store = self._event_store if not event_store: - return + return # pragma: no cover try: headers = { @@ -871,7 +883,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Get protocol version from header (already validated in _validate_protocol_version) @@ -892,7 +904,7 @@ async def send_event(event_message: EventMessage) -> None: stream_id = await event_store.replay_events_after(last_event_id, send_event) # If stream ID not in mapping, create it - if stream_id and stream_id not in self._request_streams: + if stream_id and stream_id not in self._request_streams: # pragma: no branch # Register SSE writer so close_sse_stream() can close it self._sse_stream_writers[stream_id] = sse_stream_writer @@ -909,10 +921,10 @@ async def send_event(event_message: EventMessage) -> None: event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("Replay SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay sender") # Create and start EventSourceResponse @@ -924,13 +936,13 @@ async def send_event(event_message: EventMessage) -> None: try: await response(request.scope, request.receive, send) - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay response") finally: await sse_stream_writer.aclose() await sse_stream_reader.aclose() - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error replaying events") response = self._create_error_response( "Error replaying events", @@ -944,8 +956,8 @@ async def connect( self, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], - MemoryObjectSendStream[SessionMessage], + ReadStream[SessionMessage | Exception], + WriteStream[SessionMessage], ], None, ]: @@ -957,8 +969,8 @@ async def connect( # Create the memory streams for this connection - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) # Store the streams self._read_stream_writer = read_stream_writer @@ -975,14 +987,13 @@ async def message_router(): # Determine which request stream(s) should receive this message message = session_message.message target_request_id = None - # Check if this is a response - if isinstance(message, JSONRPCResponse | JSONRPCError): - response_id = str(message.id) - # If this response is for an existing request stream, - # send it there - target_request_id = response_id + # Check if this is a response with a known request id. + # Null-id errors (e.g., parse errors) fall through to + # the GET stream since they can't be correlated. + if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: + target_request_id = str(message.id) # Extract related_request_id from meta if it exists - elif ( # pragma: no cover + elif ( session_message.metadata is not None and isinstance( session_message.metadata, @@ -998,7 +1009,7 @@ async def message_router(): # regardless of whether a client is connected # messages will be replayed on the re-connect event_id = None - if self._event_store: # pragma: lax no cover + if self._event_store: event_id = await self._event_store.store_event(request_stream_id, message) logger.debug(f"Stored {event_id} from {request_stream_id}") @@ -1009,14 +1020,14 @@ async def message_router(): except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) - else: # pragma: no cover + else: logger.debug( f"""Request stream {request_stream_id} not found for message. Still processing message as the client might reconnect and replay.""" ) except anyio.ClosedResourceError: - if self._terminated: + if self._terminated: # pragma: lax no cover logger.debug("Read stream closed by client") else: logger.exception("Unexpected closure of read stream in message router") @@ -1030,7 +1041,7 @@ async def message_router(): # Yield the streams for the caller to use yield read_stream, write_stream finally: - for stream_id in list(self._request_streams.keys()): # pragma: lax no cover + for stream_id in list(self._request_streams.keys()): await self._clean_up_memory_streams(stream_id) self._request_streams.clear() diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index ddc6e5014f..81350a8f24 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -5,7 +5,6 @@ import contextlib import logging from collections.abc import AsyncIterator -from http import HTTPStatus from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -15,6 +14,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, EventStore, @@ -39,6 +39,7 @@ class StreamableHTTPSessionManager: 2. Resumability via an optional event store 3. Connection management and lifecycle 4. Request handling and transport setup + 5. Idle session cleanup via optional timeout Important: Only one StreamableHTTPSessionManager instance should be created per application. The instance cannot be reused after its run() context has @@ -46,37 +47,51 @@ class StreamableHTTPSessionManager: Args: app: The MCP server instance - event_store: Optional event store for resumability support. - If provided, enables resumable connections where clients - can reconnect and receive missed events. - If None, sessions are still tracked but not resumable. + event_store: Optional event store for resumability support. If provided, enables resumable connections + where clients can reconnect and receive missed events. If None, sessions are still tracked but not + resumable. json_response: Whether to use JSON responses instead of SSE streams - stateless: If True, creates a completely fresh transport for each request - with no session tracking or state persistence between requests. + stateless: If True, creates a completely fresh transport for each request with no session tracking or + state persistence between requests. security_settings: Optional transport security settings. - retry_interval: Retry interval in milliseconds to suggest to clients in SSE - retry field. Used for SSE polling behavior. + retry_interval: Retry interval in milliseconds to suggest to clients in SSE retry field. Used for SSE + polling behavior. + session_idle_timeout: Optional idle timeout in seconds for stateful sessions. If set, sessions that + receive no HTTP requests for this duration will be automatically terminated and removed. When + retry_interval is also configured, ensure the idle timeout comfortably exceeds the retry interval to + avoid reaping sessions during normal SSE polling gaps. Default is None (no timeout). A value of 1800 + (30 minutes) is recommended for most deployments. """ def __init__( self, - app: Server[Any, Any], + app: Server[Any], event_store: EventStore | None = None, json_response: bool = False, stateless: bool = False, security_settings: TransportSecuritySettings | None = None, retry_interval: int | None = None, + session_idle_timeout: float | None = None, ): + if session_idle_timeout is not None and session_idle_timeout <= 0: + raise ValueError("session_idle_timeout must be a positive number of seconds") + if stateless and session_idle_timeout is not None: + raise RuntimeError("session_idle_timeout is not supported in stateless mode") + self.app = app self.event_store = event_store self.json_response = json_response self.stateless = stateless self.security_settings = security_settings self.retry_interval = retry_interval + self.session_idle_timeout = session_idle_timeout # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + # Identity of the credential that created each session; requests for a + # session must present the same credential. + self._session_owners: dict[str, AuthorizationContext] = {} # The task group will be set during lifespan self._task_group = None @@ -123,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: self._task_group = None # Clear any remaining server instances self._server_instances.clear() + self._session_owners.clear() async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. @@ -161,7 +177,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA self.app.create_initialization_options(), stateless=True, ) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") # Assert task group is not None for type checking @@ -180,10 +196,33 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + # Existing session case if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] + if requestor != self._session_owners.get(request_mcp_session_id): + # A session can only be used with the credential that created + # it. Respond exactly as if the session did not exist. + logger.warning( + "Rejecting request for session %s: credential does not match the one that created the session", + request_mcp_session_id[:64], + ) + body = JSONRPCError( + jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found") + ) + response = Response( + body.model_dump_json(by_alias=True, exclude_unset=True), + status_code=404, + media_type="application/json", + ) + await response(scope, receive, send) + return logger.debug("Session already exists, handling request directly") + # Push back idle deadline on activity + if transport.idle_scope is not None and self.session_idle_timeout is not None: + transport.idle_scope.deadline = anyio.current_time() + self.session_idle_timeout # pragma: no cover await transport.handle_request(scope, receive, send) return @@ -201,6 +240,8 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S ) assert http_transport.mcp_session_id is not None + if requestor is not None: + self._session_owners[http_transport.mcp_session_id] = requestor self._server_instances[http_transport.mcp_session_id] = http_transport logger.info(f"Created new transport with session ID: {new_session_id}") @@ -210,16 +251,32 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE read_stream, write_stream = streams task_status.started() try: - await self.app.run( - read_stream, - write_stream, - self.app.create_initialization_options(), - stateless=False, # Stateful mode - ) + # Use a cancel scope for idle timeout — when the + # deadline passes the scope cancels app.run() and + # execution continues after the ``with`` block. + # Incoming requests push the deadline forward. + idle_scope = anyio.CancelScope() + if self.session_idle_timeout is not None: + idle_scope.deadline = anyio.current_time() + self.session_idle_timeout + http_transport.idle_scope = idle_scope + + with idle_scope: + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=False, + ) + + if idle_scope.cancelled_caught: + assert http_transport.mcp_session_id is not None + logger.info(f"Session {http_transport.mcp_session_id} idle timeout") + self._server_instances.pop(http_transport.mcp_session_id, None) + self._session_owners.pop(http_transport.mcp_session_id, None) + await http_transport.terminate() except Exception: logger.exception(f"Session {http_transport.mcp_session_id} crashed") finally: - # Only remove from instances if not terminated if ( # pragma: no branch http_transport.mcp_session_id and http_transport.mcp_session_id in self._server_instances @@ -230,6 +287,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE f"{http_transport.mcp_session_id} from active instances." ) del self._server_instances[http_transport.mcp_session_id] + self._session_owners.pop(http_transport.mcp_session_id, None) # Assert task group is not None for type checking assert self._task_group is not None @@ -242,15 +300,12 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE # Unknown or expired session ID - return 404 per MCP spec # TODO: Align error code once spec clarifies # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 - error_response = JSONRPCError( - jsonrpc="2.0", - id="server-error", - error=ErrorData(code=INVALID_REQUEST, message="Session not found"), + logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}") + body = JSONRPCError( + jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found") ) response = Response( - content=error_response.model_dump_json(by_alias=True, exclude_none=True), - status_code=HTTPStatus.NOT_FOUND, - media_type="application/json", + body.model_dump_json(by_alias=True, exclude_unset=True), status_code=404, media_type="application/json" ) await response(scope, receive, send) diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0e..d9e9f965b3 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -40,7 +40,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def _validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" if not host: logger.warning("Missing Host header in request") @@ -62,7 +62,7 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests if not origin: @@ -94,7 +94,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res Returns None if validation passes, or an error Response if validation fails. """ # Always validate Content-Type for POST requests - if is_post: # pragma: no branch + if is_post: content_type = request.headers.get("content-type") if not self._validate_content_type(content_type): return Response("Invalid Content-Type header", status_code=400) @@ -103,14 +103,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover + # Validate Host header + host = request.headers.get("host") + if not self._validate_host(host): + return Response("Invalid Host header", status_code=421) - # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover + # Validate Origin header + origin = request.headers.get("origin") + if not self._validate_origin(origin): + return Response("Invalid Origin header", status_code=403) - return None # pragma: no cover + return None diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index a4c8448112..277f9b5af2 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -1,32 +1,26 @@ from contextlib import asynccontextmanager import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic_core import ValidationError from starlette.types import Receive, Scope, Send from starlette.websockets import WebSocket from mcp import types +from mcp.shared._context_streams import create_context_streams from mcp.shared.message import SessionMessage -@asynccontextmanager # pragma: no cover +@asynccontextmanager async def websocket_server(scope: Scope, receive: Receive, send: Send): - """WebSocket server transport for MCP. This is an ASGI application, suitable to be - used with a framework like Starlette and a server like Hypercorn. + """WebSocket server transport for MCP. This is an ASGI application, suitable for use + with a framework like Starlette and a server like Hypercorn. """ websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage](0) async def ws_reader(): try: @@ -34,22 +28,22 @@ async def ws_reader(): async for msg in websocket.iter_text(): try: client_message = types.jsonrpc_message_adapter.validate_json(msg, by_name=False) - except ValidationError as exc: + except ValidationError as exc: # pragma: no cover await read_stream_writer.send(exc) continue session_message = SessionMessage(client_message) await read_stream_writer.send(session_message) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await websocket.close() async def ws_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - obj = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + obj = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await websocket.send_text(obj) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await websocket.close() async with anyio.create_task_group() as tg: diff --git a/src/mcp/shared/_callable_inspection.py b/src/mcp/shared/_callable_inspection.py new file mode 100644 index 0000000000..0e89e446f8 --- /dev/null +++ b/src/mcp/shared/_callable_inspection.py @@ -0,0 +1,33 @@ +"""Callable inspection utilities. + +Adapted from Starlette's `is_async_callable` implementation. +https://github.com/encode/starlette/blob/main/starlette/_utils.py +""" + +from __future__ import annotations + +import functools +import inspect +from collections.abc import Awaitable, Callable +from typing import Any, TypeGuard, TypeVar, overload + +T = TypeVar("T") + +AwaitableCallable = Callable[..., Awaitable[T]] + + +@overload +def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ... + + +@overload +def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ... + + +def is_async_callable(obj: Any) -> Any: + while isinstance(obj, functools.partial): # pragma: lax no cover + obj = obj.func + + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) + ) diff --git a/src/mcp/shared/_context.py b/src/mcp/shared/_context.py index 2facc2a493..bbcee2d02c 100644 --- a/src/mcp/shared/_context.py +++ b/src/mcp/shared/_context.py @@ -13,8 +13,12 @@ @dataclass(kw_only=True) class RequestContext(Generic[SessionT]): - """Common context for handling incoming requests.""" + """Common context for handling incoming requests. + + For request handlers, request_id is always populated. + For notification handlers, request_id is None. + """ - request_id: RequestId - meta: RequestParamsMeta | None session: SessionT + request_id: RequestId | None = None + meta: RequestParamsMeta | None = None diff --git a/src/mcp/shared/_context_streams.py b/src/mcp/shared/_context_streams.py new file mode 100644 index 0000000000..04c33306d9 --- /dev/null +++ b/src/mcp/shared/_context_streams.py @@ -0,0 +1,119 @@ +"""Context-aware memory stream wrappers. + +anyio memory streams do not propagate ``contextvars.Context`` across task +boundaries. These thin wrappers capture the sender's context at ``send()`` +time and expose it on the receive side via ``last_context``, so consumers +can restore it with ``ctx.run(handler, item)``. + +The iteration interface is unchanged (yields ``T``, not tuples), keeping +these wrappers duck-type compatible with plain ``MemoryObjectSendStream`` +and ``MemoryObjectReceiveStream``. +""" + +from __future__ import annotations + +import contextvars +from types import TracebackType +from typing import Any, Generic, TypeVar + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +T = TypeVar("T") + +# Internal payload carried through the underlying raw stream. +_Envelope = tuple[contextvars.Context, T] + + +class ContextSendStream(Generic[T]): + """Send-side wrapper that snapshots ``contextvars.copy_context()`` on every ``send()``.""" + + __slots__ = ("_inner",) + + def __init__(self, inner: MemoryObjectSendStream[_Envelope[T]]) -> None: + self._inner = inner + + async def send(self, item: T) -> None: + await self._inner.send((contextvars.copy_context(), item)) + + def close(self) -> None: + self._inner.close() + + async def aclose(self) -> None: + await self._inner.aclose() + + def clone(self) -> ContextSendStream[T]: # pragma: no cover + return ContextSendStream(self._inner.clone()) + + async def __aenter__(self) -> ContextSendStream[T]: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.aclose() + return None + + +class ContextReceiveStream(Generic[T]): + """Receive-side wrapper that yields ``T`` and stores the sender's context in ``last_context``.""" + + __slots__ = ("_inner", "last_context") + + def __init__(self, inner: MemoryObjectReceiveStream[_Envelope[T]]) -> None: + self._inner = inner + self.last_context: contextvars.Context | None = None + + async def receive(self) -> T: + ctx, item = await self._inner.receive() + self.last_context = ctx + return item + + def close(self) -> None: + self._inner.close() + + async def aclose(self) -> None: + await self._inner.aclose() + + def clone(self) -> ContextReceiveStream[T]: # pragma: no cover + return ContextReceiveStream(self._inner.clone()) + + def __aiter__(self) -> ContextReceiveStream[T]: + return self + + async def __anext__(self) -> T: + try: + return await self.receive() + except anyio.EndOfStream: + raise StopAsyncIteration + + async def __aenter__(self) -> ContextReceiveStream[T]: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.aclose() + return None + + +class create_context_streams( + tuple[ContextSendStream[T], ContextReceiveStream[T]], +): + """Create context-aware memory object streams. + + Supports ``create_context_streams[T](n)`` bracket syntax, + matching anyio's ``create_memory_object_stream`` API style. + """ + + def __new__(cls, max_buffer_size: float = 0) -> tuple[ContextSendStream[T], ContextReceiveStream[T]]: # type: ignore[type-var] + raw_send: MemoryObjectSendStream[Any] + raw_receive: MemoryObjectReceiveStream[Any] + raw_send, raw_receive = anyio.create_memory_object_stream(max_buffer_size) + return (ContextSendStream(raw_send), ContextReceiveStream(raw_receive)) diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index 8cf7bda2ad..251469eaa1 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -44,26 +44,38 @@ def create_mcp_http_client( The returned AsyncClient must be used as a context manager to ensure proper cleanup of connections. - Examples: - # Basic usage with MCP defaults + Example: + Basic usage with MCP defaults: + + ```python async with create_mcp_http_client() as client: response = await client.get("https://api.example.com") + ``` + + With custom headers: - # With custom headers + ```python headers = {"Authorization": "Bearer token"} async with create_mcp_http_client(headers) as client: response = await client.get("/endpoint") + ``` - # With both custom headers and timeout + With both custom headers and timeout: + + ```python timeout = httpx.Timeout(60.0, read=300.0) async with create_mcp_http_client(headers, timeout) as client: response = await client.get("/long-request") + ``` + + With authentication: - # With authentication + ```python from httpx import BasicAuth auth = BasicAuth(username="user", password="pass") async with create_mcp_http_client(headers, timeout, auth) as client: response = await client.get("/protected-endpoint") + ``` """ # Set MCP defaults kwargs: dict[str, Any] = {"follow_redirects": True} diff --git a/src/mcp/shared/_otel.py b/src/mcp/shared/_otel.py new file mode 100644 index 0000000000..170e873a0f --- /dev/null +++ b/src/mcp/shared/_otel.py @@ -0,0 +1,36 @@ +"""OpenTelemetry helpers for MCP.""" + +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any + +from opentelemetry.context import Context +from opentelemetry.propagate import extract, inject +from opentelemetry.trace import SpanKind, get_tracer + +_tracer = get_tracer("mcp-python-sdk") + + +@contextmanager +def otel_span( + name: str, + *, + kind: SpanKind, + attributes: dict[str, Any] | None = None, + context: Context | None = None, +) -> Iterator[Any]: + """Create an OTel span.""" + with _tracer.start_as_current_span(name, kind=kind, attributes=attributes, context=context) as span: + yield span + + +def inject_trace_context(meta: dict[str, Any]) -> None: + """Inject W3C trace context (traceparent/tracestate) into a `_meta` dict.""" + inject(meta) + + +def extract_trace_context(meta: dict[str, Any]) -> Context: + """Extract W3C trace context from a `_meta` dict.""" + return extract(meta) diff --git a/src/mcp/shared/_stream_protocols.py b/src/mcp/shared/_stream_protocols.py new file mode 100644 index 0000000000..b799751329 --- /dev/null +++ b/src/mcp/shared/_stream_protocols.py @@ -0,0 +1,49 @@ +"""Stream protocols for MCP transports. + +These are general-purpose protocols satisfied by both ``MemoryObjectSendStream``/ +``MemoryObjectReceiveStream`` and the context-aware wrappers in ``_context_streams``. +""" + +from __future__ import annotations + +from types import TracebackType +from typing import Protocol, TypeVar + +from typing_extensions import Self + +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +class ReadStream(Protocol[T_co]): + """Protocol for reading items from a stream. + + Consumers that need the sender's context should use + ``getattr(stream, 'last_context', None)``. + """ + + async def receive(self) -> T_co: ... + async def aclose(self) -> None: ... + def __aiter__(self) -> ReadStream[T_co]: ... + async def __anext__(self) -> T_co: ... + async def __aenter__(self) -> Self: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... + + +class WriteStream(Protocol[T_contra]): + """Protocol for writing items to a stream.""" + + async def send(self, item: T_contra, /) -> None: ... + async def aclose(self) -> None: ... + async def __aenter__(self) -> Self: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index bf03a8b8dd..3b48152d5b 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -33,9 +33,8 @@ def __init__(self, message: str): class OAuthClientMetadata(BaseModel): - """RFC 7591 OAuth 2.0 Dynamic Client Registration metadata. + """RFC 7591 OAuth 2.0 Dynamic Client Registration Metadata. See https://datatracker.ietf.org/doc/html/rfc7591#section-2 - for the full specification. """ redirect_uris: list[AnyUrl] | None = Field(..., min_length=1) @@ -68,15 +67,33 @@ class OAuthClientMetadata(BaseModel): software_id: str | None = None software_version: str | None = None + @field_validator( + "client_uri", + "logo_uri", + "tos_uri", + "policy_uri", + "jwks_uri", + mode="before", + ) + @classmethod + def _empty_string_optional_url_to_none(cls, v: object) -> object: + # RFC 7591 §2 marks these URL fields OPTIONAL. Some authorization servers + # echo omitted metadata back as "" instead of dropping the keys, which + # AnyHttpUrl would otherwise reject — throwing away an otherwise valid + # registration response. Treat "" as absent. + if v == "": + return None + return v + def validate_scope(self, requested_scope: str | None) -> list[str] | None: if requested_scope is None: return None requested_scopes = requested_scope.split(" ") allowed_scopes = [] if self.scope is None else self.scope.split(" ") for scope in requested_scopes: - if scope not in allowed_scopes: # pragma: no branch + if scope not in allowed_scopes: raise InvalidScopeError(f"Client was not registered with scope {scope}") - return requested_scopes # pragma: no cover + return requested_scopes def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: if redirect_uri is not None: @@ -145,9 +162,9 @@ class ProtectedResourceMetadata(BaseModel): resource_documentation: AnyHttpUrl | None = None resource_policy_uri: AnyHttpUrl | None = None resource_tos_uri: AnyHttpUrl | None = None - # tls_client_certificate_bound_access_tokens default is False, but ommited here for clarity + # tls_client_certificate_bound_access_tokens default is False, but omitted here for clarity tls_client_certificate_bound_access_tokens: bool | None = None authorization_details_types_supported: list[str] | None = None dpop_signing_alg_values_supported: list[str] | None = None - # dpop_bound_access_tokens_required default is False, but ommited here for clarity + # dpop_bound_access_tokens_required default is False, but omitted here for clarity dpop_bound_access_tokens_required: bool | None = None diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index 8f3c542f22..3ba880f40d 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -51,22 +51,17 @@ def check_resource_allowed(requested_resource: str, configured_resource: str) -> if requested.scheme.lower() != configured.scheme.lower() or requested.netloc.lower() != configured.netloc.lower(): return False - # Handle cases like requested=/foo and configured=/foo/ + # Normalize trailing slashes before comparison so that + # "/foo" and "/foo/" are treated as equivalent. requested_path = requested.path configured_path = configured.path - - # If requested path is shorter, it cannot be a child - if len(requested_path) < len(configured_path): - return False - - # Check if the requested path starts with the configured path - # Ensure both paths end with / for proper comparison - # This ensures that paths like "/api123" don't incorrectly match "/api" if not requested_path.endswith("/"): requested_path += "/" if not configured_path.endswith("/"): configured_path += "/" + # Check hierarchical match: requested must start with configured path. + # The trailing-slash normalization ensures "/api123/" won't match "/api/". return requested_path.startswith(configured_path) diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index 7a2b2ded4d..f153ea319d 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -12,7 +12,10 @@ class MCPError(Exception): def __init__(self, code: int, message: str, data: Any = None): super().__init__(code, message, data) - self.error = ErrorData(code=code, message=message, data=data) + if data is not None: + self.error = ErrorData(code=code, message=message, data=data) + else: + self.error = ErrorData(code=code, message=message) @property def code(self) -> int: @@ -62,6 +65,7 @@ class UrlElicitationRequiredError(MCPError): must complete one or more URL elicitations before the request can be processed. Example: + ```python raise UrlElicitationRequiredError([ ElicitRequestURLParams( message="Authorization required for your files", @@ -69,6 +73,7 @@ class UrlElicitationRequiredError(MCPError): elicitation_id="auth-001" ) ]) + ``` """ def __init__(self, elicitations: list[ElicitRequestURLParams], message: str | None = None): diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py index 38ca802daf..3f91cd0d06 100644 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -72,9 +72,10 @@ async def cancel_task( - Task is already in a terminal state (completed, failed, cancelled) Example: - @server.experimental.cancel_task() - async def handle_cancel(request: CancelTaskRequest) -> CancelTaskResult: - return await cancel_task(store, request.params.taskId) + ```python + async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult: + return await cancel_task(store, params.task_id) + ``` """ task = await store.get_task(task_id) if task is None: diff --git a/src/mcp/shared/experimental/tasks/message_queue.py b/src/mcp/shared/experimental/tasks/message_queue.py index 018c2b7b26..e17c4a8650 100644 --- a/src/mcp/shared/experimental/tasks/message_queue.py +++ b/src/mcp/shared/experimental/tasks/message_queue.py @@ -12,6 +12,7 @@ """ from abc import ABC, abstractmethod +from collections import deque from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any, Literal @@ -151,13 +152,13 @@ class InMemoryTaskMessageQueue(TaskMessageQueue): """ def __init__(self) -> None: - self._queues: dict[str, list[QueuedMessage]] = {} + self._queues: dict[str, deque[QueuedMessage]] = {} self._events: dict[str, anyio.Event] = {} - def _get_queue(self, task_id: str) -> list[QueuedMessage]: + def _get_queue(self, task_id: str) -> deque[QueuedMessage]: """Get or create the queue for a task.""" if task_id not in self._queues: - self._queues[task_id] = [] + self._queues[task_id] = deque() return self._queues[task_id] async def enqueue(self, task_id: str, message: QueuedMessage) -> None: @@ -172,7 +173,7 @@ async def dequeue(self, task_id: str) -> QueuedMessage | None: queue = self._get_queue(task_id) if not queue: return None - return queue.pop(0) + return queue.popleft() async def peek(self, task_id: str) -> QueuedMessage | None: """Return the next message without removing it.""" diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index d01d28b808..468590d095 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -5,25 +5,23 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared.message import SessionMessage -MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] +MessageStream = tuple[ContextReceiveStream[SessionMessage | Exception], ContextSendStream[SessionMessage | Exception]] @asynccontextmanager async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]: """Creates a pair of bidirectional memory streams for client-server communication. - Returns: + Yields: A tuple of (client_streams, server_streams) where each is a tuple of (read_stream, write_stream) """ # Create streams for both directions - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_to_client_send, server_to_client_receive = create_context_streams[SessionMessage | Exception](1) + client_to_server_send, client_to_server_receive = create_context_streams[SessionMessage | Exception](1) client_streams = (server_to_client_receive, client_to_server_send) server_streams = (client_to_server_receive, server_to_client_send) diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 9dedd2e5d3..1858eeac31 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -6,6 +6,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass +from typing import Any from mcp.types import JSONRPCMessage, RequestId @@ -30,8 +31,10 @@ class ServerMessageMetadata: """Metadata specific to server messages.""" related_request_id: RequestId | None = None - # Request-specific context (e.g., headers, auth info) - request_context: object | None = None + # Transport-specific request context (e.g. starlette Request for HTTP + # transports, None for stdio). Typed as Any because the server layer is + # transport-agnostic. + request_context: Any = None # Callback to close SSE stream for the current request without terminating close_sse_stream: CloseSSEStreamCallback | None = None # Callback to close the standalone GET SSE stream (for unsolicited notifications) diff --git a/src/mcp/shared/metadata_utils.py b/src/mcp/shared/metadata_utils.py index 2b66996bde..6e4d33da0f 100644 --- a/src/mcp/shared/metadata_utils.py +++ b/src/mcp/shared/metadata_utils.py @@ -1,7 +1,7 @@ """Utility functions for working with metadata in MCP types. These utilities are primarily intended for client-side usage to properly display -human-readable names in user interfaces in a spec compliant way. +human-readable names in user interfaces in a spec-compliant way. """ from mcp.types import Implementation, Prompt, Resource, ResourceTemplate, Tool @@ -18,11 +18,13 @@ def get_display_name(obj: Tool | Resource | Prompt | ResourceTemplate | Implemen For other objects: title > name Example: + ```python # In a client displaying available tools tools = await session.list_tools() for tool in tools.tools: display_name = get_display_name(tool) print(f"Available tool: {display_name}") + ``` Args: obj: An MCP object with name and optional title fields diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py deleted file mode 100644 index 510bd81632..0000000000 --- a/src/mcp/shared/progress.py +++ /dev/null @@ -1,45 +0,0 @@ -from collections.abc import Generator -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Generic - -from pydantic import BaseModel - -from mcp.shared._context import RequestContext, SessionT -from mcp.types import ProgressToken - - -class Progress(BaseModel): - progress: float - total: float | None - - -@dataclass -class ProgressContext(Generic[SessionT]): - session: SessionT - progress_token: ProgressToken - total: float | None - current: float = field(default=0.0, init=False) - - async def progress(self, amount: float, message: str | None = None) -> None: - self.current += amount - - await self.session.send_progress_notification( - self.progress_token, self.current, total=self.total, message=message - ) - - -@contextmanager -def progress( - ctx: RequestContext[SessionT], - total: float | None = None, -) -> Generator[ProgressContext[SessionT], None]: - progress_token = ctx.meta.get("progress_token") if ctx.meta else None - if progress_token is None: # pragma: no cover - raise ValueError("No progress token provided") - - progress_ctx = ProgressContext(ctx.session, progress_token, total) - try: - yield progress_ctx - finally: - pass diff --git a/src/mcp/shared/response_router.py b/src/mcp/shared/response_router.py index 7ec4a443c1..fe24b016f1 100644 --- a/src/mcp/shared/response_router.py +++ b/src/mcp/shared/response_router.py @@ -25,6 +25,7 @@ class ResponseRouter(Protocol): and deliver the response/error to the appropriate handler. Example: + ```python class TaskResultHandler(ResponseRouter): def route_response(self, request_id, response): resolver = self._pending_requests.pop(request_id, None) @@ -32,6 +33,7 @@ def route_response(self, request_id, response): resolver.set_result(response) return True return False + ``` """ def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 453e36274e..9c72a23844 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextvars import logging from collections.abc import Callable from contextlib import AsyncExitStack @@ -7,10 +8,13 @@ from typing import Any, Generic, Protocol, TypeVar import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.streams.memory import MemoryObjectSendStream +from opentelemetry.trace import SpanKind from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +from mcp.shared._otel import inject_trace_context, otel_span +from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.response_router import ResponseRouter @@ -60,8 +64,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): cancellation handling: Example: + ```python with request_responder as resp: await resp.respond(result) + ``` The context manager ensures: 1. Proper cancellation scope setup and cleanup @@ -77,11 +83,13 @@ def __init__( session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any], message_metadata: MessageMetadata = None, + context: contextvars.Context | None = None, ) -> None: self.request_id = request_id self.request_meta = request_meta self.request = request self.message_metadata = message_metadata + self.context = context self._session = session self._completed = False self._cancel_scope = anyio.CancelScope() @@ -103,7 +111,7 @@ def __exit__( ) -> None: """Exit the context manager, performing cleanup and notifying completion.""" try: - if self._completed: # pragma: no branch + if self._completed: self._on_complete(self) finally: self._entered = False @@ -115,6 +123,7 @@ async def respond(self, response: SendResultT | ErrorData) -> None: """Send a response for this request. Must be called within a context manager block. + Raises: RuntimeError: If not used within a context manager AssertionError: If request was already responded to @@ -142,7 +151,7 @@ async def cancel(self) -> None: # Send an error response to indicate cancellation await self._session._send_response( # type: ignore[reportPrivateUsage] request_id=self.request_id, - response=ErrorData(code=0, message="Request cancelled", data=None), + response=ErrorData(code=0, message="Request cancelled"), ) @property @@ -178,8 +187,8 @@ class BaseSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], # If none, reading will never time out read_timeout_seconds: float | None = None, ) -> None: @@ -235,7 +244,7 @@ async def send_request( metadata: MessageMetadata = None, progress_callback: ProgressFnT | None = None, ) -> ReceiveResultT: - """Sends a request and wait for a response. + """Sends a request and waits for a response. Raises an MCPError if the response contains an error. If a request read timeout is provided, it will take precedence over the session read timeout. @@ -261,24 +270,36 @@ async def send_request( self._progress_callbacks[request_id] = progress_callback try: - jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) - await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) - - # request read timeout takes precedence over session read timeout - timeout = request_read_timeout_seconds or self._session_read_timeout_seconds - - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - class_name = request.__class__.__name__ - message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." - raise MCPError(code=REQUEST_TIMEOUT, message=message) - - if isinstance(response_or_error, JSONRPCError): - raise MCPError.from_jsonrpc_error(response_or_error) - else: - return result_type.model_validate(response_or_error.result, by_name=False) + target = request_data.get("params", {}).get("name") + span_name = f"MCP send {request.method} {target}" if target else f"MCP send {request.method}" + + with otel_span( + span_name, + kind=SpanKind.CLIENT, + attributes={"mcp.method.name": request.method, "jsonrpc.request.id": request_id}, + ): + # Inject W3C trace context into _meta (SEP-414). + meta: dict[str, Any] = request_data.setdefault("params", {}).setdefault("_meta", {}) + inject_trace_context(meta) + + jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) + await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) + + # request read timeout takes precedence over session read timeout + timeout = request_read_timeout_seconds or self._session_read_timeout_seconds + + try: + with anyio.fail_after(timeout): + response_or_error = await response_stream_reader.receive() + except TimeoutError: + class_name = request.__class__.__name__ + message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." + raise MCPError(code=REQUEST_TIMEOUT, message=message) + + if isinstance(response_or_error, JSONRPCError): + raise MCPError.from_jsonrpc_error(response_or_error) + else: + return result_type.model_validate(response_or_error.result, by_name=False) finally: self._response_streams.pop(request_id, None) @@ -330,10 +351,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: async def _receive_loop(self) -> None: async with self._read_stream, self._write_stream: try: - async for message in self._read_stream: - if isinstance(message, Exception): # pragma: no cover - await self._handle_incoming(message) - elif isinstance(message.message, JSONRPCRequest): + + async def _handle_session_message(message: SessionMessage) -> None: + sender_context: contextvars.Context | None = getattr(self._read_stream, "last_context", None) + if isinstance(message.message, JSONRPCRequest): try: validated_request = self._receive_request_adapter.validate_python( message.message.model_dump(by_alias=True, mode="json", exclude_none=True), @@ -346,6 +367,7 @@ async def _receive_loop(self) -> None: session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), message_metadata=message.metadata, + context=sender_context, ) self._in_flight[responder.request_id] = responder await self._received_request(responder) @@ -403,6 +425,13 @@ async def _receive_loop(self) -> None: else: # Response or error await self._handle_response(message) + async for message in self._read_stream: + if isinstance(message, Exception): + await self._handle_incoming(message) + continue + + await _handle_session_message(message) + except anyio.ClosedResourceError: # This is expected when the client disconnects abruptly. # Without this handler, the exception would propagate up and @@ -415,12 +444,14 @@ async def _receive_loop(self) -> None: finally: # after the read stream is closed, we need to send errors # to any pending requests - for id, stream in self._response_streams.items(): + # Snapshot: stream.send() wakes the waiter, whose finally pops + # from _response_streams before the next __next__() call. + for id, stream in list(self._response_streams.items()): error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") try: await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) await stream.aclose() - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover # Stream might already be closed pass self._response_streams.clear() @@ -458,6 +489,12 @@ async def _handle_response(self, message: SessionMessage) -> None: if not isinstance(message.message, JSONRPCResponse | JSONRPCError): return # pragma: no cover + if message.message.id is None: + # Narrows to JSONRPCError since JSONRPCResponse.id is always RequestId + error = message.message.error + logging.warning(f"Received error with null ID: {error.message}") + await self._handle_incoming(MCPError(error.code, error.message, error.data)) + return # Normalize response ID to handle type mismatches (e.g., "0" vs 0) response_id = self._normalize_request_id(message.message.id) @@ -506,4 +543,4 @@ async def send_progress_notification( async def _handle_incoming( self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception ) -> None: - """A generic handler for incoming messages. Overwritten by subclasses.""" + """A generic handler for incoming messages. Overridden by subclasses.""" diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index 320422636a..9005d253af 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -22,7 +22,7 @@ provided by the client. See the "Protocol Version Header" at -https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#protocol-version-header). +https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#protocol-version-header. """ ProgressToken = str | int @@ -108,8 +108,7 @@ class Request(MCPModel, Generic[RequestParamsT, MethodT]): class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): - """Base class for paginated requests, - matching the schema's PaginatedRequest interface.""" + """Base class for paginated requests, matching the schema's PaginatedRequest interface.""" params: PaginatedRequestParams | None = None @@ -174,10 +173,10 @@ class Icon(MCPModel): theme: IconTheme | None = None """Optional theme specifier. - - `"light"` indicates the icon is designed for a light background, `"dark"` indicates the icon + + `"light"` indicates the icon is designed for a light background, `"dark"` indicates the icon is designed for a dark background. - + See https://modelcontextprotocol.io/specification/2025-11-25/schema#icon for more details. """ @@ -536,7 +535,7 @@ class TaskStatusNotificationParams(NotificationParams, Task): class TaskStatusNotification(Notification[TaskStatusNotificationParams, Literal["notifications/tasks/status"]]): """An optional notification from the receiver to the requestor, informing them that a task's status has changed. - Receivers are not required to send these notifications + Receivers are not required to send these notifications. """ method: Literal["notifications/tasks/status"] = "notifications/tasks/status" @@ -608,7 +607,7 @@ class ProgressNotificationParams(NotificationParams): message: str | None = None """Message related to progress. - This should provide relevant human readable progress information. + This should provide relevant human-readable progress information. """ @@ -999,7 +998,9 @@ class ToolResultContent(MCPModel): SamplingContent: TypeAlias = TextContent | ImageContent | AudioContent """Basic content types for sampling responses (without tool use). -Used for backwards-compatible CreateMessageResult when tools are not used.""" + +Used for backwards-compatible CreateMessageResult when tools are not used. +""" class SamplingMessage(MCPModel): @@ -1117,7 +1118,7 @@ class ToolAnnotations(MCPModel): idempotent_hint: bool | None = None """ If true, calling the tool repeatedly with the same arguments - will have no additional effect on the its environment. + will have no additional effect on its environment. (This property is meaningful only when `read_only_hint == false`) Default: false """ @@ -1265,7 +1266,7 @@ class ModelPreferences(MCPModel): sampling. Because LLMs can vary along multiple dimensions, choosing the "best" model is - rarely straightforward. Different models excel in different areas—some are + rarely straightforward. Different models excel in different areas—some are faster but less capable, others are more capable but more expensive, and so on. This interface allows servers to express their priorities across multiple dimensions to help clients make an appropriate selection for their use case. @@ -1369,7 +1370,7 @@ class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling class CreateMessageResult(Result): - """The client's response to a sampling/create_message request from the server. + """The client's response to a sampling/createMessage request from the server. This is the backwards-compatible version that returns single content (no arrays). Used when the request does not include tools. @@ -1386,7 +1387,7 @@ class CreateMessageResult(Result): class CreateMessageResultWithTools(Result): - """The client's response to a sampling/create_message request when tools were provided. + """The client's response to a sampling/createMessage request when tools were provided. This version supports array content for tool use flows. """ @@ -1426,14 +1427,14 @@ class PromptReference(MCPModel): type: Literal["ref/prompt"] = "ref/prompt" name: str - """The name of the prompt or prompt template""" + """The name of the prompt or prompt template.""" class CompletionArgument(MCPModel): """The argument's information for completion requests.""" name: str - """The name of the argument""" + """The name of the argument.""" value: str """The value of the argument to use for completion matching.""" @@ -1451,7 +1452,7 @@ class CompleteRequestParams(RequestParams): ref: ResourceTemplateReference | PromptReference argument: CompletionArgument context: CompletionContext | None = None - """Additional, optional context for completions""" + """Additional, optional context for completions.""" class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]): @@ -1479,7 +1480,7 @@ class Completion(MCPModel): class CompleteResult(Result): - """The server's response to a completion/complete request""" + """The server's response to a completion/complete request.""" completion: Completion @@ -1522,6 +1523,7 @@ class Root(MCPModel): class ListRootsResult(Result): """The client's response to a roots/list request from the server. + This result contains an array of Root objects, each representing a root directory or file that the server can operate on. """ @@ -1643,7 +1645,7 @@ class ElicitRequestFormParams(RequestParams): requested_schema: ElicitRequestedSchema """ - A restricted subset of JSON Schema defining the structure of expected response. + A restricted subset of JSON Schema defining the structure of the expected response. Only top-level properties are allowed, without nesting. """ @@ -1697,8 +1699,8 @@ class ElicitResult(Result): content: dict[str, str | int | float | bool | list[str] | None] | None = None """ The submitted form data, only present when action is "accept" in form mode. - Contains values matching the requested schema. Values can be strings, integers, - booleans, or arrays of strings. + Contains values matching the requested schema. Values can be strings, integers, floats, + booleans, arrays of strings, or null. For URL mode, this field is omitted. """ diff --git a/src/mcp/types/jsonrpc.py b/src/mcp/types/jsonrpc.py index 0cfdc993a5..84304a37c1 100644 --- a/src/mcp/types/jsonrpc.py +++ b/src/mcp/types/jsonrpc.py @@ -75,7 +75,7 @@ class JSONRPCError(BaseModel): """A response to a request that indicates an error occurred.""" jsonrpc: Literal["2.0"] - id: RequestId + id: RequestId | None error: ErrorData diff --git a/tests/cli/test_claude.py b/tests/cli/test_claude.py new file mode 100644 index 0000000000..73d4f0eb52 --- /dev/null +++ b/tests/cli/test_claude.py @@ -0,0 +1,146 @@ +"""Tests for mcp.cli.claude — Claude Desktop config file generation.""" + +import json +from pathlib import Path +from typing import Any + +import pytest + +from mcp.cli.claude import get_uv_path, update_claude_config + + +@pytest.fixture +def config_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + """Temp Claude config dir with get_claude_config_path and get_uv_path mocked.""" + claude_dir = tmp_path / "Claude" + claude_dir.mkdir() + monkeypatch.setattr("mcp.cli.claude.get_claude_config_path", lambda: claude_dir) + monkeypatch.setattr("mcp.cli.claude.get_uv_path", lambda: "/fake/bin/uv") + return claude_dir + + +def _read_server(config_dir: Path, name: str) -> dict[str, Any]: + config = json.loads((config_dir / "claude_desktop_config.json").read_text()) + return config["mcpServers"][name] + + +def test_generates_uv_run_command(config_dir: Path): + """Should write a uv run command that invokes mcp run on the resolved file spec.""" + assert update_claude_config(file_spec="server.py:app", server_name="my_server") + + resolved = Path("server.py").resolve() + assert _read_server(config_dir, "my_server") == { + "command": "/fake/bin/uv", + "args": ["run", "--frozen", "--with", "mcp[cli]", "mcp", "run", f"{resolved}:app"], + } + + +def test_file_spec_without_object_suffix(config_dir: Path): + """File specs without :object should still resolve to an absolute path.""" + assert update_claude_config(file_spec="server.py", server_name="s") + + assert _read_server(config_dir, "s")["args"][-1] == str(Path("server.py").resolve()) + + +def test_with_packages_sorted_and_deduplicated(config_dir: Path): + """Extra packages should appear as --with flags, sorted and deduplicated with mcp[cli].""" + assert update_claude_config(file_spec="s.py:app", server_name="s", with_packages=["zebra", "aardvark", "zebra"]) + + args = _read_server(config_dir, "s")["args"] + assert args[:8] == ["run", "--frozen", "--with", "aardvark", "--with", "mcp[cli]", "--with", "zebra"] + + +def test_with_editable_adds_flag(config_dir: Path, tmp_path: Path): + """with_editable should add --with-editable after the --with flags.""" + editable = tmp_path / "project" + assert update_claude_config(file_spec="s.py:app", server_name="s", with_editable=editable) + + args = _read_server(config_dir, "s")["args"] + assert args[4:6] == ["--with-editable", str(editable)] + + +def test_env_vars_written(config_dir: Path): + """env_vars should be written under the server's env key.""" + assert update_claude_config(file_spec="s.py:app", server_name="s", env_vars={"KEY": "val"}) + + assert _read_server(config_dir, "s")["env"] == {"KEY": "val"} + + +def test_existing_env_vars_merged_new_wins(config_dir: Path): + """Re-installing should merge env vars, with new values overriding existing ones.""" + (config_dir / "claude_desktop_config.json").write_text( + json.dumps({"mcpServers": {"s": {"env": {"OLD": "keep", "KEY": "old"}}}}) + ) + + assert update_claude_config(file_spec="s.py:app", server_name="s", env_vars={"KEY": "new"}) + + assert _read_server(config_dir, "s")["env"] == {"OLD": "keep", "KEY": "new"} + + +def test_existing_env_vars_preserved_without_new(config_dir: Path): + """Re-installing without env_vars should keep the existing env block intact.""" + (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"s": {"env": {"KEEP": "me"}}}})) + + assert update_claude_config(file_spec="s.py:app", server_name="s") + + assert _read_server(config_dir, "s")["env"] == {"KEEP": "me"} + + +def test_other_servers_preserved(config_dir: Path): + """Installing a new server should not clobber existing mcpServers entries.""" + (config_dir / "claude_desktop_config.json").write_text(json.dumps({"mcpServers": {"other": {"command": "x"}}})) + + assert update_claude_config(file_spec="s.py:app", server_name="s") + + config = json.loads((config_dir / "claude_desktop_config.json").read_text()) + assert set(config["mcpServers"]) == {"other", "s"} + assert config["mcpServers"]["other"] == {"command": "x"} + + +def test_raises_when_config_dir_missing(monkeypatch: pytest.MonkeyPatch): + """Should raise RuntimeError when Claude Desktop config dir can't be found.""" + monkeypatch.setattr("mcp.cli.claude.get_claude_config_path", lambda: None) + monkeypatch.setattr("mcp.cli.claude.get_uv_path", lambda: "/fake/bin/uv") + + with pytest.raises(RuntimeError, match="Claude Desktop config directory not found"): + update_claude_config(file_spec="s.py:app", server_name="s") + + +@pytest.mark.parametrize("which_result, expected", [("/usr/local/bin/uv", "/usr/local/bin/uv"), (None, "uv")]) +def test_get_uv_path(monkeypatch: pytest.MonkeyPatch, which_result: str | None, expected: str): + """Should return shutil.which's result, or fall back to bare 'uv' when not on PATH.""" + + def fake_which(cmd: str) -> str | None: + return which_result + + monkeypatch.setattr("shutil.which", fake_which) + assert get_uv_path() == expected + + +@pytest.mark.parametrize( + "file_spec, expected_last_arg", + [ + ("C:\\Users\\server.py", "C:\\Users\\server.py"), + ("C:\\Users\\server.py:app", "C:\\Users\\server.py:app"), + ], +) +def test_windows_drive_letter_not_split( + config_dir: Path, monkeypatch: pytest.MonkeyPatch, file_spec: str, expected_last_arg: str +): + """Drive-letter paths like 'C:\\server.py' must not be split on the drive colon. + + Before the fix, a bare 'C:\\path\\server.py' would hit rsplit(":", 1) and yield + ("C", "\\path\\server.py"), calling resolve() on Path("C") instead of the full path. + """ + seen: list[str] = [] + + def fake_resolve(self: Path) -> Path: + seen.append(str(self)) + return self + + monkeypatch.setattr(Path, "resolve", fake_resolve) + + assert update_claude_config(file_spec=file_spec, server_name="s") + + assert seen == ["C:\\Users\\server.py"] + assert _read_server(config_dir, "s")["args"][-1] == expected_last_arg diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index 0003b16797..09760f4530 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -252,6 +252,72 @@ async def test_exchange_token_client_credentials(self, mock_storage: MockTokenSt assert "scope=read write" in content assert "resource=https://api.example.com/v1/mcp" in content + @pytest.mark.anyio + async def test_exchange_token_client_secret_post_includes_client_id(self, mock_storage: MockTokenStorage): + """Test that client_secret_post includes both client_id and client_secret in body (RFC 6749 §2.3.1).""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + token_endpoint_auth_method="client_secret_post", + scopes="read write", + ) + await provider._initialize() + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com"), + authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + ) + provider.context.protocol_version = "2025-06-18" + + request = await provider._perform_authorization() + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "client_id=test-client-id" in content + assert "client_secret=test-client-secret" in content + # Should NOT have Basic auth header + assert "Authorization" not in request.headers + + @pytest.mark.anyio + async def test_exchange_token_client_secret_post_without_client_id(self, mock_storage: MockTokenStorage): + """Test client_secret_post skips body credentials when client_id is None.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="placeholder", + client_secret="test-client-secret", + token_endpoint_auth_method="client_secret_post", + scopes="read write", + ) + await provider._initialize() + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com"), + authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + ) + provider.context.protocol_version = "2025-06-18" + # Override client_info to have client_id=None (edge case) + provider.context.client_info = OAuthClientInformationFull( + redirect_uris=None, + client_id=None, + client_secret="test-client-secret", + grant_types=["client_credentials"], + token_endpoint_auth_method="client_secret_post", + scope="read write", + ) + + request = await provider._perform_authorization() + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + # Neither client_id nor client_secret should be in body since client_id is None + # (RFC 6749 §2.3.1 requires both for client_secret_post) + assert "client_id=" not in content + assert "client_secret=" not in content + assert "Authorization" not in request.headers + @pytest.mark.anyio async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): """Test token exchange without scopes.""" diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 268e968aa4..081e1d68ea 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -4,15 +4,15 @@ from unittest.mock import patch import pytest -from anyio.streams.memory import MemoryObjectSendStream import mcp.shared.memory +from mcp.client._transport import WriteStream from mcp.shared.message import SessionMessage from mcp.types import JSONRPCNotification, JSONRPCRequest class SpyMemoryObjectSendStream: - def __init__(self, original_stream: MemoryObjectSendStream[SessionMessage]): + def __init__(self, original_stream: WriteStream[SessionMessage]): self.original_stream = original_stream self.sent_messages: list[SessionMessage] = [] @@ -77,7 +77,8 @@ def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNot def stream_spy() -> Generator[Callable[[], StreamSpyCollection], None, None]: """Fixture that provides spies for both client and server write streams. - Example usage: + Example: + ```python async def test_something(stream_spy): # ... set up server and client ... @@ -92,6 +93,7 @@ async def test_something(stream_spy): # Clear for the next operation spies.clear() + ``` """ client_spy = None server_spy = None diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 5aa985e360..bb0bce4c92 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -2264,3 +2264,357 @@ async def callback_handler() -> tuple[str, str | None]: await auth_flow.asend(final_response) except StopAsyncIteration: pass + + +class TestSEP2207OfflineAccessScope: + """Test SEP-2207: offline_access scope augmentation for OIDC-flavored refresh tokens.""" + + def _make_as_metadata(self, scopes_supported: list[str] | None = None) -> OAuthMetadata: + return OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + scopes_supported=scopes_supported, + ) + + def _make_prm(self, scopes_supported: list[str] | None = None) -> ProtectedResourceMetadata: + return ProtectedResourceMetadata( + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + scopes_supported=scopes_supported, + ) + + def test_offline_access_added_when_as_supports_and_client_has_refresh_token(self): + """offline_access is appended when AS advertises it and client supports refresh_token grant.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write offline_access" + + def test_offline_access_added_with_www_authenticate_scope(self): + """offline_access is appended even when scopes come from WWW-Authenticate header.""" + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope="read write", + protected_resource_metadata=None, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write offline_access" + + def test_offline_access_not_added_when_as_does_not_support(self): + """offline_access is not added when AS does not advertise it in scopes_supported.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write" + + def test_offline_access_not_added_when_client_has_no_refresh_token_grant(self): + """offline_access is not added when client does not support refresh_token grant.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code"], + ) + assert scopes == "read write" + + def test_offline_access_not_duplicated_when_already_present(self): + """offline_access is not added again if it already appears in the selected scopes.""" + prm = self._make_prm(scopes_supported=["read", "offline_access", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read offline_access write" + + def test_offline_access_not_added_when_no_scopes_selected(self): + """offline_access is not added when no base scopes are available (None).""" + asm = self._make_as_metadata(scopes_supported=["offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=None, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + # When AS scopes are the only source and include offline_access, + # the base scope is "offline_access" and no duplication happens + assert scopes == "offline_access" + + def test_offline_access_not_added_when_as_scopes_supported_is_none(self): + """offline_access is not added when AS scopes_supported is None.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=None) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write" + + def test_offline_access_not_added_when_no_as_metadata(self): + """offline_access is not added when AS metadata is not available.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=None, + client_grant_types=["authorization_code", "refresh_token"], + ) + assert scopes == "read write" + + def test_offline_access_not_added_when_no_grant_types_provided(self): + """offline_access is not added when client_grant_types is None.""" + prm = self._make_prm(scopes_supported=["read", "write"]) + asm = self._make_as_metadata(scopes_supported=["read", "write", "offline_access"]) + + scopes = get_client_metadata_scopes( + www_authenticate_scope=None, + protected_resource_metadata=prm, + authorization_server_metadata=asm, + client_grant_types=None, + ) + assert scopes == "read write" + + def test_default_client_metadata_includes_refresh_token_grant(self): + """Default OAuthClientMetadata includes refresh_token in grant_types (SEP-2207 guidance).""" + metadata = OAuthClientMetadata(redirect_uris=[AnyUrl("http://localhost:3030/callback")]) + assert "refresh_token" in metadata.grant_types + + @pytest.mark.anyio + async def test_auth_flow_adds_offline_access_when_as_advertises( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """E2E: auth flow includes offline_access in authorization request when AS advertises it.""" + + captured_auth_url: str | None = None + captured_state: str | None = None + + async def redirect_handler(url: str) -> None: + nonlocal captured_auth_url, captured_state + captured_auth_url = url + parsed = urlparse(url) + params = parse_qs(parsed.query) + captured_state = params.get("state", [None])[0] + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", captured_state + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + # Pre-set client info to skip DCR + provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send 401 + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp",' + b' "authorization_servers": ["https://auth.example.com"],' + b' "scopes_supported": ["read", "write"]}' + ), + request=prm_request, + ) + + # OAuth metadata discovery - AS advertises offline_access + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com",' + b' "authorization_endpoint": "https://auth.example.com/authorize",' + b' "token_endpoint": "https://auth.example.com/token",' + b' "scopes_supported": ["read", "write", "offline_access"]}' + ), + request=oauth_request, + ) + + # This triggers authorization, which calls redirect_handler + token_request = await auth_flow.asend(oauth_response) + + # Verify the authorization URL included offline_access in the scope + assert captured_auth_url is not None + parsed = urlparse(captured_auth_url) + params = parse_qs(parsed.query) + scope_value = params["scope"][0] + scope_parts = scope_value.split() + assert "offline_access" in scope_parts + assert "read" in scope_parts + assert "write" in scope_parts + + # OIDC requires prompt=consent when offline_access is requested + assert params["prompt"][0] == "consent" + + # Complete the token exchange + token_response = httpx.Response( + 200, + content=( + b'{"access_token": "new_access_token", "token_type": "Bearer",' + b' "expires_in": 3600, "refresh_token": "new_refresh_token"}' + ), + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer new_access_token" + + # Close the generator + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass + + @pytest.mark.anyio + async def test_auth_flow_no_offline_access_when_as_does_not_advertise( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """E2E: auth flow does NOT include offline_access when AS doesn't advertise it.""" + + captured_auth_url: str | None = None + captured_state: str | None = None + + async def redirect_handler(url: str) -> None: + nonlocal captured_auth_url, captured_state + captured_auth_url = url + parsed = urlparse(url) + params = parse_qs(parsed.query) + captured_state = params.get("state", [None])[0] + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", captured_state + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + # Pre-set client info to skip DCR + provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + await auth_flow.__anext__() + + # Send 401 + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp",' + b' "authorization_servers": ["https://auth.example.com"],' + b' "scopes_supported": ["read", "write"]}' + ), + request=prm_request, + ) + + # OAuth metadata discovery - AS does NOT advertise offline_access + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com",' + b' "authorization_endpoint": "https://auth.example.com/authorize",' + b' "token_endpoint": "https://auth.example.com/token",' + b' "scopes_supported": ["read", "write"]}' + ), + request=oauth_request, + ) + + # This triggers authorization, which calls redirect_handler + token_request = await auth_flow.asend(oauth_response) + + # Verify the authorization URL does NOT include offline_access + assert captured_auth_url is not None + parsed = urlparse(captured_auth_url) + params = parse_qs(parsed.query) + scope_value = params["scope"][0] + scope_parts = scope_value.split() + assert "offline_access" not in scope_parts + assert "read" in scope_parts + assert "write" in scope_parts + + # prompt=consent should NOT be present without offline_access + assert "prompt" not in params + + # Complete the token exchange + token_response = httpx.Response( + 200, + content=b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer new_access_token" + + # Close the generator + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass diff --git a/tests/client/test_client.py b/tests/client/test_client.py index d483ae54b6..ac52a9024a 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -2,16 +2,19 @@ from __future__ import annotations +import contextvars +from collections.abc import Iterator +from contextlib import contextmanager from unittest.mock import patch import anyio import pytest from inline_snapshot import snapshot -from mcp import types +from mcp import MCPError, types from mcp.client._memory import InMemoryTransport from mcp.client.client import Client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer from mcp.types import ( CallToolResult, @@ -41,33 +44,36 @@ @pytest.fixture def simple_server() -> Server: """Create a simple MCP server for testing.""" - server = Server(name="test_server") - @server.list_resources() - async def handle_list_resources(): - return [Resource(uri="memory://test", name="Test Resource", description="A test resource")] + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[Resource(uri="memory://test", name="Test Resource", description="A test resource")] + ) - @server.subscribe_resource() - async def handle_subscribe_resource(uri: str): - pass + async def handle_subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + return EmptyResult() - @server.unsubscribe_resource() - async def handle_unsubscribe_resource(uri: str): - pass + async def handle_unsubscribe_resource( + ctx: ServerRequestContext, params: types.UnsubscribeRequestParams + ) -> EmptyResult: + return EmptyResult() - @server.set_logging_level() - async def handle_set_logging_level(level: str): - pass + async def handle_set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + return EmptyResult() - @server.completion() - async def handle_completion( - ref: types.PromptReference | types.ResourceTemplateReference, - argument: types.CompletionArgument, - context: types.CompletionContext | None, - ) -> types.Completion | None: - return types.Completion(values=[]) + async def handle_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + return types.CompleteResult(completion=types.Completion(values=[])) - return server + return Server( + name="test_server", + on_list_resources=handle_list_resources, + on_subscribe_resource=handle_subscribe_resource, + on_unsubscribe_resource=handle_unsubscribe_resource, + on_set_logging_level=handle_set_logging_level, + on_completion=handle_completion, + ) @pytest.fixture @@ -96,7 +102,7 @@ def greeting_prompt(name: str) -> str: async def test_client_is_initialized(app: MCPServer): """Test that the client is initialized after entering context.""" async with Client(app) as client: - assert client.server_capabilities == snapshot( + assert client.initialize_result.capabilities == snapshot( ServerCapabilities( experimental={}, prompts=PromptsCapability(list_changed=False), @@ -104,6 +110,7 @@ async def test_client_is_initialized(app: MCPServer): tools=ToolsCapability(list_changed=False), ) ) + assert client.initialize_result.server_info.name == "test" async def test_client_with_simple_server(simple_server: Server): @@ -172,6 +179,21 @@ async def test_read_resource(app: MCPServer): ) +async def test_read_resource_error_propagates(): + """MCPError raised by a server handler propagates to the client with its code intact.""" + + async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> ReadResourceResult: + raise MCPError(code=404, message="no resource with that URI was found") + + server = Server("test", on_read_resource=handle_read_resource) + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("unknown://example") + assert exc_info.value.error.code == 404 + + async def test_get_prompt(app: MCPServer): """Test getting a prompt.""" async with Client(app) as client: @@ -202,19 +224,14 @@ async def test_client_send_progress_notification(): """Test sending progress notification.""" received_from_client = None event = anyio.Event() - server = Server(name="test_server") - - @server.progress_notification() - async def handle_progress_notification( - progress_token: str | int, - progress: float = 0.0, - total: float | None = None, - message: str | None = None, - ) -> None: + + async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: nonlocal received_from_client - received_from_client = {"progress_token": progress_token, "progress": progress} + received_from_client = {"progress_token": params.progress_token, "progress": params.progress} event.set() + server = Server(name="test_server", on_progress=handle_progress) + async with Client(server) as client: await client.send_progress_notification(progress_token="token123", progress=50.0) await event.wait() @@ -306,3 +323,33 @@ async def test_client_uses_transport_directly(app: MCPServer): structured_content={"result": "Hello, Transport!"}, ) ) + + +_TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial") + + +@contextmanager +def _set_test_contextvar(value: str) -> Iterator[None]: + token = _TEST_CONTEXTVAR.set(value) + try: + yield + finally: + _TEST_CONTEXTVAR.reset(token) + + +async def test_context_propagation(): + """Sender's contextvars.Context is propagated to the server handler.""" + server = MCPServer("test") + + @server.tool() + async def check_context() -> str: + """Return the contextvar value visible to the handler.""" + return _TEST_CONTEXTVAR.get() + + async with Client(server) as client: + with _set_test_contextvar("client_value"): + result = await client.call_tool("check_context", {}) + + assert result.content[0].text == "client_value", ( # type: ignore[union-attr] + "Server handler did not see the sender's contextvars.Context" + ) diff --git a/tests/client/test_config.py b/tests/client/test_config.py deleted file mode 100644 index d1a0576ff3..0000000000 --- a/tests/client/test_config.py +++ /dev/null @@ -1,75 +0,0 @@ -import json -import subprocess -from pathlib import Path -from unittest.mock import patch - -import pytest - -from mcp.cli.claude import update_claude_config - - -@pytest.fixture -def temp_config_dir(tmp_path: Path): - """Create a temporary Claude config directory.""" - config_dir = tmp_path / "Claude" - config_dir.mkdir() - return config_dir - - -@pytest.fixture -def mock_config_path(temp_config_dir: Path): - """Mock get_claude_config_path to return our temporary directory.""" - with patch("mcp.cli.claude.get_claude_config_path", return_value=temp_config_dir): - yield temp_config_dir - - -def test_command_execution(mock_config_path: Path): - """Test that the generated command can actually be executed.""" - # Setup - server_name = "test_server" - file_spec = "test_server.py:app" - - # Update config - success = update_claude_config(file_spec=file_spec, server_name=server_name) - assert success - - # Read the generated config - config_file = mock_config_path / "claude_desktop_config.json" - config = json.loads(config_file.read_text()) - - # Get the command and args - server_config = config["mcpServers"][server_name] - command = server_config["command"] - args = server_config["args"] - - test_args = [command] + args + ["--help"] - - result = subprocess.run(test_args, capture_output=True, text=True, timeout=20, check=False) - - assert result.returncode == 0 - assert "usage" in result.stdout.lower() - - -def test_absolute_uv_path(mock_config_path: Path): - """Test that the absolute path to uv is used when available.""" - # Mock the shutil.which function to return a fake path - mock_uv_path = "/usr/local/bin/uv" - - with patch("mcp.cli.claude.get_uv_path", return_value=mock_uv_path): - # Setup - server_name = "test_server" - file_spec = "test_server.py:app" - - # Update config - success = update_claude_config(file_spec=file_spec, server_name=server_name) - assert success - - # Read the generated config - config_file = mock_config_path / "claude_desktop_config.json" - config = json.loads(config_file.read_text()) - - # Verify the command is the absolute path - server_config = config["mcpServers"][server_name] - command = server_config["command"] - - assert command == mock_uv_path diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index 5cca8c1943..cc2e14e469 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -8,7 +8,6 @@ import socket from collections.abc import AsyncGenerator, Generator from contextlib import asynccontextmanager -from typing import Any import pytest from starlette.applications import Starlette @@ -17,7 +16,7 @@ from mcp import types from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool from tests.test_helpers import wait_for_server @@ -47,54 +46,56 @@ def run_unicode_server(port: int) -> None: # pragma: no cover import uvicorn # Need to recreate the server setup in this process - server = Server(name="unicode_test_server") - - @server.list_tools() - async def list_tools() -> list[Tool]: - """List tools with Unicode descriptions.""" - return [ - Tool( - name="echo_unicode", - description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to echo back"}, + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + Tool( + name="echo_unicode", + description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", + input_schema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "Text to echo back"}, + }, + "required": ["text"], }, - "required": ["text"], - }, - ), - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: - """Handle tool calls with Unicode content.""" - if name == "echo_unicode": - text = arguments.get("text", "") if arguments else "" - return [ - TextContent( - type="text", - text=f"Echo: {text}", - ) + ), ] - else: - raise ValueError(f"Unknown tool: {name}") - - @server.list_prompts() - async def list_prompts() -> list[types.Prompt]: - """List prompts with Unicode names and descriptions.""" - return [ - types.Prompt( - name="unicode_prompt", - description="Unicode prompt - Слой хранилища, где располагаются", - arguments=[], + ) + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name == "echo_unicode": + text = params.arguments.get("text", "") if params.arguments else "" + return types.CallToolResult( + content=[ + TextContent( + type="text", + text=f"Echo: {text}", + ) + ] ) - ] + else: + raise ValueError(f"Unknown tool: {params.name}") + + async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="unicode_prompt", + description="Unicode prompt - Слой хранилища, где располагаются", + arguments=[], + ) + ] + ) - @server.get_prompt() - async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPromptResult: - """Get a prompt with Unicode content.""" - if name == "unicode_prompt": + async def handle_get_prompt( + ctx: ServerRequestContext, params: types.GetPromptRequestParams + ) -> types.GetPromptResult: + if params.name == "unicode_prompt": return types.GetPromptResult( messages=[ types.PromptMessage( @@ -106,7 +107,15 @@ async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPr ) ] ) - raise ValueError(f"Unknown prompt: {name}") + raise ValueError(f"Unknown prompt: {params.name}") + + server = Server( + name="unicode_test_server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, + ) # Create the session manager session_manager = StreamableHTTPSessionManager( diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index 4d7c53db25..f70fb9277d 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -3,9 +3,9 @@ import pytest from mcp import Client, types -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer -from mcp.types import ListToolsRequest, ListToolsResult +from mcp.types import ListToolsResult from .conftest import StreamSpyCollection @@ -105,14 +105,16 @@ async def test_list_tools_with_strict_server_validation( async def test_list_tools_with_lowlevel_server(): """Test that list_tools works with a lowlevel Server using params.""" - server = Server("test-lowlevel") - @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListToolsResult: # Echo back what cursor we received in the tool description - cursor = request.params.cursor if request.params else None + cursor = params.cursor if params else None return ListToolsResult(tools=[types.Tool(name="test_tool", description=f"cursor={cursor}", input_schema={})]) + server = Server("test-lowlevel", on_list_tools=handle_list_tools) + async with Client(server) as client: result = await client.list_tools() assert result.tools[0].description == "cursor=None" diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 6a2f49f390..be4b9a97b9 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -3,8 +3,7 @@ from mcp import Client from mcp.client.session import ClientSession -from mcp.server.mcpserver import MCPServer -from mcp.server.mcpserver.server import Context +from mcp.server.mcpserver import Context, MCPServer from mcp.shared._context import RequestContext from mcp.types import ListRootsResult, Root, TextContent @@ -26,7 +25,7 @@ async def list_roots_callback( return callback_return @server.tool("test_list_roots") - async def test_list_roots(context: Context[None], message: str): + async def test_list_roots(context: Context, message: str): roots = await context.session.list_roots() assert roots == callback_return return True diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 31cdeece73..454c1d3382 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -1,9 +1,9 @@ -from typing import Any, Literal +from typing import Literal import pytest from mcp import Client, types -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import Context, MCPServer from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, @@ -33,30 +33,23 @@ async def test_tool() -> bool: # Create a function that can send a log notification @server.tool("test_tool_with_log") async def test_tool_with_log( - message: str, level: Literal["debug", "info", "warning", "error"], logger: str + message: str, level: Literal["debug", "info", "warning", "error"], logger: str, ctx: Context ) -> bool: """Send a log notification to the client.""" - await server.get_context().log( - level=level, - message=message, - logger_name=logger, - ) + await ctx.log(level=level, data=message, logger_name=logger) return True - @server.tool("test_tool_with_log_extra") - async def test_tool_with_log_extra( - message: str, + @server.tool("test_tool_with_log_dict") + async def test_tool_with_log_dict( level: Literal["debug", "info", "warning", "error"], logger: str, - extra_string: str, - extra_dict: dict[str, Any], + ctx: Context, ) -> bool: - """Send a log notification to the client with extra fields.""" - await server.get_context().log( + """Send a log notification with a dict payload.""" + await ctx.log( level=level, - message=message, + data={"message": "Test log message", "extra_string": "example", "extra_dict": {"a": 1, "b": 2, "c": 3}}, logger_name=logger, - extra={"extra_string": extra_string, "extra_dict": extra_dict}, ) return True @@ -87,18 +80,15 @@ async def message_handler( "logger": "test_logger", }, ) - log_result_with_extra = await client.call_tool( - "test_tool_with_log_extra", + log_result_with_dict = await client.call_tool( + "test_tool_with_log_dict", { - "message": "Test log message", "level": "info", "logger": "test_logger", - "extra_string": "example", - "extra_dict": {"a": 1, "b": 2, "c": 3}, }, ) assert log_result.is_error is False - assert log_result_with_extra.is_error is False + assert log_result_with_dict.is_error is False assert len(logging_collector.log_messages) == 2 # Create meta object with related_request_id added dynamically log = logging_collector.log_messages[0] @@ -106,10 +96,10 @@ async def message_handler( assert log.logger == "test_logger" assert log.data == "Test log message" - log_with_extra = logging_collector.log_messages[1] - assert log_with_extra.level == "info" - assert log_with_extra.logger == "test_logger" - assert log_with_extra.data == { + log_with_dict = logging_collector.log_messages[1] + assert log_with_dict.level == "info" + assert log_with_dict.logger == "test_logger" + assert log_with_dict.data == { "message": "Test log message", "extra_string": "example", "extra_dict": {"a": 1, "b": 2, "c": 3}, diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 9e233acc32..69c8afeb84 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -116,6 +116,58 @@ async def test_unexpected_content_type_sends_jsonrpc_error() -> None: await session.list_tools() +def _create_http_error_app(error_status: int, *, error_on_notifications: bool = False) -> Starlette: + """Create a server that returns an HTTP error for non-init requests.""" + + async def handle_mcp_request(request: Request) -> Response: + body = await request.body() + data = json.loads(body) + + if data.get("method") == "initialize": + return _init_json_response(data) + + if "id" not in data: + if error_on_notifications: + return Response(status_code=error_status) + return Response(status_code=202) + + return Response(status_code=error_status) + + return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])]) + + +async def test_http_error_status_sends_jsonrpc_error() -> None: + """Verify HTTP 5xx errors unblock the pending request with an MCPError. + + When a server returns a non-2xx status code (e.g. 500), the client should + send a JSONRPCError so the pending request resolves immediately instead of + raising an unhandled httpx.HTTPStatusError that causes the caller to hang. + """ + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_http_error_app(500))) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + with pytest.raises(MCPError, match="Server returned an error response"): # pragma: no branch + await session.list_tools() + + +async def test_http_error_on_notification_does_not_hang() -> None: + """Verify HTTP errors on notifications are silently ignored. + + When a notification gets an HTTP error, there is no pending request to + unblock, so the client should just return without sending a JSONRPCError. + """ + app = _create_http_error_app(500, error_on_notifications=True) + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + # Should not raise or hang — the error is silently ignored for notifications + await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed")) + + def _create_invalid_json_response_app() -> Starlette: """Create a server that returns invalid JSON for requests.""" diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index cc93d303b1..d78197b5c3 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -1,50 +1,41 @@ -import inspect import logging -from contextlib import contextmanager from typing import Any -from unittest.mock import patch -import jsonschema import pytest from mcp import Client -from mcp.server.lowlevel import Server -from mcp.types import Tool - - -@contextmanager -def bypass_server_output_validation(): - """Context manager that bypasses server-side output validation. - This simulates a malicious or non-compliant server that doesn't validate - its outputs, allowing us to test client-side validation. - """ - # Save the original validate function - original_validate = jsonschema.validate - - # Create a mock that tracks which module is calling it - def selective_mock(instance: Any = None, schema: Any = None, *args: Any, **kwargs: Any) -> None: - # Check the call stack to see where this is being called from - for frame_info in inspect.stack(): - # If called from the server module, skip validation - # TODO: fix this as it's a rather gross workaround and will eventually break - # Normalize path separators for cross-platform compatibility - normalized_path = frame_info.filename.replace("\\", "/") - if "mcp/server/lowlevel/server.py" in normalized_path: - return None - # Otherwise, use the real validation (for client-side) - return original_validate(instance=instance, schema=schema, *args, **kwargs) - - with patch("jsonschema.validate", selective_mock): - yield +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + + +def _make_server( + tools: list[Tool], + structured_content: dict[str, Any], +) -> Server: + """Create a low-level server that returns the given structured_content for any tool call.""" + + async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=tools) + + async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text="result")], + structured_content=structured_content, + ) + + return Server("test-server", on_list_tools=on_list_tools, on_call_tool=on_call_tool) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_basemodel(): """Test that client validates structured content against schema for BaseModel outputs""" - # Create a malicious low-level server that returns invalid structured content - server = Server("test-server") - - # Define the expected schema for our tool output_schema = { "type": "object", "properties": {"name": {"type": "string", "title": "Name"}, "age": {"type": "integer", "title": "Age"}}, @@ -52,39 +43,27 @@ async def test_tool_structured_output_client_side_validation_basemodel(): "title": "UserOutput", } - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="get_user", description="Get user data", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - age is string instead of integer - # The low-level server will wrap this in CallToolResult - return {"name": "John", "age": "invalid"} # Invalid: age should be int + ], + structured_content={"name": "John", "age": "invalid"}, # Invalid: age should be int + ) - # Test that client validates the structured content - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_user", {}) - # Verify it's a validation error - assert "Invalid structured content returned by tool get_user" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_user", {}) + assert "Invalid structured content returned by tool get_user" in str(exc_info.value) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_primitive(): """Test that client validates structured content for primitive outputs""" - server = Server("test-server") - - # Primitive types are wrapped in {"result": value} output_schema = { "type": "object", "properties": {"result": {"type": "integer", "title": "Result"}}, @@ -92,122 +71,95 @@ async def test_tool_structured_output_client_side_validation_primitive(): "title": "calculate_Output", } - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="calculate", description="Calculate something", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - result is string instead of integer - return {"result": "not_a_number"} # Invalid: should be int + ], + structured_content={"result": "not_a_number"}, # Invalid: should be int + ) - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("calculate", {}) - assert "Invalid structured content returned by tool calculate" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("calculate", {}) + assert "Invalid structured content returned by tool calculate" in str(exc_info.value) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_dict_typed(): """Test that client validates dict[str, T] structured content""" - server = Server("test-server") - - # dict[str, int] schema output_schema = {"type": "object", "additionalProperties": {"type": "integer"}, "title": "get_scores_Output"} - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="get_scores", description="Get scores", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - values should be integers - return {"alice": "100", "bob": "85"} # Invalid: values should be int + ], + structured_content={"alice": "100", "bob": "85"}, # Invalid: values should be int + ) - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_scores", {}) - assert "Invalid structured content returned by tool get_scores" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_scores", {}) + assert "Invalid structured content returned by tool get_scores" in str(exc_info.value) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_missing_required(): """Test that client validates missing required fields""" - server = Server("test-server") - output_schema = { "type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}, "email": {"type": "string"}}, - "required": ["name", "age", "email"], # All fields required + "required": ["name", "age", "email"], "title": "PersonOutput", } - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="get_person", description="Get person data", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return structured content missing required field 'email' - return {"name": "John", "age": 30} # Missing required 'email' + ], + structured_content={"name": "John", "age": 30}, # Missing required 'email' + ) - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_person", {}) - assert "Invalid structured content returned by tool get_person" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_person", {}) + assert "Invalid structured content returned by tool get_person" in str(exc_info.value) @pytest.mark.anyio async def test_tool_not_listed_warning(caplog: pytest.LogCaptureFixture): """Test that client logs warning when tool is not in list_tools but has output_schema""" - server = Server("test-server") - @server.list_tools() - async def list_tools() -> list[Tool]: - # Return empty list - tool is not listed - return [] + async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[]) + + async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text="result")], + structured_content={"result": 42}, + ) - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - # Server still responds to the tool call with structured content - return {"result": 42} + server = Server("test-server", on_list_tools=on_list_tools, on_call_tool=on_call_tool) - # Set logging level to capture warnings caplog.set_level(logging.WARNING) - with bypass_server_output_validation(): - async with Client(server) as client: - # Call a tool that wasn't listed - result = await client.call_tool("mystery_tool", {}) - assert result.structured_content == {"result": 42} - assert result.is_error is False + async with Client(server) as client: + result = await client.call_tool("mystery_tool", {}) + assert result.structured_content == {"result": 42} + assert result.is_error is False - # Check that warning was logged - assert "Tool mystery_tool not listed" in caplog.text + assert "Tool mystery_tool not listed" in caplog.text diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 3357bc921d..6efcac0a52 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -2,7 +2,7 @@ from mcp import Client from mcp.client.session import ClientSession -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import Context, MCPServer from mcp.shared._context import RequestContext from mcp.types import ( CreateMessageRequestParams, @@ -32,8 +32,8 @@ async def sampling_callback( return callback_return @server.tool("test_sampling") - async def test_sampling_tool(message: str): - value = await server.get_context().session.create_message( + async def test_sampling_tool(message: str, ctx: Context) -> bool: + value = await ctx.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) @@ -77,9 +77,9 @@ async def sampling_callback( return callback_return @server.tool("test_backwards_compat") - async def test_tool(message: str): + async def test_tool(message: str, ctx: Context) -> bool: # Call create_message WITHOUT tools - result = await server.get_context().session.create_message( + result = await ctx.session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index d6d13e273c..f25c964f03 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -540,8 +540,8 @@ async def mock_server(): @pytest.mark.anyio -async def test_get_server_capabilities(): - """Test that get_server_capabilities returns None before init and capabilities after""" +async def test_initialize_result(): + """Test that initialize_result is None before init and contains the full result after.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -551,6 +551,8 @@ async def test_get_server_capabilities(): resources=types.ResourcesCapability(subscribe=True, list_changed=True), tools=types.ToolsCapability(list_changed=False), ) + expected_server_info = Implementation(name="mock-server", version="0.1.0") + expected_instructions = "Use the tools wisely." async def mock_server(): session_message = await client_to_server_receive.receive() @@ -564,7 +566,8 @@ async def mock_server(): result = InitializeResult( protocol_version=LATEST_PROTOCOL_VERSION, capabilities=expected_capabilities, - server_info=Implementation(name="mock-server", version="0.1.0"), + server_info=expected_server_info, + instructions=expected_instructions, ) async with server_to_client_send: @@ -590,21 +593,17 @@ async def mock_server(): server_to_client_send, server_to_client_receive, ): - assert session.get_server_capabilities() is None + assert session.initialize_result is None tg.start_soon(mock_server) await session.initialize() - capabilities = session.get_server_capabilities() - assert capabilities is not None - assert capabilities == expected_capabilities - assert capabilities.logging is not None - assert capabilities.prompts is not None - assert capabilities.prompts.list_changed is True - assert capabilities.resources is not None - assert capabilities.resources.subscribe is True - assert capabilities.tools is not None - assert capabilities.tools.list_changed is False + result = session.initialize_result + assert result is not None + assert result.server_info == expected_server_info + assert result.capabilities == expected_capabilities + assert result.instructions == expected_instructions + assert result.protocol_version == LATEST_PROTOCOL_VERSION @pytest.mark.anyio diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index f70c24eee7..06e2cba4b1 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,12 +1,12 @@ import errno -import os import shutil import sys -import tempfile import textwrap import time +from contextlib import AsyncExitStack, suppress import anyio +import anyio.abc import pytest from mcp.client.session import ClientSession @@ -16,12 +16,11 @@ _terminate_process_tree, stdio_client, ) +from mcp.os.win32.utilities import FallbackProcess from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse -from ..shared.test_win32_utils import escape_path_for_python - # Timeout for cleanup of processes that ignore SIGTERM # This timeout ensures the test fails quickly if the cleanup logic doesn't have # proper fallback mechanisms (SIGINT/SIGKILL) for processes that ignore SIGTERM @@ -221,291 +220,230 @@ def sigint_handler(signum, frame): raise -class TestChildProcessCleanup: - """Tests for child process cleanup functionality using _terminate_process_tree. - - These tests verify that child processes are properly terminated when the parent - is killed, addressing the issue where processes like npx spawn child processes - that need to be cleaned up. The tests cover various process tree scenarios: - - - Basic parent-child relationship (single child process) - - Multi-level process trees (parent → child → grandchild) - - Race conditions where parent exits during cleanup - - Note on Windows ResourceWarning: - On Windows, we may see ResourceWarning about subprocess still running. This is - expected behavior due to how Windows process termination works: - - anyio's process.terminate() calls Windows TerminateProcess() API - - TerminateProcess() immediately kills the process without allowing cleanup - - subprocess.Popen objects in the killed process can't run their cleanup code - - Python detects this during garbage collection and issues a ResourceWarning - - This warning does NOT indicate a process leak - the processes are properly - terminated. It only means the Popen objects couldn't clean up gracefully. - This is a fundamental difference between Windows and Unix process termination. - """ +# --------------------------------------------------------------------------- +# TestChildProcessCleanup — socket-based deterministic child liveness probe +# --------------------------------------------------------------------------- +# +# These tests verify that `_terminate_process_tree()` kills the *entire* process +# tree (not just the immediate child), which is critical for cleaning up tools +# like `npx` that spawn their own subprocesses. +# +# Mechanism: each subprocess in the tree connects a TCP socket back to a +# listener owned by the test. We then use two kernel-guaranteed blocking-I/O +# signals — neither requires any `sleep()` or polling loop: +# +# 1. `await listener.accept()` blocks until the subprocess connects, +# proving it is running. +# 2. After `_terminate_process_tree()`, `await stream.receive(1)` raises +# `EndOfStream` (clean close / FIN) or `BrokenResourceError` (abrupt +# close / RST — typical on Windows after TerminateJobObject) because the +# kernel closes all file descriptors when a process terminates. Either +# is the direct, OS-level proof that the child is dead. +# +# This replaces an older file-growth-watching approach whose fixed `sleep()` +# durations raced against slow Python interpreter startup on loaded CI runners. + + +def _connect_back_script(port: int) -> str: + """Return a ``python -c`` script body that connects to the given port, + sends ``b'alive'``, then blocks forever. Used by TestChildProcessCleanup + subprocesses as a liveness probe.""" + return ( + f"import socket, time\n" + f"s = socket.create_connection(('127.0.0.1', {port}))\n" + f"s.sendall(b'alive')\n" + f"time.sleep(3600)\n" + ) - @pytest.mark.anyio - @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") - async def test_basic_child_process_cleanup(self): - """Test basic parent-child process cleanup. - Parent spawns a single child process that writes continuously to a file. - """ - # Create a marker file for the child process to write to - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - marker_file = f.name - # Also create a file to verify parent started - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - parent_marker = f.name +def _spawn_then_block(child_script: str) -> str: + """Return a ``python -c`` script body that spawns ``child_script`` as a + subprocess, then blocks forever. The ``!r`` injection avoids nested-quote + escaping for arbitrary child script content.""" + return ( + f"import subprocess, sys, time\nsubprocess.Popen([sys.executable, '-c', {child_script!r}])\ntime.sleep(3600)\n" + ) - try: - # Parent script that spawns a child process - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import os - - # Mark that parent started - with open({escape_path_for_python(parent_marker)}, 'w') as f: - f.write('parent started\\n') - - # Child script that writes continuously - child_script = f''' - import time - with open({escape_path_for_python(marker_file)}, 'a') as f: - while True: - f.write(f"{time.time()}") - f.flush() - time.sleep(0.1) - ''' - - # Start the child process - child = subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent just sleeps - while True: - time.sleep(0.1) - """ - ) - print("\nStarting child process termination test...") +async def _open_liveness_listener() -> tuple[anyio.abc.SocketListener, int]: + """Open a TCP listener on localhost and return it along with its port.""" + multi = await anyio.create_tcp_listener(local_host="127.0.0.1") + sock = multi.listeners[0] + assert isinstance(sock, anyio.abc.SocketListener) + addr = sock.extra(anyio.abc.SocketAttribute.local_address) + # IPv4 local_address is (host: str, port: int) + assert isinstance(addr, tuple) and len(addr) >= 2 and isinstance(addr[1], int) + return sock, addr[1] + + +async def _accept_alive(sock: anyio.abc.SocketListener) -> anyio.abc.SocketStream: + """Accept one connection and assert the peer sent ``b'alive'``. + + Blocks deterministically until a subprocess connects (no polling). The + outer test bounds this with ``anyio.fail_after`` to catch the case where + the subprocess chain failed to start. + """ + stream = await sock.accept() + msg = await stream.receive(5) + assert msg == b"alive", f"expected b'alive', got {msg!r}" + return stream - # Start the parent process - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) - # Wait for processes to start - await anyio.sleep(0.5) +async def _assert_stream_closed(stream: anyio.abc.SocketStream) -> None: + """Assert the peer holding the other end of ``stream`` has terminated. - # Verify parent started - assert os.path.exists(parent_marker), "Parent process didn't start" + When a process dies, the kernel closes its file descriptors including + sockets. The next ``receive()`` on the peer socket unblocks with one of: - # Verify child is writing - if os.path.exists(marker_file): # pragma: no branch - initial_size = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size_after_wait = os.path.getsize(marker_file) - assert size_after_wait > initial_size, "Child process should be writing" - print(f"Child is writing (file grew from {initial_size} to {size_after_wait} bytes)") + - ``anyio.EndOfStream`` — clean close (FIN), typical after graceful exit + or POSIX ``SIGTERM``. + - ``anyio.BrokenResourceError`` — abrupt close (RST), typical after + Windows ``TerminateJobObject`` or POSIX ``SIGKILL``. - # Terminate using our function - print("Terminating process and children...") + Either is a deterministic, kernel-level signal that the process is dead — + no sleeps or polling required. + """ + with anyio.fail_after(5.0), pytest.raises((anyio.EndOfStream, anyio.BrokenResourceError)): + await stream.receive(1) + + +async def _terminate_and_reap(proc: anyio.abc.Process | FallbackProcess) -> None: + """Terminate the process tree, reap, and tear down pipe transports. + + ``_terminate_process_tree`` kills the OS process group / Job Object but does + not call ``process.wait()`` or clean up the asyncio pipe transports. On + Windows those transports leak and emit ``ResourceWarning`` when GC'd in a + later test, causing ``PytestUnraisableExceptionWarning`` knock-on failures. + + Production ``stdio.py`` avoids this via its ``stdout_reader`` task which + reads stdout to EOF (triggering ``_ProactorReadPipeTransport._eof_received`` + → ``close()``) plus ``async with process:`` which waits and closes stdin. + These tests call ``_terminate_process_tree`` directly, so they replicate + both parts here: ``wait()`` + close stdin + drain stdout to EOF. + + The stdout drain is the non-obvious part: anyio's ``StreamReaderWrapper.aclose()`` + only marks the Python-level reader closed — it never touches the underlying + ``_ProactorReadPipeTransport``. That transport starts paused and only detects + pipe EOF when someone reads, so without a drain it lives until ``__del__``. + + Idempotent: the ``returncode`` guard skips termination if already reaped + (avoids spurious WARNING/ERROR logs from ``terminate_posix_process_tree``'s + fallback path, visible because ``log_cli = true``); ``wait()`` and stream + ``aclose()`` no-op on subsequent calls; the drain raises ``ClosedResourceError`` + on the second call, caught by the suppress. The tests call this explicitly + as the action under test and ``AsyncExitStack`` calls it again on exit as a + safety net. Bounded by ``move_on_after`` to prevent hangs. + """ + with anyio.move_on_after(5.0): + if proc.returncode is None: await _terminate_process_tree(proc) + await proc.wait() + assert proc.stdin is not None + assert proc.stdout is not None + await proc.stdin.aclose() + with suppress(anyio.EndOfStream, anyio.BrokenResourceError, anyio.ClosedResourceError): + await proc.stdout.receive(65536) + await proc.stdout.aclose() - # Verify processes stopped - await anyio.sleep(0.5) - if os.path.exists(marker_file): # pragma: no branch - size_after_cleanup = os.path.getsize(marker_file) - await anyio.sleep(0.5) - final_size = os.path.getsize(marker_file) - print(f"After cleanup: file size {size_after_cleanup} -> {final_size}") - assert final_size == size_after_cleanup, ( - f"Child process still running! File grew by {final_size - size_after_cleanup} bytes" - ) +class TestChildProcessCleanup: + """Integration tests for ``_terminate_process_tree`` covering basic, + nested, and early-parent-exit process tree scenarios. See module-level + comment above for the socket-based liveness probe mechanism. + """ - print("SUCCESS: Child process was properly terminated") + @pytest.mark.anyio + async def test_basic_child_process_cleanup(self): + """Parent spawns one child; terminating the tree kills both.""" + async with AsyncExitStack() as stack: + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # Parent spawns a child; the child connects back to us. + parent_script = _spawn_then_block(_connect_back_script(port)) + proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + stack.push_async_callback(_terminate_and_reap, proc) + + # Deterministic: accept() blocks until the child connects. No sleep. + with anyio.fail_after(10.0): + stream = await _accept_alive(sock) + stack.push_async_callback(stream.aclose) - finally: - # Clean up files - for f in [marker_file, parent_marker]: - try: - os.unlink(f) - except OSError: # pragma: no cover - pass + # Terminate, reap and close transports (wraps _terminate_process_tree, + # the behavior under test). + await _terminate_and_reap(proc) + + # Deterministic: kernel closed child's socket when it died. + await _assert_stream_closed(stream) @pytest.mark.anyio - @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") async def test_nested_process_tree(self): - """Test nested process tree cleanup (parent → child → grandchild). - Each level writes to a different file to verify all processes are terminated. - """ - # Create temporary files for each process level - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1: - parent_file = f1.name - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f2: - child_file = f2.name - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f3: - grandchild_file = f3.name - - try: - # Simple nested process tree test - # We create parent -> child -> grandchild, each writing to a file - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import os - - # Child will spawn grandchild and write to child file - child_script = f'''import subprocess - import sys - import time - - # Grandchild just writes to file - grandchild_script = \"\"\"import time - with open({escape_path_for_python(grandchild_file)}, 'a') as f: - while True: - f.write(f"gc {{time.time()}}") - f.flush() - time.sleep(0.1)\"\"\" - - # Spawn grandchild - subprocess.Popen([sys.executable, '-c', grandchild_script]) - - # Child writes to its file - with open({escape_path_for_python(child_file)}, 'a') as f: - while True: - f.write(f"c {time.time()}") - f.flush() - time.sleep(0.1)''' - - # Spawn child process - subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent writes to its file - with open({escape_path_for_python(parent_file)}, 'a') as f: - while True: - f.write(f"p {time.time()}") - f.flush() - time.sleep(0.1) - """ + """Parent → child → grandchild; terminating the tree kills all three.""" + async with AsyncExitStack() as stack: + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # Build a three-level chain: parent spawns child, child spawns + # grandchild. Every level connects back to our socket. + grandchild = _connect_back_script(port) + child = ( + f"import subprocess, sys\n" + f"subprocess.Popen([sys.executable, '-c', {grandchild!r}])\n" + _connect_back_script(port) + ) + parent_script = ( + f"import subprocess, sys\n" + f"subprocess.Popen([sys.executable, '-c', {child!r}])\n" + _connect_back_script(port) ) - - # Start the parent process proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + stack.push_async_callback(_terminate_and_reap, proc) - # Let all processes start - await anyio.sleep(1.0) - - # Verify all are writing - for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): # pragma: no branch - initial_size = os.path.getsize(file_path) - await anyio.sleep(0.3) - new_size = os.path.getsize(file_path) - assert new_size > initial_size, f"{name} process should be writing" + # Deterministic: three blocking accepts, one per tree level. + streams: list[anyio.abc.SocketStream] = [] + with anyio.fail_after(10.0): + for _ in range(3): + stream = await _accept_alive(sock) + stack.push_async_callback(stream.aclose) + streams.append(stream) - # Terminate the whole tree - await _terminate_process_tree(proc) + # Terminate the entire tree (wraps _terminate_process_tree). + await _terminate_and_reap(proc) - # Verify all stopped - await anyio.sleep(0.5) - for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): # pragma: no branch - size1 = os.path.getsize(file_path) - await anyio.sleep(0.3) - size2 = os.path.getsize(file_path) - assert size1 == size2, f"{name} still writing after cleanup!" - - print("SUCCESS: All processes in tree terminated") - - finally: - # Clean up all marker files - for f in [parent_file, child_file, grandchild_file]: - try: - os.unlink(f) - except OSError: # pragma: no cover - pass + # Every level of the tree must be dead: three kernel-level EOFs. + for stream in streams: + await _assert_stream_closed(stream) @pytest.mark.anyio - @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") async def test_early_parent_exit(self): - """Test cleanup when parent exits during termination sequence. - Tests the race condition where parent might die during our termination - sequence but we can still clean up the children via the process group. + """Parent exits immediately on SIGTERM; process-group termination still + catches the child (exercises the race where the parent dies mid-cleanup). """ - # Create a temporary file for the child - with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: - marker_file = f.name - - try: - # Parent that spawns child and waits briefly - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import signal - - # Child that continues running - child_script = f'''import time - with open({escape_path_for_python(marker_file)}, 'a') as f: - while True: - f.write(f"child {time.time()}") - f.flush() - time.sleep(0.1)''' - - # Start child in same process group - subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent waits a bit then exits on SIGTERM - def handle_term(sig, frame): - sys.exit(0) - - signal.signal(signal.SIGTERM, handle_term) - - # Wait - while True: - time.sleep(0.1) - """ + async with AsyncExitStack() as stack: + sock, port = await _open_liveness_listener() + stack.push_async_callback(sock.aclose) + + # Parent installs a SIGTERM handler that exits immediately, spawns a + # child that connects back to us, then blocks. + child = _connect_back_script(port) + parent_script = ( + f"import signal, subprocess, sys, time\n" + f"signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))\n" + f"subprocess.Popen([sys.executable, '-c', {child!r}])\n" + f"time.sleep(3600)\n" ) - - # Start the parent process proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + stack.push_async_callback(_terminate_and_reap, proc) - # Let child start writing - await anyio.sleep(0.5) - - # Verify child is writing - if os.path.exists(marker_file): # pragma: no branch - size1 = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size2 = os.path.getsize(marker_file) - assert size2 > size1, "Child should be writing" + # Deterministic: child connected means both parent and child are up. + with anyio.fail_after(10.0): + stream = await _accept_alive(sock) + stack.push_async_callback(stream.aclose) - # Terminate - this will kill the process group even if parent exits first - await _terminate_process_tree(proc) + # Parent will sys.exit(0) on SIGTERM, but the process-group kill + # (POSIX killpg / Windows Job Object) must still terminate the child. + await _terminate_and_reap(proc) - # Verify child stopped - await anyio.sleep(0.5) - if os.path.exists(marker_file): # pragma: no branch - size3 = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size4 = os.path.getsize(marker_file) - assert size3 == size4, "Child should be terminated" - - print("SUCCESS: Child terminated even with parent exit during cleanup") - - finally: - # Clean up marker file - try: - os.unlink(marker_file) - except OSError: # pragma: no cover - pass + # Child must be dead despite parent's early exit. + await _assert_stream_closed(stream) @pytest.mark.anyio diff --git a/tests/client/test_transport_stream_cleanup.py b/tests/client/test_transport_stream_cleanup.py new file mode 100644 index 0000000000..1e6be3c725 --- /dev/null +++ b/tests/client/test_transport_stream_cleanup.py @@ -0,0 +1,119 @@ +"""Regression tests for memory stream leaks in client transports. + +When a connection error occurs (404, 403, ConnectError), transport context +managers must close ALL 4 memory stream ends they created. anyio memory streams +are paired but independent — closing the writer does NOT close the reader. +Unclosed stream ends emit ResourceWarning on GC, which pytest promotes to a +test failure in whatever test happens to be running when GC triggers. + +These tests force GC after the transport context exits, so any leaked stream +triggers a ResourceWarning immediately and deterministically here, rather than +nondeterministically in an unrelated later test. +""" + +import gc +import sys +from collections.abc import Iterator +from contextlib import contextmanager + +import httpx +import pytest + +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamable_http_client +from mcp.client.websocket import websocket_client + + +@contextmanager +def _assert_no_memory_stream_leak() -> Iterator[None]: + """Fail if any anyio MemoryObject stream emits ResourceWarning during the block. + + Uses a custom sys.unraisablehook to capture ONLY MemoryObject stream leaks, + ignoring unrelated resources (e.g. PipeHandle from flaky stdio tests on the + same xdist worker). gc.collect() is forced after the block to make leaks + deterministic. + """ + leaked: list[str] = [] + old_hook = sys.unraisablehook + + def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover + # Only executes if a leak occurs (i.e. the bug is present). + # args.object is the __del__ function (not the stream instance) when + # unraisablehook fires from a finalizer, so check exc_value — the + # actual ResourceWarning("Unclosed "). + # Non-MemoryObject unraisables (e.g. PipeHandle leaked by an earlier + # flaky test on the same xdist worker) are deliberately ignored — + # this test should not fail for another test's resource leak. + if "MemoryObject" in str(args.exc_value): + leaked.append(str(args.exc_value)) + + sys.unraisablehook = hook + try: + yield + gc.collect() + assert not leaked, f"Memory streams leaked: {leaked}" + finally: + sys.unraisablehook = old_hook + + +@pytest.mark.anyio +async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None: + """sse_client creates streams only after the SSE connection succeeds, so a + ConnectError propagates directly with nothing to leak. + + Before the fix, streams were created before connecting and only 2 of 4 were + closed in the finally block. + """ + with _assert_no_memory_stream_leak(): + with pytest.raises(httpx.ConnectError): + async with sse_client(f"http://127.0.0.1:{free_tcp_port}/sse"): + pytest.fail("should not reach here") # pragma: no cover + + +@pytest.mark.anyio +async def test_sse_client_closes_all_streams_on_http_error() -> None: + """sse_client creates streams only after raise_for_status() passes, so an + HTTPStatusError from a 4xx/5xx response propagates bare (not wrapped in an + ExceptionGroup) with nothing to leak — the task group is never entered. + """ + + def return_403(request: httpx.Request) -> httpx.Response: + return httpx.Response(403) + + def mock_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient(transport=httpx.MockTransport(return_403)) + + with _assert_no_memory_stream_leak(): + with pytest.raises(httpx.HTTPStatusError): + async with sse_client("http://test/sse", httpx_client_factory=mock_factory): + pytest.fail("should not reach here") # pragma: no cover + + +@pytest.mark.anyio +async def test_streamable_http_client_closes_all_streams_on_exit() -> None: + """streamable_http_client must close all 4 stream ends on exit. + + Before the fix, read_stream was never closed — not even on the happy path. + This test enters and exits the context without sending any messages, so no + network connection is ever attempted (streamable_http connects lazily). + """ + with _assert_no_memory_stream_leak(): + async with streamable_http_client("http://127.0.0.1:1/mcp"): + pass + + +@pytest.mark.anyio +async def test_websocket_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None: + """websocket_client must close all 4 stream ends when ws_connect fails. + + Before the fix, there was no try/finally at all — if ws_connect raised, + all 4 streams were leaked. + """ + with _assert_no_memory_stream_leak(): + with pytest.raises(OSError): + async with websocket_client(f"ws://127.0.0.1:{free_tcp_port}/ws"): + pytest.fail("should not reach here") # pragma: no cover diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 30ecb0ac33..c8fc41fd5d 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -2,31 +2,31 @@ import pytest -from mcp import Client +from mcp import Client, types from mcp.client._memory import InMemoryTransport -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer -from mcp.types import Resource +from mcp.types import ListResourcesResult, Resource @pytest.fixture def simple_server() -> Server: """Create a simple MCP server for testing.""" - server = Server(name="test_server") - - # pragma: no cover - handler exists only to register a resource capability. - # Transport tests verify stream creation, not handler invocation. - @server.list_resources() - async def handle_list_resources(): # pragma: no cover - return [ - Resource( - uri="memory://test", - name="Test Resource", - description="A test resource", - ) - ] - return server + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: # pragma: no cover + return ListResourcesResult( + resources=[ + Resource( + uri="memory://test", + name="Test Resource", + description="A test resource", + ) + ] + ) + + return Server(name="test_server", on_list_resources=handle_list_resources) @pytest.fixture @@ -69,7 +69,7 @@ async def test_with_mcpserver(mcpserver_server: MCPServer): async def test_server_is_running(mcpserver_server: MCPServer): """Test that the server is running and responding to requests.""" async with Client(mcpserver_server) as client: - assert client.server_capabilities is not None + assert client.initialize_result.capabilities.tools is not None async def test_list_tools(mcpserver_server: MCPServer): diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py index f21abf4d0f..613c794ebf 100644 --- a/tests/experimental/tasks/client/test_tasks.py +++ b/tests/experimental/tasks/client/test_tasks.py @@ -1,43 +1,38 @@ """Tests for the experimental client task methods (session.experimental).""" +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any import anyio import pytest from anyio import Event from anyio.abc import TaskGroup -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.shared.experimental.tasks.helpers import task_execution from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder from mcp.types import ( CallToolRequest, CallToolRequestParams, CallToolResult, - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, - ClientResult, CreateTaskResult, - GetTaskPayloadRequest, + GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, - ServerNotification, - ServerRequest, + ListToolsResult, + PaginatedRequestParams, TaskMetadata, TextContent, - Tool, ) +pytestmark = pytest.mark.anyio + @dataclass class AppContext: @@ -48,44 +43,53 @@ class AppContext: task_done_events: dict[str, Event] = field(default_factory=lambda: {}) -@pytest.mark.anyio -async def test_session_experimental_get_task() -> None: - """Test session.experimental.get_task() method.""" - # Note: We bypass the normal lifespan mechanism - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] - store = InMemoryTaskStore() +async def _handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None +) -> ListToolsResult: + raise NotImplementedError - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) +async def _handle_call_tool_with_done_event( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams, *, result_text: str = "Done" +) -> CallToolResult | CreateTaskResult: + app = ctx.lifespan_context + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) - done_event = Event() - app.task_done_events[task.task_id] = done_event + done_event = Event() + app.task_done_events[task.task_id] = done_event - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) - done_event.set() + async def do_work() -> None: + async with task_execution(task.task_id, app.store) as task_ctx: + await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) + done_event.set() - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) - raise NotImplementedError + raise NotImplementedError + + +def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): + @asynccontextmanager + async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: + async with anyio.create_task_group() as tg: + yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + return app_lifespan + + +async def test_session_experimental_get_task() -> None: + """Test session.experimental.get_task() method.""" + store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} + + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -96,280 +100,141 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool_with_done_event, + ) + server.experimental.enable_tasks(on_get_task=handle_get_task) + + async with Client(server) as client: + # Create a task + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await app_context.task_done_events[task_id].wait() + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Use session.experimental to get task status - task_status = await client_session.experimental.get_task(task_id) + # Wait for task to complete + await task_done_events[task_id].wait() - assert task_status.task_id == task_id - assert task_status.status == "completed" + # Use session.experimental to get task status + task_status = await client.session.experimental.get_task(task_id) - tg.cancel_scope.cancel() + assert task_status.task_id == task_id + assert task_status.status == "completed" -@pytest.mark.anyio async def test_session_experimental_get_task_result() -> None: """Test session.experimental.get_task_result() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: + return await _handle_call_tool_with_done_event(ctx, params, result_text="Task result content") - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete( - CallToolResult(content=[TextContent(type="text", text="Task result content")]) - ) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - @server.experimental.get_task_result() async def handle_get_task_result( - request: GetTaskPayloadRequest, + ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: - app = server.request_context.lifespan_context - result = await app.store.get_result(request.params.task_id) - assert result is not None, f"Test setup error: result for {request.params.task_id} should exist" + app = ctx.lifespan_context + result = await app.store.get_result(params.task_id) + assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CallToolResult) return GetTaskPayloadResult(**result.model_dump()) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks(on_task_result=handle_get_task_result) + + async with Client(server) as client: + # Create a task + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await app_context.task_done_events[task_id].wait() + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Use TaskClient to get task result - task_result = await client_session.experimental.get_task_result(task_id, CallToolResult) + # Wait for task to complete + await task_done_events[task_id].wait() - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - assert content.text == "Task result content" + # Use TaskClient to get task result + task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - tg.cancel_scope.cancel() + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Task result content" -@pytest.mark.anyio async def test_session_experimental_list_tasks() -> None: """Test TaskClient.list_tasks() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_list_tasks( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListTasksResult: app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: - app = server.request_context.lifespan_context - tasks_list, next_cursor = await app.store.list_tasks(cursor=request.params.cursor if request.params else None) + cursor = params.cursor if params else None + tasks_list, next_cursor = await app.store.list_tasks(cursor=cursor) return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool_with_done_event, + ) + server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) + + async with Client(server) as client: + # Create two tasks + for _ in range(2): + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create two tasks - for _ in range(2): - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - await app_context.task_done_events[create_result.task.task_id].wait() - - # Use TaskClient to list tasks - list_result = await client_session.experimental.list_tasks() + CreateTaskResult, + ) + await task_done_events[create_result.task.task_id].wait() - assert len(list_result.tasks) == 2 + # Use TaskClient to list tasks + list_result = await client.session.experimental.list_tasks() - tg.cancel_scope.cancel() + assert len(list_result.tasks) == 2 -@pytest.mark.anyio async def test_session_experimental_cancel_task() -> None: """Test TaskClient.cancel_task() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool_no_work( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context if ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata @@ -377,14 +242,12 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextCon task = await app.store.create_task(task_metadata) # Don't start any work - task stays in "working" status return CreateTaskResult(task=task) - raise NotImplementedError - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -395,14 +258,14 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" - await app.store.update_task(request.params.task_id, status="cancelled") - # CancelTaskResult extends Task, so we need to return the updated task info - updated_task = await app.store.get_task(request.params.task_id) + async def handle_cancel_task( + ctx: ServerRequestContext[AppContext], params: CancelTaskRequestParams + ) -> CancelTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" + await app.store.update_task(params.task_id, status="cancelled") + updated_task = await app.store.get_task(params.task_id) assert updated_task is not None return CancelTaskResult( task_id=updated_task.task_id, @@ -412,63 +275,35 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: ttl=updated_task.ttl, ) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=handle_call_tool_no_work, + ) + server.experimental.enable_tasks(on_get_task=handle_get_task, on_cancel_task=handle_cancel_task) + + async with Client(server) as client: + # Create a task (but don't complete it) + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task (but don't complete it) - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Verify task is working - status_before = await client_session.experimental.get_task(task_id) - assert status_before.status == "working" + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Cancel the task - await client_session.experimental.cancel_task(task_id) + # Verify task is working + status_before = await client.session.experimental.get_task(task_id) + assert status_before.status == "working" - # Verify task is cancelled - status_after = await client_session.experimental.get_task(task_id) - assert status_after.status == "cancelled" + # Cancel the task + await client.session.experimental.cancel_task(task_id) - tg.cancel_scope.cancel() + # Verify task is cancelled + status_after = await client.session.experimental.get_task(task_id) + assert status_after.status == "cancelled" diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py index 41cecc1295..b5b79033d0 100644 --- a/tests/experimental/tasks/server/test_integration.py +++ b/tests/experimental/tasks/server/test_integration.py @@ -8,46 +8,37 @@ 5. Client retrieves result with tasks/result """ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any import anyio import pytest from anyio import Event from anyio.abc import TaskGroup -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.shared.experimental.tasks.helpers import task_execution from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder from mcp.types import ( - TASK_REQUIRED, CallToolRequest, CallToolRequestParams, CallToolResult, - ClientResult, CreateTaskResult, - GetTaskPayloadRequest, GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, - ServerNotification, - ServerRequest, + ListToolsResult, + PaginatedRequestParams, TaskMetadata, TextContent, - Tool, - ToolExecution, ) +pytestmark = pytest.mark.anyio + @dataclass class AppContext: @@ -55,77 +46,57 @@ class AppContext: task_group: TaskGroup store: InMemoryTaskStore - # Events to signal when tasks complete (for testing without sleeps) task_done_events: dict[str, Event] = field(default_factory=lambda: {}) -@pytest.mark.anyio +def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): + @asynccontextmanager + async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: + async with anyio.create_task_group() as tg: + yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) + + return app_lifespan + + async def test_task_lifecycle_with_task_execution() -> None: - """Test the complete task lifecycle using the task_execution pattern. - - This demonstrates the recommended way to implement task-augmented tools: - 1. Create task in store - 2. Spawn work using task_execution() context manager - 3. Return CreateTaskResult immediately - 4. Work executes in background, auto-fails on exception - """ - # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message - server: Server[AppContext, Any] = Server("test-tasks") # type: ignore[assignment] + """Test the complete task lifecycle using the task_execution pattern.""" store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} + + async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListToolsResult: + raise NotImplementedError - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="process_data", - description="Process data asynchronously", - input_schema={ - "type": "object", - "properties": {"input": {"type": "string"}}, - }, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context - if name == "process_data" and ctx.experimental.is_task: - # 1. Create task in store + if params.name == "process_data" and ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None task = await app.store.create_task(task_metadata) - # 2. Create event to signal completion (for testing) done_event = Event() app.task_done_events[task.task_id] = done_event - # 3. Define work function using task_execution for safety - async def do_work(): + async def do_work() -> None: async with task_execution(task.task_id, app.store) as task_ctx: await task_ctx.update_status("Processing input...") - # Simulate work - input_value = arguments.get("input", "") + input_value = (params.arguments or {}).get("input", "") result_text = f"Processed: {input_value.upper()}" await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) - # Signal completion done_event.set() - # 4. Spawn work in task group (from lifespan_context) app.task_group.start_soon(do_work) - - # 5. Return CreateTaskResult immediately return CreateTaskResult(task=task) raise NotImplementedError - # Register task query handlers (delegate to store) - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -136,134 +107,91 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - @server.experimental.get_task_result() async def handle_get_task_result( - request: GetTaskPayloadRequest, + ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: - app = server.request_context.lifespan_context - result = await app.store.get_result(request.params.task_id) - assert result is not None, f"Test setup error: result for {request.params.task_id} should exist" + app = ctx.lifespan_context + result = await app.store.get_result(params.task_id) + assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CallToolResult) - # Return as GetTaskPayloadResult (which accepts extra fields) return GetTaskPayloadResult(**result.model_dump()) - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def handle_list_tasks( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListTasksResult: raise NotImplementedError - # Set up client-server communication - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no cover - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + server: Server[AppContext] = Server( + "test-tasks", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks( + on_get_task=handle_get_task, + on_task_result=handle_get_task_result, + on_list_tasks=handle_list_tasks, + ) + + async with Client(server) as client: + # Step 1: Send task-augmented tool call + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="process_data", + arguments={"input": "hello world"}, + task=TaskMetadata(ttl=60000), ), ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - # Create app context with task group and store - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # === Step 1: Send task-augmented tool call === - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="process_data", - arguments={"input": "hello world"}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - assert isinstance(create_result, CreateTaskResult) - assert create_result.task.status == "working" - task_id = create_result.task.task_id - - # === Step 2: Wait for task to complete === - await app_context.task_done_events[task_id].wait() + CreateTaskResult, + ) - task_status = await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)), - GetTaskResult, - ) + assert isinstance(create_result, CreateTaskResult) + assert create_result.task.status == "working" + task_id = create_result.task.task_id - assert task_status.task_id == task_id - assert task_status.status == "completed" + # Step 2: Wait for task to complete + await task_done_events[task_id].wait() - # === Step 3: Retrieve the actual result === - task_result = await client_session.send_request( - GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id)), - CallToolResult, - ) + task_status = await client.session.experimental.get_task(task_id) + assert task_status.task_id == task_id + assert task_status.status == "completed" - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - assert content.text == "Processed: HELLO WORLD" + # Step 3: Retrieve the actual result + task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - tg.cancel_scope.cancel() + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Processed: HELLO WORLD" -@pytest.mark.anyio async def test_task_auto_fails_on_exception() -> None: """Test that task_execution automatically fails the task on unhandled exception.""" - # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message - server: Server[AppContext, Any] = Server("test-tasks-failure") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} + + async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListToolsResult: + raise NotImplementedError - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="failing_task", - description="A task that fails", - input_schema={"type": "object", "properties": {}}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context - if name == "failing_task" and ctx.experimental.is_task: + if params.name == "failing_task" and ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None task = await app.store.create_task(task_metadata) - # Create event to signal completion (for testing) done_event = Event() app.task_done_events[task.task_id] = done_event - async def do_failing_work(): + async def do_failing_work() -> None: async with task_execution(task.task_id, app.store) as task_ctx: await task_ctx.update_status("About to fail...") raise RuntimeError("Something went wrong!") - # Note: complete() is never called, but task_execution - # will automatically call fail() due to the exception # This line is reached because task_execution suppresses the exception done_event.set() @@ -272,11 +200,10 @@ async def do_failing_work(): raise NotImplementedError - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -287,64 +214,34 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no cover - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + server: Server[AppContext] = Server( + "test-tasks-failure", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks(on_get_task=handle_get_task) + + async with Client(server) as client: + # Send task request + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="failing_task", + arguments={}, + task=TaskMetadata(ttl=60000), ), ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Send task request - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="failing_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - task_id = create_result.task.task_id + CreateTaskResult, + ) - # Wait for task to complete (even though it fails) - await app_context.task_done_events[task_id].wait() + task_id = create_result.task.task_id - # Check that task was auto-failed - task_status = await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)), GetTaskResult - ) + # Wait for task to complete (even though it fails) + await task_done_events[task_id].wait() - assert task_status.status == "failed" - assert task_status.status_message == "Something went wrong!" + # Check that task was auto-failed + task_status = await client.session.experimental.get_task(task_id) - tg.cancel_scope.cancel() + assert task_status.status == "failed" + assert task_status.status_message == "Something went wrong!" diff --git a/tests/experimental/tasks/server/test_run_task_flow.py b/tests/experimental/tasks/server/test_run_task_flow.py index 0d5d1df77a..027382e69e 100644 --- a/tests/experimental/tasks/server/test_run_task_flow.py +++ b/tests/experimental/tasks/server/test_run_task_flow.py @@ -8,159 +8,102 @@ These are integration tests that verify the complete flow works end-to-end. """ -from typing import Any from unittest.mock import Mock import anyio import pytest from anyio import Event -from mcp.client.session import ClientSession -from mcp.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.experimental.task_support import TaskSupport from mcp.server.lowlevel import NotificationOptions from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue -from mcp.shared.message import SessionMessage from mcp.types import ( TASK_REQUIRED, + CallToolRequestParams, CallToolResult, - CancelTaskRequest, - CancelTaskResult, CreateTaskResult, - GetTaskPayloadRequest, - GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, - ListTasksResult, + ListToolsResult, + PaginatedRequestParams, TextContent, - Tool, - ToolExecution, ) +pytestmark = pytest.mark.anyio -@pytest.mark.anyio -async def test_run_task_basic_flow() -> None: - """Test the basic run_task flow without elicitation. - 1. enable_tasks() sets up handlers - 2. Client calls tool with task field - 3. run_task() spawns work, returns CreateTaskResult - 4. Work completes in background - 5. Client polls and sees completed status - """ - server = Server("test-run-task") +async def _handle_list_tools_simple_task( + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + raise NotImplementedError - # One-line setup - server.experimental.enable_tasks() - # Track when work completes and capture received meta +async def test_run_task_basic_flow() -> None: + """Test the basic run_task flow without elicitation.""" work_completed = Event() received_meta: list[str | None] = [None] - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="simple_task", - description="A simple task", - input_schema={"type": "object", "properties": {"input": {"type": "string"}}}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) - # Capture the meta from the request (if present) if ctx.meta is not None: # pragma: no branch received_meta[0] = ctx.meta.get("custom_field") async def work(task: ServerTaskContext) -> CallToolResult: await task.update_status("Working...") - input_val = arguments.get("input", "default") + input_val = (params.arguments or {}).get("input", "default") result = CallToolResult(content=[TextContent(type="text", text=f"Processed: {input_val}")]) work_completed.set() return result return await ctx.experimental.run_task(work) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server = Server( + "test-run-task", + on_list_tools=_handle_list_tools_simple_task, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() + + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task( + "simple_task", + {"input": "hello"}, + meta={"custom_field": "test_value"}, ) - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - # Initialize - await client_session.initialize() - - # Call tool as task (with meta to test that code path) - result = await client_session.experimental.call_tool_as_task( - "simple_task", - {"input": "hello"}, - meta={"custom_field": "test_value"}, - ) - - # Should get CreateTaskResult - task_id = result.task.task_id - assert result.task.status == "working" - - # Wait for work to complete - with anyio.fail_after(5): - await work_completed.wait() - - # Poll until task status is completed - with anyio.fail_after(5): - while True: - task_status = await client_session.experimental.get_task(task_id) - if task_status.status == "completed": # pragma: no branch - break - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - # Verify the meta was passed through correctly + task_id = result.task.task_id + assert result.task.status == "working" + + with anyio.fail_after(5): + await work_completed.wait() + + with anyio.fail_after(5): + while True: + task_status = await client.session.experimental.get_task(task_id) + if task_status.status == "completed": # pragma: no branch + break + assert received_meta[0] == "test_value" -@pytest.mark.anyio async def test_run_task_auto_fails_on_exception() -> None: """Test that run_task automatically fails the task when work raises.""" - server = Server("test-run-task-fail") - server.experimental.enable_tasks() - work_failed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="failing_task", - description="A task that fails", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -169,42 +112,29 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("failing_task", {}) - task_id = result.task.task_id + server = Server( + "test-run-task-fail", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - # Wait for work to fail - with anyio.fail_after(5): - await work_failed.wait() + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task("failing_task", {}) + task_id = result.task.task_id - # Poll until task status is failed - with anyio.fail_after(5): - while True: - task_status = await client_session.experimental.get_task(task_id) - if task_status.status == "failed": # pragma: no branch - break + with anyio.fail_after(5): + await work_failed.wait() - assert "Something went wrong" in (task_status.status_message or "") + with anyio.fail_after(5): + while True: + task_status = await client.session.experimental.get_task(task_id) + if task_status.status == "failed": # pragma: no branch + break - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + assert "Something went wrong" in (task_status.status_message or "") -@pytest.mark.anyio async def test_enable_tasks_auto_registers_handlers() -> None: """Test that enable_tasks() auto-registers get_task, list_tasks, cancel_task handlers.""" server = Server("test-enable-tasks") @@ -221,63 +151,41 @@ async def test_enable_tasks_auto_registers_handlers() -> None: assert caps_after.tasks is not None assert caps_after.tasks.list is not None assert caps_after.tasks.cancel is not None - # Verify nested call capability is present assert caps_after.tasks.requests is not None assert caps_after.tasks.requests.tools is not None assert caps_after.tasks.requests.tools.call is not None -@pytest.mark.anyio async def test_enable_tasks_with_custom_store_and_queue() -> None: """Test that enable_tasks() uses provided store and queue instead of defaults.""" server = Server("test-custom-store-queue") - # Create custom store and queue custom_store = InMemoryTaskStore() custom_queue = InMemoryTaskMessageQueue() - # Enable tasks with custom implementations task_support = server.experimental.enable_tasks(store=custom_store, queue=custom_queue) - # Verify our custom implementations are used assert task_support.store is custom_store assert task_support.queue is custom_queue -@pytest.mark.anyio async def test_enable_tasks_skips_default_handlers_when_custom_registered() -> None: """Test that enable_tasks() doesn't override already-registered handlers.""" server = Server("test-custom-handlers") - # Register custom handlers BEFORE enable_tasks (never called, just for registration) - @server.experimental.get_task() - async def custom_get_task(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError - - @server.experimental.get_task_result() - async def custom_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: + # Register custom handlers via enable_tasks kwargs + async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: raise NotImplementedError - @server.experimental.list_tasks() - async def custom_list_tasks(req: ListTasksRequest) -> ListTasksResult: - raise NotImplementedError - - @server.experimental.cancel_task() - async def custom_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=custom_get_task) - # Now enable tasks - should NOT override our custom handlers - server.experimental.enable_tasks() + # Verify handler is registered + assert server._has_handler("tasks/get") + assert server._has_handler("tasks/list") + assert server._has_handler("tasks/cancel") + assert server._has_handler("tasks/result") - # Verify our custom handlers are still registered (not replaced by defaults) - # The handlers dict should contain our custom handlers - assert GetTaskRequest in server.request_handlers - assert GetTaskPayloadRequest in server.request_handlers - assert ListTasksRequest in server.request_handlers - assert CancelTaskRequest in server.request_handlers - -@pytest.mark.anyio async def test_run_task_without_enable_tasks_raises() -> None: """Test that run_task raises when enable_tasks() wasn't called.""" experimental = Experimental( @@ -294,7 +202,6 @@ async def work(task: ServerTaskContext) -> CallToolResult: await experimental.run_task(work) -@pytest.mark.anyio async def test_task_support_task_group_before_run_raises() -> None: """Test that accessing task_group before run() raises RuntimeError.""" task_support = TaskSupport.in_memory() @@ -303,7 +210,6 @@ async def test_task_support_task_group_before_run_raises() -> None: _ = task_support.task_group -@pytest.mark.anyio async def test_run_task_without_session_raises() -> None: """Test that run_task raises when session is not available.""" task_support = TaskSupport.in_memory() @@ -322,7 +228,6 @@ async def work(task: ServerTaskContext) -> CallToolResult: await experimental.run_task(work) -@pytest.mark.anyio async def test_run_task_without_task_metadata_raises() -> None: """Test that run_task raises when request is not task-augmented.""" task_support = TaskSupport.in_memory() @@ -342,29 +247,17 @@ async def work(task: ServerTaskContext) -> CallToolResult: await experimental.run_task(work) -@pytest.mark.anyio async def test_run_task_with_model_immediate_response() -> None: """Test that run_task includes model_immediate_response in CreateTaskResult._meta.""" - server = Server("test-run-task-immediate") - server.experimental.enable_tasks() - work_completed = Event() immediate_response_text = "Processing your request..." - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="task_with_immediate", - description="A task with immediate response", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -373,164 +266,102 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work, model_immediate_response=immediate_response_text) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("task_with_immediate", {}) + server = Server( + "test-run-task-immediate", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - # Verify the immediate response is in _meta - assert result.meta is not None - assert "io.modelcontextprotocol/model-immediate-response" in result.meta - assert result.meta["io.modelcontextprotocol/model-immediate-response"] == immediate_response_text + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task("task_with_immediate", {}) - with anyio.fail_after(5): - await work_completed.wait() + assert result.meta is not None + assert "io.modelcontextprotocol/model-immediate-response" in result.meta + assert result.meta["io.modelcontextprotocol/model-immediate-response"] == immediate_response_text - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + with anyio.fail_after(5): + await work_completed.wait() -@pytest.mark.anyio async def test_run_task_doesnt_complete_if_already_terminal() -> None: """Test that run_task doesn't auto-complete if work manually completed the task.""" - server = Server("test-already-complete") - server.experimental.enable_tasks() - work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="manual_complete_task", - description="A task that manually completes", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: - # Manually complete the task before returning manual_result = CallToolResult(content=[TextContent(type="text", text="Manually completed")]) await task.complete(manual_result, notify=False) work_completed.set() - # Return a different result - but it should be ignored since task is already terminal return CallToolResult(content=[TextContent(type="text", text="This should be ignored")]) return await ctx.experimental.run_task(work) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("manual_complete_task", {}) - task_id = result.task.task_id + server = Server( + "test-already-complete", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - with anyio.fail_after(5): - await work_completed.wait() + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task("manual_complete_task", {}) + task_id = result.task.task_id - # Poll until task status is completed - with anyio.fail_after(5): - while True: - status = await client_session.experimental.get_task(task_id) - if status.status == "completed": # pragma: no branch - break + with anyio.fail_after(5): + await work_completed.wait() - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + with anyio.fail_after(5): + while True: + status = await client.session.experimental.get_task(task_id) + if status.status == "completed": # pragma: no branch + break -@pytest.mark.anyio async def test_run_task_doesnt_fail_if_already_terminal() -> None: """Test that run_task doesn't auto-fail if work manually failed/cancelled the task.""" - server = Server("test-already-failed") - server.experimental.enable_tasks() - work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="manual_cancel_task", - description="A task that manually cancels then raises", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: - # Manually fail the task first await task.fail("Manually failed", notify=False) work_completed.set() - # Then raise - but the auto-fail should be skipped since task is already terminal raise RuntimeError("This error should not change status") return await ctx.experimental.run_task(work) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("manual_cancel_task", {}) - task_id = result.task.task_id + server = Server( + "test-already-failed", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - with anyio.fail_after(5): - await work_completed.wait() + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task("manual_cancel_task", {}) + task_id = result.task.task_id - # Poll until task status is failed - with anyio.fail_after(5): - while True: - status = await client_session.experimental.get_task(task_id) - if status.status == "failed": # pragma: no branch - break + with anyio.fail_after(5): + await work_completed.wait() - # Task should still be failed (from manual fail, not auto-fail from exception) - assert status.status_message == "Manually failed" # Not "This error should not change status" + with anyio.fail_after(5): + while True: + status = await client.session.experimental.get_task(task_id) + if status.status == "failed": # pragma: no branch + break - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + assert status.status_message == "Manually failed" diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 8005380d28..6a28b274ea 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -6,8 +6,9 @@ import anyio import pytest +from mcp import Client from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -23,7 +24,6 @@ CallToolRequest, CallToolRequestParams, CallToolResult, - CancelTaskRequest, CancelTaskRequestParams, CancelTaskResult, ClientResult, @@ -31,21 +31,18 @@ GetTaskPayloadRequest, GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, GetTaskRequestParams, GetTaskResult, JSONRPCError, JSONRPCNotification, JSONRPCResponse, - ListTasksRequest, ListTasksResult, - ListToolsRequest, ListToolsResult, + PaginatedRequestParams, SamplingMessage, ServerCapabilities, ServerNotification, ServerRequest, - ServerResult, Task, TaskMetadata, TextContent, @@ -53,57 +50,37 @@ ToolExecution, ) +pytestmark = pytest.mark.anyio -@pytest.mark.anyio -async def test_list_tasks_handler() -> None: - """Test that experimental list_tasks handler works.""" - server = Server("test") +async def test_list_tasks_handler() -> None: + """Test that experimental list_tasks handler works via Client.""" now = datetime.now(timezone.utc) test_tasks = [ - Task( - task_id="task-1", - status="working", - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=1000, - ), - Task( - task_id="task-2", - status="completed", - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=1000, - ), + Task(task_id="task-1", status="working", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), + Task(task_id="task-2", status="completed", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), ] - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def handle_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: return ListTasksResult(tasks=test_tasks) - handler = server.request_handlers[ListTasksRequest] - request = ListTasksRequest(method="tasks/list") - result = await handler(request) + server = Server("test") + server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) - assert isinstance(result, ServerResult) - assert isinstance(result, ListTasksResult) - assert len(result.tasks) == 2 - assert result.tasks[0].task_id == "task-1" - assert result.tasks[1].task_id == "task-2" + async with Client(server) as client: + result = await client.session.experimental.list_tasks() + assert len(result.tasks) == 2 + assert result.tasks[0].task_id == "task-1" + assert result.tasks[1].task_id == "task-2" -@pytest.mark.anyio async def test_get_task_handler() -> None: - """Test that experimental get_task handler works.""" - server = Server("test") + """Test that experimental get_task handler works via Client.""" - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + async def handle_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: now = datetime.now(timezone.utc) return GetTaskResult( - task_id=request.params.task_id, + task_id=params.task_id, status="working", created_at=now, last_updated_at=now, @@ -111,85 +88,69 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=1000, ) - handler = server.request_handlers[GetTaskRequest] - request = GetTaskRequest( - method="tasks/get", - params=GetTaskRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server = Server("test") + server.experimental.enable_tasks(on_get_task=handle_get_task) - assert isinstance(result, ServerResult) - assert isinstance(result, GetTaskResult) - assert result.task_id == "test-task-123" - assert result.status == "working" + async with Client(server) as client: + result = await client.session.experimental.get_task("test-task-123") + assert result.task_id == "test-task-123" + assert result.status == "working" -@pytest.mark.anyio async def test_get_task_result_handler() -> None: - """Test that experimental get_task_result handler works.""" - server = Server("test") + """Test that experimental get_task_result handler works via Client.""" - @server.experimental.get_task_result() - async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPayloadResult: + async def handle_get_task_result( + ctx: ServerRequestContext, params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: return GetTaskPayloadResult() - handler = server.request_handlers[GetTaskPayloadRequest] - request = GetTaskPayloadRequest( - method="tasks/result", - params=GetTaskPayloadRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server = Server("test") + server.experimental.enable_tasks(on_task_result=handle_get_task_result) - assert isinstance(result, ServerResult) - assert isinstance(result, GetTaskPayloadResult) + async with Client(server) as client: + result = await client.session.send_request( + GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="test-task-123")), + GetTaskPayloadResult, + ) + assert isinstance(result, GetTaskPayloadResult) -@pytest.mark.anyio async def test_cancel_task_handler() -> None: - """Test that experimental cancel_task handler works.""" - server = Server("test") + """Test that experimental cancel_task handler works via Client.""" - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + async def handle_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: now = datetime.now(timezone.utc) return CancelTaskResult( - task_id=request.params.task_id, + task_id=params.task_id, status="cancelled", created_at=now, last_updated_at=now, ttl=60000, ) - handler = server.request_handlers[CancelTaskRequest] - request = CancelTaskRequest( - method="tasks/cancel", - params=CancelTaskRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server = Server("test") + server.experimental.enable_tasks(on_cancel_task=handle_cancel_task) - assert isinstance(result, ServerResult) - assert isinstance(result, CancelTaskResult) - assert result.task_id == "test-task-123" - assert result.status == "cancelled" + async with Client(server) as client: + result = await client.session.experimental.cancel_task("test-task-123") + assert result.task_id == "test-task-123" + assert result.status == "cancelled" -@pytest.mark.anyio async def test_server_capabilities_include_tasks() -> None: """Test that server capabilities include tasks when handlers are registered.""" server = Server("test") - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: raise NotImplementedError - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + async def noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: raise NotImplementedError - capabilities = server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ) + server.experimental.enable_tasks(on_list_tasks=noop_list_tasks, on_cancel_task=noop_cancel_task) + + capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) assert capabilities.tasks is not None assert capabilities.tasks.list is not None @@ -198,259 +159,164 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: assert capabilities.tasks.requests.tools is not None -@pytest.mark.anyio -async def test_server_capabilities_partial_tasks() -> None: +@pytest.mark.skip( + reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " + "so partial capabilities aren't possible yet. Low-level API should support " + "selectively enabling/disabling task capabilities." +) +async def test_server_capabilities_partial_tasks() -> None: # pragma: no cover """Test capabilities with only some task handlers registered.""" server = Server("test") - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: raise NotImplementedError # Only list_tasks registered, not cancel_task + server.experimental.enable_tasks(on_list_tasks=noop_list_tasks) - capabilities = server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ) + capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) assert capabilities.tasks is not None assert capabilities.tasks.list is not None assert capabilities.tasks.cancel is None # Not registered -@pytest.mark.anyio async def test_tool_with_task_execution_metadata() -> None: """Test that tools can declare task execution mode.""" - server = Server("test") - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="quick_tool", - description="Fast tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_FORBIDDEN), - ), - Tool( - name="long_tool", - description="Long running tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ), - Tool( - name="flexible_tool", - description="Can be either", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_OPTIONAL), - ), - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="quick_tool", + description="Fast tool", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_FORBIDDEN), + ), + Tool( + name="long_tool", + description="Long running tool", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ), + Tool( + name="flexible_tool", + description="Can be either", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_OPTIONAL), + ), + ] + ) - tools_handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list") - result = await tools_handler(request) + server = Server("test", on_list_tools=handle_list_tools) - assert isinstance(result, ServerResult) - assert isinstance(result, ListToolsResult) - tools = result.tools + async with Client(server) as client: + result = await client.list_tools() + tools = result.tools - assert tools[0].execution is not None - assert tools[0].execution.task_support == TASK_FORBIDDEN - assert tools[1].execution is not None - assert tools[1].execution.task_support == TASK_REQUIRED - assert tools[2].execution is not None - assert tools[2].execution.task_support == TASK_OPTIONAL + assert tools[0].execution is not None + assert tools[0].execution.task_support == TASK_FORBIDDEN + assert tools[1].execution is not None + assert tools[1].execution.task_support == TASK_REQUIRED + assert tools[2].execution is not None + assert tools[2].execution.task_support == TASK_OPTIONAL -@pytest.mark.anyio async def test_task_metadata_in_call_tool_request() -> None: - """Test that task metadata is accessible via RequestContext when calling a tool.""" - server = Server("test") + """Test that task metadata is accessible via ctx when calling a tool.""" captured_task_metadata: TaskMetadata | None = None - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="long_task", - description="A long running task", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support="optional"), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal captured_task_metadata - ctx = server.request_context captured_task_metadata = ctx.experimental.task_metadata - return [TextContent(type="text", text="done")] - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + return CallToolResult(content=[TextContent(type="text", text="done")]) + + server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + + async with Client(server) as client: + # Call tool with task metadata + await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="long_task", + arguments={}, + task=TaskMetadata(ttl=60000), ), ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Call tool with task metadata - await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="long_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CallToolResult, - ) - - tg.cancel_scope.cancel() + CallToolResult, + ) assert captured_task_metadata is not None assert captured_task_metadata.ttl == 60000 -@pytest.mark.anyio async def test_task_metadata_is_task_property() -> None: - """Test that RequestContext.experimental.is_task works correctly.""" - server = Server("test") + """Test that ctx.experimental.is_task works correctly.""" is_task_values: list[bool] = [] - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="test_tool", - description="Test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: is_task_values.append(ctx.experimental.is_task) - return [TextContent(type="text", text="done")] + return CallToolResult(content=[TextContent(type="text", text="done")]) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch + async with Client(server) as client: + # Call without task metadata + await client.session.send_request( + CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), + CallToolResult, + ) - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + # Call with task metadata + await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams(name="test_tool", arguments={}, task=TaskMetadata(ttl=60000)), ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Call without task metadata - await client_session.send_request( - CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), - CallToolResult, - ) - - # Call with task metadata - await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}, task=TaskMetadata(ttl=60000)), - ), - CallToolResult, - ) - - tg.cancel_scope.cancel() + CallToolResult, + ) assert len(is_task_values) == 2 assert is_task_values[0] is False # First call without task assert is_task_values[1] is True # Second call with task -@pytest.mark.anyio async def test_update_capabilities_no_handlers() -> None: """Test that update_capabilities returns early when no task handlers are registered.""" server = Server("test-no-handlers") - # Access experimental to initialize it, but don't register any task handlers _ = server.experimental caps = server.get_capabilities(NotificationOptions(), {}) - - # Without any task handlers registered, tasks capability should be None assert caps.tasks is None -@pytest.mark.anyio +async def test_update_capabilities_partial_handlers() -> None: + """Test that update_capabilities skips list/cancel when only tasks/get is registered.""" + server = Server("test-partial") + # Access .experimental to create the ExperimentalHandlers instance + exp = server.experimental + # Second access returns the same cached instance + assert server.experimental is exp + + async def noop_get(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + raise NotImplementedError + + server._add_request_handler("tasks/get", noop_get) + + caps = server.get_capabilities(NotificationOptions(), {}) + assert caps.tasks is not None + assert caps.tasks.list is None + assert caps.tasks.cancel is None + + async def test_default_task_handlers_via_enable_tasks() -> None: - """Test that enable_tasks() auto-registers working default handlers. - - This exercises the default handlers in lowlevel/experimental.py: - - _default_get_task (task not found) - - _default_get_task_result - - _default_list_tasks - - _default_cancel_task - """ + """Test that enable_tasks() auto-registers working default handlers.""" server = Server("test-default-handlers") - # Enable tasks with default handlers (no custom handlers registered) task_support = server.experimental.enable_tasks() store = task_support.store @@ -493,24 +359,18 @@ async def run_server() -> None: task = await store.create_task(TaskMetadata(ttl=60000)) # Test list_tasks (default handler) - list_result = await client_session.send_request(ListTasksRequest(), ListTasksResult) + list_result = await client_session.experimental.list_tasks() assert len(list_result.tasks) == 1 assert list_result.tasks[0].task_id == task.task_id # Test get_task (default handler - found) - get_result = await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id=task.task_id)), - GetTaskResult, - ) + get_result = await client_session.experimental.get_task(task.task_id) assert get_result.task_id == task.task_id assert get_result.status == "working" # Test get_task (default handler - not found path) with pytest.raises(MCPError, match="not found"): - await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id="nonexistent-task")), - GetTaskResult, - ) + await client_session.experimental.get_task("nonexistent-task") # Create a completed task to test get_task_result completed_task = await store.create_task(TaskMetadata(ttl=60000)) @@ -529,9 +389,7 @@ async def run_server() -> None: assert "io.modelcontextprotocol/related-task" in payload_result.meta # Test cancel_task (default handler) - cancel_result = await client_session.send_request( - CancelTaskRequest(params=CancelTaskRequestParams(task_id=task.task_id)), CancelTaskResult - ) + cancel_result = await client_session.experimental.cancel_task(task.task_id) assert cancel_result.task_id == task.task_id assert cancel_result.status == "cancelled" diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py index 57122da7b9..2d0378a9ce 100644 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -17,7 +17,7 @@ from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.lowlevel import NotificationOptions from mcp.shared._context import RequestContext @@ -26,6 +26,7 @@ from mcp.shared.message import SessionMessage from mcp.types import ( TASK_REQUIRED, + CallToolRequestParams, CallToolResult, CreateMessageRequestParams, CreateMessageResult, @@ -35,11 +36,12 @@ ErrorData, GetTaskPayloadResult, GetTaskResult, + ListToolsResult, + PaginatedRequestParams, SamplingMessage, TaskMetadata, TextContent, Tool, - ToolExecution, ) @@ -181,24 +183,21 @@ async def test_scenario1_normal_tool_normal_elicitation() -> None: Server calls session.elicit() directly, client responds immediately. """ - server = Server("test-scenario1") elicit_received = Event() tool_result: list[str] = [] - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Normal elicitation - expects immediate response result = await ctx.session.elicit( message="Please confirm the action", @@ -209,6 +208,8 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append("confirmed" if confirmed else "cancelled") return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + server = Server("test-scenario1", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + # Elicitation callback for client async def elicitation_callback( context: RequestContext[ClientSession], @@ -262,27 +263,24 @@ async def test_scenario2_normal_tool_task_augmented_elicitation() -> None: Server calls session.experimental.elicit_as_task(), client creates a task for the elicitation and returns CreateTaskResult. Server polls client. """ - server = Server("test-scenario2") elicit_received = Event() tool_result: list[str] = [] # Client-side task store for handling task-augmented elicitation client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Task-augmented elicitation - server polls client result = await ctx.session.experimental.elicit_as_task( message="Please confirm the action", @@ -294,6 +292,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append("confirmed" if confirmed else "cancelled") return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + server = Server("test-scenario2", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) task_handlers = create_client_task_handlers(client_task_store, elicit_received) # Set up streams @@ -342,26 +341,13 @@ async def test_scenario3_task_augmented_tool_normal_elicitation() -> None: Client calls tool as task. Inside the task, server uses task.elicit() which queues the request and delivers via tasks/result. """ - server = Server("test-scenario3") - server.experimental.enable_tasks() - elicit_received = Event() work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -377,6 +363,9 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario3", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() + # Elicitation callback for client async def elicitation_callback( context: RequestContext[ClientSession], @@ -452,29 +441,16 @@ async def test_scenario4_task_augmented_tool_task_augmented_elicitation() -> Non 5. Server gets the ElicitResult and completes the tool task 6. Client's tasks/result returns with the CallToolResult """ - server = Server("test-scenario4") - server.experimental.enable_tasks() - elicit_received = Event() work_completed = Event() # Client-side task store for handling task-augmented elicitation client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -491,6 +467,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario4", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() task_handlers = create_client_task_handlers(client_task_store, elicit_received) # Set up streams @@ -553,27 +531,24 @@ async def test_scenario2_sampling_normal_tool_task_augmented_sampling() -> None: Server calls session.experimental.create_message_as_task(), client creates a task for the sampling and returns CreateTaskResult. Server polls client. """ - server = Server("test-scenario2-sampling") sampling_received = Event() tool_result: list[str] = [] # Client-side task store for handling task-augmented sampling client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="generate_text", + description="Generate text using sampling", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Task-augmented sampling - server polls client result = await ctx.session.experimental.create_message_as_task( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], @@ -587,6 +562,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append(response_text) return CallToolResult(content=[TextContent(type="text", text=response_text)]) + server = Server("test-scenario2-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) # Set up streams @@ -636,29 +612,16 @@ async def test_scenario4_sampling_task_augmented_tool_task_augmented_sampling() which sends task-augmented sampling. Client creates its own task for the sampling, and server polls the client. """ - server = Server("test-scenario4-sampling") - server.experimental.enable_tasks() - sampling_received = Event() work_completed = Event() # Client-side task store for handling task-augmented sampling client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -677,6 +640,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario4-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) # Set up streams diff --git a/tests/experimental/tasks/test_message_queue.py b/tests/experimental/tasks/test_message_queue.py index a8517e535c..eca113d5b4 100644 --- a/tests/experimental/tasks/test_message_queue.py +++ b/tests/experimental/tasks/test_message_queue.py @@ -1,5 +1,6 @@ """Tests for TaskMessageQueue and InMemoryTaskMessageQueue.""" +from collections import deque from datetime import datetime, timezone import anyio @@ -270,7 +271,7 @@ async def is_empty_with_injection(tid: str) -> bool: if call_count == 2 and tid == task_id: # Before second check, inject a message - this simulates a message # arriving between event creation and the double-check - queue._queues[task_id] = [QueuedMessage(type="request", message=make_request())] + queue._queues[task_id] = deque([QueuedMessage(type="request", message=make_request())]) return await original_is_empty(tid) queue.is_empty = is_empty_with_injection # type: ignore[method-assign] diff --git a/tests/experimental/tasks/test_spec_compliance.py b/tests/experimental/tasks/test_spec_compliance.py index d00ce40a45..38d7d0a664 100644 --- a/tests/experimental/tasks/test_spec_compliance.py +++ b/tests/experimental/tasks/test_spec_compliance.py @@ -10,17 +10,17 @@ import pytest -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY from mcp.types import ( - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, CreateTaskResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + PaginatedRequestParams, ServerCapabilities, Task, ) @@ -44,13 +44,22 @@ def test_server_without_task_handlers_has_no_tasks_capability() -> None: assert caps.tasks is None +async def _noop_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + raise NotImplementedError + + +async def _noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: + raise NotImplementedError + + +async def _noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: + raise NotImplementedError + + def test_server_with_list_tasks_handler_declares_list_capability() -> None: """Server with list_tasks handler declares tasks.list capability.""" server: Server = Server("test") - - @server.experimental.list_tasks() - async def handle_list(req: ListTasksRequest) -> ListTasksResult: - raise NotImplementedError + server.experimental.enable_tasks(on_list_tasks=_noop_list_tasks) caps = _get_capabilities(server) assert caps.tasks is not None @@ -60,10 +69,7 @@ async def handle_list(req: ListTasksRequest) -> ListTasksResult: def test_server_with_cancel_task_handler_declares_cancel_capability() -> None: """Server with cancel_task handler declares tasks.cancel capability.""" server: Server = Server("test") - - @server.experimental.cancel_task() - async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_cancel_task=_noop_cancel_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -75,10 +81,7 @@ def test_server_with_get_task_handler_declares_requests_tools_call_capability() (get_task is required for task-augmented tools/call support) """ server: Server = Server("test") - - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -86,28 +89,30 @@ async def handle_get(req: GetTaskRequest) -> GetTaskResult: assert caps.tasks.requests.tools is not None -def test_server_without_list_handler_has_no_list_capability() -> None: +@pytest.mark.skip( + reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " + "so partial capabilities aren't possible yet. Low-level API should support " + "selectively enabling/disabling task capabilities." +) +def test_server_without_list_handler_has_no_list_capability() -> None: # pragma: no cover """Server without list_tasks handler has no tasks.list capability.""" server: Server = Server("test") - - # Register only get_task (not list_tasks) - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None assert caps.tasks.list is None -def test_server_without_cancel_handler_has_no_cancel_capability() -> None: +@pytest.mark.skip( + reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " + "so partial capabilities aren't possible yet. Low-level API should support " + "selectively enabling/disabling task capabilities." +) +def test_server_without_cancel_handler_has_no_cancel_capability() -> None: # pragma: no cover """Server without cancel_task handler has no tasks.cancel capability.""" server: Server = Server("test") - - # Register only get_task (not cancel_task) - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -117,18 +122,11 @@ async def handle_get(req: GetTaskRequest) -> GetTaskResult: def test_server_with_all_task_handlers_has_full_capability() -> None: """Server with all task handlers declares complete tasks capability.""" server: Server = Server("test") - - @server.experimental.list_tasks() - async def handle_list(req: ListTasksRequest) -> ListTasksResult: - raise NotImplementedError - - @server.experimental.cancel_task() - async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: - raise NotImplementedError - - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks( + on_list_tasks=_noop_list_tasks, + on_cancel_task=_noop_cancel_task, + on_get_task=_noop_get_task, + ) caps = _get_capabilities(server) assert caps.tasks is not None diff --git a/tests/interaction/README.md b/tests/interaction/README.md new file mode 100644 index 0000000000..be68c3b0f1 --- /dev/null +++ b/tests/interaction/README.md @@ -0,0 +1,228 @@ +# Interaction-model test suite + +This suite enumerates the MCP interaction model as end-to-end tests: one test per piece of +functionality, asserting the full client↔server round trip through the public API. It exists to +pin the SDK's observable behaviour — every request type, every notification direction, every +error plane — so that internal rewrites of the send/receive path can be proven equivalent by +running the suite before and after. + +```bash +uv run --frozen pytest tests/interaction/ +``` + +The whole suite is in-process and event-driven — including the streamable HTTP, SSE, and OAuth +flows — with a single subprocess test for stdio. + +## Ground rules + +- **Public API only.** Tests drive a `Client` connected to a `Server` or `MCPServer`. Nothing + reaches into session internals, so the suite keeps working when those internals change. + `ClientSession` is used directly only for behaviours `Client` cannot express (skipping + initialization, requesting a non-default protocol version). +- **Pin current behaviour.** Every test passes against the current `main`, including behaviours + that diverge from the specification. A failing or xfailed test proves nothing about whether a + rewrite preserved behaviour; a passing test that pins the wrong output exactly does. Known + divergences are recorded as data on the requirement (see below), not worked around in the test. +- **Spec-mandated assertions, not implementation quirks.** Error *codes* are asserted against + the constants in `mcp.types`; error *message strings* are pinned only where they are the + SDK's own deliberate output. +- **No sleeps, no real I/O.** Concurrency is coordinated with `anyio.Event`; every wait that + could hang is bounded by `anyio.fail_after(5)`. The HTTP and OAuth tests drive the Starlette + app in-process through the suite's streaming ASGI bridge (`transports/_bridge.py`), which + delivers each response chunk as the server produces it — full duplex, but still no sockets, + threads, or subprocesses anywhere outside the one stdio test. + +## Layout + +```text +tests/interaction/ + _requirements.py the requirements manifest (see below) + _helpers.py shared type aliases + the wire-recording transport + _connect.py the transport-parametrized connection factories + conftest.py the connect fixture (the transport matrix) + test_coverage.py enforces the manifest ↔ test contract + lowlevel/ one file per feature area, against the low-level Server + mcpserver/ the same feature areas in MCPServer's natural idiom + transports/ behaviour specific to one transport (sessions, resumability, framing) + auth/ OAuth flows against an in-process authorization server +``` + +The two server APIs produce genuinely different wire output for the same conceptual feature +(`MCPServer` generates schemas, converts exceptions to `isError` results, attaches structured +content), so they get parallel directories with mirrored file names rather than one parametrized +test body — each directory pins its flavour's true output exactly. + +### The transport matrix + +Transport-agnostic tests take the `connect` fixture instead of constructing `Client(server)` +directly, and therefore run once per transport: over the in-memory transport, over the server's +real streamable HTTP app driven in-process through the streaming bridge, and over the legacy SSE +transport the same way. A test connects with `async with connect(server, ...) as client:` and +asserts the same output on every leg, because the transport is not supposed to change observable +behaviour. Tests that are tied to one transport do not use the fixture: the wire-recording tests +(their seam is the in-memory stream pair), the bare-`ClientSession` lifecycle tests, the +real-clock timeout tests (the timeout machinery is transport-independent and must not race +transport latency), and everything under `transports/`, which pins behaviour only observable on +that transport. + +A transport conformance test in `transports/` speaks raw `httpx` against the mounted ASGI app +**only** when its assertion is about HTTP semantics that `Client` cannot observe — status codes, +response headers, SSE event fields, which stream a message travels on. Any other behaviour is +asserted through a `Client`, connected to the mounted app via `client_via_http(http)` so several +clients can share one session manager. + +## The requirements manifest + +`_requirements.py` maps every behaviour the suite covers to the reason it must hold: + +```python +"tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content.", +), +``` + +- **`source`** is a deep link into the MCP specification for externally mandated behaviour, + the literal string `"sdk"` for behaviour the SDK chose where the spec is silent, or + `"issue:#n"` for a regression lock. +- **`behavior`** describes the *required* behaviour — what the specification (or the SDK's own + contract) says should happen. Tests always pin the SDK's current behaviour; where that falls + short of `behavior`, the gap is recorded as data rather than hidden in the test. +- **`divergence`** records that gap for entries whose tests pin the divergent current behaviour. +- **`deferred`** marks a behaviour that is tracked but has no test in this suite, with a precise + reason: the SDK does not implement it, the negative cannot be observed, the assertion is + schema-level rather than interaction-level, the feature is experimental (tasks), or the test + would require real-time waits the suite refuses. +- **`transports`** names the transports a behaviour applies to; omitted means transport-independent. +- **`issue`** carries the tracking link for a recorded gap once one is filed. + +Tests link themselves to the manifest with a decorator: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: ... +``` + +`test_coverage.py` enforces the contract in both directions: every non-deferred requirement must +be exercised by at least one test, every deferred requirement by none, and an unknown ID fails at +import time. A behaviour without a manifest entry cannot be silently half-tested, and a manifest +entry without a test cannot be silently aspirational. + +### The divergence lifecycle + +1. A test reveals that the SDK does not do what the spec says. The test pins what the SDK + *actually does* and a `Divergence(note=..., issue=...)` goes on the requirement. +2. When the behaviour is eventually fixed, the pinned test fails. Whoever makes the change finds + the divergence note explaining that the old behaviour was a known gap, re-pins the test to the + spec-correct output, and deletes the `Divergence`. +3. An empty divergence list means the SDK is spec-conformant on every behaviour the suite covers. + +A requirement may carry both `divergence` and `deferred`: the divergence records that the SDK falls +short of the spec, and the deferral records why no test pins it (typically because the divergent +behaviour cannot be driven through the public API). Divergence alone implies a test pins the +divergent behaviour; divergence plus deferred means the gap is known but unpinned. + +This is also the triage key for any rewrite: a test that fails on the new code path either has a +divergence note (the rewrite accidentally fixed a known gap — decide whether to keep the fix) or +it does not (the rewrite broke something that was correct — fix the rewrite). + +### When a new spec revision is released + +1. Update `SPEC_REVISION` and walk the new revision's changelog. +2. For each changed interaction, find its requirements (the IDs use the wire method strings the + changelog speaks in), re-audit the tests against the new text, and update `source` links and + assertions where behaviour legitimately changed. +3. New interactions get new requirements and new tests; removed interactions get their + requirements deleted along with their tests. +4. A behaviour that is correct under both revisions needs no change beyond the `source` link. + +## Writing a test + +The shortest complete example of the conventions: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) +``` + +- **The server is defined inside the test** (or in a small fixture at the top of the file when + several tests genuinely share it). The whole observable behaviour fits on one screen. +- **Test names are behaviour sentences** — they state the observable outcome, not the feature + being poked. Docstrings add the one or two sentences of context a reviewer needs, including + whether the assertion is spec-mandated, SDK-defined, or a known divergence. +- **Handlers assert their dispatch identity first** (`assert params.name == "add"`), proving the + request that arrived is the request the test sent. +- **The result proves the round trip.** Server-side observations travel back to the test through + the protocol itself (a tool returns what it saw) or through a closure-captured list; the test + asserts after the call returns. +- **Order within a test**: server handlers → server construction → client callbacks → connect → + act → assert. The test reads in the order the conversation happens. +- A registered handler or tool that a test never invokes gets a `raise NotImplementedError` body + so it cannot silently become load-bearing. +- A test that needs a peer no real `Server` or `Client` can play (a server that answers initialize + with an unsupported version, a client that sends malformed params) plays that side of the wire by + hand over `create_client_server_memory_streams()`. This scripted-peer pattern is the suite's only + way to drive behaviour the typed API cannot produce, and the docstring of every such test says so. + +Stack a second `@requirement` decorator only when a test's natural assertions incidentally prove +another behaviour — one capabilities snapshot proving four `*:capability:declared` entries, one +input-schema identity check proving each preserved keyword. Do not build a test around covering +many requirements at once; if the assertions would be separate, write separate tests. + +### Choosing an assertion + +| The property under test is… | Assert with | +|---|---| +| the result of a transformation (arguments → output, exception → error result) | `result == snapshot(...)` of the full object, so any field the implementation adds or drops fails the test | +| pass-through of an opaque value (`_meta`, cursors) | identity against the same variable that was sent — a snapshot of a pass-through value only matches the input because a human checked two literals correspond | +| an error | `pytest.raises(MCPError)` and a snapshot of `exc.value.error` when the message is the SDK's own; a plain `==` on `.code` against the `mcp.types` constant when it is not | +| third-party output embedded in a result (validation messages) | the stable prefix only — never pin text that changes with a dependency upgrade | + +### Notifications and concurrency + +The client's receive loop dispatches each incoming message to completion before reading the next, +and the in-memory transport delivers everything on one ordered stream. Together these guarantee +that every notification a server handler emits before its response reaches the client callback +before the originating request returns — so tests collect notifications into a plain list and +assert after the call, with no synchronisation. The exceptions: + +- a notification not triggered by a request the test is awaiting needs an `anyio.Event` set in + the receiving handler and awaited under `anyio.fail_after(5)`; +- the ordering guarantee does not survive transports that split messages across streams (the + streamable HTTP standalone GET stream) — see `transports/test_streamable_http.py`. + +### Coverage + +CI requires 100% line and branch coverage, including `tests/`, and `strict-no-cover` fails the +build if a line marked `# pragma: no cover` is ever executed. When a new test starts covering a +pragma'd line in `src/`, delete the pragma in the same change. Do not add new `# type: ignore` or +`# noqa` comments; restructure instead. Two pragmas are sanctioned in this suite's test code, both +for known-upstream tracer bugs and only after restructuring has been tried: `# pragma: no branch` +on a `with`/`async with` line whose only fault is coverage.py mis-tracing the exit arc of a nested +async context (reserve it for shapes that cannot collapse — a sync `with` adjacent to an +`async with`); and `# pragma: lax no cover` on a single statement that 3.11's tracer drops because +the preceding `async with` unwinds via `coro.throw()` (python/cpython#106749, wontfix on 3.11) — +this hits any test that must run statements after a `ClientSession`/`streamable_http_client` exits +but still inside an outer `async with`, and no restructure can avoid it. + +A handful of `# pragma: lax no cover` markers in `src/` cover teardown exception handlers whose +execution is timing-dependent under the in-process HTTP bridge — the POST-stream and +stateless-session `except Exception` handlers in `server/streamable_http*.py`, the `_terminated` +check in `message_router`, and the response-stream double-close guard in +`BaseSession._receive_loop`. `strict-no-cover` does not check `lax` lines; do not promote them to +strict `no cover` without first making the teardown ordering deterministic. The suite also relies +on a one-line `src/mcp/server/sse.py` fix (`sse_stream_reader.aclose()`) that closes a stream the +SSE leg would otherwise leak. diff --git a/tests/interaction/__init__.py b/tests/interaction/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py new file mode 100644 index 0000000000..1faf4aa8d6 --- /dev/null +++ b/tests/interaction/_connect.py @@ -0,0 +1,360 @@ +"""Transport-parametrized connection factories for the interaction suite. + +The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body +runs over each transport without naming any of them: the factory is a drop-in replacement for +constructing `Client(server, ...)` and yields the connected client. The HTTP factories drive the +server's real Starlette app through the in-process streaming bridge, so the full transport layer +(session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses. +""" + +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, Protocol + +import httpx +from httpx_sse import ServerSentEvent, aconnect_sse +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route + +from mcp.client.client import Client +from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier +from mcp.server.auth.settings import AuthSettings +from mcp.server.mcpserver import MCPServer +from mcp.server.sse import SseServerTransport +from mcp.server.streamable_http import EventStore +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + jsonrpc_message_adapter, +) +from tests.interaction.transports._bridge import StreamingASGITransport + +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + +# DNS-rebinding protection validates Host/Origin headers against a real network attack that cannot +# exist for an in-process ASGI app, so the in-process factories disable it; tests that exercise the +# protection itself pass explicit settings (or transport_security=None to get the localhost +# auto-enable behaviour). +NO_DNS_REBINDING_PROTECTION = TransportSecuritySettings(enable_dns_rebinding_protection=False) + + +class Connect(Protocol): + """Connect a Client to a server over the transport selected by the `connect` fixture. + + Accepts the same keyword arguments as `Client` and yields the connected client. + """ + + def __call__( + self, + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, + ) -> AbstractAsyncContextManager[Client]: ... + + +@asynccontextmanager +async def connect_in_memory( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server over the in-memory transport.""" + async with Client( + server, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +@asynccontextmanager +async def connect_over_streamable_http( + server: Server | MCPServer, + *, + stateless_http: bool = False, + json_response: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server's streamable HTTP app, entirely in process. + + With the defaults this is the matrix leg (stateful sessions, SSE responses); the + transport-specific tests pass `stateless_http` or `json_response` to select the other + server modes, and the resumability tests pass an `event_store` (with `retry_interval=0` so + the client's reconnection wait is a no-op). + """ + app = server.streamable_http_app( + stateless_http=stateless_http, + json_response=json_response, + event_store=event_store, + retry_interval=retry_interval, + transport_security=NO_DNS_REBINDING_PROTECTION, + ) + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client, + Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client, + ): + yield client + + +@asynccontextmanager +async def mounted_app( + server: Server | MCPServer, + *, + stateless_http: bool = False, + json_response: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + transport_security: TransportSecuritySettings | None = NO_DNS_REBINDING_PROTECTION, + on_request: Callable[[httpx.Request], Awaitable[None]] | None = None, + headers: dict[str, str] | None = None, + auth: AuthSettings | None = None, + token_verifier: TokenVerifier | None = None, + auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, +) -> AsyncIterator[tuple[httpx.AsyncClient, StreamableHTTPSessionManager]]: + """Mount the server's streamable HTTP app on the in-process bridge and yield an httpx client. + + Yields the httpx client (rooted at the in-process origin) and the live session manager. Tests + use this in two ways: for raw-httpx assertions (status codes, headers, SSE bytes) the test + speaks HTTP through the yielded client directly; for client-driven assertions the test wraps + that client in `client_via_http(http)`, which lets several `Client`s share the one mounted + session manager. `on_request` records every outgoing HTTP request before it leaves the + yielded client. + + DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the + localhost auto-enable behaviour) to test the protection itself. + """ + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + app = lowlevel.streamable_http_app( + stateless_http=stateless_http, + json_response=json_response, + event_store=event_store, + retry_interval=retry_interval, + transport_security=transport_security, + auth=auth, + token_verifier=token_verifier, + auth_server_provider=auth_server_provider, + ) + event_hooks = {"request": [on_request]} if on_request is not None else None + async with ( + server.session_manager.run(), + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, event_hooks=event_hooks, headers=headers + ) as http_client, + ): + yield http_client, server.session_manager + + +@asynccontextmanager +async def client_via_http( + http_client: httpx.AsyncClient, + *, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Connect a `Client` over an already-mounted streamable HTTP app. + + Use with `mounted_app(...)` so several `Client`s share the one session manager, or so a + client-driven assertion can sit alongside raw-httpx assertions in the same test. The + underlying `httpx.AsyncClient` is left open when the `Client` exits. + """ + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + async with Client( + transport, + logging_callback=logging_callback, + message_handler=message_handler, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]: + """Decode SSE events into JSON-RPC messages, skipping priming events that carry no data.""" + return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data] + + +async def post_jsonrpc( + http: httpx.AsyncClient, body: dict[str, object], *, session_id: str | None = None +) -> tuple[httpx.Response, list[JSONRPCMessage]]: + """POST a JSON-RPC body and read its SSE response stream to completion. + + Returns the HTTP response (for header/status assertions) and the parsed JSON-RPC messages + that arrived on the response's SSE stream. Only meaningful for requests the server answers + with `text/event-stream`; for error responses or 202 notification acknowledgements, use + `httpx.AsyncClient.post` directly and assert on the response. + """ + async with aconnect_sse(http, "POST", "/mcp", json=body, headers=base_headers(session_id=session_id)) as source: + events = [event async for event in source.aiter_sse()] + return source.response, parse_sse_messages(events) + + +def base_headers(*, session_id: str | None = None) -> dict[str, str]: + """Standard request headers for raw-httpx streamable-HTTP tests. + + Every well-formed request carries these (Accept covering both response representations, + Content-Type for POST bodies, MCP-Protocol-Version at the latest revision, and the session + ID once one exists), so a test that wants to assert a specific rejection only varies the one + header under test. + """ + headers = { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + "mcp-protocol-version": LATEST_PROTOCOL_VERSION, + } + if session_id is not None: + headers["mcp-session-id"] = session_id + return headers + + +def initialize_body(request_id: int = 1) -> dict[str, object]: + """A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it.""" + params = InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="raw", version="0.0.0"), + ) + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="initialize", params=params.model_dump(by_alias=True, exclude_none=True) + ).model_dump(by_alias=True, exclude_none=True) + + +async def initialize_via_http(http: httpx.AsyncClient) -> str: + """Perform the initialize handshake over a raw `httpx.AsyncClient` and return the session ID. + + Validates the SSE response and sends the `notifications/initialized` follow-up, so the server + is fully ready for subsequent feature requests when this returns. + """ + async with aconnect_sse(http, "POST", "/mcp", json=initialize_body(), headers=base_headers()) as source: + assert source.response.status_code == 200 + # An event-store-backed server opens the stream with a priming event (empty data); skip it. + events = [event async for event in source.aiter_sse() if event.data] + assert len(events) == 1 + assert JSONRPCResponse.model_validate_json(events[0].data).id == 1 + session_id = source.response.headers["mcp-session-id"] + initialized = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/initialized"}, + headers=base_headers(session_id=session_id), + ) + assert initialized.status_code == 202 + return session_id + + +def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]: + """Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/. + + `MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which + the SSE-specific tests need; building the app explicitly here gives both server flavours the + same routing while keeping that handle. + """ + sse = SseServerTransport( + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + + async def handle_sse(request: Request) -> Response: + async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): + await lowlevel.run(read, write, lowlevel.create_initialization_options()) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse, methods=["GET"]), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + return app, sse + + +@asynccontextmanager +async def connect_over_sse( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" + app, _ = build_sse_app(server) + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + # The SSE server transport's connect_sse runs the entire MCP session inside the GET + # request and only releases its streams after that request observes a disconnect, so the + # bridge must let the application drain rather than cancelling at close. + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) + async with Client( + transport, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py new file mode 100644 index 0000000000..25833b0ca5 --- /dev/null +++ b/tests/interaction/_helpers.py @@ -0,0 +1,107 @@ +"""Shared helpers for the interaction suite. + +Keep this module small: it exists only for (a) types that every test would otherwise have to +assemble from the SDK's internals to annotate a client callback, and (b) the recording transport +used by the wire-level tests. Server fixtures and assertion helpers belong in the test that uses +them. +""" + +from types import TracebackType + +import anyio +from typing_extensions import Self + +from mcp.client._transport import ReadStream, Transport, TransportStreams, WriteStream +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ClientResult, ServerNotification, ServerRequest + +# TODO: this union is the parameter type of every client message handler (MessageHandlerFnT), +# but the SDK does not export a name for it -- writing a correctly-typed handler requires +# importing RequestResponder from mcp.shared.session and assembling the union by hand. It +# should be a named, exported alias next to MessageHandlerFnT (like ClientRequestContext is +# for the request callbacks), at which point this alias can be deleted. +IncomingMessage = RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception +"""Everything a client message handler can receive.""" + + +class _RecordingReadStream: + """Delegates to a read stream, appending every received message to a log.""" + + def __init__(self, inner: ReadStream[SessionMessage | Exception], log: list[SessionMessage | Exception]) -> None: + self._inner = inner + self._log = log + + async def receive(self) -> SessionMessage | Exception: + item = await self._inner.receive() + self._log.append(item) + return item + + async def aclose(self) -> None: + await self._inner.aclose() + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> SessionMessage | Exception: + try: + return await self.receive() + except anyio.EndOfStream: + raise StopAsyncIteration from None + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class _RecordingWriteStream: + """Delegates to a write stream, appending every sent message to a log.""" + + def __init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage]) -> None: + self._inner = inner + self._log = log + + async def send(self, item: SessionMessage, /) -> None: + self._log.append(item) + await self._inner.send(item) + + async def aclose(self) -> None: + await self._inner.aclose() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class RecordingTransport: + """Wraps a Transport and records every message crossing the client's transport boundary. + + `sent` holds everything the client wrote towards the server; `received` holds everything the + server delivered to the client. The recording sits at the transport seam -- the exact payloads + a real transport would serialise -- and never touches the session, so wire-level assertions + written against it survive changes to the receive path. + """ + + def __init__(self, inner: Transport) -> None: + self.inner = inner + self.sent: list[SessionMessage] = [] + self.received: list[SessionMessage | Exception] = [] + + async def __aenter__(self) -> TransportStreams: + read_stream, write_stream = await self.inner.__aenter__() + return _RecordingReadStream(read_stream, self.received), _RecordingWriteStream(write_stream, self.sent) + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + return await self.inner.__aexit__(exc_type, exc_val, exc_tb) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py new file mode 100644 index 0000000000..109b30fc77 --- /dev/null +++ b/tests/interaction/_requirements.py @@ -0,0 +1,2816 @@ +"""Requirements manifest for the interaction-model test suite. + +Every user-facing behaviour the SDK must satisfy, keyed by a stable `:[:]` +ID. Each entry owns the tests that exercise it: tests declare `@requirement("")` (a test that +proves several behaviours stacks several decorators) and `test_coverage.py` enforces the contract +in both directions: every non-deferred requirement has at least one test, and every test carries +at least one requirement. + +Sources: + spec URL -- externally mandated by the MCP specification (deep link to the section) + `sdk` -- a behavioural guarantee the SDK chose; not spec-mandated + `issue:#n` -- regression lock-in for a previously fixed bug + +The `behavior` sentence describes the REQUIRED behaviour -- what the specification (or the SDK's +own contract) says should happen. Tests always pin the SDK's current behaviour. Where current +behaviour falls short of `behavior`, the gap is recorded as data: `divergence` on entries whose +tests pin the divergent behaviour, or `deferred` on entries that are tracked but not yet covered +by a test in this suite. An entry may carry both: `divergence` records the spec-compliance gap +(issue-able) and `deferred` records why no test exists; `divergence` alone implies a test pins +the divergent behaviour. `issue` carries the tracking link for a recorded gap once one is filed. + +`deferred` reasons take one of three shapes: where the behaviour is exercised elsewhere in this +repo the reason names the covering test path; where the SDK does not implement the behaviour at +all the reason starts with "Not implemented in the SDK"; and where an interaction-level test is +planned but not yet written the reason starts with "Not yet covered here". + +`transports` records which transports a behaviour applies to (or is observable on); None means +the behaviour is transport-independent. + +The ID vocabulary and entry granularity are aligned with the TypeScript SDK's end-to-end +requirements suite, so coverage and recorded divergences can be compared across the two SDKs +entry by entry; IDs that exist in only one SDK reflect genuinely different API surface. +""" + +import re +from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal, TypeVar + +import pytest + +SPEC_REVISION = "2025-11-25" +SPEC_BASE_URL = f"https://modelcontextprotocol.io/specification/{SPEC_REVISION}" + +Transport = Literal["in-memory", "stdio", "streamable-http", "sse"] + +_TestFn = TypeVar("_TestFn", bound=Callable[..., object]) + +_SOURCE_PATTERN = re.compile(r"https://modelcontextprotocol\.io/specification/.+|sdk|issue:#\d+") + +_TASKS_DEFERRAL = ( + "Tasks are experimental and the spec is being substantially revised; python task behaviour is " + "covered by tests/experimental/tasks/ until the next spec revision settles." +) + + +@dataclass(frozen=True, kw_only=True) +class Divergence: + """A documented gap between the SDK behaviour this suite pins and what `source` mandates.""" + + note: str + issue: str | None = None + + +@dataclass(frozen=True, kw_only=True) +class Requirement: + """A single testable behaviour and the provenance of why it must hold.""" + + source: str + behavior: str + transports: tuple[Transport, ...] | None = None + divergence: Divergence | None = None + deferred: str | None = None + issue: str | None = None + + def __post_init__(self) -> None: + if not _SOURCE_PATTERN.fullmatch(self.source): + raise ValueError(f"source must be a specification URL, 'sdk', or 'issue:#n', got {self.source!r}") + + +REQUIREMENTS: dict[str, Requirement] = { + # ═══════════════════════════════════════════════════════════════════════════ + # Lifecycle & version negotiation + # ═══════════════════════════════════════════════════════════════════════════ + "lifecycle:capability:client-not-declared": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", + behavior=( + "The client rejects sending notifications or registering handlers for capabilities it did not declare." + ), + divergence=Divergence( + note=( + "The client does not check its own declared capabilities before sending notifications or " + "serving callbacks; nothing prevents a caller from violating the spec's MUST." + ), + ), + deferred=( + "Not implemented in the SDK: the client does not check its own declared capabilities before " + "sending notifications or serving callbacks." + ), + ), + "lifecycle:capability:server-not-advertised": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", + behavior=( + "The client rejects calls to methods (e.g. resources/list) for capabilities the server did not advertise." + ), + divergence=Divergence( + note=( + "The client sends any request regardless of the server's advertised capabilities and " + "surfaces whatever the server answers; the spec's MUST is not enforced." + ), + ), + deferred=( + "Not implemented in the SDK: the client sends any request regardless of the server's " + "advertised capabilities and surfaces whatever the server answers." + ), + ), + "lifecycle:initialize:basic": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Connecting sends initialize with the protocol version, client capabilities, and client " + "info; the server responds with its own and the connection is established." + ), + ), + "lifecycle:initialize:server-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The initialize result identifies the server: name and version, plus title when declared.", + ), + "lifecycle:initialize:instructions": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="A server may include an instructions string in the initialize result; the client exposes it.", + ), + "lifecycle:initialize:capabilities:from-handlers": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The server advertises a capability for each feature area it has a registered handler for, " + "and omits the capability for areas it does not." + ), + ), + "lifecycle:initialize:capabilities:minimal": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior="A server with no feature handlers advertises no feature capabilities.", + ), + "lifecycle:initialize:client-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The client's name, version, and title are visible to server handlers after initialization.", + ), + "lifecycle:initialize:client-capabilities": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The client capabilities visible to the server reflect which client callbacks are configured " + "(sampling, elicitation, roots)." + ), + ), + "lifecycle:initialized-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "After successful initialization, the client sends exactly one initialized notification, " + "before any non-ping request." + ), + ), + "lifecycle:ping": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="ping in either direction returns an empty result.", + ), + "ping:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A client-initiated ping receives an empty result from the server.", + ), + "ping:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A server-initiated ping receives an empty result from the client.", + ), + "lifecycle:requests-before-initialized": Requirement( + source="sdk", + behavior=( + "A request other than ping sent before the initialization handshake completes is rejected with an error." + ), + ), + "lifecycle:pre-initialization-ordering": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Before initialization completes, the client sends no requests other than pings, and the " + "server sends no requests other than pings and logging." + ), + divergence=Divergence( + note=( + "The server's send methods (create_message / elicit_form / list_roots) do not check " + "initialization state before sending; on the client side, Client always completes the " + "handshake before any caller code runs." + ), + ), + deferred=( + "Not implemented in the SDK: neither side enforces sender-side restraint. The server's send " + "methods (create_message / elicit_form / list_roots) do not check initialization state before " + "sending, and there is no natural hook to issue a server-to-client request between the " + "initialize response and the initialized notification through the public API; on the client " + "side, Client always completes the handshake before any caller code runs." + ), + ), + "lifecycle:version:downgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "When the server returns an older supported protocol version, the client downgrades to it " + "and the connection succeeds at that version." + ), + ), + "lifecycle:version:match": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "When the server supports the requested protocol version it echoes that version in the " + "initialize result, and the connection proceeds at that version." + ), + ), + "lifecycle:version:server-fallback-latest": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "An initialize request carrying a protocol version the server does not support is answered " + "with another version the server supports — the latest one — rather than an error." + ), + ), + "lifecycle:version:reject-unsupported": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "A client that receives an initialize response carrying a protocol version it does not " + "support fails initialization with an error rather than proceeding with the session." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Protocol primitives: cancellation, timeout, progress, errors, _meta + # ═══════════════════════════════════════════════════════════════════════════ + "protocol:request-id:unique": Requirement( + source=f"{SPEC_BASE_URL}/basic#requests", + behavior=( + "Every request sent on a session carries a unique, non-null string or integer id; ids are " + "never reused within the session." + ), + ), + "protocol:notifications:no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic#notifications", + behavior=( + "Notifications are never answered: every message the server delivers is either the response " + "to a request the client sent or a notification carrying no id." + ), + ), + "protocol:cancel:abort-signal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#cancellation-flow", + behavior=( + "Cancelling an in-flight request through the client API sends notifications/cancelled with " + "the request id and fails the local call." + ), + deferred=( + "Not implemented in the SDK: there is no public client-side API to cancel an in-flight " + "request; cancellation requires hand-constructing the notification (which is how " + "protocol:cancel:in-flight exercises the receiving side)." + ), + ), + "protocol:cancel:handler-abort-propagates": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="On the receiving side, a cancellation notification stops the running request handler.", + ), + "protocol:cancel:in-flight": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A cancellation notification for an in-flight request stops the server-side handler, and the " + "receiver does not send a response for the cancelled request." + ), + divergence=Divergence( + note=( + "The spec says receivers of a cancellation SHOULD NOT send a response for the cancelled " + "request; the server sends an error response (code 0, 'Request cancelled'), which is what " + "unblocks the SDK client's pending call." + ), + ), + ), + "protocol:cancel:initialize-not-cancellable": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="The client never sends notifications/cancelled for the initialize request.", + deferred=( + "Not implemented in the SDK: the client has no public cancellation API at all, so no pathway " + "exists that could cancel initialize; there is no distinct behaviour to pin beyond that absence." + ), + ), + "protocol:cancel:late-response-ignored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A response that arrives after the sender issued notifications/cancelled is ignored; the " + "request stays failed and no error is raised." + ), + divergence=Divergence( + note=( + "A response whose id matches no in-flight request is delivered to the message handler " + "as a RuntimeError rather than being silently ignored. The post-cancellation case is the " + "same code path; tested in its unknown-id form because that is deterministic without the " + "client-side cancellation API the SDK does not yet provide." + ), + ), + ), + "protocol:cancel:server-survives": Requirement( + source="sdk", + behavior="The session continues to serve new requests after an earlier request was cancelled.", + ), + "protocol:cancel:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A server that abandons an in-flight server-initiated request (sampling, elicitation, roots) " + "cancels it, and the client stops processing the cancelled request." + ), + divergence=Divergence( + note=( + "Abandoning a server-side send_request emits no cancellation notification, and the client " + "could not act on one anyway: client callbacks run inline in the receive loop, so a " + "cancellation is not even read until the callback has finished." + ), + ), + deferred=( + "Not implemented in the SDK: abandoning a server-side send_request emits no cancellation " + "notification (the same sender-side gap recorded on protocol:timeout:sends-cancellation), and " + "the client could not act on one anyway because client callbacks run inline in the receive " + "loop, so a cancellation would not even be read until the callback had already finished." + ), + ), + "protocol:cancel:unknown-id-ignored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#error-handling", + behavior=( + "The receiver silently ignores a cancellation notification referencing an unknown or " + "already-completed request id; no error response is sent and no exception is raised." + ), + ), + "protocol:cancel:sender-targeting": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "Cancellation notifications reference only requests that were previously issued in the same " + "direction and are believed to still be in flight." + ), + deferred=( + "Not implemented in the SDK: there is no public client-side cancel API to drive (see " + "protocol:cancel:abort-signal), so the sender-side targeting rule has nothing to pin." + ), + ), + "protocol:error:connection-closed": Requirement( + source="sdk", + behavior="Closing the transport fails all in-flight requests with a connection-closed error.", + ), + "protocol:error:internal-error": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior=( + "An unhandled exception in a request handler is returned to the caller as JSON-RPC error " + "-32603 Internal error." + ), + divergence=Divergence( + note=( + "The low-level Server returns code 0 (not a defined JSON-RPC code) instead of -32603 and " + "leaks str(exc) as the error message." + ), + ), + ), + "protocol:error:invalid-params": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request with malformed params is answered with JSON-RPC error -32602 Invalid params.", + ), + "protocol:error:method-not-found": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request whose method has no registered handler is answered with a METHOD_NOT_FOUND error.", + ), + "protocol:meta:related-task": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="Messages may carry related-task _meta associating them with a task.", + deferred=_TASKS_DEFERRAL, + ), + "meta:request-to-handler": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object the client attaches to a request is visible to the server handler.", + ), + "meta:result-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object a handler attaches to its result is delivered to the client.", + ), + "protocol:progress:callback": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a handler during a request are delivered to the caller's " + "progress callback, in order, with their progress, total, and message." + ), + ), + "protocol:progress:token-injected": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Supplying a progress callback attaches a progress token to the outgoing request, which the " + "server-side handler can observe in its request metadata." + ), + ), + "protocol:progress:token-unique": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=("Concurrent in-flight requests that each supply a progress callback carry distinct progress tokens."), + ), + "protocol:progress:monotonic": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "The progress value increases with each notification for a given token, even when the total is unknown." + ), + divergence=Divergence( + note=( + "The spec MUST is not enforced: progress values are not validated on either side, so a " + "handler that emits non-increasing values has them forwarded to the callback unchanged." + ), + ), + ), + "protocol:progress:stops-after-completion": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#behavior-requirements", + behavior="Progress notifications for a token stop once the associated request completes.", + divergence=Divergence( + note=( + "send_progress_notification does not check whether the token's request has already " + "completed; the late notification is sent and reaches the client." + ), + ), + ), + "protocol:progress:late-dropped-by-client": Requirement( + source="sdk", + behavior=( + "A progress notification that arrives after its request has completed is not delivered to the " + "original progress callback." + ), + ), + "protocol:progress:no-token": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior="Without a progress callback the request carries no progress token.", + ), + "protocol:progress:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior="A progress notification sent by the client is delivered to the server's progress handler.", + ), + "protocol:timeout:basic": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior=( + "A request that exceeds its read timeout fails with a request-timeout error instead of " + "waiting forever for the response." + ), + ), + "protocol:timeout:max-total": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A maximum total timeout is enforced even when progress notifications keep arriving.", + divergence=Divergence( + note=( + "There is no maximum-total-timeout option; only the per-request read timeout exists, so the " + "spec's SHOULD that an overall maximum is always enforced cannot be satisfied." + ), + ), + deferred=( + "Not implemented in the SDK: there is no maximum-total-timeout option; only the per-request " + "read timeout exists." + ), + ), + "protocol:timeout:reset-on-progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="When configured to do so, each progress notification resets the request's read timeout.", + deferred=( + "Not implemented in the SDK: progress notifications do not reset the request read timeout and " + "no option exists to enable that." + ), + ), + "protocol:timeout:sends-cancellation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior=( + "When a request times out, the sender issues notifications/cancelled for that request before " + "failing the local call." + ), + divergence=Divergence( + note=( + "The client only raises locally and sends nothing on timeout, so the server keeps running the handler." + ), + ), + ), + "protocol:timeout:session-survives": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="The session continues to serve new requests after an earlier request timed out.", + ), + "protocol:timeout:session-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A session-level read timeout applies to every request that does not override it.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools + # ═══════════════════════════════════════════════════════════════════════════ + "tools:call:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#audio-content", + behavior="A tool result can carry audio content: base64 data with a mimeType.", + ), + "tools:call:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#embedded-resources", + behavior="A tool result can carry an embedded resource with full text or blob contents.", + ), + "tools:call:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#image-content", + behavior="A tool result can carry image content: base64 data with a mimeType.", + ), + "tools:call:content:mixed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-result", + behavior="A tool result can carry multiple content blocks of different types; order is preserved.", + ), + "tools:call:content:resource-link": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior="A tool result can carry a resource_link content block referencing a resource by URI.", + ), + "tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content to the caller.", + ), + "tools:call:concurrent": Requirement( + source="sdk", + behavior=( + "Multiple tool calls in flight on one session are dispatched concurrently, and each caller " + "receives the response to its own request." + ), + ), + "tools:call:elicitation-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#user-interaction-model", + behavior=( + "A tool handler that issues an elicitation receives the client's result and can embed it in " + "the tool call result." + ), + ), + "tools:call:is-error": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "A tool execution failure is returned as a result with isError true and the failure described " + "in content, not as a JSON-RPC error." + ), + ), + "tools:call:logging-mid-execution": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "Log notifications emitted by a tool handler during execution reach the client's logging " + "callback before the tool result returns." + ), + ), + "tools:call:progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a tool handler reach the caller's progress callback before " + "the tool result returns." + ), + ), + "tools:call:sampling-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A tool handler that issues a sampling request receives the client's completion and can embed " + "it in the tool call result." + ), + ), + "tools:call:structured-content": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool result can carry structuredContent alongside content; the client receives both.", + ), + "tools:call:structured-content:text-mirror": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool returning structured content also returns the serialized JSON as a text content block.", + ), + "tools:call:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.", + ), + "tools:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#capabilities", + behavior="A server with a list_tools handler advertises the tools capability in its initialize result.", + ), + "tools:input-schema:json-schema-2020-12": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "A tool registered with a JSON Schema 2020-12 inputSchema (nested objects, $defs references) " + "is discoverable and callable." + ), + ), + "tools:input-schema:preserve-additional-properties": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema additionalProperties as registered.", + ), + "tools:input-schema:preserve-defs": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema $defs as registered.", + ), + "tools:input-schema:preserve-schema-dialect": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves the inputSchema $schema dialect URI as registered.", + ), + "tools:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#list-changed-notification", + behavior=( + "When the tool set changes, the server sends notifications/tools/list_changed and it reaches " + "the client's handler." + ), + ), + "tools:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#listing-tools", + behavior="tools/list returns the registered tools with name, description, and inputSchema.", + ), + "tools:list:metadata": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "Optional Tool fields supplied by the server (title, annotations, outputSchema, icons, _meta) " + "are delivered to the client unchanged." + ), + ), + "tools:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "tools/list supports cursor pagination: the nextCursor returned by a list handler round-trips " + "back to the handler as an opaque cursor until the listing is exhausted." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "client:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="The client skips structured-content validation when the tool result has isError true.", + ), + "client:output-schema:validate": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior=( + "A tool result whose structuredContent does not conform to the tool's declared outputSchema " + "is rejected by the client: the call raises instead of returning the invalid result." + ), + ), + "client:output-schema:missing-structured": Requirement( + source="sdk", + behavior="A tool that declares an output schema but returns no structuredContent fails client-side validation.", + ), + "client:output-schema:auto-list": Requirement( + source="sdk", + behavior=( + "Calling a tool whose output schema is not yet cached issues an implicit tools/list to " + "populate the cache; subsequent calls of the same tool do not." + ), + divergence=Divergence( + note=( + "Design concern rather than spec violation: the implicit request is invisible to the " + "caller, and against a server that registers only on_call_tool a successful call surfaces " + "as METHOD_NOT_FOUND from a tools/list the caller never asked for." + ), + ), + ), + "mcpserver:output-schema:missing-structured": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior="A tool with an output schema whose function returns no structured content produces a server error.", + ), + "mcpserver:output-schema:server-validate": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior=( + "MCPServer validates structured content against the tool's output schema before returning; a " + "mismatch produces a server error." + ), + ), + "mcpserver:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="Server-side output schema validation is skipped when the tool returns an isError result.", + ), + "mcpserver:tool:duplicate-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-names", + behavior="Registering a tool with a name already in use is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; " + "warn_on_duplicate_tools defaults to True and warning is the only effect -- there is " + "no rejection mode." + ), + ), + ), + "mcpserver:tool:extra": Requirement( + source="sdk", + behavior=( + "Tool functions can access request metadata (request id, client params, session) through the " + "Context parameter." + ), + ), + "mcpserver:tool:handler-throws": Requirement( + source="sdk", + behavior=( + "An exception raised by a tool function (ToolError or otherwise) is caught and returned as a " + "tool result with isError true and the failure text in content; it does not become a JSON-RPC error." + ), + ), + "mcpserver:tool:input-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "Arguments that fail the tool's input validation produce a tool execution error (isError true " + "with the validation failure described in content) without invoking the function." + ), + ), + "mcpserver:tool:naming-validation": Requirement( + source="sdk", + behavior=( + "Registering a tool whose name violates the spec's tool-naming conventions emits a warning; " + "registration still succeeds." + ), + ), + "mcpserver:tool:output-schema:model": Requirement( + source="sdk", + behavior=( + "A tool returning a typed model advertises a matching generated outputSchema and returns the " + "model's fields as structuredContent alongside a serialised text block." + ), + ), + "mcpserver:tool:output-schema:wrapped": Requirement( + source="sdk", + behavior=( + "A tool returning a non-object type (primitive or list) wraps the value as {'result': ...} in " + "structuredContent, with a matching generated outputSchema." + ), + ), + "mcpserver:tool:schema-variants": Requirement( + source="sdk", + behavior=( + "Tool input schemas generated from complex parameter types (unions, nested models, " + "constrained types) validate and coerce arguments before the function runs." + ), + ), + "mcpserver:tool:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name that was never registered returns a JSON-RPC error.", + divergence=Divergence( + note=( + "The spec classifies unknown tools as a protocol error (its example uses -32602 Invalid " + "params); MCPServer reports a tool execution error (isError true) instead. The low-level " + "path follows the spec example (see tools:call:unknown-name)." + ), + ), + ), + "mcpserver:tool:url-elicitation-error": Requirement( + source="sdk", + behavior=( + "A tool function that raises the URL-elicitation-required error surfaces to the caller as " + "error -32042 with the elicitation parameters intact." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # MCPServer: Context helpers (SDK) + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:context:logging": Requirement( + source="sdk", + behavior=( + "The Context logging helpers (debug/info/warning/error) send log message notifications at the " + "corresponding severity." + ), + ), + "mcpserver:context:progress": Requirement( + source="sdk", + behavior=( + "Context.report_progress sends a progress notification against the requesting client's progress token." + ), + ), + "mcpserver:context:elicit": Requirement( + source="sdk", + behavior=( + "Context.elicit sends a form elicitation built from a typed schema and returns a typed " + "accepted/declined/cancelled result." + ), + ), + "mcpserver:context:read-resource": Requirement( + source="sdk", + behavior="Context.read_resource reads a resource registered on the same server from inside a tool.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Resources + # ═══════════════════════════════════════════════════════════════════════════ + "resources:annotations": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#annotations", + behavior="Resource annotations supplied by the server round-trip to the client in the list result.", + divergence=Divergence( + note=( + "The SDK Annotations model is missing the schema's lastModified field; MCPModel uses the " + "pydantic default extra='ignore', so the value is silently dropped on parse." + ), + ), + ), + "resources:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#capabilities", + behavior=( + "A server with resource handlers advertises the resources capability, including the subscribe " + "sub-flag when a subscribe handler is registered." + ), + ), + "resources:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#list-changed-notification", + behavior=( + "When the resource set changes, the server sends notifications/resources/list_changed and it " + "reaches the client's handler." + ), + ), + "resources:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#listing-resources", + behavior=( + "resources/list returns the registered resources with uri, name, and the optional descriptive " + "fields supplied by the server." + ), + ), + "resources:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/list supports cursor pagination.", + ), + "resources:read:blob": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns binary contents base64-encoded in blob.", + ), + "resources:read:template-vars": Requirement( + source="sdk", + behavior="Variables extracted from a templated resource URI reach the resource function as typed arguments.", + ), + "resources:read:text": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns text contents carrying uri, mimeType, and the text.", + ), + "resources:read:unknown-uri": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for an unknown URI returns JSON-RPC error -32002 (resource not found).", + ), + "resources:subscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="resources/subscribe delivers the URI to the server's subscribe handler and returns an empty result.", + ), + "resources:subscribe:capability-required": Requirement( + source="sdk", + behavior=( + "resources/subscribe to a server that did not advertise the subscribe capability is rejected with an error." + ), + ), + "resources:subscribe:updated": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/subscribe, changes to that resource send notifications/resources/updated.", + deferred=( + "Not implemented in the SDK: the server keeps no subscription state linking subscribe to " + "updated notifications; emitting updates is entirely handler code. The two halves are pinned " + "separately by resources:subscribe and resources:updated-notification." + ), + ), + "resources:templates:list": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#resource-templates", + behavior=( + "resources/templates/list returns the registered templates with their uriTemplate and descriptive fields." + ), + ), + "resources:templates:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/templates/list supports cursor pagination.", + ), + "resources:unsubscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "resources/unsubscribe delivers the URI to the server's unsubscribe handler and returns an empty result." + ), + ), + "resources:unsubscribe:stops-updates": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/unsubscribe the server stops sending updated notifications for that URI.", + deferred=( + "Not implemented in the SDK: the server keeps no subscription state, so whether updated " + "notifications stop after unsubscribe is entirely handler code; there is no SDK behaviour to " + "pin beyond the unsubscribe request reaching the handler (covered by resources:unsubscribe)." + ), + ), + "resources:updated-notification": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "A resources/updated notification sent by the server reaches the client carrying the URI of " + "the changed resource." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Resources: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:resource:duplicate-name": Requirement( + source="sdk", + behavior="Registering a resource or template with a duplicate identifier is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name). " + "Templates differ: a duplicate uri_template silently replaces the first with no warning." + ), + ), + ), + "mcpserver:resource:read-throws-surfaced": Requirement( + source="sdk", + behavior="A resource function that raises is surfaced to the caller as a JSON-RPC error response.", + ), + "mcpserver:resource:static": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.resource() for a fixed URI is listed by resources/list and " + "served by resources/read at that URI." + ), + ), + "mcpserver:resource:template": Requirement( + source="sdk", + behavior=( + "A function registered with a URI template is listed by resources/templates/list and matched " + "by resources/read, receiving the parameters extracted from the requested URI." + ), + ), + "mcpserver:resource:unknown-uri": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for a URI matching no registered resource returns JSON-RPC error -32002.", + divergence=Divergence( + note=( + "The spec reserves -32002 for resource-not-found; MCPServer raises ResourceError, which " + "the low-level server converts to error code 0." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts + # ═══════════════════════════════════════════════════════════════════════════ + "prompts:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#capabilities", + behavior="A server with a list_prompts handler advertises the prompts capability in its initialize result.", + ), + "prompts:get:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#audio-content", + behavior="Prompt messages may contain audio content with base64 data and a mimeType.", + ), + "prompts:get:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#embedded-resources", + behavior="Prompt messages may contain embedded resource content.", + ), + "prompts:get:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#image-content", + behavior="Prompt messages may contain image content.", + ), + "prompts:get:missing-required-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get omitting a required argument returns JSON-RPC error -32602 (Invalid params).", + divergence=Divergence( + note=( + "MCPServer's prompt renderer raises a plain ValueError before the prompt function runs, " + "which the low-level server converts to error code 0 with the exception text as the message." + ), + ), + ), + "prompts:get:multi-message": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="A prompt can return multiple messages mixing user and assistant roles; order is preserved.", + ), + "prompts:get:no-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get with no arguments returns the prompt's messages.", + ), + "prompts:get:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get for an unknown prompt name returns JSON-RPC error -32602 (Invalid params).", + ), + "prompts:get:with-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get delivers the supplied arguments to the prompt handler and returns its messages.", + ), + "prompts:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#list-changed-notification", + behavior=( + "When the prompt set changes, the server sends notifications/prompts/list_changed and it " + "reaches the client's handler." + ), + ), + "prompts:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#listing-prompts", + behavior="prompts/list returns the registered prompts with name, description, and argument declarations.", + ), + "prompts:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="prompts/list supports cursor pagination.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:prompt:args-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#implementation-considerations", + behavior="prompts/get arguments that fail the prompt's argument schema are rejected before the function runs.", + ), + "mcpserver:prompt:decorated": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.prompt() is listed with arguments derived from its signature " + "and rendered into prompt messages by prompts/get." + ), + ), + "mcpserver:prompt:duplicate-name": Requirement( + source="sdk", + behavior="Registering a duplicate prompt name is rejected at registration time.", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name)." + ), + ), + ), + "mcpserver:prompt:optional-args": Requirement( + source="sdk", + behavior="A prompt with optional arguments can be fetched without supplying them.", + ), + "mcpserver:prompt:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get for a name that was never registered returns JSON-RPC error -32602 (Invalid params).", + divergence=Divergence( + note=( + "The spec's example uses -32602 Invalid params for unknown prompts; MCPServer raises " + "ValueError, which the low-level server converts to error code 0." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Completion + # ═══════════════════════════════════════════════════════════════════════════ + "completion:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior="A server with a completion handler advertises the completions capability in its initialize result.", + ), + "completion:complete:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior=( + "A server with no completion handler does not advertise the completions capability and rejects " + "completion/complete with METHOD_NOT_FOUND." + ), + ), + "completion:context-arguments": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", + behavior="Previously-resolved argument values supplied in context.arguments reach the completion handler.", + ), + "completion:error:invalid-ref": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#error-handling", + behavior=( + "completion/complete with a ref naming an unknown prompt or non-matching resource URI returns " + "JSON-RPC error -32602 (Invalid params)." + ), + ), + "completion:prompt-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/prompt returns suggested values for the named prompt argument.", + ), + "completion:resource-template-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/resource returns suggested values for a URI template variable.", + ), + "completion:result-shape": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#completion-results", + behavior="The completion result carries values (at most 100), an optional total, and an optional hasMore flag.", + ), + "mcpserver:completion:capability-auto": Requirement( + source="sdk", + behavior=( + "MCPServer advertises the completions capability when at least one completion source is " + "registered, and omits it otherwise." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Logging + # ═══════════════════════════════════════════════════════════════════════════ + "logging:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#capabilities", + behavior=( + "A server that emits log message notifications declares the logging capability in its initialize result." + ), + divergence=Divergence( + note=( + "MCPServer registers no setLevel handler, so capability derivation leaves logging unset " + "even though the Context helpers send log message notifications." + ), + ), + ), + "logging:message:all-levels": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", + behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", + ), + "logging:message:fields": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "A log message sent by a server handler is delivered to the client's logging callback with its " + "severity level, logger name, and data." + ), + ), + "logging:message:filtered": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="After logging/setLevel, log messages below the configured level are not sent.", + divergence=Divergence( + note=( + "Neither MCPServer (which rejects logging/setLevel with method-not-found) nor the " + "low-level Server (which leaves the handler entirely to the author) implements any " + "filtering; messages are delivered at every severity regardless of the requested level." + ), + ), + ), + "logging:set-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="logging/setLevel delivers the requested level to the server's handler and returns an empty result.", + ), + "logging:set-level:invalid-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#error-handling", + behavior="logging/setLevel with an invalid level value returns JSON-RPC error -32602 (Invalid params).", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Sampling (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "sampling:capability:declare": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A client that handles sampling requests advertises the sampling capability in its initialize request." + ), + ), + "sampling:create:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A sampling/createMessage request from a server handler is answered by the client's sampling " + "callback, and the callback's result (role, content, model, stopReason) is returned to the handler." + ), + ), + "sampling:create:include-context": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior="The includeContext value supplied by the server reaches the client callback intact.", + ), + "sampling:context:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "The server does not use includeContext values thisServer or allServers unless the client " + "declared the sampling.context capability." + ), + divergence=Divergence( + note=( + "include_context is forwarded regardless of the client's declared sampling.context " + "capability; the server-side validator only checks tools/tool_choice." + ), + ), + ), + "sampling:create:model-preferences": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#model-preferences", + behavior=( + "The model preferences supplied by the server (hints and the cost, speed, and intelligence " + "priorities) reach the client callback intact." + ), + ), + "sampling:create:system-prompt": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior="The system prompt supplied by the server reaches the client callback intact.", + ), + "sampling:create:tools": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", + behavior=( + "A sampling request carrying tools and toolChoice reaches the client, and a tool_use response " + "with a toolUse stop reason returns to the requesting handler." + ), + deferred=( + "Not implemented in the SDK: Client does not expose ClientSession's sampling_capabilities " + "parameter, so a client can never declare sampling.tools and the server-side validator " + "rejects every tool-enabled request before it is sent." + ), + ), + "sampling:create-message:audio-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#audio-content", + behavior="Sampling messages can carry audio content: base64 data with a mimeType.", + ), + "sampling:create-message:image-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#image-content", + behavior="Sampling messages can carry image content: base64 data with a mimeType.", + ), + "sampling:create-message:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A sampling request to a client that did not declare the sampling capability fails with an " + "error rather than hanging or being silently dropped; the spec names no error code for this case." + ), + ), + "sampling:error:user-rejected": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#error-handling", + behavior=( + "A sampling request the user rejects is answered with a JSON-RPC error (the spec's code for " + "this case is -1, 'User rejected sampling request'), surfaced to the requesting handler as an MCPError." + ), + ), + "sampling:message:content-cardinality": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling", + behavior="A sampling message's content may be a single block or an array of blocks.", + ), + "sampling:result:no-tools-single-content": Requirement( + source="sdk", + behavior=( + "When the request carries no tools, a sampling callback result whose content is an array is " + "rejected by the client." + ), + divergence=Divergence( + note=( + "The client does not validate the callback result against the request shape; an array-content " + "result for a tool-free request is accepted client-side and surfaces as a raw " + "pydantic.ValidationError from the server's response parsing (send_request) instead." + ), + ), + ), + "sampling:result:with-tools-array-content": Requirement( + source="sdk", + behavior=( + "When the request includes tools, the client accepts a callback result whose content is an " + "array including tool_use blocks." + ), + deferred=( + "Not implemented in the SDK: requires declaring sampling.tools, which the high-level client " + "cannot do (see sampling:create:tools)." + ), + ), + "sampling:tool-result:no-mixed-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-result-messages", + behavior=( + "A user sampling message that carries tool_result content contains only tool_result blocks; " + "mixing tool_result with text, image, or audio content is rejected as invalid." + ), + ), + "sampling:tool-use:result-balance": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-use-and-result-balance", + behavior=( + "In a sampling/createMessage request, every assistant tool_use block in messages MUST be " + "matched by a tool_result with the same toolUseId in the immediately-following user message; " + "an unmatched tool_use is rejected with -32602 Invalid params." + ), + divergence=Divergence( + note=( + "The client does not validate inbound tool_use/tool_result balance; the SDK enforces " + "the rule server-side instead, before the request leaves the server (see " + "sampling:tool-use:server-preflight)." + ), + ), + deferred=( + "Not implemented on the client receive path: validation runs only on the server send path " + "(pinned by sampling:tool-use:server-preflight)." + ), + ), + "sampling:tool-use:server-preflight": Requirement( + source="sdk", + behavior=( + "The server validates tool_use/tool_result balance before sending a sampling/createMessage " + "request; an unmatched tool_use raises ValueError and the request never reaches the wire." + ), + ), + "sampling:tools:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", + behavior=( + "A tool-enabled sampling request to a client that did not declare sampling.tools is rejected " + "by the server before anything reaches the wire (the SDK surfaces this as an Invalid params error)." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Elicitation (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "elicitation:capability:empty-is-form": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior="A client advertising an empty elicitation capability accepts form-mode elicitation requests.", + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares explicit " + "form and url sub-capabilities, so an empty elicitation capability cannot be produced through " + "the public API." + ), + ), + "elicitation:capability:mode-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "The client answers elicitation requests for a mode it did not advertise with JSON-RPC error " + "-32602 (Invalid params)." + ), + deferred=( + "Not implemented in the SDK: a client cannot be configured form-only or url-only, so the " + "per-mode mismatch error cannot arise (see elicitation:url:not-supported)." + ), + ), + "elicitation:capability:server-respects-mode": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior=( + "The server refuses to send an elicitation request with a mode the connected client did not " + "declare in its capabilities." + ), + divergence=Divergence( + note=( + "The server does not check the client's declared elicitation modes before sending " + "elicitation/create; the spec's MUST NOT is not enforced." + ), + ), + ), + "elicitation:form:action:accept": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior=( + "A form-mode elicitation answered with action 'accept' returns the user's content to the " + "requesting handler." + ), + ), + "elicitation:form:action:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'cancel' returns no content to the handler.", + ), + "elicitation:form:action:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'decline' returns no content to the handler.", + ), + "elicitation:form:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", + behavior=( + "A form-mode elicitation delivers the message and requested schema to the client callback " + "exactly as the server sent them." + ), + ), + "elicitation:form:defaults": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Optional default values declared in a form-mode requested schema are pre-populated into the " + "form presented to the user." + ), + deferred=( + "Not implemented in the SDK: there is no form-rendering layer that could pre-populate " + "defaults; client callbacks receive the requested schema as-is." + ), + ), + "elicitation:form:mode-omitted-default": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#elicitation-requests", + behavior="An elicitation request with no mode field is treated as form mode by the client.", + ), + "elicitation:form:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "An elicitation request to a client that did not declare the elicitation capability is " + "answered with -32602 Invalid params." + ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32602.", + ), + ), + "elicitation:form:schema:enum-variants": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Requested-schema enum fields (including titled and multi-select variants) reach the client " + "callback as sent." + ), + ), + "elicitation:form:schema:primitives": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior="Requested-schema fields may be string (with format), number or integer, or boolean.", + ), + "elicitation:form:schema:restricted-subset": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior=( + "Form-mode requested schemas are flat objects with primitive-typed properties only; nested " + "structures and arrays of objects are not used." + ), + divergence=Divergence( + note=( + "ServerSession.elicit_form forwards an arbitrary dict[str, Any] schema unchanged; no shape " + "validation at the low-level session layer (the high-level Context.elicit / " + "elicit_with_validation helper enforces primitive-only fields before generating the schema)." + ), + ), + ), + "elicitation:form:response-validation": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-security", + behavior=( + "Accepted form-mode content is validated against the requested schema: the client validates " + "the response before sending and the server validates the content it receives." + ), + divergence=Divergence( + note=( + "The client never validates outbound content; ServerSession.elicit_form returns received " + "content unvalidated (the high-level Context.elicit / elicit_with_validation helper " + "validates server-side, but the low-level session API does not)." + ), + ), + ), + "elicitation:url:action:accept-no-content": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior=( + "A URL-mode elicitation delivers the message, URL, and elicitationId to the client; an accept " + "response carries no content (accept means the user agreed to visit the URL, not that the " + "interaction completed)." + ), + ), + "elicitation:url:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation-requests", + behavior=( + "A url-mode elicitation delivers the elicitation id and URL to the client callback exactly as " + "the server sent them." + ), + ), + "elicitation:url:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with cancel returns the action with no content.", + ), + "elicitation:url:complete-notification": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "An elicitation/complete notification sent by the server after an out-of-band elicitation " + "finishes reaches the client carrying the elicitationId." + ), + ), + "elicitation:url:complete-unknown-ignored": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "The client ignores an elicitation/complete notification referencing an unknown or " + "already-completed elicitationId without error." + ), + ), + "elicitation:url:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with decline returns the action with no content.", + ), + "elicitation:url:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "A URL-mode elicitation to a client that declared only form-mode support is rejected with an " + "Invalid params error." + ), + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares both the " + "form and url sub-capabilities, so a form-only client cannot be constructed." + ), + ), + "elicitation:url:required-error": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A handler that cannot proceed without a URL elicitation rejects the request with error " + "-32042, carrying the pending elicitations in the error data." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Roots (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "roots:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", + ), + "roots:list-changed:client-emits": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior=( + "A client that declared roots.listChanged sends notifications/roots/list_changed when its set " + "of roots changes." + ), + deferred=( + "Not implemented in the SDK: the client does not own the root set (it calls back to the host " + "via list_roots_callback), so there is no mutation it could observe to auto-emit on; the SDK " + "provides send_roots_list_changed() for the host to call when its roots change, and that " + "emission path is covered by roots:list-changed." + ), + ), + "roots:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior=( + "A roots/list request from a server handler is answered by the client's roots callback, and " + "the returned roots (uri, name) reach the handler." + ), + ), + "roots:list:client-error": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior="A roots callback that answers with an error surfaces to the requesting handler as an MCPError.", + ), + "roots:list:empty": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior="An empty roots list is a valid response and reaches the handler as such.", + ), + "roots:list:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior=( + "A roots/list request to a client that did not declare the roots capability is answered with " + "-32601 Method not found." + ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32601.", + ), + ), + "roots:uri:file-scheme": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root", + behavior="Every root returned by the client identifies itself with a file:// URI.", + deferred=( + "Schema-level validation: the FileUrl type on Root.uri rejects any non-file:// scheme at " + "construction and at parse, so a non-conforming root cannot reach the wire from either side; " + "type-level coverage belongs in tests/test_types.py rather than this interaction suite." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # list_changed & dynamic registration + # ═══════════════════════════════════════════════════════════════════════════ + "client:list-changed:auto-refresh": Requirement( + source="sdk", + behavior=( + "A client configured to react to list_changed notifications automatically re-fetches the " + "corresponding list and delivers the fresh result to its callback." + ), + deferred=( + "Not implemented in the SDK: the client has no list-changed auto-refresh mechanism; " + "notifications are only delivered to the message handler." + ), + ), + "client:list-changed:capability-gated": Requirement( + source="sdk", + behavior=( + "The client does not activate list-changed handling for a kind the server did not advertise " + "with listChanged true." + ), + deferred="Not implemented in the SDK: no client-side list-changed handling exists to gate.", + ), + "client:list-changed:signal-only": Requirement( + source="sdk", + behavior="A client configured for signal-only list-changed handling is notified without auto-refreshing.", + deferred="Not implemented in the SDK: no client-side list-changed handling exists.", + ), + "mcpserver:list-changed:debounce": Requirement( + source="sdk", + behavior=( + "Bursts of registration changes on MCPServer are debounced into one list_changed notification per kind." + ), + deferred=( + "Not implemented in the SDK: MCPServer does not send list_changed notifications on " + "registration changes at all (see mcpserver:register:post-connect), so there is nothing to " + "debounce." + ), + ), + "mcpserver:register:post-connect": Requirement( + source="sdk", + behavior=( + "A tool, resource, or prompt registered or removed after the client connected appears in (or " + "disappears from) the corresponding list results, and the change is announced with a " + "list_changed notification." + ), + divergence=Divergence( + note=( + "MCPServer never sends list_changed notifications on registration changes, so a connected " + "client cannot learn that the set changed without polling." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Pagination + # ═══════════════════════════════════════════════════════════════════════════ + "pagination:exhaustion": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "Following nextCursor until it is absent yields every page exactly once; a result without " + "nextCursor ends the sequence." + ), + ), + "pagination:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#error-handling", + behavior="A list request with an invalid cursor returns JSON-RPC error -32602 (Invalid params).", + ), + "pagination:client:cursor-handling": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#implementation-guidelines", + behavior=( + "The client treats cursors as opaque tokens — it does not parse, modify, or persist them — " + "and does not assume a fixed page size." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tasks (experimental) + # ═══════════════════════════════════════════════════════════════════════════ + "tasks:auth:context-isolation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-isolation-and-access-control", + behavior=( + "When an authorization context is available, task operations are scoped to the context that " + "created the task: other contexts cannot get it, retrieve its result, cancel it, or see it in " + "tasks/list." + ), + transports=("streamable-http",), + deferred=_TASKS_DEFERRAL, + ), + "tasks:bidirectional": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#definitions", + behavior="Task APIs are bidirectional: the server may create, get, list, and cancel tasks on the client.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:no-handler-abort": Requirement( + source="sdk", + behavior=( + "tasks/cancel marks the task cancelled without aborting the originating request handler " + "(the spec says receivers SHOULD attempt to stop execution)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:remains-cancelled": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior=( + "After tasks/cancel, the task remains cancelled even if the underlying handler subsequently " + "completes or fails." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:terminal-rejected": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a task already in a terminal state returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a working task transitions it to cancelled and returns the updated task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:ttl-honored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#ttl-and-resource-management", + behavior=( + "tasks/get responses include the actual ttl applied by the receiver (or null for unlimited); " + "the create-task result carries the same value." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:via-tool-call": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#creating-tasks", + behavior="A task-augmented tools/call returns a create-task result instead of the tool result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:get": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#getting-tasks", + behavior="tasks/get returns the task's current status, ttl, timestamps, and status message.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:initial-working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-lifecycle", + behavior="A newly created task has status 'working'.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:input-required": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "While a task awaits a side-channel client response its status is input_required; once the " + "response arrives the task leaves input_required (typically returning to working)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/list with an invalid cursor returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#listing-tasks", + behavior="tasks/list returns created tasks and supports cursor pagination.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:no-capability:ignore-task-param": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-support-and-handling", + behavior=( + "A receiver that did not declare task capability for a request type processes the request " + "normally and returns the ordinary result, ignoring the task augmentation." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:progress:after-create": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-progress-notifications", + behavior=( + "After the create-task result, progress notifications keyed to the original progress token " + "continue to reach the caller until the task is terminal." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:request-cancel:no-task-cancel": Requirement( + source="sdk", + behavior="A cancellation notification for the originating request does not auto-cancel the created task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:failed": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-execution-errors", + behavior="tasks/result for a failed task returns the failure result (isError true).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:related-task-meta": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="The tasks/result response carries related-task _meta naming the requested task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:terminal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#result-retrieval", + behavior="tasks/result for a completed task returns the stored result of the original request type.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drain-fifo": Requirement( + source="sdk", + behavior="tasks/result drains queued related-task messages in FIFO order before returning the final result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drop-on-cancel": Requirement( + source="sdk", + behavior="When a task is cancelled before tasks/result, queued related-task messages are dropped.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:elicitation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "An elicitation issued mid-task is delivered through the tasks/result side-channel, and the " + "client's response routes back to the handler." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:queue": Requirement( + source="sdk", + behavior=( + "Server-to-client requests with related-task metadata sent while no tasks/result is open are queued." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:sampling": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "A sampling request issued mid-task is delivered through the tasks/result side-channel, and " + "the client's response routes back to the task." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:stream": Requirement( + source="sdk", + behavior=( + "Calling tasks/result while the task is working streams related-task messages as they are " + "produced, then returns the result." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:status-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-notification", + behavior="Task status notifications deliver status updates carrying the full task fields.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:forbidden-with-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=( + "A task-augmented tools/call on a tool that does not support tasks returns Method not found (-32601)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:required-no-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=("A plain tools/call on a tool that requires task augmentation returns Method not found (-32601)."), + deferred=_TASKS_DEFERRAL, + ), + "tasks:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/get, tasks/result, and tasks/cancel for an unknown task id return Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Transports (in-suite coverage) + # ═══════════════════════════════════════════════════════════════════════════ + "transport:streamable-http:stateful": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip (initialize, tool calls, tool errors) works through the " + "streamable HTTP framing in its default stateful SSE-response mode." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:json-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="The interaction round trip works when the server answers with plain JSON instead of SSE.", + transports=("streamable-http",), + ), + "transport:streamable-http:stateless": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip works in stateless mode, where every request is served by a " + "fresh transport with no session id." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:notifications": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "Notifications emitted during a request are delivered on that request's SSE stream and reach " + "the client's callbacks, in order, before the response." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:stateless-restrictions": Requirement( + source="sdk", + behavior=( + "A handler that attempts a server-initiated request in stateless mode fails with an error " + "result, because there is no session to call back through." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:unrelated-messages": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-to-client message that is not related to an in-flight request is routed to the " + "standalone GET stream and delivered to the client listening on it, not to any request's " + "own stream." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-initiated request nested inside an in-flight call round-trips over stateful streamable HTTP." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:resumability": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="A client that reconnects with Last-Event-ID receives the events it missed.", + transports=("streamable-http",), + ), + "transport:streamable-http:origin-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior="Requests with an invalid Origin header are rejected with 403 before reaching the session.", + transports=("streamable-http",), + ), + "transport:sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "A client connected over the legacy HTTP+SSE transport completes the handshake and round-trips " + "requests, with server messages delivered on the SSE stream." + ), + transports=("sse",), + ), + "transport:sse:endpoint-event": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior="Opening the SSE stream delivers an `endpoint` event naming the message-POST URL as the first event.", + transports=("sse",), + ), + "transport:sse:post:session-routing": Requirement( + source="sdk", + behavior=( + "The endpoint URL carries a fresh session identifier; the server registers the session before " + "the endpoint event is sent and releases it when the stream disconnects, and a POST that names " + "no session id, a malformed session id, or an unknown session id is rejected (400/400/404)." + ), + transports=("sse",), + ), + "transport:stdio": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior=( + "A Client connected to a real SDK Server over stdio initializes, calls a tool with arguments, " + "and receives notifications and results over the child process's stdin/stdout." + ), + transports=("stdio",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: session lifecycle + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:session:cors-expose": Requirement( + source="sdk", + behavior="CORS configuration exposes the Mcp-Session-Id header so browser clients can read it.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: CORS configuration is left to the hosting ASGI application.", + ), + "hosting:session:create": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "An initialize POST without a session id creates a session and returns Mcp-Session-Id in the " + "response headers." + ), + transports=("streamable-http",), + ), + "hosting:session:delete": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="DELETE with a valid Mcp-Session-Id terminates the session.", + transports=("streamable-http",), + ), + "hosting:session:id-charset": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Generated Mcp-Session-Id values contain only visible ASCII characters.", + transports=("streamable-http",), + ), + "hosting:session:isolation": Requirement( + source="sdk", + behavior="Each session gets its own server instance; closing one session does not affect others.", + transports=("streamable-http",), + ), + "hosting:session:missing-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A non-initialize POST without Mcp-Session-Id in stateful mode returns 400.", + transports=("streamable-http",), + ), + "hosting:session:post-termination-404": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "After a session is terminated, any further request carrying that session ID is answered with " + "404 Not Found." + ), + transports=("streamable-http",), + ), + "hosting:session:reinitialize": Requirement( + source="sdk", + behavior="A second initialize on an already-initialized session transport is rejected.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The transport forwards a second initialize carrying the existing session ID to the running " + "server, which answers it as a fresh handshake; nothing rejects re-initialization." + ), + ), + ), + "hosting:session:reuse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST carrying a valid Mcp-Session-Id routes to that session's transport with state preserved.", + transports=("streamable-http",), + ), + "hosting:session:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST, GET, or DELETE with an unknown Mcp-Session-Id returns 404.", + transports=("streamable-http",), + ), + "hosting:stateless:concurrent-clients": Requirement( + source="sdk", + behavior="Multiple independent clients can connect to a stateless server concurrently.", + transports=("streamable-http",), + ), + "hosting:stateless:no-reuse": Requirement( + source="sdk", + behavior="A stateless per-request transport cannot be reused for a second request.", + transports=("streamable-http",), + ), + "hosting:stateless:no-session-id": Requirement( + source="sdk", + behavior="In stateless mode no Mcp-Session-Id is emitted and no session validation is performed.", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: auth + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:auth:as-router": Requirement( + source="sdk", + behavior=( + "The authorization-server routes expose the authorize, token, and registration endpoints " + "(and revocation when supported)." + ), + transports=("streamable-http",), + ), + "hosting:auth:aud-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#access-token-usage", + behavior="The resource server validates that the token audience matches its resource identifier.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "BearerAuthBackend never inspects AccessToken.resource; a token issued for a different " + "resource is accepted. Spec MUST." + ), + ), + ), + "hosting:auth:authinfo-propagates": Requirement( + source="sdk", + behavior="A valid token's auth info is exposed to request handlers.", + transports=("streamable-http",), + ), + "hosting:auth:expired-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="An expired token returns 401 invalid_token.", + transports=("streamable-http",), + divergence=Divergence( + note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", + ), + ), + "hosting:auth:invalid-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="A malformed bearer token or token-verification failure returns 401 with WWW-Authenticate.", + transports=("streamable-http",), + divergence=Divergence( + note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", + ), + ), + "hosting:auth:metadata-endpoints": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The MCP server publishes protected-resource metadata at its well-known endpoint, and the " + "authorization server (which the SDK can also host) publishes authorization-server metadata " + "at its own." + ), + transports=("streamable-http",), + ), + "hosting:auth:missing-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior=( + "A request without an Authorization header is rejected with 401; the WWW-Authenticate header " + "carries resource_metadata (one of the spec's two permitted discovery mechanisms)." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK never emits a `scope` parameter in any WWW-Authenticate challenge — neither the " + "discovery-time 401 (#protected-resource-metadata-discovery-requirements SHOULD) nor the " + "runtime 403 (#runtime-insufficient-scope-errors SHOULD); and for the no-credentials case " + 'it emits error="invalid_token", which RFC 6750 Section 3.1 says SHOULD NOT appear when no ' + "authentication information was presented." + ), + ), + ), + "hosting:auth:prm:authorization-servers-field": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The protected-resource metadata document includes an authorization_servers array with at least one entry." + ), + transports=("streamable-http",), + ), + "hosting:auth:query-token-ignored": Requirement( + source="sdk", + behavior=( + "An access token presented in the URI query string is not accepted; the request is treated as " + "unauthenticated." + ), + transports=("streamable-http",), + ), + "hosting:auth:scope-403": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#runtime-insufficient-scope-errors", + behavior=( + "A token lacking a required scope returns 403 with WWW-Authenticate carrying " + "insufficient_scope, the required scope, and resource_metadata." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + 'The SDK emits error="insufficient_scope" and error_description but never the `scope` ' + "parameter the spec SHOULD include; the SDK client reads `scope` from this header to drive " + "step-up (utils.py extract_scope_from_www_auth) — a resource-server/client asymmetry." + ), + ), + ), + "hosting:auth:as:authorize-requires-pkce": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The bundled authorization endpoint rejects an authorize request that omits " + "`code_challenge` with `invalid_request`." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:verifier-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The bundled token endpoint rejects an authorization-code exchange whose `code_verifier` " + "does not hash to the stored `code_challenge` with `invalid_grant`." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:code-single-use": Requirement( + source="sdk", + behavior=( + "An authorization code can be exchanged exactly once; a second exchange of the same code " + "is rejected with `invalid_grant`. Enforced by the provider deleting the code on first use; " + "the handler relies on `load_authorization_code` returning None." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:redirect-uri-binding": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", + behavior=( + "The bundled token endpoint rejects an authorization-code exchange whose `redirect_uri` " + "differs from the one used at authorize; the bundled authorize endpoint rejects a " + "`redirect_uri` not in the client's registered list without redirecting to it." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "RFC 6749 §5.2 assigns redirect_uri mismatch at the token endpoint to invalid_grant; " + "the SDK's TokenHandler returns invalid_request (src/mcp/server/auth/handlers/token.py:157). " + "The rejection itself is the security-relevant property and is correct." + ), + ), + ), + "hosting:auth:as:redirect-uri-scheme": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#communication-security", + behavior=( + "The bundled registration endpoint accepts only redirect URIs that use HTTPS or target a loopback host." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "Not enforced: the registration handler models redirect_uris as AnyUrl with no scheme or " + "host check, so http://evil.example/callback is accepted and registered. The spec's " + "localhost-or-HTTPS rule is left to the provider implementation." + ), + ), + ), + "hosting:auth:as:token-cache-headers": Requirement( + source="sdk", + behavior=("Every token-endpoint response carries `Cache-Control: no-store` and `Pragma: no-cache`."), + transports=("streamable-http",), + ), + "hosting:auth:as:register-error-response": Requirement( + source="sdk", + behavior=( + "The bundled registration endpoint answers invalid client metadata with HTTP 400 and an " + "RFC 7591 error body." + ), + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: resumability + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:resume:bad-event-id": Requirement( + source="sdk", + behavior="A Last-Event-ID that cannot be mapped to a stream is rejected.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The replay path returns an empty SSE stream rather than rejecting an unknown " + "Last-Event-ID; the client cannot tell an unknown ID apart from a stream with no missed " + "events." + ), + ), + ), + "hosting:resume:buffered-replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Notifications emitted while no client is connected are replayed in order on reconnect.", + transports=("streamable-http",), + ), + "hosting:resume:close-stream": Requirement( + source="sdk", + behavior="Handlers can close an SSE stream cleanly when an event store is configured.", + transports=("streamable-http",), + ), + "hosting:resume:event-ids": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="With an event store configured, every SSE event carries an id field.", + transports=("streamable-http",), + ), + "hosting:resume:priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A server-initiated SSE stream begins with a priming event carrying an event ID and an empty " + "data field; a server that closes the connection before terminating the stream sends an SSE " + "retry field first." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The retry hint is attached to the priming event itself rather than sent as a separate " + "event before the connection closes, and a priming event is only sent when an event store " + "is configured and the negotiated protocol version is at least 2025-11-25." + ), + ), + ), + "hosting:resume:replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="GET with Last-Event-ID replays stored events for that stream after the given id.", + transports=("streamable-http",), + ), + "hosting:resume:stream-scoped": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Replay via Last-Event-ID returns only messages from the stream that event id belongs to.", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: HTTP semantics + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:http:accept-406": Requirement( + source="sdk", + behavior="A request whose Accept header does not allow the response representation returns 406.", + transports=("streamable-http",), + ), + "hosting:http:batch": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body is a single JSON-RPC message; batched arrays are rejected for protocol revisions " + "that forbid them." + ), + transports=("streamable-http",), + ), + "hosting:http:content-type-415": Requirement( + source="sdk", + behavior="A POST with a Content-Type other than application/json returns 415.", + transports=("streamable-http",), + divergence=Divergence( + note=( + "The transport-security middleware rejects a non-JSON Content-Type with 400 'Invalid " + "Content-Type header' before the request reaches the transport, so the transport's own 415 " + "path is unreachable through any public entry point." + ), + ), + ), + "hosting:http:disconnect-not-cancel": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A client connection drop during an in-flight request does not cancel the server-side " + "handler; the request continues and its result remains retrievable." + ), + transports=("streamable-http",), + ), + "hosting:http:dns-rebinding": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior=( + "The Origin header is validated on every incoming connection; a request with an invalid " + "Origin is rejected with 403 Forbidden." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The spec's Origin validation is an unconditional MUST; the SDK enables it only when the " + "host is a localhost address or explicit TransportSecuritySettings are passed (with no " + "settings, no Origin validation runs), and additionally validates the Host header " + "(returning 421 on mismatch), which the spec does not require." + ), + ), + ), + "hosting:http:json-response-mode": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="With JSON response mode enabled, POST returns application/json instead of SSE.", + transports=("streamable-http",), + ), + "hosting:http:method-405": Requirement( + source="sdk", + behavior="An unsupported HTTP method on the MCP endpoint returns 405.", + transports=("streamable-http",), + ), + "hosting:http:no-broadcast": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#multiple-connections", + behavior=( + "When multiple SSE streams are open for a session, each server-originated message is sent on " + "exactly one stream, never duplicated." + ), + transports=("streamable-http",), + ), + "hosting:http:notifications-202": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A POST containing only notifications or responses returns 202 with no body.", + transports=("streamable-http",), + ), + "hosting:http:onerror": Requirement( + source="sdk", + behavior="Transport-level rejections are reported through an error callback on the server transport.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: the server transport has no error callback; rejections are logged.", + ), + "hosting:http:parse-error-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body that is not valid JSON or not a valid JSON-RPC message is rejected with HTTP 400; " + "the body may carry a JSON-RPC error response (the SDK sends a Parse error body)." + ), + transports=("streamable-http",), + ), + "hosting:http:protocol-version-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior="An invalid or unsupported MCP-Protocol-Version header returns 400 Bad Request.", + transports=("streamable-http",), + ), + "hosting:http:protocol-version-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior=( + "When no MCP-Protocol-Version header is received and the version cannot be determined another " + "way, the server assumes protocol version 2025-03-26." + ), + transports=("streamable-http",), + ), + "hosting:http:response-same-connection": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A response is delivered on the SSE stream opened by the POST that carried its request (or " + "that stream's resumed continuation), not on an unrelated stream." + ), + transports=("streamable-http",), + ), + "hosting:http:second-sse-rejected": Requirement( + source="sdk", + behavior="A second concurrent standalone GET SSE stream on the same session is rejected.", + transports=("streamable-http",), + ), + "hosting:http:sse-close-after-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="The server terminates a POST-initiated SSE stream after writing the JSON-RPC response.", + transports=("streamable-http",), + ), + "hosting:http:standalone-sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="GET opens a standalone SSE stream that receives server-initiated messages.", + transports=("streamable-http",), + ), + "hosting:http:standalone-sse-no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior=( + "The standalone GET SSE stream carries server requests and notifications but never a JSON-RPC " + "response, except when resuming a prior request stream." + ), + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Client transport: streamable HTTP + # ═══════════════════════════════════════════════════════════════════════════ + "client-transport:http:404-surfaces": Requirement( + source="sdk", + behavior="A 404 (session expired) on a request surfaces as an error to the caller.", + transports=("streamable-http",), + ), + "client-transport:http:session-404-reinitialize": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "A 404 in response to a request carrying a session ID makes the client start a new session " + "with a fresh InitializeRequest and no session ID attached." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The client surfaces the 404 as an error to the caller instead of re-initializing a new " + "session; the spec's MUST is not satisfied." + ), + ), + deferred=( + "Not implemented in the SDK: the client surfaces a Session terminated error instead of " + "re-initializing (the surfaced error is pinned by client-transport:http:404-surfaces)." + ), + ), + "client-transport:http:accept-header-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="The client GET to the MCP endpoint includes an Accept header listing text/event-stream.", + transports=("streamable-http",), + ), + "client-transport:http:accept-header-post": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "Every client POST to the MCP endpoint includes an Accept header listing both application/json " + "and text/event-stream." + ), + transports=("streamable-http",), + ), + "client-transport:http:concurrent-streams": Requirement( + source="sdk", + behavior="Multiple concurrent POST-initiated SSE streams each deliver their response to the right caller.", + transports=("streamable-http",), + ), + "client-transport:http:custom-client": Requirement( + source="sdk", + behavior=( + "A caller-supplied HTTP client (and its event hooks and headers) is used for all MCP traffic, " + "including auth flows." + ), + transports=("streamable-http",), + ), + "client-transport:http:custom-headers": Requirement( + source="sdk", + behavior="Caller-supplied headers are sent on every POST, GET, and DELETE to the MCP endpoint.", + transports=("streamable-http",), + ), + "client-transport:http:json-response-parsed": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A Content-Type application/json response is parsed as a single JSON-RPC message.", + transports=("streamable-http",), + ), + "client-transport:http:no-reconnect-after-close": Requirement( + source="sdk", + behavior="After the transport is closed, no further reconnection attempts are scheduled.", + transports=("streamable-http",), + ), + "client-transport:http:no-reconnect-after-response": Requirement( + source="sdk", + behavior="A POST-initiated stream that already delivered its response is not reconnected when it closes.", + transports=("streamable-http",), + ), + "client-transport:http:protocol-version-header": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior=( + "After initialization, the client sends the negotiated MCP-Protocol-Version header on every " + "subsequent HTTP request." + ), + transports=("streamable-http",), + ), + "client-transport:http:protocol-version-stored": Requirement( + source="sdk", + behavior=( + "The client transport stores the negotiated protocol version and sends it on every subsequent request." + ), + transports=("streamable-http",), + ), + "client-transport:http:reconnect-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior=( + "A standalone GET SSE stream that errors is reconnected with the Last-Event-ID of the last received event." + ), + transports=("streamable-http",), + deferred=( + "The server's standalone GET stream emits no priming event or retry hint, so the client's " + "reconnection path always sleeps the hard-coded 1 s default; a deterministic in-process test " + "would require accepting that real-time wait. The POST-stream reconnection path is covered " + "by client-transport:http:reconnect-post-priming." + ), + ), + "client-transport:http:reconnect-post-priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST-initiated SSE stream that errors before delivering its response is reconnected only " + "if a priming event (an event carrying an ID) was received on it." + ), + transports=("streamable-http",), + ), + "client-transport:http:reconnect-retry-value": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="Reconnection delay honours the server-provided SSE retry value when one was sent.", + transports=("streamable-http",), + ), + "client-transport:http:resume-stream-api": Requirement( + source="sdk", + behavior=( + "The client can capture a resumption token, reconnect with the same session id, and receive " + "the notifications it missed." + ), + transports=("streamable-http",), + ), + "client-transport:http:session-stored": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "The Mcp-Session-Id returned by initialize is stored by the client transport and sent on " + "every subsequent request." + ), + transports=("streamable-http",), + ), + "client-transport:http:sse-405-tolerated": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="Opening the standalone GET SSE stream tolerates a 405 response without failing the connection.", + transports=("streamable-http",), + ), + "client-transport:http:terminate-405-ok": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Session termination succeeds without error if the server answers 405 (termination unsupported).", + transports=("streamable-http",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Client auth + # ═══════════════════════════════════════════════════════════════════════════ + "client-auth:401-after-auth-throws": Requirement( + source="sdk", + behavior=( + "If the server still returns 401 after a successful authorization, the client fails instead of looping." + ), + transports=("streamable-http",), + ), + "client-auth:401-triggers-flow": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior="A 401 on a request triggers the OAuth authorization flow once.", + transports=("streamable-http",), + ), + "client-auth:403-scope-upgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#step-up-authorization-flow", + behavior=( + "A 403 with WWW-Authenticate triggers a scope-upgrade authorization attempt; repeated 403s do not loop." + ), + transports=("streamable-http",), + ), + "client-auth:as-metadata-discovery:priority-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", + behavior=( + "The client discovers authorization-server metadata by trying, in order, the OAuth " + "path-inserted, OIDC path-inserted, and OIDC path-appended well-known URLs (with the " + "root-path forms when the issuer URL has no path)." + ), + transports=("streamable-http",), + ), + "client-auth:as-metadata-discovery:issuer-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", + behavior=( + "The client rejects authorization-server metadata whose issuer does not match the URL the " + "metadata was retrieved from (RFC 8414 section 3.3)." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK parses authorization-server metadata without comparing issuer to the discovery " + "URL; a mismatched issuer is accepted and the flow proceeds." + ), + ), + ), + "client-auth:authorize:error-surfaces": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-flow-steps", + behavior=( + "An OAuth error redirect from the authorize endpoint aborts the flow before any token " + "request is issued, surfacing as an error to the caller." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The callback contract has no error form, so the client surfaces 'No authorization code " + "received' rather than the redirect's `error`/`error_description` values." + ), + ), + ), + "client-auth:authorize:offline-access-consent": Requirement( + source="sdk", + behavior=( + "When the authorization server's metadata advertises offline_access in scopes_supported and " + "the client uses the refresh_token grant, offline_access is appended to the requested scope " + "and prompt=consent is added to the authorize request." + ), + transports=("streamable-http",), + ), + "client-auth:bearer-header:every-request": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-requirements", + behavior=( + "Once authorized, the client sends the bearer token in the Authorization header on every HTTP " + "request to the MCP server, never in the query string." + ), + transports=("streamable-http",), + ), + "client-auth:cimd": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#client-id-metadata-documents", + behavior="The client can use a client-ID metadata document URL as its OAuth client_id instead of registration.", + transports=("streamable-http",), + ), + "client-auth:client-credentials": Requirement( + source="sdk", + behavior=( + "A client-credentials provider obtains a token without user interaction and the resulting " + "bearer token authorizes subsequent requests." + ), + transports=("streamable-http",), + ), + "client-auth:dcr:registration-error-surfaces": Requirement( + source="sdk", + behavior=( + "A 400 from the registration endpoint surfaces to the caller as an OAuthRegistrationError " + "carrying the status and the server's RFC 7591 error body." + ), + transports=("streamable-http",), + ), + "client-auth:dcr": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#dynamic-client-registration", + behavior=( + "The client performs dynamic client registration against the authorization server when no " + "client_id is preconfigured." + ), + transports=("streamable-http",), + ), + "client-auth:invalid-client-clears-all": Requirement( + source="sdk", + behavior=( + "An invalid-client or unauthorized-client error during authorization invalidates all stored credentials." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The token-response handlers do not parse the error body; an invalid_client or " + "unauthorized_client response leaves stored client_info untouched. The TypeScript SDK " + "clears it." + ), + ), + deferred=( + "Not implemented in the SDK: no token-response path inspects the error code to decide " + "whether to clear client_info." + ), + ), + "client-auth:invalid-grant-clears-tokens": Requirement( + source="sdk", + behavior="An invalid-grant error during authorization invalidates only the stored tokens.", + transports=("streamable-http",), + ), + "client-auth:pkce:refuse-if-unsupported": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The client refuses to proceed when the authorization server's metadata does not include " + "code_challenge_methods_supported, since PKCE support cannot be verified." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The client never inspects code_challenge_methods_supported and proceeds with PKCE S256 " + "regardless; the spec MUST is not enforced." + ), + ), + ), + "client-auth:pkce:s256": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The authorization request includes a PKCE S256 code challenge and the token request includes " + "the matching verifier." + ), + transports=("streamable-http",), + ), + "client-auth:pre-registration": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#preregistration", + behavior=( + "A client with statically preconfigured credentials skips dynamic registration and uses them directly." + ), + transports=("streamable-http",), + ), + "client-auth:private-key-jwt": Requirement( + source="sdk", + behavior="The client can authenticate the client-credentials grant with a signed JWT assertion.", + transports=("streamable-http",), + ), + "client-auth:prm-discovery:fallback-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior=( + "The client uses resource_metadata from WWW-Authenticate when present, then falls back to the " + "well-known protected-resource locations in the documented order." + ), + transports=("streamable-http",), + ), + "client-auth:prm-discovery:no-prm-fallback": Requirement( + source="sdk", + behavior=( + "When every protected-resource metadata probe fails, the client falls back to discovering " + "authorization-server metadata directly at the MCP server's origin (the legacy 2025-03-26 path) " + "rather than aborting." + ), + transports=("streamable-http",), + ), + "client-auth:prm-resource-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The client refuses to proceed when the protected-resource metadata's resource field does not " + "match the server URL it is connecting to." + ), + transports=("streamable-http",), + ), + "client-auth:refresh:transparent": Requirement( + source="sdk", + behavior=( + "An access token the client considers expired is transparently refreshed before the next " + "request, using the stored refresh token; the refresh request includes the resource indicator " + "and the new token is persisted." + ), + transports=("streamable-http",), + ), + "client-auth:resource-parameter": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#resource-parameter-implementation", + behavior=( + "The client includes the canonical server URI as the resource parameter in both the " + "authorization request and the token request." + ), + transports=("streamable-http",), + ), + "client-auth:scope-selection:priority": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#scope-selection-strategy", + behavior=( + "Client selects requested scope from the WWW-Authenticate scope param if present; otherwise " + "uses scopes_supported from the PRM document; otherwise omits scope." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK inserts an extra fallback step between PRM and omit: if the authorization " + "server metadata advertises scopes_supported, that list is used (client/auth/utils.py). " + "This is beyond the spec's two-step chain." + ), + ), + ), + "client-auth:state:verify": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", + behavior=( + "A state parameter is included in the authorization URL, and authorization results with a " + "missing or mismatched state are discarded." + ), + transports=("streamable-http",), + ), + "client-auth:token-endpoint-auth-method": Requirement( + source="sdk", + behavior="The client authenticates to the token endpoint using the auth method established at registration.", + transports=("streamable-http",), + ), + "client-auth:token-provenance": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior=( + "The client sends the MCP server only tokens issued by that server's authorization server, " + "never tokens obtained elsewhere." + ), + transports=("streamable-http",), + deferred=( + "Untestable negative through the public API: there is no path to inject a token obtained " + "elsewhere into the auth provider's state, so the absence cannot be observed end to end." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # stdio transport + # ═══════════════════════════════════════════════════════════════════════════ + "transport:stdio:clean-shutdown": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#shutdown", + behavior="Closing the client transport closes the child process's stdin and the server exits cleanly.", + transports=("stdio",), + ), + "transport:stdio:stream-purity": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior=( + "Nothing that is not a valid MCP message is written to the server's stdout, and nothing that " + "is not a valid MCP message is written to its stdin." + ), + transports=("stdio",), + divergence=Divergence( + note=( + "stdio_server's own writes satisfy this, but it does not redirect or guard sys.stdout: " + "handler code that calls print() writes directly to the protocol stream and corrupts the " + "framing. The spec MUST is satisfied only as long as application code behaves." + ), + ), + ), + "transport:stdio:no-embedded-newlines": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior="Serialized JSON-RPC messages on stdio contain no embedded newlines; one message per line.", + transports=("stdio",), + ), + "transport:stdio:shutdown-escalation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#stdio", + behavior=( + "If the server process does not exit after stdin is closed, the client transport terminates " + "it (and kills it if still alive) after a grace period." + ), + transports=("stdio",), + deferred=( + "A server that ignores stdin close takes the full PROCESS_TERMINATION_TIMEOUT (2.0 s) grace " + "period plus up to a further 2.0 s for SIGTERM/SIGKILL escalation; testing that path is " + "real-time-bound (the constant is module-level with no public override) and so is deliberately " + "excluded from this suite. Covered by tests/client/test_stdio.py." + ), + ), + "transport:stdio:stderr-passthrough": Requirement( + source="sdk", + behavior="Server stderr is available to the client and is not consumed by the transport.", + transports=("stdio",), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Composite end-to-end flows + # ═══════════════════════════════════════════════════════════════════════════ + "flow:compat:dual-transport-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "A single server instance can serve streamable HTTP and the legacy SSE transport " + "concurrently; clients on either transport can call the same tools." + ), + transports=("streamable-http", "sse"), + ), + "flow:compat:streamable-then-sse-fallback": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "When a streamable HTTP initialize fails with 400, 404, or 405, falling back to the legacy " + "SSE client transport against the same server connects successfully." + ), + transports=("streamable-http", "sse"), + divergence=Divergence( + note=( + "The SDK provides no automatic streamable-HTTP-to-SSE client fallback; the spec's " + "client-side SHOULD is left to the application to compose from streamable_http_client " + "and sse_client. Both halves are independently proven by the matrix." + ), + ), + deferred=( + "A demonstration test would only re-prove what the matrix already covers (an SSE-only " + "server is reachable via sse_client; an unmounted route returns 404), with the application " + "doing the fallback in between rather than the SDK." + ), + ), + "flow:elicitation:multi-step-form": Requirement( + source="sdk", + behavior=( + "A single tool handler issues sequential elicitations; an accept on one step feeds the next, " + "and a decline or cancel at any step short-circuits to a final result." + ), + ), + "flow:elicitation:url-at-session-init": Requirement( + source="sdk", + behavior=( + "The server can issue a URL-mode elicitation over the standalone GET stream immediately after " + "session initialization, before any client request." + ), + transports=("streamable-http",), + deferred=( + "Not implemented in the SDK: no public per-session post-initialization hook exists on either " + "server flavour (Server.lifespan runs at server startup, not per session; ServerSession " + "handles the initialized notification internally with no callback). Driving 'before any " + "client request' deterministically would also require knowing the standalone GET stream is " + "established, which has no synchronization signal." + ), + ), + "flow:elicitation:url-required-then-retry": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A tool call rejected with the URL-elicitation-required error can be retried successfully " + "after the client completes the URL flow and the server announces completion." + ), + ), + "flow:multi-client:stateful-isolation": Requirement( + source="sdk", + behavior=( + "Independent clients connected to one stateful server each receive a distinct session and " + "only the notifications produced by their own requests." + ), + transports=("streamable-http",), + ), + "flow:oauth:authorization-code-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-flow-steps", + behavior=( + "Connecting to a protected server walks the authorization-code flow end to end: the first " + "attempt requires authorization, the code is exchanged, and a subsequent connection succeeds." + ), + transports=("streamable-http",), + ), + "flow:resume:tool-call-resumption-token": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior=( + "A tool call interrupted mid-stream is transparently resumed by the client transport using " + "the last-seen event id, delivering only the remaining notifications and the final result." + ), + transports=("streamable-http",), + ), + "flow:session:terminate-then-reconnect": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=("After terminating a session, a fresh connection obtains a new session id and operations succeed."), + transports=("streamable-http",), + ), + "flow:tool-result:resource-link-follow": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior=( + "A resource_link returned by a tool call can be followed with resources/read on the linked " + "URI to retrieve the referenced contents." + ), + ), +} + + +def requirement(requirement_id: str) -> Callable[[_TestFn], _TestFn]: + """Mark a test as exercising a requirement from :data:`REQUIREMENTS`. + + Applies the `requirement` pytest marker and records the coverage link checked by + `test_coverage.py`. Unknown IDs fail at import time so a typo surfaces as a collection + error on the offending test, not as a missing-coverage report later. + """ + if requirement_id not in REQUIREMENTS: + raise KeyError(f"Unknown requirement id {requirement_id!r}: add it to REQUIREMENTS in {__name__}") + + def apply(test_fn: _TestFn) -> _TestFn: + covered_by(requirement_id).append(f"{test_fn.__module__}.{test_fn.__qualname__}") + return pytest.mark.requirement(requirement_id)(test_fn) + + return apply + + +_COVERAGE: dict[str, list[str]] = {} + + +def covered_by(requirement_id: str) -> list[str]: + """Return the (mutable) list of test names recorded as exercising `requirement_id`.""" + return _COVERAGE.setdefault(requirement_id, []) diff --git a/tests/interaction/auth/__init__.py b/tests/interaction/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py new file mode 100644 index 0000000000..d013364f33 --- /dev/null +++ b/tests/interaction/auth/_harness.py @@ -0,0 +1,465 @@ +"""In-process harness for the auth interaction tests. + +Co-hosts the SDK's authorization-server routes, protected-resource metadata route, and the +bearer-gated MCP endpoint on one Starlette app via `Server.streamable_http_app(auth=..., +token_verifier=..., auth_server_provider=...)`, drives that app through the streaming bridge +on a single `httpx.AsyncClient` carrying `auth=OAuthClientProvider(...)`, and completes the +authorize redirect headlessly by GETing the URL through the same bridge and parsing the code +from the 302 `Location`. The whole authorization-code flow runs in one event loop with no +sockets, no threads, and no real time. +""" + +import json +from collections.abc import AsyncIterator, Callable, Mapping, Sequence +from contextlib import AsyncExitStack, asynccontextmanager +from dataclasses import dataclass, field +from typing import Any +from urllib.parse import parse_qs, parse_qsl, urlsplit + +import httpx +from pydantic import AnyHttpUrl, AnyUrl, BaseModel +from starlette.types import ASGIApp, Receive, Scope, Send + +from mcp.client.auth import OAuthClientProvider +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server +from mcp.server.auth.provider import AccessToken, ProviderTokenVerifier +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions, RevocationOptions +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider +from tests.interaction.transports._bridge import StreamingASGITransport + +REDIRECT_URI = f"{BASE_URL}/oauth/callback" + +AppShim = Callable[[ASGIApp], ASGIApp] + + +@dataclass +class RecordedRequest: + """A snapshot of an `httpx.Request` at the moment it was sent. + + The auth flow re-yields the same `httpx.Request` object after mutating its headers in + place for the retry, so tests that need to assert on the first attempt's headers must + capture a copy rather than a live reference. `record_requests` produces these. + """ + + method: str + url: httpx.URL + headers: dict[str, str] + content: bytes + + @property + def path(self) -> str: + return self.url.path + + +def record_requests() -> tuple[list[RecordedRequest], Callable[[httpx.Request], None]]: + """Build an `on_request` callback that snapshots each request, and the list it appends to.""" + recorded: list[RecordedRequest] = [] + + def on_request(request: httpx.Request) -> None: + recorded.append( + RecordedRequest( + method=request.method, + url=request.url, + headers=dict(request.headers), + content=bytes(request.content), + ) + ) + + return recorded, on_request + + +def metadata_body(model: BaseModel, **extra: object) -> bytes: + """Serialize a metadata model to a JSON body for `shimmed_app(serve=...)`. + + `extra` keys are merged into the serialized object so a test can inject fields the model + does not declare (e.g. an unknown extension field, to prove the client's parser tolerates + unrecognized members per RFC 8414/9728 §3.2). The model itself would silently drop such + fields at construction, so they have to be added after serialization. + """ + document = model.model_dump(by_alias=True, mode="json", exclude_none=True) + document.update(extra) + return json.dumps(document).encode() + + +class StaticTokenVerifier: + """A `TokenVerifier` backed by a fixed token→`AccessToken` mapping. + + Any token string not in the mapping verifies to `None`, which the bearer middleware treats + as an unrecognized token. Tests seed the mapping with the exact token shapes (valid, expired, + wrong scope, wrong audience) they need so the resource-server gate's behaviour is asserted in + isolation from the authorization-server provider. + """ + + def __init__(self, tokens: Mapping[str, AccessToken]) -> None: + self._tokens = dict(tokens) + + async def verify_token(self, token: str) -> AccessToken | None: + return self._tokens.get(token) + + +class InMemoryTokenStorage: + """A `TokenStorage` that holds tokens and client info as instance attributes. + + Tests pre-seed `client_info` (via the constructor or by assignment) to drive the + pre-registered path, and read both attributes after the flow to assert what the SDK + persisted. + """ + + def __init__(self, *, client_info: OAuthClientInformationFull | None = None) -> None: + self.tokens: OAuthToken | None = None + self.client_info: OAuthClientInformationFull | None = client_info + + async def get_tokens(self) -> OAuthToken | None: + return self.tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self.tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self.client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self.client_info = client_info + + +class HeadlessOAuth: + """Completes the authorize step in-process by following the redirect through the bridge. + + `redirect_handler` GETs the authorize URL on the bound client (with `auth=None` so the + request does not re-enter the locked auth flow), parses `code` and `state` from the 302 + `Location`, and stashes them; `callback_handler` returns the stashed pair. Tests inspect + `authorize_url` to assert what the SDK put on the authorize request. + + `state_override`: when set, `callback_handler` returns this value as the state instead of + the one parsed from the redirect, so tests can drive the state-mismatch path. + """ + + def __init__(self, *, state_override: str | None = None) -> None: + self.authorize_url: str | None = None + self.authorize_urls: list[str] = [] + self.error: str | None = None + self._state_override = state_override + self._http: httpx.AsyncClient | None = None + self._code: str = "" + self._state: str | None = None + + def bind(self, http_client: httpx.AsyncClient) -> None: + self._http = http_client + + async def redirect_handler(self, authorization_url: str) -> None: + assert self._http is not None + self.authorize_url = authorization_url + self.authorize_urls.append(authorization_url) + # auth=None is load-bearing: without it the GET re-enters OAuthClientProvider.async_auth_flow + # through its context lock and the flow deadlocks. + response = await self._http.get(authorization_url, follow_redirects=False, auth=None) + assert response.status_code == 302, f"authorize endpoint returned {response.status_code}: {response.text}" + params = parse_qs(urlsplit(response.headers["location"]).query) + self._code = params.get("code", [""])[0] + self._state = params.get("state", [None])[0] + self.error = params.get("error", [None])[0] + + async def callback_handler(self) -> tuple[str, str | None]: + return self._code, self._state_override if self._state_override is not None else self._state + + +def auth_settings( + *, required_scopes: Sequence[str] = ("mcp",), valid_scopes: Sequence[str] | None = None +) -> AuthSettings: + """Build `AuthSettings` for the co-hosted authorization + resource server. + + The issuer and resource URLs use the suite's loopback origin, which `validate_issuer_url` + accepts in lieu of HTTPS. Dynamic client registration is enabled. `valid_scopes` defaults + to `required_scopes` so a client requesting exactly those passes registration scope + validation; tests pass a wider set when they need the protected-resource metadata's + `scopes_supported` (which mirrors `required_scopes`) to differ from what the client may + register or when AS metadata should advertise additional scopes such as `offline_access`. + """ + required = list(required_scopes) + valid = list(valid_scopes) if valid_scopes is not None else required + return AuthSettings( + issuer_url=AnyHttpUrl(BASE_URL), + resource_server_url=AnyHttpUrl(f"{BASE_URL}/mcp"), + required_scopes=required, + client_registration_options=ClientRegistrationOptions( + enabled=True, valid_scopes=valid, default_scopes=required + ), + revocation_options=RevocationOptions(enabled=False), + ) + + +def oauth_client_metadata() -> OAuthClientMetadata: + """Build the client's registration metadata. + + `scope` is left unset so the SDK's scope-selection strategy chooses one from the server's + metadata before registration. + """ + return OAuthClientMetadata( + client_name="interaction-suite", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + ) + + +def shimmed_app( + app: ASGIApp, + *, + not_found: frozenset[str] = frozenset(), + serve: Mapping[str, bytes | tuple[int, bytes]] | None = None, +) -> ASGIApp: + """Wrap an ASGI app so specific paths return canned responses before reaching the real app. + + Paths in `serve` return the given body as `application/json` (status 200, or the supplied + status when the value is a `(status, body)` pair); paths in `not_found` return 404; + everything else reaches the wrapped app unchanged. Used by the discovery tests to make a + well-known endpoint 404 or return alternate metadata while keeping the real authorization + and MCP endpoints behind it. + """ + overrides: dict[str, tuple[int, bytes]] = { + path: value if isinstance(value, tuple) else (200, value) for path, value in (serve or {}).items() + } + + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + path = scope["path"] + if path in overrides: + status, body = overrides[path] + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body)).encode()), + ], + } + ) + await send({"type": "http.response.body", "body": body}) + return + if path in not_found: + await send({"type": "http.response.start", "status": 404, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + await app(scope, receive, send) + + return wrapped + + +def shim( + *, not_found: frozenset[str] = frozenset(), serve: Mapping[str, bytes | tuple[int, bytes]] | None = None +) -> AppShim: + """Build an `app_shim` for `connect_with_oauth` that applies `shimmed_app` with these overrides.""" + return lambda app: shimmed_app(app, not_found=not_found, serve=serve) + + +@dataclass +class _FirstChallenge: + """ASGI shim that answers the first request to a path with 401 + a given WWW-Authenticate. + + Subsequent requests pass through to the wrapped app. Used to make the initial 401 carry + parameters (such as `scope=`) that the SDK's own bearer middleware cannot be configured + to emit, so client behaviour driven by those parameters is reachable end to end. Reserve + this pattern for behaviour the real server cannot be made to produce. + """ + + app: ASGIApp + path: str + www_authenticate: str + _seen: set[str] = field(default_factory=set[str]) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["path"] == self.path and self.path not in self._seen: + self._seen.add(self.path) + await send( + { + "type": "http.response.start", + "status": 401, + "headers": [(b"www-authenticate", self.www_authenticate.encode())], + } + ) + await send({"type": "http.response.body", "body": b""}) + return + await self.app(scope, receive, send) + + +def first_challenge_shim(www_authenticate: str, *, path: str = "/mcp") -> Callable[[ASGIApp], ASGIApp]: + """Build an `app_shim` that 401s the first request to `path` with the given header value.""" + return lambda app: _FirstChallenge(app, path, www_authenticate) + + +def step_up_shim(www_authenticate: str, *, on_nth_authenticated_post: int = 2) -> AppShim: + """Build an `app_shim` that 403s the Nth authenticated POST to `/mcp` with the given challenge. + + Subsequent requests pass through. Used to drive the client's `insufficient_scope` step-up + handling: the SDK's bearer middleware never emits `scope=` in its 403 challenge (see the + divergence on `hosting:auth:scope-403`), so the test supplies the 403 itself. Reserve this + pattern for behaviour the real server cannot be made to produce. + + The default `on_nth_authenticated_post=2` targets the `notifications/initialized` POST: the + first authenticated POST is the auth flow's retry of the original initialize request (yielded + after the 401 branch, where the generator ends without inspecting the response), so a 403 + there would not reach the step-up handler. + """ + seen = 0 + fired = False + + def factory(app: ASGIApp) -> ASGIApp: + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + nonlocal seen, fired + if ( + not fired + and scope["type"] == "http" + and scope["path"] == "/mcp" + and scope["method"] == "POST" + and any(name == b"authorization" for name, _ in scope["headers"]) + ): + seen += 1 + if seen < on_nth_authenticated_post: + await app(scope, receive, send) + return + fired = True + await send( + { + "type": "http.response.start", + "status": 403, + "headers": [(b"www-authenticate", www_authenticate.encode())], + } + ) + await send({"type": "http.response.body", "body": b""}) + return + await app(scope, receive, send) + + return wrapped + + return factory + + +def m2m_token_shim(provider: InMemoryAuthorizationServerProvider, *, scopes: list[str]) -> AppShim: + """Build an `app_shim` that handles `grant_type=client_credentials` at `/token`. + + The SDK server's `TokenHandler` only routes `authorization_code` and `refresh_token`, so a + `client_credentials` request would fail discriminator validation. This shim mints a token via + `provider.mint_access_token` so the M2M client providers can complete e2e against the real + bearer middleware. The shim is harness; the SDK-under-test is the client provider, whose + outbound `/token` body the test asserts. The shim does not authenticate the client (no + credential check) because the test asserts the credentials on the recorded request, not on + the server's acceptance. + """ + + def factory(app: ASGIApp) -> ASGIApp: + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["path"] == "/token" and scope["method"] == "POST": + # The streaming bridge buffers the request body and delivers it in a single + # http.request event, so one receive is sufficient. + message = await receive() + assert not message.get("more_body", False) + form = dict(parse_qsl(message.get("body", b"").decode())) + assert form.get("grant_type") == "client_credentials", ( + f"m2m_token_shim only handles client_credentials; got {form.get('grant_type')!r}" + ) + access = provider.mint_access_token(client_id="m2m", scopes=scopes, resource=form.get("resource")) + token = OAuthToken(access_token=access, token_type="Bearer", expires_in=3600, scope=" ".join(scopes)) + response_body = token.model_dump_json(exclude_none=True).encode() + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(response_body)).encode()), + (b"cache-control", b"no-store"), + ], + } + ) + await send({"type": "http.response.body", "body": response_body}) + return + await app(scope, receive, send) + + return wrapped + + return factory + + +@asynccontextmanager +async def connect_with_oauth( + server: Server, + *, + provider: InMemoryAuthorizationServerProvider, + settings: AuthSettings | None = None, + storage: InMemoryTokenStorage | None = None, + client_metadata: OAuthClientMetadata | None = None, + client_metadata_url: str | None = None, + headless: HeadlessOAuth | None = None, + auth: httpx.Auth | None = None, + verify_tokens: bool = True, + app_shim: Callable[[ASGIApp], ASGIApp] | None = None, + on_request: Callable[[httpx.Request], None] | None = None, +) -> AsyncIterator[tuple[Client, HeadlessOAuth]]: + """Connect a `Client` to a server's bearer-gated streamable-HTTP app, completing OAuth in process. + + Yields the connected `Client` and the `HeadlessOAuth` whose `authorize_url` records what the + SDK put on the authorize request. `on_request` records every HTTP request the underlying + `httpx.AsyncClient` issues, including those yielded from inside the auth flow. + + `headless`: supply a pre-configured `HeadlessOAuth` to override the callback behaviour + (state mismatch, error redirects). `verify_tokens=False` mounts the MCP endpoint without + the bearer middleware so a flow driven by a shimmed 401 completes regardless of the granted + scopes. `app_shim` wraps the built Starlette app before it reaches the bridge transport, + for tests that need to intercept or rewrite specific server responses. + + `auth`: supply a pre-built `httpx.Auth` (such as `ClientCredentialsOAuthProvider`) to use + instead of constructing the default `OAuthClientProvider`; in that case `storage`, + `client_metadata`, `client_metadata_url`, and `headless` are unused (the yielded + `HeadlessOAuth` is never invoked and its `authorize_url` stays None). + """ + settings = settings if settings is not None else auth_settings() + storage = storage if storage is not None else InMemoryTokenStorage() + client_metadata = client_metadata if client_metadata is not None else oauth_client_metadata() + headless = headless if headless is not None else HeadlessOAuth() + + oauth = ( + auth + if auth is not None + else OAuthClientProvider( + server_url=f"{BASE_URL}/mcp", + client_metadata=client_metadata, + storage=storage, + redirect_handler=headless.redirect_handler, + callback_handler=headless.callback_handler, + client_metadata_url=client_metadata_url, + ) + ) + + app: ASGIApp = server.streamable_http_app( + auth=settings, + token_verifier=ProviderTokenVerifier(provider) if verify_tokens else None, + auth_server_provider=provider, + transport_security=NO_DNS_REBINDING_PROTECTION, + ) + if app_shim is not None: + app = app_shim(app) + + event_hooks: dict[str, list[Callable[..., Any]]] | None = None + if on_request is not None: + record = on_request + + async def hook(request: httpx.Request) -> None: + record(request) + + event_hooks = {"request": [hook]} + + async with AsyncExitStack() as stack: + await stack.enter_async_context(server.session_manager.run()) + http_client = await stack.enter_async_context( + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks + ) + ) + headless.bind(http_client) + client = await stack.enter_async_context( + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client)) + ) + yield client, headless diff --git a/tests/interaction/auth/_provider.py b/tests/interaction/auth/_provider.py new file mode 100644 index 0000000000..5c88995a30 --- /dev/null +++ b/tests/interaction/auth/_provider.py @@ -0,0 +1,186 @@ +"""An in-memory implementation of the SDK's OAuth authorization-server provider protocol. + +The provider holds clients, authorization codes, refresh tokens and access tokens in plain +instance dicts so tests can inspect them; tokens are minted from `secrets.token_hex` so the +values are unique without being predictable. The behaviour mirrors what the SDK's authorization +handlers expect: `authorize` immediately mints a code and returns the redirect, `exchange_*` +issue and rotate tokens, and `load_*` are simple lookups. Only the parts the auth interaction +suite drives are implemented; methods the suite does not exercise raise `NotImplementedError`. +""" + +import secrets +import time + +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthAuthorizationServerProvider, + RefreshToken, + TokenError, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +_TOKEN_LIFETIME_SECONDS = 3600 + + +class InMemoryAuthorizationServerProvider( + OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken] +): + """An OAuth authorization-server provider backed by in-memory dicts. + + Holds registered clients, issued codes, refresh tokens and access tokens as instance state + so tests can both drive the SDK's authorization handlers and inspect what was issued. + + Knobs: + `default_scopes`: scopes granted when an authorize request supplies none. + `deny_authorize`: every authorize request returns an `error=access_denied` redirect. + `issue_expired_first`: the first issued token's `expires_in` is in the past so the + client immediately considers it expired and refreshes; the server-side + `AccessToken.expires_at` stays in the future so the bearer middleware accepts it + on the retry that completes the connect. + `fail_next_refresh`: the next refresh-token exchange raises `invalid_grant` once. + `reject_all_tokens`: `load_access_token` returns None for every token, so the bearer + middleware 401s every authenticated request. + """ + + def __init__( + self, + *, + default_scopes: list[str] | None = None, + deny_authorize: bool = False, + issue_expired_first: bool = False, + fail_next_refresh: bool = False, + reject_all_tokens: bool = False, + ) -> None: + self._default_scopes = list(default_scopes) if default_scopes is not None else ["mcp"] + self._issuer = "http://127.0.0.1:8000" + self._deny_authorize = deny_authorize + self._issue_expired_first = issue_expired_first + self._fail_next_refresh = fail_next_refresh + self._reject_all_tokens = reject_all_tokens + self._tokens_issued = 0 + self.clients: dict[str, OAuthClientInformationFull] = {} + self.codes: dict[str, AuthorizationCode] = {} + self.refresh_tokens: dict[str, RefreshToken] = {} + self.access_tokens: dict[str, AccessToken] = {} + + def _next_expires_in(self) -> int: + self._tokens_issued += 1 + if self._issue_expired_first and self._tokens_issued == 1: + return -_TOKEN_LIFETIME_SECONDS + return _TOKEN_LIFETIME_SECONDS + + def mint_access_token(self, *, client_id: str, scopes: list[str], resource: str | None = None) -> str: + """Mint and store an access token, returning its value. + + Used by the auth-code and refresh exchanges and by the M2M `/token` shim. The + server-side `expires_at` is always in the future regardless of `issue_expired_first`, + which only affects what the client is told. + """ + access = f"access_{secrets.token_hex(16)}" + self.access_tokens[access] = AccessToken( + token=access, + client_id=client_id, + scopes=scopes, + expires_at=int(time.time()) + _TOKEN_LIFETIME_SECONDS, + resource=resource, + ) + return access + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + assert client_info.client_id is not None + self.clients[client_info.client_id] = client_info + + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: + """Mint an authorization code immediately and return the redirect carrying it. + + A real provider would interpose user consent here; the test provider grants + unconditionally so the headless redirect handler can complete the flow in-process. + When `deny_authorize` is set, returns an `error=access_denied` redirect instead. + """ + assert client.client_id is not None + if self._deny_authorize: + return construct_redirect_uri( + str(params.redirect_uri), error="access_denied", error_description="user denied", state=params.state + ) + code = AuthorizationCode( + code=f"code_{secrets.token_hex(16)}", + client_id=client.client_id, + scopes=params.scopes or self._default_scopes, + expires_at=time.time() + 300, + code_challenge=params.code_challenge, + redirect_uri=params.redirect_uri, + redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, + resource=params.resource, + ) + self.codes[code.code] = code + # `iss` is RFC 9207's authorization-response issuer identifier — an extra parameter many + # real authorization servers send. Including it on every success redirect proves the + # client tolerates unrecognized callback parameters (RFC 6749 §4.1.2 MUST) by virtue of + # every flow test passing unchanged. + return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state, iss=self._issuer) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + return self.codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Mint an access token and a refresh token for a valid authorization code, then consume the code.""" + assert client.client_id is not None + access = self.mint_access_token( + client_id=client.client_id, scopes=authorization_code.scopes, resource=authorization_code.resource + ) + refresh = f"refresh_{secrets.token_hex(16)}" + self.refresh_tokens[refresh] = RefreshToken( + token=refresh, + client_id=client.client_id, + scopes=authorization_code.scopes, + ) + del self.codes[authorization_code.code] + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._next_expires_in(), + scope=" ".join(authorization_code.scopes), + refresh_token=refresh, + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + if self._reject_all_tokens: + return None + return self.access_tokens.get(token) + + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + return self.refresh_tokens.get(refresh_token) + + async def exchange_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: RefreshToken, scopes: list[str] + ) -> OAuthToken: + """Mint a new access token and rotate the refresh token, consuming the old one.""" + assert client.client_id is not None + if self._fail_next_refresh: + self._fail_next_refresh = False + raise TokenError(error="invalid_grant", error_description="refresh denied by harness") + access = self.mint_access_token(client_id=client.client_id, scopes=scopes) + new_refresh = f"refresh_{secrets.token_hex(16)}" + self.refresh_tokens[new_refresh] = RefreshToken(token=new_refresh, client_id=client.client_id, scopes=scopes) + del self.refresh_tokens[refresh_token.token] + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._next_expires_in(), + scope=" ".join(scopes), + refresh_token=new_refresh, + ) + + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: + """Not exercised by this suite; revocation is out of scope for the interaction tests.""" + raise NotImplementedError diff --git a/tests/interaction/auth/test_as_handlers.py b/tests/interaction/auth/test_as_handlers.py new file mode 100644 index 0000000000..5cb4e92d86 --- /dev/null +++ b/tests/interaction/auth/test_as_handlers.py @@ -0,0 +1,300 @@ +"""Error-plane behaviour of the SDK's bundled OAuth authorization-server handlers. + +The end-to-end OAuth tests prove the handlers' happy paths; these tests drive the same +mounted authorization server directly with raw httpx so the assertions are the HTTP +semantics (status, redirect target, error body, headers) the OAuth RFCs mandate. Almost +every behaviour here is enforced by the SDK's own handlers; where the pinned output +deviates from the RFC, the manifest entry carries the divergence. +""" + +import base64 +import hashlib +import secrets +from collections.abc import AsyncIterator +from urllib.parse import parse_qs, urlsplit + +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server +from mcp.server.auth.provider import ProviderTokenVerifier +from mcp.shared.auth import OAuthClientInformationFull +from tests.interaction._connect import mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import REDIRECT_URI, auth_settings, oauth_client_metadata +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +async def as_app() -> AsyncIterator[tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider]]: + """Co-host the SDK's authorization-server routes and yield a raw httpx client against them.""" + provider = InMemoryAuthorizationServerProvider() + settings = auth_settings() + async with mounted_app( + Server("guarded"), + auth=settings, + token_verifier=ProviderTokenVerifier(provider), + auth_server_provider=provider, + ) as (http, _): + yield http, provider + + +def _pkce_pair() -> tuple[str, str]: + """Generate a (code_verifier, code_challenge) pair the same way the SDK client does.""" + verifier = secrets.token_urlsafe(48)[:64] + challenge = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") + return verifier, challenge + + +async def _register_client(http: httpx.AsyncClient) -> OAuthClientInformationFull: + """Dynamically register a client and return its full credentials.""" + response = await http.post("/register", content=oauth_client_metadata().model_dump_json()) + assert response.status_code == 201 + return OAuthClientInformationFull.model_validate_json(response.content) + + +async def _mint_code(http: httpx.AsyncClient) -> tuple[OAuthClientInformationFull, str, str]: + """Register a client, complete a valid authorize step, and return (client_info, code, verifier).""" + client_info = await _register_client(http) + assert client_info.client_id is not None + verifier, challenge = _pkce_pair() + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": REDIRECT_URI, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": "s", + }, + follow_redirects=False, + ) + assert response.status_code == 302 + redirect = urlsplit(response.headers["location"]) + assert f"{redirect.scheme}://{redirect.netloc}{redirect.path}" == REDIRECT_URI + code = parse_qs(redirect.query)["code"][0] + return client_info, code, verifier + + +def _token_form(client_info: OAuthClientInformationFull, **overrides: str) -> dict[str, str]: + """Build the form body for an authorization-code token request, with the defaults a real client would send.""" + assert client_info.client_id is not None + assert client_info.client_secret is not None + form = { + "grant_type": "authorization_code", + "client_id": client_info.client_id, + "client_secret": client_info.client_secret, + "redirect_uri": REDIRECT_URI, + } + form.update(overrides) + return form + + +@requirement("hosting:auth:as:authorize-requires-pkce") +async def test_authorize_without_a_code_challenge_is_rejected_with_invalid_request( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorize request omitting `code_challenge` is redirected back with `error=invalid_request`. + + PKCE is mandatory: the bundled authorize handler models `code_challenge` as a required field, so + a code without a stored challenge can never be issued. That makes the PKCE-downgrade attack (a + token request carrying a verifier for a code minted without a challenge) structurally impossible + through these handlers, so no separate downgrade-guard test is needed. + """ + http, _ = as_app + client_info = await _register_client(http) + assert client_info.client_id is not None + + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": REDIRECT_URI, + "state": "abc", + }, + follow_redirects=False, + ) + + assert response.status_code == 302 + redirect = urlsplit(response.headers["location"]) + assert f"{redirect.scheme}://{redirect.netloc}{redirect.path}" == REDIRECT_URI + params = parse_qs(redirect.query) + assert params["error"] == ["invalid_request"] + assert params["state"] == ["abc"] + assert "code_challenge" in params["error_description"][0] + + +@requirement("hosting:auth:as:verifier-mismatch") +async def test_a_mismatched_code_verifier_is_rejected_with_invalid_grant( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A token exchange whose `code_verifier` does not hash to the stored challenge is rejected.""" + http, _ = as_app + client_info, code, _ = await _mint_code(http) + + response = await http.post("/token", data=_token_form(client_info, code=code, code_verifier="0" * 64)) + + assert response.status_code == 400 + assert response.json() == snapshot({"error": "invalid_grant", "error_description": "incorrect code_verifier"}) + + +@requirement("hosting:auth:as:code-single-use") +async def test_reusing_an_authorization_code_is_rejected_with_invalid_grant( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorization code can be exchanged exactly once; a second exchange is `invalid_grant`. + + The handler does not track used codes itself: it returns `invalid_grant` whenever the provider's + `load_authorization_code` returns None, and the in-memory provider deletes the code on first + exchange. The test proves the combination enforces single-use; a provider that did not consume + codes would not get this guarantee from the handler. + """ + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + form = _token_form(client_info, code=code, code_verifier=verifier) + + first = await http.post("/token", data=form) + assert first.status_code == 200 + assert first.json()["token_type"] == "Bearer" + + second = await http.post("/token", data=form) + assert second.status_code == 400 + assert second.json() == snapshot( + {"error": "invalid_grant", "error_description": "authorization code does not exist"} + ) + + +@requirement("hosting:auth:as:redirect-uri-binding") +async def test_a_redirect_uri_differing_from_authorize_is_rejected_at_the_token_endpoint( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A token exchange whose `redirect_uri` differs from the one used at authorize is rejected. + + This is the security-critical half of redirect-URI binding: a code intercepted via redirect + substitution cannot be redeemed because the attacker cannot reproduce the original authorize + redirect URI at the token endpoint. RFC 6749 §5.2 specifies `invalid_grant` for this case; + the SDK returns `invalid_request` (see the divergence on the requirement). The rejection + itself is the security property and is correct. + """ + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + + response = await http.post( + "/token", + data=_token_form(client_info, code=code, code_verifier=verifier, redirect_uri=f"{REDIRECT_URI}/different"), + ) + + assert response.status_code == 400 + assert response.json() == snapshot( + { + "error": "invalid_request", + "error_description": "redirect_uri did not match the one used when creating auth code", + } + ) + + +@requirement("hosting:auth:as:token-cache-headers") +async def test_token_responses_carry_cache_control_no_store( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """Every token-endpoint response (success and error) carries `Cache-Control: no-store`.""" + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + form = _token_form(client_info, code=code, code_verifier=verifier) + + success = await http.post("/token", data=form) + assert success.status_code == 200 + assert success.headers["cache-control"] == "no-store" + assert success.headers["pragma"] == "no-cache" + + failure = await http.post("/token", data=form) + assert failure.status_code == 400 + assert failure.headers["cache-control"] == "no-store" + assert failure.headers["pragma"] == "no-cache" + + +@requirement("hosting:auth:as:register-error-response") +async def test_registration_with_invalid_metadata_is_rejected_with_400( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """Invalid client metadata at the registration endpoint returns 400 with an RFC 7591 error body.""" + http, _ = as_app + + malformed = await http.post("/register", json={"redirect_uris": ["not-a-url"]}) + assert malformed.status_code == 400 + assert malformed.json()["error"] == "invalid_client_metadata" + + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + + no_auth_code = await http.post("/register", json=body | {"grant_types": ["refresh_token"]}) + assert no_auth_code.status_code == 400 + assert no_auth_code.json() == snapshot( + {"error": "invalid_client_metadata", "error_description": "grant_types must include 'authorization_code'"} + ) + + bad_scope = await http.post("/register", json=body | {"scope": "forbidden"}) + assert bad_scope.status_code == 400 + body = bad_scope.json() + assert body["error"] == "invalid_client_metadata" + # The description embeds a set difference whose ordering is not stable, so assert the prefix. + assert body["error_description"].startswith("Requested scopes are not valid: ") + + +@requirement("hosting:auth:as:redirect-uri-binding") +async def test_authorize_with_an_unregistered_redirect_uri_is_rejected_directly( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorize request naming an unregistered `redirect_uri` returns 400 without redirecting to it. + + The security property is that the authorization server never redirects to an unvalidated URI: + the response is a direct JSON error to the user agent, not a 302 to the attacker's host. + """ + http, _ = as_app + client_info = await _register_client(http) + assert client_info.client_id is not None + _, challenge = _pkce_pair() + + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": "http://127.0.0.1:8000/evil", + "code_challenge": challenge, + "code_challenge_method": "S256", + }, + follow_redirects=False, + ) + + assert response.status_code == 400 + assert "location" not in response.headers + body = response.json() + assert body["error"] == "invalid_request" + assert "not registered" in body["error_description"] + + +@requirement("hosting:auth:as:redirect-uri-scheme") +async def test_a_non_loopback_http_redirect_uri_is_accepted_at_registration( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A registration carrying a non-HTTPS, non-loopback redirect URI is accepted. + + The spec requires every redirect URI to be either HTTPS or a loopback host; the bundled + registration handler does not enforce this and registers `http://evil.example/callback` + successfully. See the divergence on the requirement. + """ + http, provider = as_app + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + body["redirect_uris"] = ["http://evil.example/callback"] + + response = await http.post("/register", json=body) + + assert response.status_code == 201 + info = OAuthClientInformationFull.model_validate_json(response.content) + assert [str(u) for u in (info.redirect_uris or [])] == ["http://evil.example/callback"] + assert info.client_id in provider.clients diff --git a/tests/interaction/auth/test_authorize_token.py b/tests/interaction/auth/test_authorize_token.py new file mode 100644 index 0000000000..cb8524c097 --- /dev/null +++ b/tests/interaction/auth/test_authorize_token.py @@ -0,0 +1,399 @@ +"""Authorization-request, token-request, and PKCE wire-level invariants of the SDK's OAuth client. + +Every test connects a real `Client` end to end via `connect_with_oauth`; the assertions are on +the parsed authorize URL and the recorded `/token` form body, because those wire shapes are what +the spec mandates and `Client` cannot observe them. The recording uses `record_requests`, which +snapshots each request at send time so the auth flow's in-place header mutation on retry never +affects what was captured for the first attempt. + +Tests #1/#2/#4/#5 share one `recorded_oauth_flow` fixture (one connect, several disjoint +assertions on its recording); the others connect fresh because each needs a different harness +configuration. +""" + +import base64 +import hashlib +import json +import re +from collections.abc import AsyncIterator +from dataclasses import dataclass +from urllib.parse import parse_qsl, quote, urlsplit + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl, AnyUrl + +from mcp import types +from mcp.client.auth import OAuthFlowError +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata +from mcp.types import ListToolsResult, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + HeadlessOAuth, + InMemoryTokenStorage, + RecordedRequest, + auth_settings, + connect_with_oauth, + first_challenge_shim, + record_requests, + shimmed_app, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH = "/.well-known/oauth-protected-resource/mcp" +ASM_PATH = "/.well-known/oauth-authorization-server" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + +def authorize_params(authorize_url: str) -> dict[str, str]: + """Parse the authorize URL's query string into a flat dict (one value per key).""" + return dict(parse_qsl(urlsplit(authorize_url).query)) + + +def form_body(request: RecordedRequest) -> dict[str, str]: + """Parse an `application/x-www-form-urlencoded` request body into a flat dict.""" + return dict(parse_qsl(request.content.decode())) + + +def find(recorded: list[RecordedRequest], method: str, path: str) -> list[RecordedRequest]: + """Filter recorded requests by method and exact path.""" + return [r for r in recorded if r.method == method and r.path == path] + + +@dataclass +class RecordedFlow: + """One completed OAuth connect: every recorded request, plus the parsed authorize URL params.""" + + requests: list[RecordedRequest] + authorize_url: str + + @property + def authorize(self) -> dict[str, str]: + return authorize_params(self.authorize_url) + + @property + def token_request(self) -> RecordedRequest: + token_posts = find(self.requests, "POST", "/token") + assert len(token_posts) == 1 + return token_posts[0] + + +@pytest.fixture +async def recorded_oauth_flow() -> AsyncIterator[RecordedFlow]: + """Run one full OAuth connect with default configuration and yield its recorded wire traffic. + + `valid_scopes` includes `offline_access` so the AS metadata advertises it and the SDK's + SEP-2207 auto-append (and the resulting `prompt=consent`) is exercised; `required_scopes` + stays at `["mcp"]` so the issued token still passes the bearer middleware. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "offline_access"]) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, settings=settings, on_request=on_request) as ( + client, + headless, + ): + await client.list_tools() + + assert headless.authorize_url is not None + yield RecordedFlow(requests=recorded, authorize_url=headless.authorize_url) + + +@requirement("client-auth:pkce:s256") +@requirement("client-auth:resource-parameter") +@requirement("client-auth:authorize:offline-access-consent") +async def test_the_authorize_url_carries_s256_pkce_and_the_resource_indicator( + recorded_oauth_flow: RecordedFlow, +) -> None: + """Every spec-mandated parameter appears on the authorize URL with the right value. + + The full key set is snapshotted so a parameter added or dropped fails the test. The + `code_challenge` length bound is the RFC 7636 §4.2 grammar; an S256 challenge is in + practice always 43 characters, so the upper bound is never approached. + """ + params = recorded_oauth_flow.authorize + + assert sorted(params) == snapshot( + [ + "client_id", + "code_challenge", + "code_challenge_method", + "prompt", + "redirect_uri", + "resource", + "response_type", + "scope", + "state", + ] + ) + assert params["response_type"] == "code" + assert params["code_challenge_method"] == "S256" + assert 43 <= len(params["code_challenge"]) <= 128 + # The exact resource value depends on canonical-URI normalisation (a spec ambiguity); pin + # the stable prefix so the test does not lock in a trailing-slash decision. + assert params["resource"].startswith(BASE_URL) + assert params["state"] != "" + + assert params["scope"].split(" ") == snapshot(["mcp", "offline_access"]) + assert params["prompt"] == "consent" + + +@requirement("client-auth:pkce:s256") +async def test_the_code_verifier_on_the_token_request_hashes_to_the_code_challenge( + recorded_oauth_flow: RecordedFlow, +) -> None: + """The PKCE verifier sent on /token is the S256 pre-image of the challenge sent on /authorize. + + The verifier is also checked against RFC 7636 §4.1's length and `unreserved` charset. + """ + challenge = recorded_oauth_flow.authorize["code_challenge"] + verifier = form_body(recorded_oauth_flow.token_request)["code_verifier"] + + assert re.fullmatch(r"[A-Za-z0-9._~-]{43,128}", verifier) + assert base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") == challenge + + +@requirement("client-auth:state:verify") +async def test_a_mismatched_state_on_the_callback_aborts_the_flow() -> None: + """A callback whose state does not match the value sent on /authorize raises and stops the flow. + + The auth flow runs inside the streamable-HTTP client's task group, so the `OAuthFlowError` + reaches the test wrapped in nested single-element exception groups; `pytest.RaisesGroup` + asserts the leaf type and the SDK-authored message prefix (the full message embeds two + random tokens). + """ + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + headless = HeadlessOAuth(state_override="wrong-state") + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^State parameter mismatch:"), flatten_subgroups=True + ): + # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), + # so an `async with` body would be unreachable; entering explicitly avoids dead code. + await connect_with_oauth(server, provider=provider, headless=headless).__aenter__() + + +@requirement("client-auth:resource-parameter") +async def test_the_authorization_code_token_request_carries_grant_type_code_redirect_and_resource( + recorded_oauth_flow: RecordedFlow, +) -> None: + """The /token form body has exactly the auth-code grant fields, with redirect_uri and resource matching /authorize. + + `client_secret` is present because the SDK's dynamic-registration handler issues a secret + and the client defaults to `client_secret_post`. + """ + token_req = recorded_oauth_flow.token_request + body = form_body(token_req) + + assert sorted(body) == snapshot( + ["client_id", "client_secret", "code", "code_verifier", "grant_type", "redirect_uri", "resource"] + ) + assert body["grant_type"] == "authorization_code" + assert body["code"] != "" + assert body["redirect_uri"] == recorded_oauth_flow.authorize["redirect_uri"] + assert body["resource"] == recorded_oauth_flow.authorize["resource"] + assert token_req.headers["content-type"] == "application/x-www-form-urlencoded" + + +@requirement("client-auth:bearer-header:every-request") +async def test_every_mcp_request_after_auth_carries_the_bearer_header_and_never_a_query_token( + recorded_oauth_flow: RecordedFlow, +) -> None: + """Every MCP request after the flow has `Authorization: Bearer ...` and never `?access_token=`. + + The first /mcp POST is the unauthenticated trigger and is asserted to carry no Authorization + header; that assertion is only meaningful because the recording snapshots requests at send + time (the SDK mutates the same request object in place for the retry). + """ + mcp_posts = find(recorded_oauth_flow.requests, "POST", "/mcp") + assert len(mcp_posts) >= 3 + + assert "authorization" not in mcp_posts[0].headers + for r in mcp_posts[1:]: + assert r.headers["authorization"].startswith("Bearer ") + assert r.headers["authorization"] != "Bearer " + assert "access_token" not in dict(r.url.params) + + +@requirement("client-auth:token-endpoint-auth-method") +async def test_a_client_with_a_secret_authenticates_the_token_request_with_http_basic() -> None: + """A `client_secret_basic` client sends URL-encoded credentials in HTTP Basic, not the body. + + Credentials are URL-encoded before base64 per RFC 6749 §2.3.1; the secret contains `/` so + the encoding is observable. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + client_info = OAuthClientInformationFull( + client_id="cid", + client_secret="s/cret", + token_endpoint_auth_method="client_secret_basic", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="mcp", + ) + await provider.register_client(client_info) + storage = InMemoryTokenStorage(client_info=client_info) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + await client.list_tools() + + assert find(recorded, "POST", "/register") == [] + [token_req] = find(recorded, "POST", "/token") + + decoded = base64.b64decode(token_req.headers["authorization"].removeprefix("Basic ")).decode() + assert decoded == f"{quote('cid', safe='')}:{quote('s/cret', safe='')}" + assert "client_secret" not in form_body(token_req) + + +@requirement("client-auth:token-endpoint-auth-method") +async def test_the_registered_auth_method_is_used_regardless_of_as_metadata_advertised_methods() -> None: + """The token-endpoint auth method comes from the registered client info, not from AS metadata. + + The shim serves AS metadata advertising only `client_secret_basic`; the client dynamically + registers and the SDK's registration handler issues `client_secret_post`. The client uses + `client_secret_post` (secret in the body, no Basic header) because the SDK reads the + registered `token_endpoint_auth_method`, not `token_endpoint_auth_methods_supported`. Other + SDKs (TypeScript, Go) do consult the AS metadata; this test pins where the python SDK's + selection point lives. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + override = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + token_endpoint_auth_methods_supported=["client_secret_basic"], + ) + serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, app_shim=lambda app: shimmed_app(app, serve=serve), on_request=on_request + ) as (client, _): + await client.list_tools() + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content).get("token_endpoint_auth_method") is None + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert "client_secret" in body + assert body["client_secret"] != "" + assert "authorization" not in token_req.headers + + +@requirement("client-auth:scope-selection:priority") +async def test_scope_is_selected_from_the_www_authenticate_challenge_over_prm_metadata() -> None: + """When the 401 challenge carries `scope=`, that value is requested instead of the PRM scopes. + + The SDK's bearer middleware never emits `scope=` in WWW-Authenticate (see the divergence + on `hosting:auth:scope-403`), so the test supplies the first 401 itself via + `first_challenge_shim` and disables token verification so the post-auth retry succeeds + regardless of the granted scope. PRM advertises `["from-prm"]` (it mirrors + `required_scopes`); the challenge says `from-header`; the authorize URL must carry + `from-header`. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(default_scopes=["from-header"]) + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["from-prm"], valid_scopes=["from-header", "from-prm"]) + challenge = f'Bearer scope="from-header", resource_metadata="{BASE_URL}{PRM_PATH}"' + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + settings=settings, + verify_tokens=False, + app_shim=first_challenge_shim(challenge), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["scope"] == "from-header" + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content)["scope"] == "from-header" + + +@requirement("client-auth:pkce:refuse-if-unsupported") +async def test_pkce_is_still_sent_when_as_metadata_omits_code_challenge_methods_supported() -> None: + """AS metadata without `code_challenge_methods_supported` does not stop the client sending PKCE. + + The spec says the client MUST refuse to proceed in this case; the SDK proceeds and the flow + completes. See the divergence on the requirement. + """ + override = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + ) + assert override.code_challenge_methods_supported is None + serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} + + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, app_shim=lambda app: shimmed_app(app, serve=serve) + ) as (client, headless): + result = await client.list_tools() + + assert headless.authorize_url is not None + params = authorize_params(headless.authorize_url) + assert params["code_challenge_method"] == "S256" + assert params["code_challenge"] != "" + assert result.tools[0].name == "echo" + + +@requirement("client-auth:authorize:error-surfaces") +async def test_an_authorize_error_on_the_callback_aborts_the_flow_before_the_token_request() -> None: + """An `error=` redirect from /authorize aborts the flow with no /token request issued. + + The SDK's callback contract is `() -> (code, state)` with no error form, so the failure is + observed as an empty code reaching the SDK and `OAuthFlowError("No authorization code + received")` being raised. The actual `error` value from the redirect is not surfaced to the + caller; that gap is noted in the manifest. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(deny_authorize=True) + server = Server("guarded", on_list_tools=list_tools) + headless = HeadlessOAuth() + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^No authorization code received$"), flatten_subgroups=True + ): + await connect_with_oauth(server, provider=provider, headless=headless, on_request=on_request).__aenter__() + + assert headless.error == "access_denied" + assert find(recorded, "POST", "/token") == [] diff --git a/tests/interaction/auth/test_bearer.py b/tests/interaction/auth/test_bearer.py new file mode 100644 index 0000000000..341a8e0db9 --- /dev/null +++ b/tests/interaction/auth/test_bearer.py @@ -0,0 +1,189 @@ +"""Resource-server bearer-token gate: status codes and `WWW-Authenticate` for each token shape. + +These tests mount only the resource-server side of the auth wiring (a `StaticTokenVerifier` +seeded with hand-built tokens, no authorization-server provider) and speak raw HTTP, since +every assertion is about HTTP semantics the SDK `Client` cannot observe: the 401/403 status, +the `WWW-Authenticate` header structure, and that a wrong-audience token reaches the MCP +endpoint behind the gate. The flow side of the same 401 is `test_flow.py`'s flagship test. +""" + +import time +from collections.abc import AsyncIterator + +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server +from mcp.server.auth.provider import AccessToken +from mcp.types import JSONRPCResponse +from tests.interaction._connect import base_headers, initialize_body, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import StaticTokenVerifier, auth_settings + +pytestmark = pytest.mark.anyio + +REQUIRED_SCOPE = "mcp:read" +RESOURCE_METADATA_URL = "http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp" + +_FUTURE = int(time.time()) + 3600 +_PAST = int(time.time()) - 3600 + +TOKENS = { + "tok-valid": AccessToken(token="tok-valid", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_FUTURE), + "tok-expired": AccessToken(token="tok-expired", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_PAST), + "tok-noscope": AccessToken(token="tok-noscope", client_id="c", scopes=["other:thing"], expires_at=_FUTURE), + "tok-wrong-aud": AccessToken( + token="tok-wrong-aud", + client_id="c", + scopes=[REQUIRED_SCOPE], + expires_at=_FUTURE, + resource="https://other.example/mcp", + ), +} + + +@pytest.fixture +async def protected() -> AsyncIterator[httpx.AsyncClient]: + """A bearer-gated streamable-HTTP app (resource server only) on the in-process bridge.""" + server = Server("rs") + settings = auth_settings(required_scopes=[REQUIRED_SCOPE]) + async with mounted_app(server, auth=settings, token_verifier=StaticTokenVerifier(TOKENS)) as (http, _): + yield http + + +async def post_mcp( + http: httpx.AsyncClient, *, bearer: str | None = None, query: dict[str, str] | None = None +) -> httpx.Response: + """POST an initialize body to `/mcp`, optionally with a bearer token and/or a query string.""" + headers = base_headers() + if bearer is not None: + headers["authorization"] = f"Bearer {bearer}" + return await http.post("/mcp", headers=headers, params=query, json=initialize_body()) + + +def parse_www_authenticate(value: str) -> dict[str, str]: + """Parse a `Bearer k="v", k="v"` challenge into a dict. + + The SDK emits each parameter exactly once, comma-space separated, with double-quoted + values that contain no quotes themselves; this helper relies on that and would fail + visibly if the format changed. + """ + scheme, _, params = value.partition(" ") + assert scheme == "Bearer" + return {key: quoted.strip('"') for key, _, quoted in (pair.partition("=") for pair in params.split(", "))} + + +@requirement("hosting:auth:missing-401") +async def test_a_request_with_no_authorization_header_is_challenged_with_resource_metadata( + protected: httpx.AsyncClient, +) -> None: + """No `Authorization` header → 401 with a `WWW-Authenticate` carrying `resource_metadata`. + + The snapshot pins current behaviour: the SDK collapses the no-header, unknown-token, and + expired-token cases into one challenge (`error="invalid_token"`, no `scope` parameter). The + spec says the discovery-time challenge SHOULD include `scope` and RFC 6750 says the + no-credentials case SHOULD NOT carry an error code; both gaps are recorded as the divergence + on this requirement. Asserting the dict equals an exact key set also pins that no parameter + appears twice. + """ + response = await post_mcp(protected) + + assert response.status_code == 401 + assert response.headers["www-authenticate"] == snapshot( + 'Bearer error="invalid_token", error_description="Authentication required", ' + 'resource_metadata="http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp"' + ) + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "Authentication required", + "resource_metadata": RESOURCE_METADATA_URL, + } + assert response.json() == snapshot({"error": "invalid_token", "error_description": "Authentication required"}) + + +@requirement("hosting:auth:invalid-401") +async def test_an_unrecognized_bearer_token_is_answered_401_invalid_token(protected: httpx.AsyncClient) -> None: + """A token the verifier does not recognize is answered 401 `invalid_token`. + + The challenge is identical to the no-header case (the backend returns `None` for both); the + missing `scope` parameter is the recorded divergence on this requirement. + """ + response = await post_mcp(protected, bearer="tok-unknown") + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "Authentication required", + "resource_metadata": RESOURCE_METADATA_URL, + } + + +@requirement("hosting:auth:expired-401") +async def test_an_expired_token_is_answered_401(protected: httpx.AsyncClient) -> None: + """A token whose `expires_at` is in the past is answered 401 `invalid_token`. + + The expiry check is the bearer backend's, against the wall clock; the test seeds a concrete + past timestamp so no time mocking is involved. The missing `scope` parameter is the recorded + divergence on this requirement. + """ + response = await post_mcp(protected, bearer="tok-expired") + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" + + +@requirement("hosting:auth:scope-403") +async def test_a_token_missing_a_required_scope_is_answered_403_insufficient_scope_without_a_scope_param( + protected: httpx.AsyncClient, +) -> None: + """A token lacking the required scope is answered 403 `insufficient_scope`, with no `scope` parameter. + + The spec's runtime-insufficient-scope guidance says the challenge SHOULD include `scope` + naming the required scope; the SDK never emits it, recorded as the divergence on this + requirement. The SDK client reads `scope` from this header to drive step-up, so the gap is + a resource-server/client asymmetry. + """ + response = await post_mcp(protected, bearer="tok-noscope") + + assert response.status_code == 403 + parsed = parse_www_authenticate(response.headers["www-authenticate"]) + assert parsed == { + "error": "insufficient_scope", + "error_description": f"Required scope: {REQUIRED_SCOPE}", + "resource_metadata": RESOURCE_METADATA_URL, + } + assert "scope" not in parsed + + +@requirement("hosting:auth:aud-validation") +async def test_a_token_with_a_mismatched_audience_is_accepted(protected: httpx.AsyncClient) -> None: + """A token whose `resource` does not match the server's resource identifier is accepted. + + The spec mandates the resource server validate the token's audience; the bearer backend + never inspects `AccessToken.resource`, so the request passes the gate and the MCP endpoint + serves it. This pins current behaviour with the divergence recorded on the requirement. + """ + response = await post_mcp(protected, bearer="tok-wrong-aud") + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/event-stream") + # The body is finite SSE: a result event followed by stream close. Pull the JSON-RPC response + # out of the buffered text to prove the MCP endpoint actually answered the initialize request. + [data] = [line.removeprefix("data: ") for line in response.text.splitlines() if line.startswith("data: ")] + assert "protocolVersion" in JSONRPCResponse.model_validate_json(data).result + + +@requirement("hosting:auth:query-token-ignored") +async def test_an_access_token_in_the_query_string_is_not_accepted(protected: httpx.AsyncClient) -> None: + """A valid token presented in the URI query string is treated as no authentication. + + The bearer backend reads only the `Authorization` header, so `?access_token=...` is never + consulted; the request is treated as unauthenticated and answered 401. This satisfies, by + absence, the security best-practice that resource servers must not accept query-string + tokens. + """ + response = await post_mcp(protected, query={"access_token": "tok-valid"}) + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" diff --git a/tests/interaction/auth/test_discovery.py b/tests/interaction/auth/test_discovery.py new file mode 100644 index 0000000000..68c33c8a2d --- /dev/null +++ b/tests/interaction/auth/test_discovery.py @@ -0,0 +1,333 @@ +"""Protected-resource and authorization-server metadata discovery, end to end. + +Every client-side test connects a real `Client` via `connect_with_oauth` and asserts on the +recorded request paths the discovery probes produced; the discovery URL ordering is a wire +detail `Client` cannot observe directly but the recording can. Tests that need a metadata +endpoint to 404 or return alternate content wrap the SDK's app in `shimmed_app` while leaving +the real authorize and token endpoints behind it, so the rest of the flow runs unaltered. + +The two server-side tests (#5, #6) drive raw httpx against `mounted_app` because their +assertions are the metadata response bodies and headers, which `Client` does not surface. +""" + +import json + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl + +from mcp import types +from mcp.client.auth import OAuthFlowError, OAuthRegistrationError +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata +from mcp.types import ListToolsResult, Tool +from tests.interaction._connect import BASE_URL, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + RecordedRequest, + auth_settings, + connect_with_oauth, + metadata_body, + record_requests, + shim, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH_SUFFIXED = "/.well-known/oauth-protected-resource/mcp" +PRM_ROOT = "/.well-known/oauth-protected-resource" +ASM_ROOT = "/.well-known/oauth-authorization-server" +OIDC_ROOT = "/.well-known/openid-configuration" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="probe", input_schema={"type": "object"})]) + + +def discovery_gets(recorded: list[RecordedRequest]) -> list[str]: + """Return the well-known GET paths in recorded order, ignoring everything else.""" + return [r.path for r in recorded if r.method == "GET" and "/.well-known/" in r.path] + + +def real_asm() -> OAuthMetadata: + """Build an authorization-server metadata document pointing at the real co-hosted endpoints.""" + return OAuthMetadata( + issuer=AnyHttpUrl(BASE_URL), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + ) + + +@requirement("client-auth:prm-discovery:fallback-order") +async def test_prm_discovery_uses_the_resource_metadata_url_from_www_authenticate() -> None: + """The first protected-resource probe is the URL the 401's `WWW-Authenticate` header supplied. + + With co-hosted defaults the header carries the path-suffixed well-known URL; the client + fetches that one first and, because it succeeds, never falls back. The single-probe + sequence proves priority 1. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, on_request=on_request) as (client, _): + await client.list_tools() + + assert discovery_gets(recorded) == snapshot([PRM_PATH_SUFFIXED, ASM_ROOT]) + assert (recorded[0].method, recorded[0].path) == ("POST", "/mcp") + assert (recorded[1].method, recorded[1].path) == ("GET", PRM_PATH_SUFFIXED) + + +@requirement("client-auth:prm-discovery:fallback-order") +async def test_prm_discovery_falls_back_from_path_well_known_to_root_on_404() -> None: + """When the path-suffixed PRM well-known 404s, the client falls back to the root well-known. + + The exact GET count is not asserted: the WWW-Authenticate URL equals the path well-known + here, so the SDK probes it twice (once as priority 1, once as priority 2) before reaching + root. Asserting "path before root, root reached, then the flow proceeds" pins the spec + invariant; the duplicate probe is an implementation detail. The served PRM body carries an + unrecognized field to prove the client's parser ignores unknown members (RFC 9728 §3.2). + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + app_shim = shim( + not_found=frozenset({PRM_PATH_SUFFIXED}), + serve={PRM_ROOT: metadata_body(prm, x_unknown_extension="ignored")}, + ) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + await client.list_tools() + + well_known = discovery_gets(recorded) + assert PRM_PATH_SUFFIXED in well_known + assert PRM_ROOT in well_known + assert well_known.index(PRM_PATH_SUFFIXED) < well_known.index(PRM_ROOT) + assert any(r.path == "/authorize" for r in recorded) + + +@requirement("client-auth:prm-discovery:no-prm-fallback") +async def test_when_every_prm_probe_fails_the_client_discovers_as_metadata_at_the_server_origin() -> None: + """When every protected-resource metadata probe 404s, the client falls back to the legacy path. + + The legacy 2025-03-26 behaviour: with no PRM document available, treat the MCP server's + origin as the authorization server and fetch its `/.well-known/oauth-authorization-server` + directly. The real co-hosted ASM endpoint is at exactly that location, so the flow completes. + The recorded sequence shows both PRM well-known paths probed (and failed) before ASM_ROOT. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + app_shim = shim(not_found=frozenset({PRM_PATH_SUFFIXED, PRM_ROOT})) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + result = await client.list_tools() + + well_known = discovery_gets(recorded) + assert PRM_PATH_SUFFIXED in well_known + assert PRM_ROOT in well_known + assert well_known[-1] == ASM_ROOT + assert all(well_known.index(prm) < well_known.index(ASM_ROOT) for prm in (PRM_PATH_SUFFIXED, PRM_ROOT)) + assert result.tools[0].name == "probe" + + +@requirement("client-auth:dcr:registration-error-surfaces") +async def test_a_400_from_the_registration_endpoint_surfaces_as_a_registration_error() -> None: + """A 400 from `/register` surfaces as `OAuthRegistrationError` carrying the server's body. + + The shim makes `/register` return RFC 7591's `invalid_client_metadata`; the SDK reads the + body and raises with the status and text in the message, before any authorize or token + request is made. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + error_body = json.dumps({"error": "invalid_client_metadata", "error_description": "no"}).encode() + app_shim = shim(serve={"/register": (400, error_body)}) + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthRegistrationError, match=r"^Registration failed: 400 .*invalid_client_metadata"), + flatten_subgroups=True, + ): + await connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request).__aenter__() + + assert [r.path for r in recorded if r.path in ("/authorize", "/token")] == [] + + +@requirement("client-auth:prm-resource-mismatch") +async def test_prm_with_a_mismatched_resource_aborts_the_flow_before_authorize() -> None: + """A PRM document whose `resource` does not cover the server URL aborts the flow. + + The shim serves PRM at the URL the WWW-Authenticate header supplies, but with a `resource` + on a different path; `check_resource_allowed` rejects it and `OAuthFlowError` is raised + before any authorize or token request is made. The error reaches the test wrapped in nested + single-element exception groups by the streamable-HTTP client's task group. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/other"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + app_shim = shim(serve={PRM_PATH_SUFFIXED: metadata_body(prm)}) + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^Protected resource .* does not match expected"), + flatten_subgroups=True, + ): + await connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request).__aenter__() + + assert [r.path for r in recorded if r.path in ("/authorize", "/token")] == [] + + +@requirement("client-auth:as-metadata-discovery:priority-order") +@pytest.mark.parametrize( + ("authorization_server", "not_found", "serve_at", "expected_order"), + [ + pytest.param( + f"{BASE_URL}/", + frozenset({ASM_ROOT}), + OIDC_ROOT, + [ASM_ROOT, OIDC_ROOT], + id="root-issuer", + ), + pytest.param( + f"{BASE_URL}/tenant", + frozenset({f"{ASM_ROOT}/tenant", f"{OIDC_ROOT}/tenant"}), + "/tenant/.well-known/openid-configuration", + [f"{ASM_ROOT}/tenant", f"{OIDC_ROOT}/tenant", "/tenant/.well-known/openid-configuration"], + id="path-issuer", + ), + ], +) +async def test_as_metadata_discovery_falls_back_through_the_spec_endpoint_order( + authorization_server: str, not_found: frozenset[str], serve_at: str, expected_order: list[str] +) -> None: + """Authorization-server metadata is fetched at the spec's endpoints in the spec's order. + + The shim 404s every endpoint before the last so the recording proves each probe and its + position. For an issuer URL with no path the order is OAuth root then OIDC root; for an + issuer URL with a path component it is OAuth path-inserted, OIDC path-inserted, then OIDC + path-appended (the spec's three-endpoint MUST). The path-issuer case is driven by serving + a PRM whose `authorization_servers` carries the path; the SDK's own AS routes stay at root + (the served body points at the real `/authorize` and `/token`). The served bodies carry an + unrecognized field to prove the client's parser ignores unknown members (RFC 8414 §3.2). + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(authorization_server)] + ) + app_shim = shim( + not_found=not_found, + serve={ + PRM_PATH_SUFFIXED: metadata_body(prm), + serve_at: metadata_body(real_asm(), x_unknown_extension="ignored"), + }, + ) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + await client.list_tools() + + assert discovery_gets(recorded) == [PRM_PATH_SUFFIXED, *expected_order] + + +@requirement("hosting:auth:metadata-endpoints") +@requirement("hosting:auth:prm:authorization-servers-field") +async def test_the_prm_endpoint_serves_the_resource_url_and_at_least_one_authorization_server() -> None: + """The protected-resource metadata document the SDK serves identifies the resource and an authorization server. + + Also asserts the response is `application/json` (RFC 9728 §3.2) and that fields the SDK has + no value for are absent rather than null (`PydanticJSONResponse` serializes with + `exclude_none=True`, satisfying RFC 9728 §3.2's omit-zero-value rule). + """ + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + + async with mounted_app(server, auth=auth_settings(), auth_server_provider=provider) as (http, _): + response = await http.get(PRM_PATH_SUFFIXED) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("application/json") + + document = json.loads(response.content) + assert "resource_documentation" not in document + assert "scopes_supported" in document + + metadata = ProtectedResourceMetadata.model_validate(document) + assert str(metadata.resource).rstrip("/") == f"{BASE_URL}/mcp" + assert len(metadata.authorization_servers) >= 1 + assert metadata.bearer_methods_supported == ["header"] + + +@requirement("hosting:auth:as-router") +async def test_as_metadata_advertises_authorize_token_registration_and_s256() -> None: + """The authorization-server metadata document the SDK serves names the required endpoints and S256.""" + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + + async with mounted_app(server, auth=auth_settings(), auth_server_provider=provider) as (http, _): + response = await http.get(ASM_ROOT) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("application/json") + + metadata = OAuthMetadata.model_validate_json(response.content) + assert str(metadata.issuer).rstrip("/") == BASE_URL + assert str(metadata.authorization_endpoint) == f"{BASE_URL}/authorize" + assert str(metadata.token_endpoint) == f"{BASE_URL}/token" + assert str(metadata.registration_endpoint) == f"{BASE_URL}/register" + assert metadata.response_types_supported == ["code"] + assert metadata.code_challenge_methods_supported is not None + assert "S256" in metadata.code_challenge_methods_supported + + +@requirement("client-auth:as-metadata-discovery:issuer-validation") +async def test_as_metadata_with_a_mismatched_issuer_is_accepted_and_the_flow_proceeds() -> None: + """Authorization-server metadata whose `issuer` does not match the discovery URL is accepted. + + RFC 8414 §3.3 requires the client to reject the document; the SDK parses and uses it + without comparing `issuer` to the URL it was fetched from. See the divergence on the + requirement. The served body carries an unrecognized field as a fold-in proof of + unknown-field tolerance. + """ + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + metadata = real_asm() + metadata.issuer = AnyHttpUrl(f"{BASE_URL}/wrong-issuer") + app_shim = shim(serve={ASM_ROOT: metadata_body(metadata, x_unknown_extension="ignored")}) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "probe" diff --git a/tests/interaction/auth/test_flow.py b/tests/interaction/auth/test_flow.py new file mode 100644 index 0000000000..968fc5f980 --- /dev/null +++ b/tests/interaction/auth/test_flow.py @@ -0,0 +1,239 @@ +"""End-to-end OAuth authorization-code flow against the SDK's own server, fully in process. + +Auth is HTTP-only so these tests are not transport-parametrized; each connects via +`connect_with_oauth`, which co-hosts the SDK's authorization server, protected-resource +metadata, and bearer-gated MCP endpoint on one bridge-backed Starlette app and drives the +whole flow through one `httpx.AsyncClient` carrying the SDK's `OAuthClientProvider`. The +authorize redirect completes headlessly through the same bridge, so every request the flow +makes is observable via `on_request`. +""" + +import json +from collections import Counter +from urllib.parse import parse_qs, urlsplit + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot +from pydantic import AnyUrl + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import get_access_token +from mcp.shared.auth import OAuthClientInformationFull +from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + InMemoryTokenStorage, + auth_settings, + connect_with_oauth, + oauth_client_metadata, + shimmed_app, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})]) + + +@requirement("flow:oauth:authorization-code-roundtrip") +@requirement("client-auth:401-triggers-flow") +@requirement("hosting:auth:missing-401") +async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow_connects() -> None: + """Connecting to a bearer-gated server walks the full authorization-code flow and succeeds. + + Three requirements are proven by one connect: the flow runs end to end (authorization-code + roundtrip), it was triggered by a 401 on the first MCP request (401-triggers-flow), and + that 401 carried `resource_metadata` in `WWW-Authenticate` for discovery (missing-401). + The flagship test pins the recorded request sequence so the discovery → registration → + authorize → token → retry order is asserted explicitly. + + Steps the SDK is expected to perform: + 1. POST /mcp without a token → 401 with `WWW-Authenticate: Bearer resource_metadata=...`. + 2. GET the protected-resource metadata. + 3. GET the authorization-server metadata. + 4. POST /register (dynamic client registration). + 5. GET /authorize → 302 with code+state (completed by the headless redirect). + 6. POST /token (authorization-code exchange). + 7. Retry POST /mcp with `Authorization: Bearer ` → succeeds. + """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as ( + client, + headless, + ): + result = await client.list_tools() + + assert result == snapshot(ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})])) + assert headless.authorize_url is not None + + paths = [(r.method, r.url.path) for r in requests] + assert Counter(paths) == snapshot( + Counter( + { + ("POST", "/mcp"): 4, + ("GET", "/.well-known/oauth-protected-resource/mcp"): 1, + ("GET", "/.well-known/oauth-authorization-server"): 1, + ("POST", "/register"): 1, + ("GET", "/authorize"): 1, + ("POST", "/token"): 1, + ("GET", "/mcp"): 1, + ("DELETE", "/mcp"): 1, + } + ) + ) + + assert (requests[0].method, requests[0].url.path) == ("POST", "/mcp") + # The recorded Request objects are live references: the auth flow mutates the original + # request's headers in place when it adds the bearer token for the retry, so the first + # entry's headers cannot be used to assert "no Authorization on the first attempt". The + # path multiset above proving discovery happened is the evidence the first attempt was 401. + + # The first PRM discovery GET carries the protocol-version header (an SDK behaviour, not a + # spec requirement on discovery requests). + prm_get = next(r for r in requests if r.url.path == "/.well-known/oauth-protected-resource/mcp") + assert prm_get.headers.get("mcp-protocol-version") == snapshot("2025-11-25") + + authorize = parse_qs(urlsplit(headless.authorize_url).query) + assert authorize["response_type"] == ["code"] + assert authorize["code_challenge_method"] == ["S256"] + assert authorize["client_id"][0] in provider.clients + + assert storage.tokens is not None + bearer = f"Bearer {storage.tokens.access_token}" + authed_mcp = [r for r in requests if r.url.path == "/mcp" and r.headers.get("authorization") == bearer] + assert len(authed_mcp) > 0 + assert storage.tokens.access_token in provider.access_tokens + + +@requirement("hosting:auth:authinfo-propagates") +async def test_the_access_token_reaches_the_tool_handler_via_get_access_token() -> None: + """A tool handler reads the request's access token through `get_access_token()`.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "whoami" + token = get_access_token() + assert token is not None + return CallToolResult(content=[TextContent(text=" ".join(token.scopes))]) + + server = Server("guarded", on_list_tools=list_tools, on_call_tool=call_tool) + provider = InMemoryAuthorizationServerProvider() + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider) as (client, _): + result = await client.call_tool("whoami", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mcp")])) + + +@requirement("client-auth:pre-registration") +async def test_a_preregistered_client_skips_registration() -> None: + """A client whose storage already holds client info uses it instead of registering. + + The provider holds the same registration server-side so the authorize and token steps + accept it; the recorded requests prove no `/register` call was made. + """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + client_info = OAuthClientInformationFull( + client_id="preregistered", + client_secret="s3cret", + token_endpoint_auth_method="client_secret_post", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="mcp", + ) + await provider.register_client(client_info) + storage.client_info = client_info + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as ( + client, + _, + ): + await client.list_tools() + + assert [r.url.path for r in requests].count("/register") == 0 + assert list(provider.clients) == ["preregistered"] + + +@requirement("client-auth:dcr") +async def test_the_dcr_request_carries_the_client_metadata() -> None: + """Dynamic registration sends the client's metadata and persists what the server issued. + + The body of the recorded `/register` POST carries the metadata the test supplied (with the + scope filled in from server discovery), and the server's issued client_id and secret are + persisted to storage and held by the provider. + """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + client_metadata = oauth_client_metadata() + client_metadata.software_id = "interaction-test-suite" + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, storage=storage, client_metadata=client_metadata, on_request=requests.append + ) as (client, _): + await client.list_tools() + + register = next(r for r in requests if r.url.path == "/register") + assert register.headers["content-type"] == "application/json" + body = json.loads(register.content) + assert body == snapshot( + { + "redirect_uris": ["http://127.0.0.1:8000/oauth/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "scope": "mcp", + "client_name": "interaction-suite", + "software_id": "interaction-test-suite", + } + ) + + assert storage.client_info is not None + assert storage.client_info.client_id is not None + assert storage.client_info.client_secret is not None + assert list(provider.clients) == [storage.client_info.client_id] + + +async def test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_wrapped_app() -> None: + """Harness self-test: `shimmed_app` serves canned bodies, 404s, and forwards everything else. + + Wraps a real auth-hosting Starlette app so the forward path is exercised against the SDK's + own routing; provided here so the discovery tests can rely on the shim without each adding + their own contract test. + """ + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + real_app = server.streamable_http_app(auth=auth_settings(), auth_server_provider=provider) + app = shimmed_app(real_app, not_found=frozenset({"/missing"}), serve={"/override": b'{"shimmed": true}'}) + async with server.session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + served = await http.get("/override") + assert served.status_code == 200 + assert served.headers["content-type"] == "application/json" + assert served.json() == {"shimmed": True} + + assert (await http.get("/missing")).status_code == 404 + + forwarded = await http.get("/.well-known/oauth-authorization-server") + assert forwarded.status_code == 200 + assert forwarded.json()["issuer"] == "http://127.0.0.1:8000/" diff --git a/tests/interaction/auth/test_lifecycle.py b/tests/interaction/auth/test_lifecycle.py new file mode 100644 index 0000000000..aa552ae8a6 --- /dev/null +++ b/tests/interaction/auth/test_lifecycle.py @@ -0,0 +1,445 @@ +"""Token lifecycle, step-up, and registration-variant flows of the SDK's OAuth client. + +Every test connects end to end via `connect_with_oauth`; the assertions are recording-first +(the recorded request sequence is asserted before, or independently of, the call result), so a +surprise in the refresh or step-up paths produces a readable diff of what fired rather than an +opaque failure. The provider knobs that drive each scenario are documented per test. +""" + +import base64 +from collections import Counter +from urllib.parse import parse_qsl, urlsplit + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl, AnyUrl + +from mcp import MCPError, types +from mcp.client.auth.extensions.client_credentials import ClientCredentialsOAuthProvider, PrivateKeyJWTOAuthProvider +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata +from mcp.types import INTERNAL_ERROR, ListToolsResult, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + InMemoryTokenStorage, + RecordedRequest, + auth_settings, + connect_with_oauth, + m2m_token_shim, + metadata_body, + record_requests, + shim, + step_up_shim, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH = "/.well-known/oauth-protected-resource/mcp" +ASM_PATH = "/.well-known/oauth-authorization-server" +CIMD_URL = "https://client.example/.well-known/mcp-client" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + +def form_body(request: RecordedRequest) -> dict[str, str]: + """Parse an `application/x-www-form-urlencoded` request body into a flat dict.""" + return dict(parse_qsl(request.content.decode())) + + +def authorize_params(authorize_url: str) -> dict[str, str]: + """Parse the authorize URL's query string into a flat dict.""" + return dict(parse_qsl(urlsplit(authorize_url).query)) + + +def find(recorded: list[RecordedRequest], method: str, path: str) -> list[RecordedRequest]: + return [r for r in recorded if r.method == method and r.path == path] + + +def path_counts(recorded: list[RecordedRequest]) -> Counter[tuple[str, str]]: + return Counter((r.method, r.path) for r in recorded) + + +def cimd_supported_metadata() -> bytes: + """AS metadata advertising `client_id_metadata_document_supported: true` (the SDK server never sets it).""" + metadata = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + response_types_supported=["code"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + client_id_metadata_document_supported=True, + ) + return metadata_body(metadata) + + +def seeded_client(provider: InMemoryAuthorizationServerProvider, **kwargs: object) -> OAuthClientInformationFull: + """Register a client with the provider and return its info, for pre-registration and CIMD scenarios.""" + base: dict[str, object] = { + "client_id": "preregistered", + "token_endpoint_auth_method": "none", + "redirect_uris": [AnyUrl(REDIRECT_URI)], + "grant_types": ["authorization_code", "refresh_token"], + "scope": "mcp", + } + base.update(kwargs) + info = OAuthClientInformationFull.model_validate(base) + assert info.client_id is not None + provider.clients[info.client_id] = info + return info + + +@requirement("client-auth:refresh:transparent") +async def test_an_expired_access_token_is_transparently_refreshed_before_the_next_request() -> None: + """An access token the client considers expired is refreshed and the new bearer is used. + + The provider tells the client `expires_in=-3600` for the first token while keeping the + server-side `expires_at` in the future, so the connect's retry succeeds and the next + request finds the token expired and refreshes. The recorded requests prove exactly one + `grant_type=refresh_token` exchange carrying the resource indicator, and the bearer used + after the refresh is the second access token, which is the one persisted to storage. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(issue_expired_first=True) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + token_posts = find(recorded, "POST", "/token") + bodies = [form_body(r) for r in token_posts] + assert [b["grant_type"] for b in bodies] == snapshot(["authorization_code", "refresh_token"]) + + refresh_body = bodies[1] + assert sorted(refresh_body) == snapshot(["client_id", "client_secret", "grant_type", "refresh_token", "resource"]) + assert refresh_body["refresh_token"].startswith("refresh_") + assert refresh_body["resource"].startswith(BASE_URL) + + bearers = {r.headers["authorization"] for r in recorded if r.path == "/mcp" and "authorization" in r.headers} + assert len(bearers) == 2 + assert storage.tokens is not None + assert f"Bearer {storage.tokens.access_token}" in bearers + assert storage.tokens.expires_in == 3600 + + +@requirement("client-auth:403-scope-upgrade") +async def test_a_403_insufficient_scope_triggers_one_reauthorize_with_the_challenged_scope() -> None: + """A 403 `insufficient_scope` challenge is answered by one re-authorize with the challenge's scope. + + The shim 403s the second authenticated `/mcp` POST (the `notifications/initialized` request, + which reaches the auth flow's step-up handler; the first authenticated POST is the post-401 + retry, after which the generator ends without inspecting the response). The challenge names a + wider scope; step-up reuses cached metadata and the existing client registration, + re-authorizes with the new scope, and the connect completes. The client is pre-registered + with both scopes so the server's authorize handler accepts the wider second request. One + re-authorize, one retry; the spec's SHOULD-retry-limit ("a few") is not enforced. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage(client_info=seeded_client(provider, scope="mcp write")) + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "write"]) + challenge = 'Bearer error="insufficient_scope", scope="mcp write"' + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + settings=settings, + app_shim=step_up_shim(challenge), + on_request=on_request, + ) as (client, headless): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + assert len(headless.authorize_urls) == 2 + assert authorize_params(headless.authorize_urls[0])["scope"] == "mcp" + assert authorize_params(headless.authorize_urls[1])["scope"] == "mcp write" + + counts = path_counts(recorded) + assert counts[("GET", PRM_PATH)] == 1 + assert counts[("GET", ASM_PATH)] == 1 + assert counts[("POST", "/register")] == 0 + assert counts[("GET", "/authorize")] == 2 + assert counts[("POST", "/token")] == 2 + + +@requirement("client-auth:401-after-auth-throws") +async def test_a_second_401_after_a_completed_oauth_flow_surfaces_without_looping() -> None: + """A 401 on the post-auth retry surfaces as an error rather than re-entering discovery. + + The provider rejects every token at verification, so the full flow runs once and the retry + is 401'd. The auth-flow generator ends after that retry, so the 401 propagates and the + transport converts it to an INTERNAL_ERROR result, raising during connect. Discovery, + registration, authorize, and token each ran exactly once: no loop. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(reject_all_tokens=True) + server = Server("guarded", on_list_tools=list_tools) + + def is_internal_error(error: MCPError) -> bool: + return error.error.code == INTERNAL_ERROR + + with anyio.fail_after(5): + with pytest.RaisesGroup(pytest.RaisesExc(MCPError, check=is_internal_error), flatten_subgroups=True): + # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), + # so an `async with` body would be unreachable; entering explicitly avoids dead code. + await connect_with_oauth(server, provider=provider, on_request=on_request).__aenter__() + + counts = path_counts(recorded) + assert counts[("GET", PRM_PATH)] == 1 + assert counts[("GET", ASM_PATH)] == 1 + assert counts[("POST", "/register")] == 1 + assert counts[("GET", "/authorize")] == 1 + assert counts[("POST", "/token")] == 1 + assert counts[("POST", "/mcp")] == 2 + + +@requirement("client-auth:cimd") +async def test_cimd_is_selected_when_the_as_advertises_support_and_a_metadata_url_is_supplied() -> None: + """A client-ID metadata-document URL is used as `client_id` instead of registering. + + AS metadata is shimmed to advertise `client_id_metadata_document_supported: true`; the + provider is pre-seeded so the server's authorize and token handlers accept the URL as a + client_id (the SDK server has no CIMD-aware client lookup of its own). The recorded + requests prove no `/register` call, the authorize URL's `client_id` is the CIMD URL, the + token request uses `token_endpoint_auth_method=none`, and storage persists the URL as + `client_id`. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + seeded_client(provider, client_id=CIMD_URL) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + client_metadata_url=CIMD_URL, + app_shim=shim(serve={ASM_PATH: cimd_supported_metadata()}), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert find(recorded, "POST", "/register") == [] + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["client_id"] == CIMD_URL + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body["client_id"] == CIMD_URL + assert "client_secret" not in body + assert "authorization" not in token_req.headers + + assert storage.client_info is not None + assert storage.client_info.client_id == CIMD_URL + assert storage.client_info.token_endpoint_auth_method == "none" + + +@requirement("client-auth:invalid-grant-clears-tokens") +async def test_a_failed_refresh_clears_stored_tokens_and_restarts_the_full_flow() -> None: + """A non-200 refresh response clears the in-memory tokens and the flow re-runs from discovery. + + The first token is reported expired so the next request refreshes; the provider denies the + refresh once with `invalid_grant`, the auth flow clears its tokens, the unauthenticated + request 401s, and discovery, authorize, and token run again. The original registration is + preserved (`client_info` is not cleared). The SDK clears tokens on any non-200 refresh + response, not specifically `error=invalid_grant`; `source="sdk"` so this is a precision + note rather than a divergence. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(issue_expired_first=True, fail_next_refresh=True) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + token_posts = find(recorded, "POST", "/token") + assert [form_body(r)["grant_type"] for r in token_posts] == snapshot( + ["authorization_code", "refresh_token", "authorization_code"] + ) + + counts = path_counts(recorded) + assert counts[("POST", "/register")] == 1 + assert counts[("GET", "/authorize")] == 2 + assert counts[("GET", PRM_PATH)] == 2 + assert counts[("GET", ASM_PATH)] == 2 + + assert storage.client_info is not None + assert storage.tokens is not None + assert storage.tokens.access_token in provider.access_tokens + + +@requirement("client-auth:client-credentials") +async def test_client_credentials_provider_obtains_a_token_without_an_authorize_step() -> None: + """The client-credentials provider connects with no authorize step and a `client_credentials` grant. + + The SDK server's `TokenHandler` does not route `client_credentials`, so the harness shim + handles it (the shim is harness; the SDK-under-test is the client provider). The recorded + `/token` body proves the grant type, scope, resource indicator, and HTTP-Basic client + authentication; no `/authorize` or `/register` request was made. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + auth = ClientCredentialsOAuthProvider( + server_url=f"{BASE_URL}/mcp", + storage=InMemoryTokenStorage(), + client_id="m2m-client", + client_secret="m2m-secret", + scopes="mcp", + ) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + auth=auth, + app_shim=m2m_token_shim(provider, scopes=["mcp"]), + on_request=on_request, + ) as (client, headless): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + assert headless.authorize_url is None + assert find(recorded, "GET", "/authorize") == [] + assert find(recorded, "POST", "/register") == [] + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body == snapshot( + {"grant_type": "client_credentials", "resource": "http://127.0.0.1:8000/mcp", "scope": "mcp"} + ) + decoded = base64.b64decode(token_req.headers["authorization"].removeprefix("Basic ")).decode() + assert decoded == "m2m-client:m2m-secret" + + +@requirement("client-auth:private-key-jwt") +async def test_private_key_jwt_provider_authenticates_the_token_request_with_an_assertion() -> None: + """The private-key-JWT provider sends a `client_assertion` on the token request, with the issuer as audience. + + The assertion provider is a closure that records the audience it was called with and returns + a fixed opaque value (the JWT contents are not the SDK's concern here); the test asserts the + `client_assertion`/`client_assertion_type` form fields and that the audience matches the AS + metadata's issuer. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + audiences: list[str] = [] + + async def assertion_provider(audience: str) -> str: + audiences.append(audience) + return "header.payload.sig" + + auth = PrivateKeyJWTOAuthProvider( + server_url=f"{BASE_URL}/mcp", + storage=InMemoryTokenStorage(), + client_id="m2m-jwt-client", + assertion_provider=assertion_provider, + scopes="mcp", + ) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + auth=auth, + app_shim=m2m_token_shim(provider, scopes=["mcp"]), + on_request=on_request, + ) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + assert audiences == [f"{BASE_URL}/"] + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body == snapshot( + { + "grant_type": "client_credentials", + "client_assertion": "header.payload.sig", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "resource": "http://127.0.0.1:8000/mcp", + "scope": "mcp", + } + ) + assert "client_secret" not in body + assert "authorization" not in token_req.headers + + +@pytest.mark.parametrize( + ("case", "preseed_storage", "advertise_cimd"), + [("cimd_unsupported_falls_through_to_dcr", False, False), ("preregistered_beats_cimd", True, True)], + ids=["cimd_unsupported_falls_through_to_dcr", "preregistered_beats_cimd"], +) +@requirement("client-auth:cimd") +async def test_registration_priority_prefers_preregistered_then_cimd_then_dcr( + case: str, preseed_storage: bool, advertise_cimd: bool +) -> None: + """The client picks pre-registration over CIMD over DCR, falling through when each is unavailable. + + Two priority edges are exercised: with a CIMD URL configured but no AS support, DCR runs and + the registered `client_id` is used; with a CIMD URL configured and AS support but a + pre-registered client in storage, the stored `client_id` is used and neither CIMD nor DCR + runs. (The positive CIMD case and pre-registration over DCR are covered by their own tests.) + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + storage = InMemoryTokenStorage() + + expected_client_id: str + if preseed_storage: + info = seeded_client(provider) + storage.client_info = info + assert info.client_id is not None + expected_client_id = info.client_id + else: + expected_client_id = "" + + app_shim = shim(serve={ASM_PATH: cimd_supported_metadata()}) if advertise_cimd else None + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + client_metadata_url=CIMD_URL, + app_shim=app_shim, + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + chosen_client_id = authorize_params(headless.authorize_url)["client_id"] + assert chosen_client_id != CIMD_URL + + if case == "cimd_unsupported_falls_through_to_dcr": + assert len(find(recorded, "POST", "/register")) == 1 + assert chosen_client_id in provider.clients + else: + assert find(recorded, "POST", "/register") == [] + assert chosen_client_id == expected_client_id diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py new file mode 100644 index 0000000000..c2ace45077 --- /dev/null +++ b/tests/interaction/conftest.py @@ -0,0 +1,23 @@ +"""Shared fixtures for the interaction suite.""" + +import pytest + +from tests.interaction._connect import Connect, connect_in_memory, connect_over_sse, connect_over_streamable_http + +_FACTORIES: dict[str, Connect] = { + "in-memory": connect_in_memory, + "streamable-http": connect_over_streamable_http, + "sse": connect_over_sse, +} + + +@pytest.fixture(params=sorted(_FACTORIES)) +def connect(request: pytest.FixtureRequest) -> Connect: + """The transport-parametrized connection factory: a test using it runs once per transport. + + Tests that are tied to one transport (the wire-recording tests, the bare-ClientSession tests, + the transport-specific tests under transports/) do not use this fixture and connect directly. + """ + transport_name = request.param + assert isinstance(transport_name, str) + return _FACTORIES[transport_name] diff --git a/tests/interaction/lowlevel/__init__.py b/tests/interaction/lowlevel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py new file mode 100644 index 0000000000..6f1454e58a --- /dev/null +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -0,0 +1,234 @@ +"""Cancellation interactions against the low-level Server, driven through the public Client API. + +There is no client-side cancellation API: cancelling means sending a CancelledNotification +carrying the request id, which only the server-side handler can observe (`ctx.request_id`), so +these tests capture the id from inside the blocked handler before cancelling. The handler blocks +on an Event rather than a sleep, and every wait is bounded by `anyio.fail_after`. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + EmptyResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + PingRequest, + ServerCapabilities, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:cancel:in-flight") +@requirement("protocol:cancel:handler-abort-propagates") +async def test_cancellation_stops_in_flight_handler(connect: Connect) -> None: + """Cancelling an in-flight request interrupts its handler and fails the pending call. + + The server answers the cancelled request with an error response (the spec says it should + not respond at all; see the divergence note on the requirement), so the caller's pending + request raises rather than hanging. + """ + started = anyio.Event() + handler_cancelled = anyio.Event() + request_ids: list[types.RequestId] = [] + errors: list[ErrorData] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + assert ctx.request_id is not None + request_ids.append(ctx.request_id) + started.set() + try: + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + except anyio.get_cancelled_exc_class(): + handler_cancelled.set() + raise + raise NotImplementedError # unreachable: the wait above never completes normally + + server = Server("blocker", on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_capture_error() -> None: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}) + errors.append(exc_info.value.error) + + task_group.start_soon(call_and_capture_error) + await started.wait() + await client.session.send_notification( + types.CancelledNotification( + params=types.CancelledNotificationParams(request_id=request_ids[0], reason="user aborted") + ) + ) + + await handler_cancelled.wait() + + assert errors == snapshot([ErrorData(code=0, message="Request cancelled")]) + + +@requirement("protocol:cancel:server-survives") +async def test_session_serves_requests_after_cancellation(connect: Connect) -> None: + """A request cancelled mid-flight does not poison the session: the next request succeeds.""" + started = anyio.Event() + request_ids: list[types.RequestId] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool(name="block", input_schema={"type": "object"}), + types.Tool(name="echo", input_schema={"type": "object"}), + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + return CallToolResult(content=[TextContent(text="still alive")]) + assert ctx.request_id is not None + request_ids.append(ctx.request_id) + started.set() + await anyio.Event().wait() # blocks until cancelled + raise NotImplementedError # unreachable + + server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_swallow_cancellation_error() -> None: + with pytest.raises(MCPError): + await client.call_tool("block", {}) + + task_group.start_soon(call_and_swallow_cancellation_error) + await started.wait() + await client.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=request_ids[0])) + ) + + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + + +@requirement("protocol:cancel:unknown-id-ignored") +async def test_cancellation_for_unknown_request_is_ignored(connect: Connect) -> None: + """A cancellation referencing a request id that is not in flight is ignored without error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + return CallToolResult(content=[TextContent(text="unbothered")]) + + server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=9999)) + ) + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="unbothered")])) + + +@requirement("protocol:cancel:late-response-ignored") +async def test_a_response_for_an_unknown_request_id_surfaces_to_the_message_handler() -> None: + """A response whose id matches no in-flight request is surfaced to the message handler as a RuntimeError. + + The spec says a sender SHOULD ignore a response that arrives after it issued a cancellation; + that is the same client-side code path as any response with an unknown id, and that form is + deterministic to test without depending on the cancellation API the SDK does not yet provide. + See the divergence note on the requirement. + + A real Server cannot be made to answer with a fabricated id, so the test plays the server's + side of the wire by hand. Reserve this pattern for behaviour no real server can produce. The + other tests in this file run over the transport matrix; this one is in-memory only because the + scripted-peer mechanism is the in-memory stream pair, not because the behaviour is + transport-specific. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + + def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage: + return SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + + init = await server_read.receive() + assert isinstance(init, SessionMessage) + assert isinstance(init.message, JSONRPCRequest) + assert init.message.method == "initialize" + await server_write.send( + respond( + init.message.id, + InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="scripted", version="0.0.1"), + ), + ) + ) + + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + + ping = await server_read.receive() + assert isinstance(ping, SessionMessage) + assert isinstance(ping.message, JSONRPCRequest) + assert ping.message.method == "ping" + # First answer with a fabricated id that matches nothing in flight, then the real id. + await server_write.send(respond(9999, EmptyResult())) + await server_write.send(respond(ping.message.id, EmptyResult())) + + incoming: list[IncomingMessage] = [] + + async def message_handler(message: IncomingMessage) -> None: + incoming.append(message) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as task_group, + ClientSession(client_read, client_write, message_handler=message_handler) as session, + ): + task_group.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + await session.initialize() + pong = await session.send_request(PingRequest(), EmptyResult) + + assert pong == snapshot(EmptyResult()) + assert len(incoming) == 1 + assert isinstance(incoming[0], RuntimeError) + # The full message embeds the response object's repr; only the prefix is stable. + assert str(incoming[0]).startswith("Received response with an unknown request ID:") diff --git a/tests/interaction/lowlevel/test_completion.py b/tests/interaction/lowlevel/test_completion.py new file mode 100644 index 0000000000..6a35404df3 --- /dev/null +++ b/tests/interaction/lowlevel/test_completion.py @@ -0,0 +1,131 @@ +"""Completion interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + METHOD_NOT_FOUND, + CompleteResult, + Completion, + ErrorData, + PromptReference, + ResourceTemplateReference, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("completion:prompt-arg") +@requirement("completion:result-shape") +async def test_complete_prompt_argument(connect: Connect) -> None: + """Completing a prompt argument delivers the ref, argument name, and current value to the handler. + + The returned values are filtered by the argument's value, proving the value reached the handler. + """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, PromptReference) + assert params.ref.name == "code_review" + assert params.argument.name == "language" + candidates = ["python", "pytorch", "ruby"] + matches = [candidate for candidate in candidates if candidate.startswith(params.argument.value)] + return CompleteResult(completion=Completion(values=matches, total=len(matches), has_more=False)) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + PromptReference(name="code_review"), argument={"name": "language", "value": "py"} + ) + + assert result == snapshot( + CompleteResult(completion=Completion(values=["python", "pytorch"], total=2, has_more=False)) + ) + + +@requirement("completion:resource-template-arg") +async def test_complete_resource_template_variable(connect: Connect) -> None: + """Completing a URI template variable delivers the template URI and variable name to the handler.""" + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, ResourceTemplateReference) + assert params.ref.uri == "github://repos/{owner}/{repo}" + assert params.argument.name == "owner" + return CompleteResult(completion=Completion(values=[f"{params.argument.value}contextprotocol"])) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + argument={"name": "owner", "value": "model"}, + ) + + assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol"]))) + + +@requirement("completion:context-arguments") +async def test_complete_receives_context_arguments(connect: Connect) -> None: + """Previously-resolved arguments passed as completion context reach the handler. + + The returned value is derived from the context, proving it arrived. + """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert params.argument.name == "repo" + assert params.context is not None + assert params.context.arguments is not None + return CompleteResult(completion=Completion(values=[f"{params.context.arguments['owner']}/python-sdk"])) + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + result = await client.complete( + ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "modelcontextprotocol"}, + ) + + assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol/python-sdk"]))) + + +@requirement("completion:error:invalid-ref") +async def test_completion_against_an_unknown_ref_is_rejected_with_invalid_params(connect: Connect) -> None: + """completion/complete with a ref naming an unknown prompt is answered with -32602 Invalid params. + + The lowlevel server does not validate refs itself (it has no prompt/template registry to check + against); rejecting an unknown ref is the handler's job, and this test pins the spec-recommended + way to do it. + """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, PromptReference) + raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.ref.name!r}") + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.complete(PromptReference(name="ghost"), argument={"name": "x", "value": ""}) + + assert exc_info.value.error.code == INVALID_PARAMS + + +@requirement("completion:complete:not-supported") +@requirement("protocol:error:method-not-found") +async def test_complete_without_handler_is_method_not_found(connect: Connect) -> None: + """A server with no completion handler advertises no completions capability and rejects the request.""" + server = Server("incomplete") + + async with connect(server) as client: + assert client.initialize_result.capabilities.completions is None + + with pytest.raises(MCPError) as exc_info: + await client.complete(PromptReference(name="anything"), argument={"name": "topic", "value": ""}) + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py new file mode 100644 index 0000000000..b8edf601d0 --- /dev/null +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -0,0 +1,662 @@ +"""Form- and URL-mode elicitation against the low-level Server, driven through the public Client API. + +The final test plays the server's side of the wire by hand to issue an elicitation request with no +mode field, because the typed server API (`elicit_form`/`elicit_url`) always serializes one. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, UrlElicitationRequiredError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + ElicitCompleteNotification, + ElicitCompleteNotificationParams, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +REQUESTED_SCHEMA: dict[str, object] = { + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], +} + + +@requirement("elicitation:form:action:accept") +@requirement("elicitation:form:basic") +@requirement("tools:call:elicitation-roundtrip") +async def test_elicit_form_accepted_content_returns_to_handler(connect: Connect) -> None: + """An accepted form elicitation returns the user's content to the requesting handler. + + The tool reports the action as text and the received content as structured content, proving + the client's answer made it back into the tool's own result. + """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form("Choose a username.", REQUESTED_SCHEMA) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"username": "ada", "newsletter": True}) + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Choose a username.", + requested_schema={ + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept")], + structured_content={"username": "ada", "newsletter": True}, + ) + ) + + +@requirement("elicitation:form:action:decline") +async def test_elicit_form_decline_returns_no_content(connect: Connect) -> None: + """A declined form elicitation returns the decline action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:form:action:cancel") +async def test_elicit_form_cancel_returns_no_content(connect: Connect) -> None: + """A cancelled form elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:form:not-supported") +@requirement("elicitation:capability:server-respects-mode") +async def test_elicit_form_without_callback_is_error(connect: Connect) -> None: + """Eliciting from a client that configured no elicitation callback fails with an error. + + The client's default callback answers with an Invalid request error, which the server-side + elicit call raises as an MCPError; the tool reports the code and message it caught. The spec + requires -32602 for an undeclared mode (see the divergence note on the requirement). The + request reaching the client also shows the server does not check the client's declared + elicitation capability before sending (see the divergence on `server-respects-mode`). + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="ask", description="Ask the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask" + try: + await ctx.session.elicit_form("Anyone there?", {"type": "object", "properties": {}}) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # elicit_form cannot succeed without a client callback + + server = Server("asker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ask", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Elicitation not supported")])) + + +@requirement("elicitation:url:action:accept-no-content") +@requirement("elicitation:url:basic") +async def test_elicit_url_delivers_url_and_returns_accept_without_content(connect: Connect) -> None: + """A URL elicitation delivers the message, URL, and elicitation id to the client; accepting it + returns the action with no content. + + Accept means the user agreed to visit the URL, not that the out-of-band interaction finished, + so there is never form content to return. + """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert received == snapshot( + [ + ElicitRequestURLParams( + _meta={}, + message="Authorize access to your calendar.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + assert result == snapshot(CallToolResult(content=[TextContent(text="accept content=None")])) + + +@requirement("elicitation:url:decline") +async def test_elicit_url_decline_returns_no_content(connect: Connect) -> None: + """A declined URL elicitation returns the decline action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:url:cancel") +async def test_elicit_url_cancel_returns_no_content(connect: Connect) -> None: + """A cancelled URL elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async with connect(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:url:complete-notification") +async def test_elicitation_complete_notification_carries_the_elicited_id_back_to_the_client(connect: Connect) -> None: + """After a URL elicitation finishes, the server announces it with a notification carrying the same id. + + The lifecycle under test: the tool elicits a URL interaction with an elicitationId, the user + agrees to visit the URL, the out-of-band interaction finishes, and the server emits + elicitation/complete so the client can correlate the completion with the elicitation it + accepted earlier. The completion notification carries ``related_request_id`` so over + streamable HTTP it rides the tool call's own stream and reaches the client before the call + returns; the same ordering already holds on in-memory and SSE transports. + """ + elicitation_id = "auth-001" + elicited_ids: list[str] = [] + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="link_account", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "link_account" + answer = await ctx.session.elicit_url( + "Authorize access to your files.", "https://example.com/oauth/authorize", elicitation_id + ) + assert answer.action == "accept" + await ctx.session.send_elicit_complete(elicitation_id, related_request_id=ctx.request_id) + return CallToolResult(content=[TextContent(text="linked")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestURLParams) + elicited_ids.append(params.elicitation_id) + return ElicitResult(action="accept") + + async with connect(server, message_handler=collect, elicitation_callback=answer_url) as client: + await client.call_tool("link_account", {}) + + # The completion notification refers to the same elicitation the client accepted. + assert elicited_ids == [elicitation_id] + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="auth-001"))] + ) + + +@requirement("elicitation:url:required-error") +async def test_url_elicitation_required_error_carries_pending_elicitations(connect: Connect) -> None: + """A request that cannot proceed until a URL interaction completes is rejected with error -32042. + + This is the non-interactive alternative to elicit_url: instead of asking and waiting, the + handler rejects the whole request and lists the required URL elicitations in the error data. + The client is expected to present those URLs, wait for the matching elicitation/complete + notifications, and retry the original request. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + server = Server("authorizer", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) + + +@requirement("elicitation:form:schema:primitives") +@requirement("elicitation:form:schema:enum-variants") +async def test_elicit_form_schema_with_every_primitive_and_enum_type_reaches_the_callback_as_sent( + connect: Connect, +) -> None: + """A requested schema covering every spec-listed property kind is delivered to the callback unchanged. + + One schema with one property per kind: a formatted string, an integer with bounds, a number, + a boolean, a plain enum, a oneOf-const titled enum, and a multi-select array-of-enum. The + callback observing the same schema as the handler sent proves both the primitive coverage and + the enum-variant coverage in one snapshot. + """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "email": {"type": "string", "format": "email", "title": "Email", "description": "Contact address."}, + "age": {"type": "integer", "minimum": 0, "maximum": 150}, + "score": {"type": "number"}, + "subscribe": {"type": "boolean", "default": False}, + "tier": {"type": "string", "enum": ["free", "pro", "team"]}, + "region": { + "oneOf": [ + {"const": "eu", "title": "Europe"}, + {"const": "na", "title": "North America"}, + ], + }, + "channels": {"type": "array", "items": {"type": "string", "enum": ["email", "sms", "push"]}}, + }, + "required": ["email"], + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="onboard", description="Onboard the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + answer = await ctx.session.elicit_form("Tell us about yourself.", schema) + return CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("onboarder", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"email": "ada@example.com"}) + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("onboard", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:schema:restricted-subset") +async def test_elicit_form_with_a_nested_schema_is_forwarded_unchanged(connect: Connect) -> None: + """A requested schema with nested-object and array-of-object properties passes through unchanged. + + The spec restricts form-mode requested schemas to flat objects with primitive-typed properties; + this test pins that the SDK does not enforce that restriction on either side (see the + divergence on the requirement). + """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, + }, + "contacts": { + "type": "array", + "items": {"type": "object", "properties": {"name": {"type": "string"}}}, + }, + }, + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="profile", description="Collect a profile.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "profile" + answer = await ctx.session.elicit_form("Profile details.", schema) + return CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("profiler", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("profile", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:response-validation") +async def test_accepted_elicitation_content_that_violates_the_schema_reaches_the_handler_unchanged( + connect: Connect, +) -> None: + """Accepted form content that contradicts the requested schema is delivered to the handler unchanged. + + The schema requires a string `name`; the callback answers with a wrong-type value and an extra + field. Nothing on either side validates the response against the schema (see the divergence on + the requirement), so the handler observes exactly what the callback sent. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form( + "Choose a name.", + {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, + ) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content={"name": 42, "extra": "field"}) + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="accept")], structured_content={"name": 42, "extra": "field"}) + ) + + +@requirement("elicitation:url:complete-unknown-ignored") +async def test_elicitation_complete_for_an_unknown_id_is_received_without_error(connect: Connect) -> None: + """An elicitation/complete for an id the client never elicited is delivered and does not fail anything. + + No URL elicitation precedes the notification; the client neither tracks elicitation ids nor + rejects unknown ones, so the call completes normally and the message handler observes the + notification as-is. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="noop", description="Send a stray complete.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "noop" + await ctx.session.send_elicit_complete("never-elicited", related_request_id=ctx.request_id) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("notifier", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(server, message_handler=collect) as client: + result = await client.call_tool("noop", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="never-elicited"))] + ) + + +@requirement("elicitation:form:mode-omitted-default") +async def test_a_mode_less_elicitation_request_is_treated_as_form_mode() -> None: + """An elicitation/create request with no mode field reaches the client callback as form-mode. + + The typed server API always serializes a mode (`elicit_form` writes 'form', `elicit_url` writes + 'url'), so this test plays the server's side of the wire by hand to send a request body without + one. Reserve this pattern for behaviour the typed server API cannot produce. + """ + received: list[types.ElicitRequestParams] = [] + answered = anyio.Event() + server_received: list[JSONRPCMessage] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={}) + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + initialize = await server_read.receive() + assert isinstance(initialize, SessionMessage) + request = initialize.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="legacy", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + # No mode key: a server speaking a pre-mode revision of the spec sends only message + schema. + await server_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="elicitation/create", + params={"message": "Legacy ask.", "requestedSchema": {"type": "object", "properties": {}}}, + ) + ) + ) + response = await server_read.receive() + assert isinstance(response, SessionMessage) + server_received.append(response.message) + answered.set() + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write, elicitation_callback=answer_form) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + await session.initialize() + await answered.wait() + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta=None, + message="Legacy ask.", + requested_schema={"type": "object", "properties": {}}, + ) + ] + ) + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].mode == "form" + assert len(server_received) == 1 + assert isinstance(server_received[0], JSONRPCResponse) + assert server_received[0].id == 2 diff --git a/tests/interaction/lowlevel/test_flows.py b/tests/interaction/lowlevel/test_flows.py new file mode 100644 index 0000000000..8d96582341 --- /dev/null +++ b/tests/interaction/lowlevel/test_flows.py @@ -0,0 +1,203 @@ +"""Composed multi-feature flows against the low-level Server, driven through the public Client API. + +Each test reads as the scenario it proves: the steps run top to bottom in the order a real client +would perform them, composing two or more feature areas (a tool call followed by a resource read; +a chain of elicitations inside one tool call; the full URL-elicitation-required retry loop). The +individual features are pinned by their own tests; these prove they compose. +""" + +from collections.abc import Awaitable, Callable + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, UrlElicitationRequiredError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.server.session import ServerSession +from mcp.types import ( + URL_ELICITATION_REQUIRED, + CallToolResult, + ElicitCompleteNotification, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + EmptyResult, + ListToolsResult, + ReadResourceResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ListToolsHandler = Callable[ + [ServerRequestContext, types.PaginatedRequestParams | None], Awaitable[types.ListToolsResult] +] + + +def _list_tools(*names: str) -> ListToolsHandler: + """A list_tools handler advertising the named tools, so call_tool's implicit list succeeds.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name=name, input_schema={"type": "object"}) for name in names]) + + return list_tools + + +@requirement("flow:tool-result:resource-link-follow") +async def test_a_resource_link_returned_by_a_tool_can_be_followed_with_read(connect: Connect) -> None: + """A tool returns a resource_link; reading that link's URI returns the referenced contents. + + Steps: (1) call the tool, (2) extract the link from its content, (3) read_resource on the + link's URI, (4) the read result carries the linked contents. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "generate" + return CallToolResult(content=[ResourceLink(uri="file:///report.txt", name="report")]) + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + assert str(params.uri) == "file:///report.txt" + return ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + + server = Server( + "linker", on_list_tools=_list_tools("generate"), on_call_tool=call_tool, on_read_resource=read_resource + ) + + async with connect(server) as client: + called = await client.call_tool("generate", {}) + link = called.content[0] + assert isinstance(link, ResourceLink) + read = await client.read_resource(link.uri) + + assert called == snapshot(CallToolResult(content=[ResourceLink(name="report", uri="file:///report.txt")])) + assert read == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + ) + + +@requirement("flow:elicitation:multi-step-form") +async def test_a_tool_handler_chains_form_elicitations_feeding_each_answer_forward(connect: Connect) -> None: + """Sequential form elicitations inside one tool call: each accepted answer feeds the next step. + + Steps: (1) call the tool, (2) the handler issues a step-one form elicitation that the client + accepts with content, (3) the handler issues a step-two elicitation whose message references + the step-one answer, (4) the client accepts step two, (5) the tool result summarises both + answers. The callback is invoked exactly twice with the expected messages and schemas. The + short-circuit on decline is the application's choice (proven separately by the per-action + elicitation tests); what this flow pins is that the chain itself works end to end. + """ + received: list[ElicitRequestFormParams] = [] + answers: list[dict[str, str | int | float | bool | list[str] | None]] = [{"name": "ada"}, {"age": 37}] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + first = await ctx.session.elicit_form( + "Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}} + ) + assert first.action == "accept" and first.content is not None + second = await ctx.session.elicit_form( + f"Step 2: confirm age for {first.content['name']}.", + {"type": "object", "properties": {"age": {"type": "integer"}}}, + ) + assert second.action == "accept" and second.content is not None + return CallToolResult(content=[TextContent(text=f"{first.content['name']} is {second.content['age']}")]) + + server = Server("onboarder", on_list_tools=_list_tools("onboard"), on_call_tool=call_tool) + + async def answer(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestFormParams) + received.append(params) + return ElicitResult(action="accept", content=answers[len(received) - 1]) + + async with connect(server, elicitation_callback=answer) as client: + result = await client.call_tool("onboard", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ada is 37")])) + assert [(p.message, p.requested_schema) for p in received] == snapshot( + [ + ("Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}}), + ("Step 2: confirm age for ada.", {"type": "object", "properties": {"age": {"type": "integer"}}}), + ] + ) + + +@requirement("flow:elicitation:url-required-then-retry") +async def test_a_tool_rejected_with_url_elicitation_required_succeeds_on_retry_after_completion( + connect: Connect, +) -> None: + """The full URL-elicitation-required retry loop: -32042, completion announced, retry succeeds. + + Steps: (1) the first call is rejected with -32042 carrying the required URL elicitation in + its error data, (2) the client extracts the elicitation id from the error, (3) the server + announces completion via the elicitation/complete notification (driven via the captured + session, the same way a real out-of-band callback would reach a held session reference), + (4) the client observes the matching completion notification and retries, (5) the retry + succeeds. The handler distinguishes the two calls by a closure flag the test flips between + them; the test waits on the completion notification with an event so the retry only happens + after the announcement has arrived. + """ + elicitation_id = "auth-001" + authorised: list[bool] = [False] + captured: list[ServerSession] = [] + completed = anyio.Event() + notifications: list[ElicitCompleteNotification] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + captured.append(ctx.session) + if not authorised[0]: + # The log line gives the message handler a non-completion notification, so the test's + # filtering branch is exercised in both directions and the wait remains specific. + await ctx.session.send_log_message(level="warning", data="authorisation required", logger="gate") + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorize file access.", + url="https://example.com/oauth/authorize", + elicitation_id=elicitation_id, + ) + ] + ) + return CallToolResult(content=[TextContent(text="contents")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server( + "gatekeeper", + on_list_tools=_list_tools("read_files"), + on_call_tool=call_tool, + on_set_logging_level=set_logging_level, + ) + + async def collect(message: IncomingMessage) -> None: + if isinstance(message, ElicitCompleteNotification): + notifications.append(message) + completed.set() + + async with connect(server, message_handler=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + required = UrlElicitationRequiredError.from_error(exc_info.value.error) + assert [e.elicitation_id for e in required.elicitations] == [elicitation_id] + + # The out-of-band interaction completes; the server announces it on the same session. + await captured[0].send_elicit_complete(elicitation_id) + with anyio.fail_after(5): + await completed.wait() + assert notifications[0].params.elicitation_id == elicitation_id + + authorised[0] = True + result = await client.call_tool("read_files", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="contents")])) diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py new file mode 100644 index 0000000000..91adbf5611 --- /dev/null +++ b/tests/interaction/lowlevel/test_initialize.py @@ -0,0 +1,384 @@ +"""Initialization handshake against the low-level Server, driven through the public Client API. + +The later tests drive a bare ClientSession over an InMemoryTransport instead: Client always +performs the full handshake with the latest protocol version, so skipping initialization or +requesting a different version can only be expressed one level down. The final test goes one step +further and plays the server's side of the wire by hand, because no real Server can be made to +answer initialize with an unsupported protocol version. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.client._memory import InMemoryTransport +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + INVALID_PARAMS, + CallToolResult, + ClientCapabilities, + CompletionsCapability, + EmptyResult, + ErrorData, + Icon, + Implementation, + InitializeRequest, + InitializeRequestParams, + InitializeResult, + JSONRPCRequest, + JSONRPCResponse, + ListToolsRequest, + ListToolsResult, + LoggingCapability, + PromptsCapability, + ResourcesCapability, + ServerCapabilities, + TextContent, + ToolsCapability, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("lifecycle:initialize:basic") +@requirement("lifecycle:initialize:server-info") +async def test_initialize_returns_server_info(connect: Connect) -> None: + """Every identity field the server declares is returned to the client in server_info.""" + server = Server( + "greeter", + version="1.2.3", + title="Greeter", + description="Greets people.", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + + async with connect(server) as client: + server_info = client.initialize_result.server_info + + assert server_info == snapshot( + Implementation( + name="greeter", + title="Greeter", + description="Greets people.", + version="1.2.3", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + ) + + +@requirement("lifecycle:initialize:instructions") +async def test_initialize_returns_instructions(connect: Connect) -> None: + """Instructions are returned when the server declares them and omitted when it does not.""" + async with connect(Server("guided", instructions="Call the add tool.")) as client: + assert client.initialize_result.instructions == snapshot("Call the add tool.") + + async with connect(Server("unguided")) as client: + assert client.initialize_result.instructions is None + + +@requirement("lifecycle:initialize:capabilities:from-handlers") +@requirement("tools:capability:declared") +@requirement("resources:capability:declared") +@requirement("prompts:capability:declared") +@requirement("completion:capability:declared") +async def test_initialize_capabilities_reflect_registered_handlers(connect: Connect) -> None: + """Each feature area with a registered handler is advertised as a capability. + + The in-memory transport connects with default initialization options, so the + list_changed flags are always False regardless of the server's notification behaviour. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> types.EmptyResult: + """Registered only so the subscribe sub-capability is advertised; never called.""" + raise NotImplementedError + + async def list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + """Registered only so the prompts capability is advertised; never called.""" + raise NotImplementedError + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> types.EmptyResult: + """Registered only so the logging capability is advertised; never called.""" + raise NotImplementedError + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + server = Server( + "full", + on_list_tools=list_tools, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + on_list_prompts=list_prompts, + on_set_logging_level=set_logging_level, + on_completion=completion, + ) + + async with connect(server) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot( + ServerCapabilities( + experimental={}, + logging=LoggingCapability(), + prompts=PromptsCapability(list_changed=False), + resources=ResourcesCapability(subscribe=True, list_changed=False), + tools=ToolsCapability(list_changed=False), + completions=CompletionsCapability(), + ) + ) + + +@requirement("lifecycle:initialize:capabilities:minimal") +async def test_initialize_minimal_server_advertises_no_capabilities(connect: Connect) -> None: + """A server with no feature handlers advertises no feature capabilities.""" + async with connect(Server("bare")) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot(ServerCapabilities(experimental={})) + + +@requirement("lifecycle:initialize:client-info") +async def test_initialize_server_sees_client_info(connect: Connect) -> None: + """The client identity supplied to Client is visible to server handlers after initialization.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="whoami", description="Report the caller.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "whoami" + assert ctx.session.client_params is not None + client_info = ctx.session.client_params.client_info + return CallToolResult(content=[TextContent(text=f"{client_info.name} {client_info.version}")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + async with connect(server, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: + result = await client.call_tool("whoami", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="acme-agent 9.9.9")])) + + +@requirement("lifecycle:initialize:client-capabilities") +async def test_initialize_server_sees_client_capabilities(connect: Connect) -> None: + """The client capabilities visible to the server reflect which callbacks the client configured.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="abilities", description="Report capabilities.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "abilities" + assert ctx.session.client_params is not None + capabilities = ctx.session.client_params.capabilities + declared = [ + name + for name, value in ( + ("sampling", capabilities.sampling), + ("elicitation", capabilities.elicitation), + ) + if value is not None + ] + if capabilities.roots is not None: + declared.append(f"roots(list_changed={capabilities.roots.list_changed})") + return CallToolResult(content=[TextContent(text=",".join(declared) or "none")]) + + async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: + """Registered only so the client declares the roots capability; never called.""" + raise NotImplementedError + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="none")])) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="roots(list_changed=True)")])) + + +@requirement("lifecycle:requests-before-initialized") +async def test_request_before_initialization_is_rejected() -> None: + """A feature request sent before the handshake completes is rejected; ping is exempt. + + Client always initializes on entry, so this drives a bare ClientSession that never sends + initialize. The server's stated reason for the rejection never reaches the client: the error + is reported as a generic invalid-params failure. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + """Registered so the request is routed to a real handler; never reached.""" + raise NotImplementedError + + server = Server("strict", on_list_tools=list_tools) + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc_info: + await session.send_request(ListToolsRequest(), ListToolsResult) + + # Ping is explicitly permitted before initialization completes. + pong = await session.send_ping() + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + ) + assert pong == snapshot(EmptyResult()) + + +@requirement("lifecycle:version:match") +@requirement("lifecycle:version:server-fallback-latest") +async def test_initialize_negotiates_protocol_version() -> None: + """The server echoes a supported requested version and answers an unsupported one with its latest. + + Client always requests the latest version, so each half hand-builds an InitializeRequest on a + bare ClientSession to control the requested version. + """ + server = Server("negotiator") + + def initialize_request(protocol_version: str) -> InitializeRequest: + return InitializeRequest( + params=InitializeRequestParams( + protocol_version=protocol_version, + capabilities=ClientCapabilities(), + client_info=Implementation(name="time-traveller", version="0.0.1"), + ) + ) + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + result = await session.send_request(initialize_request("2025-03-26"), InitializeResult) + assert result.protocol_version == snapshot("2025-03-26") + + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + result = await session.send_request(initialize_request("1999-01-01"), InitializeResult) + assert result.protocol_version == snapshot("2025-11-25") + + +@requirement("lifecycle:version:reject-unsupported") +async def test_unsupported_server_protocol_version_fails_initialization() -> None: + """An initialize response carrying a protocol version the client does not support fails initialization. + + A real Server only ever answers with a version it supports, so this test alone plays the + server's side of the wire by hand: it reads the initialize request off the raw stream and + answers it with a hand-built result. Reserve this pattern for behaviour no real server can + be made to produce. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="1991-08-06", + capabilities=ServerCapabilities(), + server_info=Implementation(name="relic", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + with pytest.raises(RuntimeError) as exc_info: + await session.initialize() + + assert str(exc_info.value) == snapshot("Unsupported protocol version from the server: 1991-08-06") + + +@requirement("lifecycle:version:downgrade") +async def test_an_older_supported_protocol_version_from_the_server_is_accepted() -> None: + """An initialize response carrying an older supported protocol version completes the handshake at that version. + + A real Server answers with the version the client requested (or its own latest), so this test + plays the server's side of the wire by hand to return a fixed older version regardless of what + was requested. Reserve this pattern for behaviour no real server can be made to produce. + """ + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-06-18", + capabilities=ServerCapabilities(), + server_info=Implementation(name="conservative", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + initialize_result = await session.initialize() + + assert initialize_result.protocol_version == snapshot("2025-06-18") diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py new file mode 100644 index 0000000000..a2f85eeacf --- /dev/null +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -0,0 +1,136 @@ +"""List-changed notifications from the low-level Server, driven through the public Client API. + +``send_*_list_changed`` does not take a ``related_request_id``, so over streamable HTTP the +notification routes to the standalone GET stream and is not guaranteed to arrive before the tool +result on its POST stream. Tests therefore wait on an event the collector sets, the same pattern +as ``transports/test_streamable_http.py::test_unrelated_server_messages_arrive_on_the_standalone_stream``. +The collector still records every message it receives, so the snapshot also proves nothing else +was delivered. + +The servers register the parent capability (resources/prompts) so that part of the spec's +precondition holds, but the ``listChanged`` sub-capability stays ``False``: ``NotificationOptions`` +is not threaded through any of the suite's connection paths. The tests therefore rely on the +recorded ``lifecycle:capability:server-not-advertised`` divergence and will need updating +alongside the fix that introduces capability gating. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolResult, + PromptListChangedNotification, + ResourceListChangedNotification, + TextContent, + ToolListChangedNotification, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:list-changed") +async def test_tool_list_changed_notification(connect: Connect) -> None: + """A tools/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="install", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "install" + await ctx.session.send_tool_list_changed() + return CallToolResult(content=[TextContent(text="installed")]) + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("install", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([ToolListChangedNotification()]) + + +@requirement("resources:list-changed") +async def test_resource_list_changed_notification(connect: Connect) -> None: + """A resources/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="mount", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "mount" + await ctx.session.send_resource_list_changed() + return CallToolResult(content=[TextContent(text="mounted")]) + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool, on_list_resources=list_resources) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("mount", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([ResourceListChangedNotification()]) + + +@requirement("prompts:list-changed") +async def test_prompt_list_changed_notification(connect: Connect) -> None: + """A prompts/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="learn", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "learn" + await ctx.session.send_prompt_list_changed() + return CallToolResult(content=[TextContent(text="learned")]) + + async def list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + """Registered so the prompts capability is advertised; the client never lists prompts.""" + raise NotImplementedError + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool, on_list_prompts=list_prompts) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("learn", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([PromptListChangedNotification()]) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py new file mode 100644 index 0000000000..fba632ef4d --- /dev/null +++ b/tests/interaction/lowlevel/test_logging.py @@ -0,0 +1,127 @@ +"""Logging interactions against the low-level Server, driven through the public Client API. + +Notification ordering: the in-memory transport delivers every server-to-client message on one +ordered stream, and the client's receive loop dispatches each incoming message to completion +before reading the next one. Over streamable HTTP that ordered single-stream guarantee holds +only for messages that carry a ``related_request_id`` (they ride the originating request's POST +stream); without it the message routes to the standalone GET stream and may arrive after the +response. These tests pass ``related_request_id`` so they can collect into a plain list and +assert after the request completes on every transport leg -- no events, no waiting. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, EmptyResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ALL_LEVELS: tuple[types.LoggingLevel, ...] = ( + "debug", + "info", + "notice", + "warning", + "error", + "critical", + "alert", + "emergency", +) + + +@requirement("logging:set-level") +async def test_set_logging_level_reaches_handler(connect: Connect) -> None: + """The level requested by the client is delivered to the server's handler verbatim.""" + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + assert params.level == "warning" + return EmptyResult() + + server = Server("logger", on_set_logging_level=set_logging_level) + + async with connect(server) as client: + result = await client.set_logging_level("warning") + + assert result == snapshot(EmptyResult()) + + +@requirement("logging:message:fields") +@requirement("tools:call:logging-mid-execution") +async def test_log_messages_reach_logging_callback_in_order(connect: Connect) -> None: + """Log messages sent during a tool call arrive at the logging callback, in order, before the call returns. + + The two messages pin the full notification shape: severity, optional logger name, and both + string and structured data payloads. + """ + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="chatty", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "chatty" + await ctx.session.send_log_message( + level="info", data="starting up", logger="app.lifecycle", related_request_id=ctx.request_id + ) + await ctx.session.send_log_message( + level="error", data={"code": 502, "retryable": True}, related_request_id=ctx.request_id + ) + return CallToolResult(content=[TextContent(text="done")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + async with connect(server, logging_callback=collect) as client: + result = await client.call_tool("chatty", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")])) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="info", logger="app.lifecycle", data="starting up"), + LoggingMessageNotificationParams(level="error", data={"code": 502, "retryable": True}), + ] + ) + + +@requirement("logging:message:all-levels") +async def test_log_messages_at_every_severity_level(connect: Connect) -> None: + """Each of the eight RFC 5424 severity levels is deliverable as a log message notification.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="siren", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "siren" + for level in ALL_LEVELS: + await ctx.session.send_log_message( + level=level, data=f"a {level} message", related_request_id=ctx.request_id + ) + return CallToolResult(content=[TextContent(text="logged")]) + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + async with connect(server, logging_callback=collect) as client: + await client.call_tool("siren", {}) + + assert [params.level for params in received] == list(ALL_LEVELS) diff --git a/tests/interaction/lowlevel/test_meta.py b/tests/interaction/lowlevel/test_meta.py new file mode 100644 index 0000000000..a9e4f994d8 --- /dev/null +++ b/tests/interaction/lowlevel/test_meta.py @@ -0,0 +1,63 @@ +"""Request and result _meta round trips against the low-level Server, through the public Client API. + +Meta is opaque pass-through data, so these tests assert identity against the value that was sent +rather than snapshotting a literal: the expected value and the sent value are the same variable, +which also proves the SDK injected nothing alongside it. +""" + +import pytest + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, RequestParamsMeta, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("meta:request-to-handler") +async def test_request_meta_reaches_handler(connect: Connect) -> None: + """The _meta object the client attaches to a request arrives at the tool handler unchanged.""" + request_meta: RequestParamsMeta = {"example.com/trace": "abc-123"} + observed_metas: list[dict[str, object]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="traced", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "traced" + assert ctx.meta is not None + observed_metas.append(dict(ctx.meta)) + return CallToolResult(content=[TextContent(text="traced")]) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.call_tool("traced", {}, meta=request_meta) + + assert observed_metas == [dict(request_meta)] + + +@requirement("meta:result-to-client") +async def test_result_meta_reaches_client(connect: Connect) -> None: + """The _meta object a handler attaches to its result is delivered to the client unchanged.""" + result_meta = {"example.com/cost": 3} + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="metered", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "metered" + return CallToolResult(content=[TextContent(text="done")], _meta=result_meta) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("metered", {}) + + assert result == CallToolResult(content=[TextContent(text="done")], _meta=result_meta) diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py new file mode 100644 index 0000000000..77db90401e --- /dev/null +++ b/tests/interaction/lowlevel/test_pagination.py @@ -0,0 +1,242 @@ +"""Cursor pagination of the list operations against the low-level Server. + +The cursor is an opaque string chosen by the server: the suite only asserts that whatever the +handler returns as next_cursor comes back verbatim on the client's next call, not any particular +pagination scheme. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + Prompt, + Resource, + ResourceTemplate, + Tool, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:list:pagination") +async def test_next_cursor_round_trips_through_the_client(connect: Connect) -> None: + """The next_cursor a list handler returns reaches the client, and the cursor the client sends + back on the following call reaches the handler verbatim. + """ + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None # the client always sends params, even without a cursor + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListToolsResult( + tools=[Tool(name="alpha", input_schema={"type": "object"})], + next_cursor=cursor, + ) + return ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})]) + + server = Server("paginated", on_list_tools=list_tools) + + async with connect(server) as client: + first_page = await client.list_tools() + second_page = await client.list_tools(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [tool.name for tool in first_page.tools] == ["alpha"] + assert second_page == snapshot(ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})])) + + +@requirement("pagination:exhaustion") +@requirement("tools:list:pagination") +async def test_paginating_until_next_cursor_is_absent_yields_every_page(connect: Connect) -> None: + """Following next_cursor until it is absent visits every page exactly once, in order.""" + pages: dict[str | None, tuple[str, str | None]] = { + None: ("alpha", "page-2"), + "page-2": ("beta", "page-3"), + "page-3": ("gamma", None), + } + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + tool_name, next_cursor = pages[params.cursor] + return ListToolsResult(tools=[Tool(name=tool_name, input_schema={"type": "object"})], next_cursor=next_cursor) + + server = Server("paginated", on_list_tools=list_tools) + + collected: list[str] = [] + cursor: str | None = None + requests_made = 0 + async with connect(server) as client: + while True: + result = await client.list_tools(cursor=cursor) + requests_made += 1 + assert requests_made <= len(pages), "the server kept returning next_cursor past the last page" + collected.extend(tool.name for tool in result.tools) + if result.next_cursor is None: + break + cursor = result.next_cursor + + assert collected == snapshot(["alpha", "beta", "gamma"]) + assert requests_made == len(pages) + + +@requirement("pagination:client:cursor-handling") +async def test_the_client_follows_opaque_cursors_through_pages_of_varying_sizes(connect: Connect) -> None: + """The client passes a server-issued cursor back byte-for-byte and follows pages of varying sizes. + + The cursors are deliberately base64-looking strings (with padding and URL-unsafe characters) to + show the client treats them as opaque tokens; the page sizes [3, 1, 2] show the loop relies only + on next_cursor, not on a fixed page size. + """ + cursor_to_page_2 = "YWxwaGE+YnJhdm8/Y2hhcmxpZQ==" + cursor_to_page_3 = "ZGVsdGE=" + pages: dict[str | None, tuple[list[str], str | None]] = { + None: (["alpha", "beta", "gamma"], cursor_to_page_2), + cursor_to_page_2: (["delta"], cursor_to_page_3), + cursor_to_page_3: (["epsilon", "zeta"], None), + } + received_cursors: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + received_cursors.append(params.cursor) + names, next_cursor = pages[params.cursor] + return ListToolsResult( + tools=[Tool(name=name, input_schema={"type": "object"}) for name in names], next_cursor=next_cursor + ) + + server = Server("paginated", on_list_tools=list_tools) + + page_sizes: list[int] = [] + cursor: str | None = None + async with connect(server) as client: + while True: + result = await client.list_tools(cursor=cursor) + page_sizes.append(len(result.tools)) + if result.next_cursor is None: + break + cursor = result.next_cursor + + # Identity, not a snapshot: what arrived at the handler is exactly what the handler issued. + assert received_cursors == [None, cursor_to_page_2, cursor_to_page_3] + assert page_sizes == [3, 1, 2] + + +@requirement("pagination:invalid-cursor") +async def test_an_unrecognized_pagination_cursor_is_rejected_with_invalid_params(connect: Connect) -> None: + """A list request with a cursor the server did not issue is answered with -32602 Invalid params. + + The lowlevel server does not validate cursors itself (they are opaque to it); rejecting an + unrecognized cursor is the handler's job, and this test pins the spec-recommended way to do it. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + assert params.cursor == "never-issued" + raise MCPError(code=INVALID_PARAMS, message=f"Unknown cursor: {params.cursor!r}") + + server = Server("paginated", on_list_tools=list_tools) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.list_tools(cursor="never-issued") + + assert exc_info.value.error.code == INVALID_PARAMS + + +@requirement("resources:list:pagination") +async def test_resources_list_supports_cursor_pagination(connect: Connect) -> None: + """resources/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListResourcesResult(resources=[Resource(uri="memo://1", name="first")], next_cursor=cursor) + return ListResourcesResult(resources=[Resource(uri="memo://2", name="second")]) + + server = Server("paginated", on_list_resources=list_resources) + + async with connect(server) as client: + first_page = await client.list_resources() + second_page = await client.list_resources(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [resource.name for resource in first_page.resources] == ["first"] + assert [resource.name for resource in second_page.resources] == ["second"] + assert second_page.next_cursor is None + + +@requirement("resources:templates:pagination") +async def test_resource_templates_list_supports_cursor_pagination(connect: Connect) -> None: + """resources/templates/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_resource_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListResourceTemplatesResult( + resource_templates=[ResourceTemplate(name="first", uri_template="users://{id}")], + next_cursor=cursor, + ) + return ListResourceTemplatesResult( + resource_templates=[ResourceTemplate(name="second", uri_template="teams://{id}")] + ) + + server = Server("paginated", on_list_resource_templates=list_resource_templates) + + async with connect(server) as client: + first_page = await client.list_resource_templates() + second_page = await client.list_resource_templates(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [template.name for template in first_page.resource_templates] == ["first"] + assert [template.name for template in second_page.resource_templates] == ["second"] + assert second_page.next_cursor is None + + +@requirement("prompts:list:pagination") +async def test_prompts_list_supports_cursor_pagination(connect: Connect) -> None: + """prompts/list round-trips the cursor like every other list operation.""" + cursor = "page-2" + seen_cursors: list[str | None] = [] + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListPromptsResult(prompts=[Prompt(name="first")], next_cursor=cursor) + return ListPromptsResult(prompts=[Prompt(name="second")]) + + server = Server("paginated", on_list_prompts=list_prompts) + + async with connect(server) as client: + first_page = await client.list_prompts() + second_page = await client.list_prompts(cursor=first_page.next_cursor) + + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [prompt.name for prompt in first_page.prompts] == ["first"] + assert [prompt.name for prompt in second_page.prompts] == ["second"] + assert second_page.next_cursor is None diff --git a/tests/interaction/lowlevel/test_ping.py b/tests/interaction/lowlevel/test_ping.py new file mode 100644 index 0000000000..797e20dc35 --- /dev/null +++ b/tests/interaction/lowlevel/test_ping.py @@ -0,0 +1,53 @@ +"""Ping interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, EmptyResult, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("lifecycle:ping") +@requirement("ping:client-to-server") +async def test_client_ping_returns_empty_result(connect: Connect) -> None: + """A client ping is answered with an empty result, even by a server with no handlers.""" + server = Server("silent") + + async with connect(server) as client: + result = await client.send_ping() + + assert result == snapshot(EmptyResult()) + + +@requirement("lifecycle:ping") +@requirement("ping:server-to-client") +async def test_server_ping_returns_empty_result(connect: Connect) -> None: + """A server-initiated ping sent while a request is in flight is answered by the client. + + The tool returns the type of the ping response, proving the round trip completed inside + the handler before the tool result was produced. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="ping_back", description="Ping the client.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ping_back" + pong = await ctx.session.send_ping() + return CallToolResult(content=[TextContent(text=type(pong).__name__)]) + + server = Server("pinger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ping_back", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="EmptyResult")])) diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py new file mode 100644 index 0000000000..6350c33a33 --- /dev/null +++ b/tests/interaction/lowlevel/test_progress.py @@ -0,0 +1,301 @@ +"""Progress interactions against the low-level Server, driven through the public Client API. + +Server-to-client progress emitted during a request follows the same ordering guarantee as +logging notifications (see test_logging.py) -- on the in-memory transport unconditionally, and +over streamable HTTP only when sent with ``related_request_id`` so the notification rides the +originating request's POST stream rather than the standalone GET stream. These tests pass +``related_request_id`` so no synchronisation is needed. The client-to-server direction is a +standalone notification with no response to await, so that test waits on an event set by the +server's handler. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.server.session import ServerSession +from mcp.shared.session import ProgressFnT +from mcp.types import CallToolResult, ProgressNotification, ProgressNotificationParams, ProgressToken, TextContent +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:progress:callback") +@requirement("tools:call:progress") +async def test_progress_during_tool_call_reaches_callback_in_order(connect: Connect) -> None: + """Progress notifications emitted by a tool handler reach the caller's progress callback in order.""" + received: list[tuple[float, float | None, str | None]] = [] + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="download", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "download" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + await ctx.session.send_progress_notification( + token, 1.0, total=3.0, message="first chunk", related_request_id=str(ctx.request_id) + ) + await ctx.session.send_progress_notification( + token, 2.0, total=3.0, message="second chunk", related_request_id=str(ctx.request_id) + ) + await ctx.session.send_progress_notification( + token, 3.0, total=3.0, message="done", related_request_id=str(ctx.request_id) + ) + return CallToolResult(content=[TextContent(text="downloaded")]) + + server = Server("downloader", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("download", {}, progress_callback=collect) + + assert result == snapshot(CallToolResult(content=[TextContent(text="downloaded")])) + assert received == snapshot([(1.0, 3.0, "first chunk"), (2.0, 3.0, "second chunk"), (3.0, 3.0, "done")]) + + +@requirement("protocol:progress:token-injected") +async def test_progress_token_visible_to_handler(connect: Connect) -> None: + """Supplying a progress callback attaches a progress token that the handler can read from the request meta.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "inspect" + assert ctx.meta is not None + return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def ignore(progress: float, total: float | None, message: str | None) -> None: + """A progress callback that is never invoked; the tool only inspects the token.""" + raise NotImplementedError + + async with connect(server) as client: + result = await client.call_tool("inspect", {}, progress_callback=ignore) + + # The token is the request id of the tools/call request itself (initialize is request 0). + assert result == snapshot(CallToolResult(content=[TextContent(text="1")])) + + +@requirement("protocol:progress:no-token") +async def test_no_progress_callback_means_no_token(connect: Connect) -> None: + """Without a progress callback the request carries no progress token. + + The low-level API has no way to report request-scoped progress without a token, so a handler + that sees no token has nothing to send progress against. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "inspect" + assert ctx.meta is not None + return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("inspect", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="None")])) + + +@requirement("protocol:progress:client-to-server") +async def test_client_progress_notification_reaches_server_handler(connect: Connect) -> None: + """A progress notification sent by the client is delivered to the server's progress handler.""" + received: list[ProgressNotificationParams] = [] + delivered = anyio.Event() + + async def on_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None: + received.append(params) + delivered.set() + + server = Server("observer", on_progress=on_progress) + + async with connect(server) as client: + await client.send_progress_notification("upload-1", 0.5, total=1.0, message="halfway") + with anyio.fail_after(5): + await delivered.wait() + + assert received == snapshot( + [ProgressNotificationParams(progress_token="upload-1", progress=0.5, total=1.0, message="halfway")] + ) + + +@requirement("protocol:progress:token-unique") +async def test_concurrent_requests_carry_distinct_progress_tokens(connect: Connect) -> None: + """Two concurrent requests carry distinct progress tokens, and each callback sees only its own progress. + + Without the barrier the first call could run to completion before the second starts, so only one + token would be live at a time and the demultiplexing would never be exercised. The handlers each + block until both have started and then hand control back and forth so the four progress + notifications are emitted in strict a, b, a, b order on the wire. The two handlers send different + progress values so a stream swap (token A delivered to callback B and vice versa) would fail: each + callback receiving exactly its own values proves notifications are routed by token, not by arrival + order or by chance. + """ + progress_values = {"a": (1.0, 2.0), "b": (10.0, 20.0)} + tokens: dict[str, ProgressToken] = {} + entered = {"a": anyio.Event(), "b": anyio.Event()} + # turns[n] is set to release the nth emission; each emission releases the next. + turns = [anyio.Event() for _ in range(4)] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + assert params.arguments is not None + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + label = params.arguments["label"] + tokens[label] = token + entered[label].set() + # The two handlers interleave by waiting on alternating turns: a takes 0 and 2, b takes 1 and 3. + first, second = (0, 2) if label == "a" else (1, 3) + await turns[first].wait() + await ctx.session.send_progress_notification( + token, progress_values[label][0], related_request_id=str(ctx.request_id) + ) + turns[first + 1].set() + await turns[second].wait() + await ctx.session.send_progress_notification( + token, progress_values[label][1], related_request_id=str(ctx.request_id) + ) + if second + 1 < len(turns): + turns[second + 1].set() + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) + + received_a: list[float] = [] + received_b: list[float] = [] + + async def collect_a(progress: float, total: float | None, message: str | None) -> None: + received_a.append(progress) + + async def collect_b(progress: float, total: float | None, message: str | None) -> None: + received_b.append(progress) + + async with connect(server) as client: + + async def call(label: str, collect: ProgressFnT) -> None: + await client.call_tool("report", {"label": label}, progress_callback=collect) + + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: # pragma: no branch + task_group.start_soon(call, "a", collect_a) + task_group.start_soon(call, "b", collect_b) + await entered["a"].wait() + await entered["b"].wait() + turns[0].set() + + assert tokens["a"] != tokens["b"] + assert received_a == [1.0, 2.0] + assert received_b == [10.0, 20.0] + + +@requirement("protocol:progress:stops-after-completion") +@requirement("protocol:progress:late-dropped-by-client") +async def test_progress_sent_after_the_response_is_not_delivered_to_the_callback(connect: Connect) -> None: + """A progress notification sent after the response is emitted, and the client drops it from the callback. + + This single body proves both halves: the server's `send_progress_notification` happily sends for + a token whose request has already completed (the spec MUST that progress stops is not enforced; + see the divergence on `stops-after-completion`), and the client, having removed the callback when + the call returned, does not deliver the late notification to it. The message handler observes the + late notification arriving so the test knows when to assert without polling. + """ + captured: list[tuple[ServerSession, ProgressToken]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + captured.append((ctx.session, token)) + await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[float] = [] + late_progress_arrived = anyio.Event() + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append(progress) + + async def message_handler(message: IncomingMessage) -> None: + if isinstance(message, ProgressNotification) and message.params.progress == 1.0: + late_progress_arrived.set() + + async with connect(server, message_handler=message_handler) as client: + with anyio.fail_after(5): + await client.call_tool("report", {}, progress_callback=collect) + assert received == [0.5] + + server_session, token = captured[0] + await server_session.send_progress_notification(token, 1.0) + await late_progress_arrived.wait() + + assert received == [0.5] + + +@requirement("protocol:progress:monotonic") +async def test_non_increasing_progress_values_are_forwarded_unchanged(connect: Connect) -> None: + """A handler that emits non-increasing progress values has them forwarded to the callback unchanged. + + The spec says progress MUST increase with each notification; the SDK does not enforce that on + either side. See the divergence note on the requirement. + """ + received: list[float] = [] + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append(progress) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="zigzag", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "zigzag" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) + await ctx.session.send_progress_notification(token, 0.3, related_request_id=str(ctx.request_id)) + await ctx.session.send_progress_notification(token, 0.9, related_request_id=str(ctx.request_id)) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("zigzagger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.call_tool("zigzag", {}, progress_callback=collect) + + assert received == snapshot([0.5, 0.3, 0.9]) diff --git a/tests/interaction/lowlevel/test_prompts.py b/tests/interaction/lowlevel/test_prompts.py new file mode 100644 index 0000000000..868b82692c --- /dev/null +++ b/tests/interaction/lowlevel/test_prompts.py @@ -0,0 +1,209 @@ +"""Prompt interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + AudioContent, + EmbeddedResource, + ErrorData, + GetPromptResult, + Icon, + ImageContent, + ListPromptsResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("prompts:list:basic") +async def test_list_prompts_returns_registered_prompts(connect: Connect) -> None: + """The prompts returned by the handler reach the client with their argument declarations intact.""" + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: + return ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", description="The code to review.", required=True), + PromptArgument(name="style_guide", description="Optional style guide to apply."), + ], + icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], + ), + Prompt(name="daily_standup"), + ] + ) + + server = Server("prompter", on_list_prompts=list_prompts) + + async with connect(server) as client: + result = await client.list_prompts() + + assert result == snapshot( + ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", description="The code to review.", required=True), + PromptArgument(name="style_guide", description="Optional style guide to apply."), + ], + icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], + ), + Prompt(name="daily_standup"), + ] + ) + ) + + +@requirement("prompts:get:with-args") +async def test_get_prompt_substitutes_arguments(connect: Connect) -> None: + """Arguments supplied by the client reach the prompt handler; the templated message comes back.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "greet" + assert params.arguments is not None + return GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text=f"Hello, {params.arguments['name']}!"))], + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("greet", {"name": "Ada"}) + + assert result == snapshot( + GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text="Hello, Ada!"))], + ) + ) + + +@requirement("prompts:get:multi-message") +async def test_get_prompt_multiple_messages_preserve_roles_and_order(connect: Connect) -> None: + """A prompt returning a user/assistant conversation reaches the client with roles and order intact.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "geography_quiz" + return GetPromptResult( + messages=[ + PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), + PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), + PromptMessage(role="user", content=TextContent(text="And of Italy?")), + ] + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("geography_quiz") + + assert result == snapshot( + GetPromptResult( + messages=[ + PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), + PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), + PromptMessage(role="user", content=TextContent(text="And of Italy?")), + ] + ) + ) + + +@requirement("prompts:get:no-args") +async def test_get_prompt_without_arguments_returns_the_messages(connect: Connect) -> None: + """A prompt fetched with no arguments delivers None as the handler's arguments and returns its messages.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "static" + assert params.arguments is None + return GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("static") + + assert result == snapshot( + GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + ) + + +@requirement("prompts:get:content:image") +@requirement("prompts:get:content:audio") +@requirement("prompts:get:content:embedded-resource") +async def test_get_prompt_with_non_text_content_round_trips(connect: Connect) -> None: + """Prompt messages can carry image, audio, and embedded-resource content; all reach the client. + + A single full-result snapshot proves all three content types round-trip: each block in the result + is one of the three behaviours under test. Tiny fixed base64 payloads ("aW1n" is b"img", "YXVk" + is b"aud") so the snapshot pins the exact bytes. + """ + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "media" + return GetPromptResult( + messages=[ + PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), + PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage( + role="user", + content=EmbeddedResource( + resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + ), + ), + ] + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("media", {}) + + assert result == snapshot( + GetPromptResult( + messages=[ + PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), + PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage( + role="user", + content=EmbeddedResource( + resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + ), + ), + ] + ) + ) + + +@requirement("prompts:get:unknown-name") +async def test_get_prompt_unknown_name_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised prompt name with MCPError produces a JSON-RPC error. + + The error's code and message chosen by the handler reach the client verbatim. + """ + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.name}") + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("nope") + + assert exc_info.value.error == snapshot(ErrorData(code=INVALID_PARAMS, message="Unknown prompt: nope")) diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py new file mode 100644 index 0000000000..4e369d3645 --- /dev/null +++ b/tests/interaction/lowlevel/test_resources.py @@ -0,0 +1,309 @@ +"""Resource interactions against the low-level Server, driven through the public Client API.""" + +import base64 + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + METHOD_NOT_FOUND, + Annotations, + BlobResourceContents, + CallToolResult, + EmptyResult, + ErrorData, + Icon, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + Resource, + ResourceTemplate, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + TextContent, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("resources:list:basic") +@requirement("resources:annotations") +async def test_list_resources_returns_registered_resources(connect: Connect) -> None: + """Listed resources reach the client with their URIs, names, and optional descriptive fields intact. + + The fully-populated entry includes annotations, so the snapshot also proves they round-trip. + The SDK's Annotations model omits the schema's lastModified field (see the divergence on + resources:annotations); the input is built via model_validate with lastModified set so the + snapshot pins the drop and will fail once the SDK adds the field. + """ + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + Resource(uri="memo://minimal", name="minimal"), + Resource( + uri="file:///project/README.md", + name="readme", + title="Project README", + description="The project's front page.", + mime_type="text/markdown", + size=1024, + annotations=Annotations.model_validate( + {"audience": ["user", "assistant"], "priority": 0.8, "lastModified": "2025-01-01T00:00:00Z"} + ), + icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + + server = Server("library", on_list_resources=list_resources) + + async with connect(server) as client: + result = await client.list_resources() + + assert result == snapshot( + ListResourcesResult( + resources=[ + Resource(uri="memo://minimal", name="minimal"), + Resource( + uri="file:///project/README.md", + name="readme", + title="Project README", + description="The project's front page.", + mime_type="text/markdown", + size=1024, + annotations=Annotations(audience=["user", "assistant"], priority=0.8), + icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + ) + + +@requirement("resources:read:text") +async def test_read_resource_text(connect: Connect) -> None: + """Reading a text resource returns its contents with the URI, MIME type, and text supplied by the handler.""" + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=params.uri, mime_type="text/plain", text="Hello, world!")] + ) + + server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + result = await client.read_resource("file:///greeting.txt") + + assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="file:///greeting.txt", mime_type="text/plain", text="Hello, world!")] + ) + ) + + +@requirement("resources:read:blob") +async def test_read_resource_binary(connect: Connect) -> None: + """Reading a binary resource returns its contents base64-encoded in the blob field.""" + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[ + BlobResourceContents( + uri=params.uri, + mime_type="image/png", + blob=base64.b64encode(b"\x89PNG").decode(), + ) + ] + ) + + server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + result = await client.read_resource("file:///pixel.png") + + assert result == snapshot( + ReadResourceResult( + contents=[BlobResourceContents(uri="file:///pixel.png", mime_type="image/png", blob="iVBORw==")] + ) + ) + + +@requirement("resources:read:unknown-uri") +async def test_read_resource_unknown_uri_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised URI with MCPError produces a JSON-RPC error. + + The spec reserves -32002 for resource-not-found; the code is the handler's choice and reaches + the client verbatim. + """ + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + raise MCPError(code=-32002, message=f"Resource not found: {params.uri}") + + server = Server("library", on_read_resource=read_resource) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("file:///missing.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=-32002, message="Resource not found: file:///missing.txt")) + + +@requirement("resources:templates:list") +async def test_list_resource_templates_returns_registered_templates(connect: Connect) -> None: + """Listed resource templates reach the client with their URI templates and descriptive fields intact.""" + + async def list_resource_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + return ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + + server = Server("library", on_list_resource_templates=list_resource_templates) + + async with connect(server) as client: + result = await client.list_resource_templates() + + assert result == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], + ), + ] + ) + ) + + +@requirement("resources:subscribe") +async def test_subscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: + """Subscribing to a resource delivers the URI to the server's subscribe handler and returns an empty result.""" + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_subscribe_resource=subscribe_resource) + + async with connect(server) as client: + result = await client.subscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:subscribe:capability-required") +async def test_subscribe_without_a_subscribe_handler_is_method_not_found(connect: Connect) -> None: + """Subscribing to a server that registered no subscribe handler is rejected with METHOD_NOT_FOUND. + + The rejection comes from no handler being registered, not from any capability check; see the + divergence on lifecycle:capability:server-not-advertised. + """ + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + server = Server("library", on_list_resources=list_resources) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.subscribe_resource("file:///watched.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + + +@requirement("resources:unsubscribe") +async def test_unsubscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: + """Unsubscribing from a resource delivers the URI to the server's unsubscribe handler.""" + + async def unsubscribe_resource(ctx: ServerRequestContext, params: types.UnsubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_unsubscribe_resource=unsubscribe_resource) + + async with connect(server) as client: + result = await client.unsubscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:updated-notification") +async def test_resource_updated_notification_reaches_client(connect: Connect) -> None: + """A resources/updated notification sent during a tool call reaches the client with the resource URI. + + ``send_resource_updated`` does not take a ``related_request_id``, so over streamable HTTP the + notification routes to the standalone GET stream and is not guaranteed to arrive before the + tool result; the test waits on an event the collector sets. The collector records every + message the handler receives, so the assertion also proves nothing else was delivered. + """ + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="touch", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "touch" + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="touched")]) + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" + raise NotImplementedError + + server = Server( + "library", + on_list_tools=list_tools, + on_call_tool=call_tool, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + ) + + async with connect(server, message_handler=collect) as client: + await client.call_tool("touch", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot( + [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + ) diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py new file mode 100644 index 0000000000..8149e0befb --- /dev/null +++ b/tests/interaction/lowlevel/test_roots.py @@ -0,0 +1,166 @@ +"""Roots interactions against the low-level Server, driven through the public Client API.""" + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import FileUrl + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.types import INTERNAL_ERROR, CallToolResult, ErrorData, ListRootsResult, Root, TextContent +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("roots:list:basic") +async def test_list_roots_round_trip(connect: Connect) -> None: + """A roots/list request from a tool handler is answered by the client's roots callback. + + The tool reports the URIs and names it received, proving the client's roots reached the server. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + result = await ctx.session.list_roots() + lines = [f"{root.uri} name={root.name}" for root in result.roots] + return CallToolResult(content=[TextContent(text="\n".join(lines))]) + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult( + roots=[ + Root(uri=FileUrl("file:///home/alice/project"), name="project"), + Root(uri=FileUrl("file:///home/alice/scratch")), + ] + ) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="file:///home/alice/project name=project\nfile:///home/alice/scratch name=None")] + ) + ) + + +@requirement("roots:list:empty") +async def test_list_roots_empty(connect: Connect) -> None: + """A client with no roots to offer answers roots/list with an empty list, not an error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="count_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "count_roots" + result = await ctx.session.list_roots() + return CallToolResult(content=[TextContent(text=str(len(result.roots)))]) + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult(roots=[]) + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("count_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="0")])) + + +@requirement("roots:list:not-supported") +async def test_list_roots_without_callback_is_error(connect: Connect) -> None: + """A roots/list request to a client with no roots callback fails with an error the handler can observe. + + The client's default callback answers with INVALID_REQUEST rather than leaving the server + hanging; the spec names -32601 for this case (see the divergence note on the requirement). + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # list_roots cannot succeed without a client callback + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: List roots not supported")])) + + +@requirement("roots:list:client-error") +async def test_list_roots_callback_error_surfaces_to_the_handler(connect: Connect) -> None: + """A roots callback that answers with an error fails the roots/list request with that exact error. + + The callback's code and message reach the requesting handler verbatim as an MCPError. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ErrorData: + return ErrorData(code=INTERNAL_ERROR, message="roots provider crashed") + + async with connect(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32603: roots provider crashed")])) + + +@requirement("roots:list-changed") +async def test_roots_list_changed_reaches_server_handler(connect: Connect) -> None: + """A roots/list_changed notification from the client is delivered to the server's handler. + + Unlike a request, a notification has no response to await: the handler sets an event and the + test waits on it, which is the only synchronisation point proving delivery. + """ + delivered = anyio.Event() + received: list[types.NotificationParams | None] = [] + + async def roots_list_changed(ctx: ServerRequestContext, params: types.NotificationParams | None) -> None: + received.append(params) + delivered.set() + + server = Server("rooted", on_roots_list_changed=roots_list_changed) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + """Registered so the client declares the roots capability; the server never asks for roots.""" + raise NotImplementedError + + async with connect(server, list_roots_callback=list_roots) as client: + await client.send_roots_list_changed() + with anyio.fail_after(5): + await delivered.wait() + + assert received == snapshot([None]) diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py new file mode 100644 index 0000000000..260e564192 --- /dev/null +++ b/tests/interaction/lowlevel/test_sampling.py @@ -0,0 +1,687 @@ +"""Sampling interactions against the low-level Server, driven through the public Client API. + +Each test nests a sampling/createMessage request inside a tool call: the tool handler calls +ctx.session.create_message(), the client's sampling callback answers it, and the handler +round-trips what it received back to the test through its tool result. +""" + +import pydantic +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + AudioContent, + CallToolResult, + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, + ErrorData, + ImageContent, + ModelHint, + ModelPreferences, + SamplingCapability, + SamplingMessage, + TextContent, + ToolResultContent, + ToolUseContent, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("sampling:create:basic") +@requirement("tools:call:sampling-roundtrip") +async def test_create_message_round_trip(connect: Connect) -> None: + """A handler's sampling request is answered by the client callback, and the callback's result + (role, content, model, stop reason) is returned to the handler. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=f"{result.model}/{result.stop_reason}: {result.content.text}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult( + role="assistant", + content=TextContent(text="Hello to you too."), + model="mock-llm-1", + stop_reason="endTurn", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-llm-1/endTurn: Hello to you too.")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create:include-context") +@requirement("sampling:create:model-preferences") +@requirement("sampling:create:system-prompt") +@requirement("sampling:context:server-gated-by-capability") +async def test_create_message_params_reach_callback(connect: Connect) -> None: + """Every sampling parameter the handler supplies arrives at the client callback unchanged. + + The client has not declared the sampling.context capability (Client cannot declare it), yet + include_context="thisServer" reaches the callback regardless: the spec's SHOULD NOT is not + enforced. See the divergence note on `sampling:context:server-gated-by-capability`. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + max_tokens=50, + system_prompt="You are terse.", + include_context="thisServer", + temperature=0.7, + stop_sequences=["\n\n", "END"], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult(role="assistant", content=TextContent(text="ok"), model="mock-llm-1") + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + system_prompt="You are terse.", + include_context="thisServer", + temperature=0.7, + max_tokens=50, + stop_sequences=["\n\n", "END"], + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_request_with_image_content_reaches_callback(connect: Connect) -> None: + """A sampling request message carrying image content arrives at the client callback intact. + + This is the server-to-client direction: the server includes an image in the conversation it + asks the client to sample from. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="describe_image", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "describe_image" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + image = params.messages[0].content + assert isinstance(image, ImageContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"described {image.mime_type} ({image.data})"), + model="mock-vision-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("describe_image", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="described image/png (aW1n)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_result_with_image_content_returns_to_handler(connect: Connect) -> None: + """A sampling result whose content is an image is returned to the requesting handler intact. + + This is the client-to-server direction: the model's response is an image rather than text. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="draw", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "draw" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Draw a cat."))], + max_tokens=100, + ) + image = result.content + assert isinstance(image, ImageContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {image.mime_type} {image.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=ImageContent(data="Y2F0", mime_type="image/png"), + model="mock-vision-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("draw", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-vision-1: image/png Y2F0")])) + + +@requirement("sampling:error:user-rejected") +async def test_create_message_callback_error(connect: Connect) -> None: + """A sampling callback that answers with an error surfaces to the requesting handler as an MCPError. + + The error here is the spec's own example for a user rejecting a sampling request (code -1); + the callback's code and message reach the handler verbatim, whatever they are. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback(context: ClientRequestContext, params: CreateMessageRequestParams) -> ErrorData: + return ErrorData(code=-1, message="User rejected sampling request") + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-1: User rejected sampling request")])) + + +@requirement("sampling:create-message:not-supported") +async def test_create_message_without_callback_is_error(connect: Connect) -> None: + """A sampling request to a client with no sampling callback fails with the SDK's default error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # create_message cannot succeed without a client callback + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Sampling not supported")])) + + +@requirement("sampling:tools:server-gated-by-capability") +async def test_create_message_with_tools_is_rejected_for_unsupporting_client(connect: Connect) -> None: + """A tool-enabled sampling request to a client that has not declared sampling.tools never leaves the server. + + The client supports plain sampling but cannot declare the tools sub-capability (Client does not + expose it), so the server-side validator rejects the request before anything reaches the wire. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="What is the weather?"))], + max_tokens=100, + tools=[types.Tool(name="get_weather", input_schema={"type": "object"})], + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the validator rejects every tool-enabled request + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the plain sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="-32602: Client does not support sampling tools capability")]) + ) + + +@requirement("sampling:tool-result:no-mixed-content") +async def test_create_message_with_mixed_tool_result_content_is_rejected(connect: Connect) -> None: + """A sampling request whose user message mixes tool_result with other content never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, even + when no tools are passed, so the client callback is never invoked and the handler observes the + ValueError directly. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="summarise_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "summarise_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + ToolResultContent(tool_use_id="call-1", content=[TextContent(text="42")]), + TextContent(text="Also, a comment alongside the result."), + ], + ) + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("summarise_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="ValueError: The last message must contain only tool_result content if any is present") + ] + ) + ) + + +@requirement("sampling:capability:declare") +async def test_a_client_with_a_sampling_callback_declares_the_sampling_capability(connect: Connect) -> None: + """A client connecting with a sampling callback advertises the sampling capability to the server. + + Client cannot declare any sub-capabilities (it does not expose ClientSession's + sampling_capabilities parameter), so the snapshot pins an empty SamplingCapability. + """ + captured: list[SamplingCapability | None] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="capabilities", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "capabilities" + assert ctx.session.client_params is not None + captured.append(ctx.session.client_params.capabilities.sampling) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Registered only so the sampling capability is advertised; never called.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + await client.call_tool("capabilities", {}) + + assert captured == snapshot([SamplingCapability()]) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_request_with_audio_content_reaches_callback(connect: Connect) -> None: + """A sampling request message carrying audio content arrives at the client callback intact. + + This is the server-to-client direction: the server includes audio in the conversation it asks + the client to sample from. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="transcribe", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "transcribe" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + audio = params.messages[0].content + assert isinstance(audio, AudioContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"transcribed {audio.mime_type} ({audio.data})"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("transcribe", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="transcribed audio/wav (c25k)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_result_with_audio_content_returns_to_handler(connect: Connect) -> None: + """A sampling result whose content is audio is returned to the requesting handler intact. + + This is the client-to-server direction: the model's response is audio rather than text. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="speak", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "speak" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello, aloud."))], + max_tokens=100, + ) + audio = result.content + assert isinstance(audio, AudioContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {audio.mime_type} {audio.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=AudioContent(data="aGVsbG8=", mime_type="audio/wav"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("speak", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-audio-1: audio/wav aGVsbG8=")])) + + +@requirement("sampling:message:content-cardinality") +async def test_create_message_with_list_valued_message_content_reaches_callback(connect: Connect) -> None: + """A sampling message whose content is a list of blocks arrives at the client callback as a list.""" + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="caption", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "caption" + result = await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + content = params.messages[0].content + assert isinstance(content, list) + return CreateMessageResult( + role="assistant", content=TextContent(text=f"{len(content)} blocks"), model="mock-llm-1" + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("caption", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="2 blocks")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:tool-use:server-preflight") +async def test_create_message_with_mismatched_tool_use_and_result_ids_is_rejected(connect: Connect) -> None: + """A sampling request whose tool_result ids do not match the preceding tool_use ids never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, so the + client callback is never invoked and the handler observes the ValueError directly. The spec's + client-side -32602 check is tracked separately at sampling:tool-use:result-balance. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="continue_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "continue_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="assistant", + content=[ToolUseContent(id="call-1", name="weather", input={})], + ), + SamplingMessage( + role="user", + content=[ToolResultContent(tool_use_id="call-WRONG", content=[TextContent(text="42")])], + ), + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("continue_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="ValueError: ids of tool_result blocks and tool_use blocks from previous message do not match" + ) + ] + ) + ) + + +@requirement("sampling:result:no-tools-single-content") +async def test_array_content_result_for_a_tool_free_request_surfaces_as_a_validation_error(connect: Connect) -> None: + """An array-content sampling result for a tool-free request is accepted by the client and fails server-side. + + Only the exception type is asserted: the message is pydantic's, which changes across releases. + See the divergence note on the requirement: the intended behaviour is that the client rejects + the result; instead the client accepts it and the server's response parsing raises. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Two thoughts, please."))], + max_tokens=100, + ) + except pydantic.ValidationError as exc: + return CallToolResult(content=[TextContent(text=type(exc).__name__)]) + raise NotImplementedError # the array-content result fails server-side parsing every time + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResultWithTools: + return CreateMessageResultWithTools( + role="assistant", + content=[TextContent(text="First thought."), TextContent(text="Second thought.")], + model="mock-llm-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ValidationError")])) diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py new file mode 100644 index 0000000000..a9c83d641d --- /dev/null +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -0,0 +1,114 @@ +"""Request timeouts against the low-level Server, driven through the public Client API. + +The handler blocks on an event that is never set, so the awaited response can never arrive and +any positive timeout fires deterministically on the next event-loop pass. The timeout is therefore +set to an effectively-zero duration: the tests add no wall-clock time to the suite. (Zero itself +cannot be used: a falsy read_timeout_seconds is silently treated as "no timeout".) +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("protocol:timeout:basic") +@requirement("protocol:timeout:sends-cancellation") +async def test_request_timeout_fails_the_pending_call() -> None: + """A request whose response does not arrive within its read timeout fails with a timeout error. + + No cancellation is sent to the server (see the divergence note on the requirement): the handler + starts and is still running after the caller has already given up. The test waits for the + handler to have started only after the timeout has fired, so the timeout itself races nothing. + """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + # The request was already on the wire: the handler still runs even though the caller gave up. + with anyio.fail_after(5): + await handler_started.wait() + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. Waited 1e-06 seconds.", + ) + ) + + +@requirement("protocol:timeout:session-survives") +async def test_session_serves_requests_after_timeout() -> None: + """A timed-out request does not poison the session: the next request succeeds.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool(name="block", input_schema={"type": "object"}), + types.Tool(name="echo", input_schema={"type": "object"}), + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + return CallToolResult(content=[TextContent(text="still alive")]) + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError): + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + + +@requirement("protocol:timeout:session-default") +async def test_session_level_timeout_applies_to_every_request() -> None: + """A read timeout configured on the client applies to requests that do not set their own.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + # The one real wall-clock wait in the suite, and it cannot be made effectively zero like the + # per-request timeouts: a session-level timeout also governs the initialize handshake, so the + # value must be long enough for the in-process handshake to complete before the blocked tool + # call waits it out in full. 50ms buys a ~50x safety margin over the handshake's actual + # latency; lowering it only erodes the margin against CI scheduler jitter without saving + # anything perceptible. + async with Client(server, read_timeout_seconds=0.05) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. Waited 0.05 seconds.", + ) + ) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py new file mode 100644 index 0000000000..e8053fbaa7 --- /dev/null +++ b/tests/interaction/lowlevel/test_tools.py @@ -0,0 +1,512 @@ +"""Tool interactions against the low-level Server, driven through the public Client API.""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + AudioContent, + CallToolResult, + EmbeddedResource, + ErrorData, + Icon, + ImageContent, + ListToolsResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, + ToolAnnotations, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content(connect: Connect) -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="add", description="Add two integers.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) + + +@requirement("tools:call:is-error") +async def test_call_tool_execution_error_is_returned_as_result(connect: Connect) -> None: + """A tool reporting its own failure with is_error=True reaches the client as a result, not an exception. + + Tool execution errors are part of the result so the caller (typically a model) can see + them; only protocol-level failures become JSON-RPC errors. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "flux" + return CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + ) + + +@requirement("tools:call:unknown-name") +async def test_call_tool_unknown_tool_is_protocol_error(connect: Connect) -> None: + """A handler that rejects an unrecognised tool name with MCPError produces a JSON-RPC error. + + The error's code, message, and data chosen by the handler reach the client verbatim. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + raise MCPError(code=INVALID_PARAMS, message=f"Unknown tool: {params.name}", data={"requested": params.name}) + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("nope", {}) + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Unknown tool: nope", data={"requested": "nope"}) + ) + + +@requirement("protocol:error:internal-error") +async def test_call_tool_uncaught_exception_becomes_error_response(connect: Connect) -> None: + """An uncaught exception in the tool handler surfaces to the client as a JSON-RPC error. + + The low-level server reports it with code 0 and the exception text as the message; see the + divergence note on the requirement. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "explode" + raise ValueError("boom") + + server = Server("errors", on_call_tool=call_tool) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("explode", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="boom")) + + +@requirement("tools:list:basic") +async def test_list_tools_returns_registered_tools(connect: Connect) -> None: + """The tools advertised by the server's list handler arrive at the client unchanged.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] + ) + + server = Server("calculator", on_list_tools=list_tools) + + async with connect(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] + ) + ) + + +@requirement("tools:input-schema:json-schema-2020-12") +@requirement("tools:input-schema:preserve-additional-properties") +@requirement("tools:input-schema:preserve-defs") +@requirement("tools:input-schema:preserve-schema-dialect") +async def test_tools_list_preserves_arbitrary_input_schema_keywords(connect: Connect) -> None: + """A rich JSON Schema 2020-12 inputSchema reaches the client unchanged and the tool is callable. + + The single identity assertion below proves all four pass-through behaviours at once: the same + dict literal that was registered is the dict that arrives, so $schema, $defs, the nested object + property, and additionalProperties are each preserved by virtue of the whole schema being + preserved. The follow-up call proves the rich-schema tool is callable end to end. + """ + schema = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$defs": {"positive": {"type": "integer", "exclusiveMinimum": 0}}, + "properties": { + "count": {"$ref": "#/$defs/positive"}, + "options": { + "type": "object", + "properties": {"verbose": {"type": "boolean"}}, + "additionalProperties": False, + }, + }, + "required": ["count"], + "additionalProperties": False, + } + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="typed", input_schema=schema)]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "typed" + assert params.arguments == {"count": 3, "options": {"verbose": True}} + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("typed", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + listed = await client.list_tools() + called = await client.call_tool("typed", {"count": 3, "options": {"verbose": True}}) + + assert listed.tools[0].input_schema == schema + assert called == snapshot(CallToolResult(content=[TextContent(text="ok")])) + + +@requirement("tools:list:metadata") +async def test_list_tools_optional_fields_round_trip(connect: Connect) -> None: + """Every optional Tool field the server supplies reaches the client unchanged.""" + + tool = Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[tool]) + + server = Server("annotated", on_list_tools=list_tools) + + async with connect(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + ] + ) + ) + + +@requirement("tools:call:content:mixed") +@requirement("tools:call:content:image") +@requirement("tools:call:content:audio") +@requirement("tools:call:content:resource-link") +@requirement("tools:call:content:embedded-resource") +async def test_call_tool_multiple_content_block_types(connect: Connect) -> None: + """A tool result can mix every content block type; all of them arrive in order. + + The payloads are tiny fixed base64 strings ("aW1n" is b"img", "YXVk" is b"aud") so the + snapshot pins the exact bytes the client receives. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="render", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "render" + return CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + + server = Server("renderer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("render", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + ) + + +@requirement("tools:call:structured-content") +async def test_call_tool_structured_content(connect: Connect) -> None: + """A tool result carrying structured content alongside content delivers both to the client.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="sum", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "sum" + return CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5}) + + server = Server("calculator", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + result = await client.call_tool("sum", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5})) + + +@requirement("tools:call:concurrent") +async def test_concurrent_tool_calls_complete_independently(connect: Connect) -> None: + """Two tool calls in flight at once run concurrently and each caller gets its own answer. + + Both handlers are held on a shared event after signalling that they have started, and the test + only releases them once both signals have arrived -- a server that processed requests + sequentially would never start the second handler and the test would time out instead. + """ + started: list[str] = [] + started_events = {"first": anyio.Event(), "second": anyio.Event()} + release = anyio.Event() + results: dict[str, CallToolResult] = {} + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + tag = params.arguments["tag"] + assert isinstance(tag, str) + started.append(tag) + started_events[tag].set() + await release.wait() + return CallToolResult(content=[TextContent(text=tag)]) + + server = Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: # pragma: no branch + + async def call_and_record(tag: str) -> None: + results[tag] = await client.call_tool("echo", {"tag": tag}) + + task_group.start_soon(call_and_record, "first") + task_group.start_soon(call_and_record, "second") + + # Both handlers are running at the same time before either is allowed to finish. + await started_events["first"].wait() + await started_events["second"].wait() + release.set() + + assert sorted(started) == ["first", "second"] + assert results == snapshot( + { + "first": CallToolResult(content=[TextContent(text="first")]), + "second": CallToolResult(content=[TextContent(text="second")]), + } + ) + + +@requirement("client:output-schema:validate") +async def test_call_tool_structured_content_violating_output_schema_is_rejected_by_the_client(connect: Connect) -> None: + """A result whose structured content does not conform to the tool's declared output schema never + reaches the caller: the client validates it against the schema cached from tools/list and raises. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")], structured_content={"temperature": "warm"}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + # The message embeds the jsonschema validation error, so only the SDK-authored prefix is pinned. + assert str(exc_info.value).startswith("Invalid structured content returned by tool forecast") + + +@requirement("client:output-schema:skip-on-error") +async def test_is_error_result_bypasses_client_output_schema_validation(connect: Connect) -> None: + """A tool result with isError true is returned as-is even when its structured content violates the schema. + + The schema is cached up front so the client could validate, proving the bypass is specifically the + isError flag and not an empty cache. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult( + content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True + ) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + result = await client.call_tool("forecast", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True) + ) + + +@requirement("client:output-schema:missing-structured") +async def test_declared_output_schema_with_no_structured_content_is_rejected_by_the_client(connect: Connect) -> None: + """A tool that declared an output schema but returned no structuredContent fails the client-side check. + + The error is the SDK's own message, so the full text is snapshotted. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")]) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + assert str(exc_info.value) == snapshot("Tool forecast has an output schema but did not return structured content") + + +@requirement("client:output-schema:auto-list") +async def test_call_tool_populates_the_output_schema_cache_via_an_implicit_tools_list(connect: Connect) -> None: + """Calling a tool whose schema is not cached issues exactly one implicit tools/list to populate it. + + The first call_tool of an uncached tool triggers a tools/list the caller never asked for; the + second call hits the cache and does not. This is the SDK's chosen cache strategy and the cause of + the surprising behaviour where a server with only on_call_tool sees a successful call answered + with METHOD_NOT_FOUND from a request the caller never made; see the divergence on the requirement. + """ + list_calls: list[str] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + list_calls.append("called") + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + first = await client.call_tool("forecast", {}) + assert list_calls == ["called"] + second = await client.call_tool("forecast", {}) + + assert list_calls == ["called"] + assert first == snapshot(CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21})) + assert second == first diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py new file mode 100644 index 0000000000..0f9c58aa7a --- /dev/null +++ b/tests/interaction/lowlevel/test_wire.py @@ -0,0 +1,309 @@ +"""Wire-level invariants observed at the client's transport boundary. + +These behaviours are invisible to API callers -- they are properties of the raw JSON-RPC frames. +The tests wrap the in-memory transport in a RecordingTransport, which tees every message crossing +the transport seam into a list without touching the session, so the assertions hold for whatever +the session implementation sends rather than for what its API returns. + +The later tests drive the wire by hand instead: one closes the server-to-client stream while a +request is in flight to pin the connection-closed teardown, and the last two send deliberately +malformed JSON-RPC requests that the typed client API cannot produce. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.client._memory import InMemoryTransport +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CONNECTION_CLOSED, + INVALID_PARAMS, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + EmptyResult, + ErrorData, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListRootsResult, + TextContent, +) +from tests.interaction._helpers import RecordingTransport, _RecordingReadStream +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _echo_server() -> Server: + """A server with one echo tool, used by every test in this module.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + return CallToolResult(content=[TextContent(text="ok")]) + + return Server("wire", on_list_tools=list_tools, on_call_tool=call_tool) + + +@requirement("protocol:request-id:unique") +async def test_request_ids_are_unique_and_never_null() -> None: + """Every request the client sends carries a distinct, non-null id. + + The id sequence is pinned: sequential integers from zero, in send order. + """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + await client.call_tool("echo", {}) + await client.call_tool("echo", {}) + await client.send_ping() + + sent = [message.message for message in recording.sent] + request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + assert all(request_id is not None for request_id in request_ids) + assert len(request_ids) == len(set(request_ids)) + # initialize, tools/list, tools/call, tools/call, ping -- the client does not issue a + # schema-cache refresh here because the explicit tools/list already populated the cache. + assert request_ids == snapshot([0, 1, 2, 3, 4]) + + +@requirement("protocol:notifications:no-response") +async def test_notifications_are_never_answered() -> None: + """A notification produces no response: everything the server sends back answers a request. + + The client sends two notifications (initialized and roots/list_changed) and several requests; + the messages received from the server must be exactly one response per request, each carrying + the id of the request it answers, and nothing else. + """ + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + """Registered so the client declares the roots capability; the server never asks for roots.""" + raise NotImplementedError + + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording, list_roots_callback=list_roots) as client: + await client.send_roots_list_changed() + await client.send_ping() + + sent = [message.message for message in recording.sent] + sent_request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + sent_notifications = [message for message in sent if isinstance(message, JSONRPCNotification)] + received = [message.message for message in recording.received if isinstance(message, SessionMessage)] + received_responses = [message for message in received if isinstance(message, JSONRPCResponse)] + + assert len(sent_notifications) == 2 # notifications/initialized and notifications/roots/list_changed + assert len(received_responses) == len(received) # nothing the server sent was anything but a response + assert [message.id for message in received_responses] == sent_request_ids + + +async def test_recording_read_stream_ends_iteration_when_the_sender_closes() -> None: + """The recording wrapper preserves the end-of-stream behaviour of the stream it wraps. + + This exercises the helper itself rather than an interaction-model behaviour: a transport whose + far end closes must end the client's receive loop cleanly, and the wrapper must not swallow or + mistranslate that. + """ + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + log: list[SessionMessage | Exception] = [] + async with send_stream, _RecordingReadStream(receive_stream, log) as wrapped: + await send_stream.aclose() + items = [item async for item in wrapped] + assert items == [] + assert log == [] + + +@requirement("lifecycle:initialized-notification") +async def test_exactly_one_initialized_notification_is_sent_after_the_handshake() -> None: + """The client sends initialized exactly once, between the initialize response and its first request. + + The full method sequence the client puts on the wire is pinned in send order. + """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + + sent_methods = [ + message.message.method + for message in recording.sent + if isinstance(message.message, JSONRPCRequest | JSONRPCNotification) + ] + assert sent_methods.count("notifications/initialized") == 1 + assert sent_methods == snapshot(["initialize", "notifications/initialized", "tools/list"]) + + +@requirement("protocol:error:connection-closed") +async def test_closing_the_transport_fails_in_flight_requests_with_connection_closed() -> None: + """When the server-to-client stream closes, every in-flight client request fails with CONNECTION_CLOSED. + + Driven over a bare ClientSession against a real Server so the test holds the transport stream + pair directly: once the request is in flight (the server handler signals it has started) the + test closes the server's write stream, which ends the client's receive loop and triggers the + teardown that fails the pending request. + """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + raise NotImplementedError # unreachable: the wait above never completes normally + + server = Server("blocker", on_call_tool=call_tool) + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + errors: list[ErrorData] = [] + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + async with ClientSession(client_read, client_write) as session: + with anyio.fail_after(5): + await session.initialize() + + async def call_and_capture_error() -> None: + with pytest.raises(MCPError) as exc_info: + await session.send_request( + CallToolRequest(params=CallToolRequestParams(name="block")), CallToolResult + ) + errors.append(exc_info.value.error) + + async with anyio.create_task_group() as task_group: # pragma: no branch + task_group.start_soon(call_and_capture_error) + await handler_started.wait() + await server_write.aclose() + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=CONNECTION_CLOSED, message="Connection closed")]) + + +@requirement("protocol:error:invalid-params") +async def test_malformed_request_params_are_answered_with_invalid_params() -> None: + """A request whose params fail validation is answered with -32602 Invalid params. + + The typed client API cannot construct a request with the wrong parameter types, so the test + plays the client's side of the wire by hand against a real Server: it completes the + initialization handshake at the JSON-RPC layer and then sends a tools/call whose `name` is an + integer. Reserve this pattern for behaviour the typed API cannot produce. + """ + server = Server("strict") + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": 42})) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="")]) + + +@requirement("logging:set-level:invalid-level") +async def test_set_level_with_an_unrecognized_value_is_answered_with_invalid_params() -> None: + """logging/setLevel with a value outside the spec's level enum is answered with -32602 Invalid params. + + The typed client API cannot construct a setLevel request with an unrecognized level (pyright and + the client-side model both reject it), so the test plays the client's side of the wire by hand + against a real Server. Reserve this pattern for behaviour the typed API cannot produce. + """ + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; never called -- params validation fails first.""" + raise NotImplementedError + + server = Server("logger", on_set_logging_level=set_logging_level) + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage( + JSONRPCRequest(jsonrpc="2.0", id=1, method="logging/setLevel", params={"level": "loud"}) + ) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + server_task_group.cancel_scope.cancel() + + assert len(errors) == 1 + assert errors[0].code == INVALID_PARAMS diff --git a/tests/interaction/mcpserver/__init__.py b/tests/interaction/mcpserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/mcpserver/test_completion.py b/tests/interaction/mcpserver/test_completion.py new file mode 100644 index 0000000000..7761066e94 --- /dev/null +++ b/tests/interaction/mcpserver/test_completion.py @@ -0,0 +1,38 @@ +"""Completion behaviour against MCPServer, driven through the public Client API.""" + +import pytest + +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + Completion, + CompletionArgument, + CompletionContext, + CompletionsCapability, + PromptReference, + ResourceTemplateReference, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:completion:capability-auto") +async def test_completion_capability_is_advertised_only_when_a_handler_is_registered(connect: Connect) -> None: + """An MCPServer with a registered completion handler advertises the completions capability; one without does not.""" + with_handler = MCPServer("completer") + + @with_handler.completion() + async def complete( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + async with connect(with_handler) as client: + assert client.initialize_result.capabilities.completions == CompletionsCapability() + + async with connect(MCPServer("plain")) as client: + assert client.initialize_result.capabilities.completions is None diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py new file mode 100644 index 0000000000..26556fea7a --- /dev/null +++ b/tests/interaction/mcpserver/test_context.py @@ -0,0 +1,271 @@ +"""The Context convenience methods MCPServer injects into tool functions, observed from the client.""" + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp import MCPError +from mcp.client import ClientRequestContext +from mcp.server.elicitation import AcceptedElicitation +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + METHOD_NOT_FOUND, + CallToolResult, + ElicitRequestFormParams, + ElicitRequestParams, + ElicitResult, + ErrorData, + Implementation, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:context:logging") +@requirement("logging:capability:declared") +async def test_context_logging_helpers_send_log_notifications(connect: Connect) -> None: + """Each Context logging helper sends a log message notification at the matching severity. + + All four notifications reach the client's logging callback before the tool call returns; none + of them carry a logger name unless one is passed explicitly. The server emits these without + advertising the logging capability (see the divergence note on logging:capability). + """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("chatty") + + @mcp.tool() + async def narrate(ctx: Context) -> str: + await ctx.debug("d") + await ctx.info("i") + await ctx.warning("w") + await ctx.error("e") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with connect(mcp, logging_callback=collect) as client: + result = await client.call_tool("narrate", {}) + advertised_logging = client.initialize_result.capabilities.logging + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="d"), + LoggingMessageNotificationParams(level="info", data="i"), + LoggingMessageNotificationParams(level="warning", data="w"), + LoggingMessageNotificationParams(level="error", data="e"), + ] + ) + # The spec requires servers that emit log notifications to declare the logging capability. + assert advertised_logging is None + + +@requirement("mcpserver:context:progress") +async def test_context_report_progress_sends_progress_notifications(connect: Connect) -> None: + """Context.report_progress sends progress notifications correlated to the calling request. + + The caller's progress callback receives each report, in order, before the tool call returns. + """ + received: list[tuple[float, float | None, str | None]] = [] + mcp = MCPServer("worker") + + @mcp.tool() + async def crunch(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.report_progress(2, 3, "halfway there") + return "crunched" + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with connect(mcp) as client: + result = await client.call_tool("crunch", {}, progress_callback=on_progress) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="crunched")], structured_content={"result": "crunched"}) + ) + assert received == snapshot([(1.0, 3.0, None), (2.0, 3.0, "halfway there")]) + + +@requirement("mcpserver:tool:extra") +async def test_context_exposes_request_id_and_client_info_to_a_tool(connect: Connect) -> None: + """A tool can read the per-request id and the connecting client's identity through Context. + + The request id is non-empty (its concrete value depends on transport-level sequencing, so the + test asserts the value the tool saw is the one returned, rather than pinning the literal); the + client info reflects what the caller passed to `Client`. + """ + mcp = MCPServer("introspector") + + @mcp.tool() + async def whoami(ctx: Context) -> str: + client_params = ctx.session.client_params + assert client_params is not None + return f"request {ctx.request_id} from {client_params.client_info.name} {client_params.client_info.version}" + + async with connect(mcp, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: + result = await client.call_tool("whoami", {}) + + assert isinstance(result.content[0], TextContent) + text = result.content[0].text + assert text.startswith("request ") + assert text.endswith(" from acme-agent 9.9.9") + request_id = text.removeprefix("request ").removesuffix(" from acme-agent 9.9.9") + assert request_id + + +@requirement("protocol:progress:no-token") +async def test_report_progress_without_a_progress_token_sends_nothing(connect: Connect) -> None: + """When the caller supplied no progress callback, Context.report_progress is a silent no-op. + + The tool also emits one log message as a sentinel: the message handler receives only that, + proving the notification pipeline works and no progress notification was sent for the + token-less request. + """ + received: list[IncomingMessage] = [] + mcp = MCPServer("quiet") + + @mcp.tool() + async def mill(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.info("milling done") + return "milled" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(mcp, message_handler=collect) as client: + result = await client.call_tool("mill", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="milled")], structured_content={"result": "milled"}) + ) + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="milling done"))] + ) + + +@requirement("mcpserver:context:elicit") +@requirement("tools:call:elicitation-roundtrip") +async def test_context_elicit_returns_typed_result(connect: Connect) -> None: + """Context.elicit sends a form elicitation built from a pydantic schema and returns a typed result. + + The client sees the JSON schema generated from the model; the accepted content is validated + back into the model and handed to the tool as result.data. + """ + received: list[ElicitRequestParams] = [] + mcp = MCPServer("travel") + + class TravelPreferences(BaseModel): + destination: str + window_seat: bool + + @mcp.tool() + async def book_flight(ctx: Context) -> str: + answer = await ctx.elicit("Where to?", TravelPreferences) + assert isinstance(answer, AcceptedElicitation) + return f"{answer.action}: {answer.data.destination} window={answer.data.window_seat}" + + async def answer_form(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"destination": "Lisbon", "window_seat": True}) + + async with connect(mcp, elicitation_callback=answer_form) as client: + result = await client.call_tool("book_flight", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Where to?", + requested_schema={ + "properties": { + "destination": {"title": "Destination", "type": "string"}, + "window_seat": {"title": "Window Seat", "type": "boolean"}, + }, + "required": ["destination", "window_seat"], + "title": "TravelPreferences", + "type": "object", + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept: Lisbon window=True")], + structured_content={"result": "accept: Lisbon window=True"}, + ) + ) + + +@requirement("mcpserver:context:read-resource") +async def test_context_read_resource_reads_registered_resource(connect: Connect) -> None: + """Context.read_resource lets a tool read a resource registered on the same server. + + The tool reports the MIME type and content it read, proving the resource function ran and its + return value came back through the context. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + @mcp.tool() + async def show_config(ctx: Context) -> str: + contents = list(await ctx.read_resource("config://app")) + return "\n".join(f"{item.mime_type}: {item.content!r}" for item in contents) + + async with connect(mcp) as client: + result = await client.call_tool("show_config", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="text/plain: 'theme = dark'")], + structured_content={"result": "text/plain: 'theme = dark'"}, + ) + ) + + +@requirement("logging:message:filtered") +async def test_set_logging_level_is_rejected_and_messages_are_never_filtered(connect: Connect) -> None: + """MCPServer does not support logging/setLevel, so log messages are never filtered by severity. + + The request is rejected with METHOD_NOT_FOUND because MCPServer registers no handler for it, + and every message a tool emits is delivered regardless of level. The spec says the server + should only send messages at or above the configured level; with no way to configure one, + everything is sent. + """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("unfilterable") + + @mcp.tool() + async def chatter(ctx: Context) -> str: + await ctx.debug("noise") + await ctx.error("signal") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with connect(mcp, logging_callback=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.set_logging_level("error") + + await client.call_tool("chatter", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="noise"), + LoggingMessageNotificationParams(level="error", data="signal"), + ] + ) diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py new file mode 100644 index 0000000000..2095f086d4 --- /dev/null +++ b/tests/interaction/mcpserver/test_prompts.py @@ -0,0 +1,195 @@ +"""Prompt interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + GetPromptResult, + ListPromptsResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:prompt:decorated") +async def test_list_prompts_derives_arguments_from_signature(connect: Connect) -> None: + """A decorated prompt is listed with arguments derived from the function signature. + + Parameters without a default are required; the description comes from the docstring. + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def code_review(code: str, style_guide: str = "pep8") -> str: + """Review a piece of code.""" + raise NotImplementedError # registered for listing only; never rendered + + async with connect(mcp) as client: + result = await client.list_prompts() + + assert result == snapshot( + ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", required=True), + PromptArgument(name="style_guide", required=False), + ], + ) + ] + ) + ) + + +@requirement("mcpserver:prompt:decorated") +async def test_get_prompt_renders_function_return(connect: Connect) -> None: + """The decorated function's string return value is rendered as a single user message.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A personalised greeting.""" + return f"Say hello to {name}." + + async with connect(mcp) as client: + result = await client.get_prompt("greet", {"name": "Ada"}) + + assert result == snapshot( + GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text="Say hello to Ada."))], + ) + ) + + +@requirement("mcpserver:prompt:unknown-name") +async def test_get_unknown_prompt_is_error(connect: Connect) -> None: + """Getting a prompt name that was never registered fails with a JSON-RPC error. + + The spec reserves -32602 for this case; the SDK reports code 0 (see the divergence note on + the requirement). + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; the test requests a different name.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("nope") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown prompt: nope")) + + +@requirement("prompts:get:missing-required-args") +async def test_get_prompt_with_a_missing_required_argument_is_an_error(connect: Connect) -> None: + """Getting a prompt without one of its required arguments fails with a JSON-RPC error. + + The missing argument is detected before the prompt function is called, but the spec's -32602 + Invalid params is reported as error code 0 with the bare exception text (see the divergence + note on the requirement). + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; validation rejects the call before the function runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("greet") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Missing required arguments: {'name'}")) + + +@requirement("mcpserver:prompt:args-validation") +async def test_get_prompt_with_a_wrong_type_argument_is_rejected_before_the_function_runs(connect: Connect) -> None: + """An argument that fails the function signature's type validation is rejected before the function runs. + + The decorated function is wrapped in pydantic's validate_call, so a value that cannot be + coerced to the parameter's annotation fails before the body executes. The function body + raises NotImplementedError to prove it never ran. The error is wrapped in the SDK's stable + rendering-error prefix; the body of the message is raw pydantic output and is not asserted. + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def repeat(phrase: str, count: int) -> str: + """A registered prompt; type validation rejects the call before the function runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("repeat", {"phrase": "hi", "count": "many"}) + + assert exc_info.value.error.code == 0 + assert exc_info.value.error.message.startswith("Error rendering prompt repeat: 1 validation error") + + +@requirement("mcpserver:prompt:optional-args") +async def test_get_prompt_with_an_optional_argument_omitted_uses_the_default(connect: Connect) -> None: + """A prompt rendered without one of its optional arguments uses that parameter's default value.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def review(code: str, style: str = "pep8") -> str: + """Review a snippet of code against a style guide.""" + return f"Review {code} per {style}." + + async with connect(mcp) as client: + result = await client.get_prompt("review", {"code": "x = 1"}) + + assert result == snapshot( + GetPromptResult( + description="Review a snippet of code against a style guide.", + messages=[PromptMessage(role="user", content=TextContent(text="Review x = 1 per pep8."))], + ) + ) + + +@requirement("mcpserver:prompt:duplicate-name") +async def test_registering_a_duplicate_prompt_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second prompt with an already-used name keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via the decorator with an explicit name so the test does not + redefine the same function name in this scope. + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.prompt(name="greet") + def greet_second() -> str: + """Registered with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_prompts() + result = await client.get_prompt("greet") + + assert [prompt.name for prompt in listed.prompts] == ["greet"] + assert result == snapshot( + GetPromptResult( + description="The first registration; this is the one that wins.", + messages=[PromptMessage(role="user", content=TextContent(text="first"))], + ) + ) diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py new file mode 100644 index 0000000000..57b0fdc86d --- /dev/null +++ b/tests/interaction/mcpserver/test_resources.py @@ -0,0 +1,183 @@ +"""Resource interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:resource:static") +async def test_read_static_resource(connect: Connect) -> None: + """A function registered for a fixed URI is served at that URI with its return value as text.""" + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + async with connect(mcp) as client: + result = await client.read_resource("config://app") + + assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="theme = dark")] + ) + ) + + +@requirement("mcpserver:resource:static") +async def test_list_static_and_templated_resources(connect: Connect) -> None: + """Statically-registered resources appear in resources/list; templated ones only in templates/list. + + The name and description are derived from the function name and docstring; the MIME type + defaults to text/plain. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + raise NotImplementedError # registered for listing only; never read + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + raise NotImplementedError # registered for listing only; never read + + async with connect(mcp) as client: + resources = await client.list_resources() + templates = await client.list_resource_templates() + + assert resources == snapshot( + ListResourcesResult( + resources=[ + Resource( + name="app_config", + uri="config://app", + description="The application configuration.", + mime_type="text/plain", + ) + ] + ) + ) + assert templates == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate( + name="user_profile", + uri_template="users://{user_id}/profile", + description="A user's profile.", + mime_type="text/plain", + ) + ] + ) + ) + + +@requirement("mcpserver:resource:template") +@requirement("resources:read:template-vars") +async def test_read_templated_resource(connect: Connect) -> None: + """Reading a URI that matches a registered template invokes the function with the extracted parameters.""" + mcp = MCPServer("library") + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + return f"profile for {user_id}" + + async with connect(mcp) as client: + result = await client.read_resource("users://42/profile") + + assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="users://42/profile", mime_type="text/plain", text="profile for 42")] + ) + ) + + +@requirement("mcpserver:resource:unknown-uri") +async def test_read_unknown_uri_is_error(connect: Connect) -> None: + """Reading a URI that matches no registered resource fails with a JSON-RPC error. + + The spec reserves -32002 for resource-not-found; see the divergence note on the requirement. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """A registered resource; the test reads a different URI.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("config://missing") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown resource: config://missing")) + + +@requirement("mcpserver:resource:read-throws-surfaced") +async def test_resource_function_that_raises_is_surfaced_as_a_jsonrpc_error(connect: Connect) -> None: + """An exception raised by a resource function reaches the caller as a JSON-RPC error. + + MCPServer wraps the failure in a generic error that names only the URI, so the original + exception text is not leaked to the client. The wrapped exception becomes error code 0 the + same way every other unhandled server-side exception does. + """ + mcp = MCPServer("library") + + @mcp.resource("res://boom") + def boom() -> str: + raise RuntimeError("nope") + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("res://boom") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Error reading resource res://boom")) + + +@requirement("mcpserver:resource:duplicate-name") +async def test_registering_a_duplicate_resource_uri_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second static resource at an already-used URI keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The two + registrations use different function names so the test does not redefine a name in this scope; + the resource decorator keys on the URI, not the function name. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def config_first() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.resource("config://app") + def config_second() -> str: + """Registered at a duplicate URI; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_resources() + result = await client.read_resource("config://app") + + assert [resource.uri for resource in listed.resources] == ["config://app"] + assert listed.resources[0].name == "config_first" + assert result == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="first")]) + ) diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py new file mode 100644 index 0000000000..05135c1286 --- /dev/null +++ b/tests/interaction/mcpserver/test_tools.py @@ -0,0 +1,432 @@ +"""Tool interactions against MCPServer, driven through the public Client API.""" + +import logging +from typing import Annotated, Literal + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel, Field + +from mcp import MCPError +from mcp.server.mcpserver import Context, MCPServer +from mcp.server.mcpserver.exceptions import ToolError +from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.types import ( + URL_ELICITATION_REQUIRED, + CallToolResult, + ElicitRequestURLParams, + ErrorData, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content(connect: Connect) -> None: + """Arguments reach the tool function; its return value comes back as text content. + + MCPServer also derives an output schema from the return annotation and attaches the + matching structuredContent to the result. + """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + return str(a + b) + + async with connect(mcp) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")], structured_content={"result": "5"})) + + +@requirement("mcpserver:tool:schema-variants") +async def test_complex_parameter_types_are_validated_and_coerced_before_the_tool_runs(connect: Connect) -> None: + """Literal, nested-model, and constrained parameters are validated and coerced from the wire arguments. + + The string "3" is coerced to `int` and the `point` dict to a `Point` instance before the function + body sees them, proving the generated input schema and validation pipeline cover non-trivial types. + """ + mcp = MCPServer("typed") + + class Point(BaseModel): + x: int + y: int + + @mcp.tool() + def place(mode: Literal["fast", "slow"], point: Point, count: Annotated[int, Field(ge=1, le=10)]) -> str: + assert isinstance(point, Point) + return f"{mode} at ({point.x}, {point.y}) x{count}" + + async with connect(mcp) as client: + result = await client.call_tool("place", {"mode": "fast", "point": {"x": "3", "y": 4}, "count": 5}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="fast at (3, 4) x5")], structured_content={"result": "fast at (3, 4) x5"} + ) + ) + + +@requirement("mcpserver:tool:handler-throws") +@requirement("mcpserver:output-schema:skip-on-error") +async def test_call_tool_function_exception_becomes_error_result(connect: Connect) -> None: + """An exception raised by a tool function is returned as an is_error result, not a JSON-RPC error. + + The function's `-> str` annotation gives the tool a derived output schema, but the error + result is built before any schema validation runs, so no validation failure is layered on + top of the original exception. + """ + mcp = MCPServer("errors") + + @mcp.tool() + def explode() -> str: + raise ValueError("boom") + + async with connect(mcp) as client: + result = await client.call_tool("explode", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool explode: boom")], is_error=True) + ) + + +@requirement("mcpserver:tool:handler-throws") +async def test_call_tool_tool_error_becomes_error_result(connect: Connect) -> None: + """A ToolError raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" + mcp = MCPServer("errors") + + @mcp.tool() + def flux() -> str: + raise ToolError("flux capacitor offline") + + async with connect(mcp) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool flux: flux capacitor offline")], is_error=True) + ) + + +@requirement("mcpserver:tool:unknown-name") +async def test_call_tool_unknown_name_returns_error_result(connect: Connect) -> None: + """Calling a tool name that was never registered is reported as an is_error result. + + The spec classifies unknown tools as a protocol error; see the divergence note on the + requirement. + """ + mcp = MCPServer("errors") + + @mcp.tool() + def add() -> None: + """A registered tool; the test calls a different name.""" + + async with connect(mcp) as client: + result = await client.call_tool("nope", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="Unknown tool: nope")], is_error=True)) + + +@requirement("mcpserver:tool:output-schema:model") +@requirement("tools:call:structured-content:text-mirror") +async def test_call_tool_model_return_becomes_structured_content(connect: Connect) -> None: + """A tool returning a pydantic model advertises the model's schema as the tool's output schema + and returns the model's fields as structured content alongside a serialised text block. + """ + mcp = MCPServer("weather") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def get_weather() -> Weather: + return Weather(temperature=22.5, conditions="sunny") + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("get_weather", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": { + "temperature": {"title": "Temperature", "type": "number"}, + "conditions": {"title": "Conditions", "type": "string"}, + }, + "required": ["temperature", "conditions"], + "title": "Weather", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="""\ +{ + "temperature": 22.5, + "conditions": "sunny" +}\ +""" + ) + ], + structured_content={"temperature": 22.5, "conditions": "sunny"}, + ) + ) + + +@requirement("mcpserver:tool:output-schema:wrapped") +async def test_call_tool_list_return_is_wrapped_in_result_key(connect: Connect) -> None: + """A tool returning a list wraps the value under a "result" key in both the generated output + schema and the structured content. + """ + mcp = MCPServer("primes") + + @mcp.tool() + def primes() -> list[int]: + return [2, 3, 5] + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("primes", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": {"result": {"items": {"type": "integer"}, "title": "Result", "type": "array"}}, + "required": ["result"], + "title": "primesOutput", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="2"), TextContent(text="3"), TextContent(text="5")], + structured_content={"result": [2, 3, 5]}, + ) + ) + + +@requirement("mcpserver:tool:input-validation") +async def test_call_tool_invalid_arguments_become_error_result(connect: Connect) -> None: + """Arguments that fail validation against the tool's signature are reported as an is_error + result describing the failure, not as a protocol error. + """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + """Validation rejects the arguments before the function is ever called.""" + raise NotImplementedError + + async with connect(mcp) as client: + result = await client.call_tool("add", {"b": 3}) + + # The description is raw pydantic output -- it embeds a pydantic-version-specific + # errors.pydantic.dev URL and the internal `addArguments` model name -- so only the stable + # prefix is asserted; a full snapshot would break on every pydantic upgrade. + assert result.is_error is True + assert isinstance(result.content[0], TextContent) + assert result.content[0].text.startswith("Error executing tool add: 1 validation error") + + +@requirement("mcpserver:output-schema:server-validate") +@requirement("mcpserver:output-schema:missing-structured") +async def test_tool_with_output_schema_returning_mismatched_structured_content_is_an_error_result( + connect: Connect, +) -> None: + """Structured content that fails the tool's own output schema is rejected on the server side. + + A tool annotated `Annotated[CallToolResult, Model]` returns a hand-built CallToolResult while + declaring `Model` as its output schema; MCPServer validates the supplied structured_content + against that schema before returning. The two cases -- a content shape that does not match, + and no structured content at all -- both fail that validation and are reported as is_error + results carrying the (raw pydantic) validation error wrapped in the SDK's stable prefix. + """ + mcp = MCPServer("forecaster") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def mismatched() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")], structured_content={"nope": True}) + + @mcp.tool() + def missing() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")]) + + async with connect(mcp) as client: + mismatched_result = await client.call_tool("mismatched", {}) + missing_result = await client.call_tool("missing", {}) + + # The body of each message is raw pydantic ValidationError output (model name, field paths, + # an errors.pydantic.dev URL) and changes across pydantic versions, so only the SDK's stable + # prefix is asserted. + assert mismatched_result.is_error is True + assert isinstance(mismatched_result.content[0], TextContent) + assert mismatched_result.content[0].text.startswith("Error executing tool mismatched: 2 validation errors") + + assert missing_result.is_error is True + assert isinstance(missing_result.content[0], TextContent) + assert missing_result.content[0].text.startswith("Error executing tool missing: 1 validation error") + + +@requirement("mcpserver:tool:duplicate-name") +async def test_registering_a_duplicate_tool_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second tool with an already-used name keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via add_tool with an explicit name so the test does not + redefine the same function name in this scope. + """ + mcp = MCPServer("duplicates") + + @mcp.tool() + def echo() -> str: + return "first" + + def echo_second() -> str: + """Passed to add_tool with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + mcp.add_tool(echo_second, name="echo") + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("echo", {}) + + assert [tool.name for tool in listed.tools] == ["echo"] + assert result == snapshot( + CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + ) + + +@requirement("mcpserver:tool:naming-validation") +async def test_registering_a_tool_with_a_spec_invalid_name_warns_but_does_not_reject( + connect: Connect, caplog: pytest.LogCaptureFixture +) -> None: + """A tool name that violates the SEP-986 rules logs a warning at registration but is still registered. + + The intended behaviour is rejection at registration time; MCPServer instead logs the + naming-rule violation and proceeds (see the divergence note on the requirement). The warning + spans several SDK-authored log records, so only the stable prefix and inclusion of the + offending name are asserted. + """ + mcp = MCPServer("naming") + + with caplog.at_level(logging.WARNING, logger="mcp.shared.tool_name_validation"): + + @mcp.tool(name="bad name!") + def bad() -> str: + return "ok" + + assert any( + rec.levelno == logging.WARNING + and rec.message.startswith("Tool name validation warning") + and "bad name!" in rec.message + for rec in caplog.records + ) + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("bad name!", {}) + + assert [tool.name for tool in listed.tools] == ["bad name!"] + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")], structured_content={"result": "ok"})) + + +@requirement("mcpserver:tool:url-elicitation-error") +async def test_decorated_tool_raising_url_elicitation_required_surfaces_as_error_32042(connect: Connect) -> None: + """A decorated tool raising the URL-elicitation-required error reaches the client as error -32042. + + MCPServer wraps every other tool exception as an is_error result; this error is special-cased + so it propagates as the JSON-RPC error the client needs in order to present the listed URL + interactions and retry the call. + """ + mcp = MCPServer("authorizer") + + @mcp.tool() + def read_files() -> str: + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) + + +@requirement("mcpserver:register:post-connect") +async def test_adding_and_removing_tools_does_not_notify_connected_clients(connect: Connect) -> None: + """Mutating the tool set on a running server changes tools/list but sends no notification. + + add_tool and remove_tool only update the registry: a connected client that listed the tools + before the mutation has no way to learn it should list them again. The spec provides + notifications/tools/list_changed for exactly this; MCPServer never sends it. The tool emits + one log message as a sentinel so the test proves notifications do reach the collector -- the + log message arrives, a list_changed does not. + """ + received: list[IncomingMessage] = [] + mcp = MCPServer("mutable") + + def extra() -> str: + """A tool registered at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + def doomed() -> str: + """A tool removed at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + async def grow(ctx: Context) -> str: + mcp.add_tool(extra, name="extra") + mcp.remove_tool("doomed") + await ctx.info("tool set changed") + return "mutated" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(mcp, message_handler=collect) as client: + before = await client.list_tools() + await client.call_tool("grow", {}) + after = await client.list_tools() + + assert [tool.name for tool in before.tools] == ["doomed", "grow"] + assert [tool.name for tool in after.tools] == ["grow", "extra"] + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="tool set changed"))] + ) diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py new file mode 100644 index 0000000000..7821c1eed5 --- /dev/null +++ b/tests/interaction/test_coverage.py @@ -0,0 +1,105 @@ +"""Enforces the contract between the requirements manifest and the test suite. + +The contract runs in both directions: every non-deferred entry in :data:`REQUIREMENTS` must be +exercised by at least one test, and every test in the suite must carry at least one +`@requirement(...)` mark referencing a manifest entry. Deferral reasons that point at coverage +elsewhere in the repo must point at paths that exist. Test modules are imported directly +(rather than relying on pytest collection) so the check holds even when only this file is run. +""" + +import importlib +import re +from pathlib import Path +from types import ModuleType + +import pytest + +from tests.interaction._requirements import REQUIREMENTS, Requirement, covered_by, requirement + +_SUITE_ROOT = Path(__file__).parent +_REPO_ROOT = _SUITE_ROOT.parent.parent + +# Repo paths cited inside deferral reasons ("Covered by tests/... "). +_CITED_PATH = re.compile(r"(?:tests|src)/[\w./-]*\w") + +# Tests that exercise the suite's own helpers rather than an interaction-model behaviour. +# Anything listed here is exempt from the every-test-has-a-requirement check. +_HARNESS_SELF_TESTS = { + "tests.interaction.lowlevel.test_wire.test_recording_read_stream_ends_iteration_when_the_sender_closes", + "tests.interaction.transports.test_bridge.test_response_chunks_arrive_as_the_application_sends_them", + "tests.interaction.transports.test_bridge.test_closing_the_response_delivers_a_disconnect_to_the_application", + "tests.interaction.transports.test_bridge.test_an_application_failure_before_the_response_starts_fails_the_request", + "tests.interaction.transports.test_bridge.test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect", + "tests.interaction.auth.test_flow.test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_wrapped_app", +} + + +def _import_all_test_modules() -> list[ModuleType]: + """Import every other test module in the suite so their `@requirement` decorators register.""" + modules: list[ModuleType] = [] + for path in sorted(_SUITE_ROOT.rglob("test_*.py")): + relative = path.relative_to(_SUITE_ROOT).with_suffix("") + name = f"{__package__}.{'.'.join(relative.parts)}" + if name != __name__: + modules.append(importlib.import_module(name)) + return modules + + +def test_every_requirement_is_exercised() -> None: + """Each non-deferred requirement is covered by at least one test (deferred ones by none).""" + _import_all_test_modules() + + uncovered = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is None and not covered_by(requirement_id) + ] + assert not uncovered, f"Requirements with no test and no deferred reason: {uncovered}" + + stale_deferrals = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is not None and covered_by(requirement_id) + ] + assert not stale_deferrals, f"Deferred requirements that now have tests (remove deferred): {stale_deferrals}" + + +def test_every_test_exercises_a_requirement() -> None: + """Each test in the suite carries at least one `@requirement` mark (harness self-tests excepted).""" + all_tests = { + f"{module.__name__}.{name}" + for module in _import_all_test_modules() + for name in vars(module) + if name.startswith("test_") + } + linked_tests = {test_name for requirement_id in REQUIREMENTS for test_name in covered_by(requirement_id)} + + unlinked = sorted(all_tests - linked_tests - _HARNESS_SELF_TESTS) + assert not unlinked, f"Tests with no @requirement mark: {unlinked}" + + stale_exemptions = sorted(_HARNESS_SELF_TESTS - all_tests) + assert not stale_exemptions, f"Harness self-test exemptions that no longer exist: {stale_exemptions}" + + +def test_deferral_reasons_cite_existing_paths() -> None: + """Every repo path named in a deferral reason exists, so coverage pointers cannot rot.""" + missing = sorted( + f"{requirement_id}: {cited}" + for requirement_id, spec in REQUIREMENTS.items() + if spec.deferred is not None + for cited in _CITED_PATH.findall(spec.deferred) + if not (_REPO_ROOT / cited).exists() + ) + assert not missing, f"Deferral reasons citing paths that do not exist: {missing}" + + +def test_unknown_requirement_id_is_rejected() -> None: + """Marking a test with an ID that is not in the manifest fails at decoration time.""" + with pytest.raises(KeyError, match="Unknown requirement id 'tools:call:does-not-exist'"): + requirement("tools:call:does-not-exist") + + +def test_invalid_requirement_source_is_rejected() -> None: + """A requirement whose source is not a spec URL, 'sdk', or an issue reference fails at construction.""" + with pytest.raises(ValueError, match="source must be a specification URL"): + Requirement(source="https://example.com/not-the-spec", behavior="Never constructed.") diff --git a/tests/interaction/transports/__init__.py b/tests/interaction/transports/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/transports/_bridge.py b/tests/interaction/transports/_bridge.py new file mode 100644 index 0000000000..f78c6d14b5 --- /dev/null +++ b/tests/interaction/transports/_bridge.py @@ -0,0 +1,169 @@ +"""An in-process, full-duplex HTTP transport for driving ASGI applications from httpx. + +`httpx.ASGITransport` runs the application to completion and only then hands the buffered +response to the caller, so a server that streams its response — the streamable HTTP transport's +SSE responses — can never converse with the client mid-request: a server-initiated request +nested inside a still-open call deadlocks. `StreamingASGITransport` removes that limitation by +running the application as a background task and forwarding every `http.response.body` chunk to +the client the moment it is sent. Everything happens on the one event loop: no sockets, no +threads, no sleeps, no extra dependencies. + +The behavioural contract, pinned by `test_bridge.py`: + +- The request body is buffered before the application is invoked (MCP requests are small JSON + documents); the response streams chunk by chunk. +- Closing the response — or the whole client — delivers `http.disconnect` to the application, + exactly as a real server sees when its peer goes away. +- An exception the application raises before sending `http.response.start` fails the originating + request with that same exception. After the response has started, a failure is visible to the + client only through the response itself (status code, truncated body) — the same signal a real + server over a real socket would give. + +The transport owns an anyio task group for the application tasks; it is opened and closed by +`httpx.AsyncClient`'s own context manager, so use the client as a context manager (the suite +always does). Closing the transport cancels every running application task by default; set +`cancel_on_close=False` to wait for the application's own disconnect handling instead. +""" + +import math +from collections.abc import AsyncIterator +from types import TracebackType + +import anyio +import anyio.abc +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream +from starlette.types import ASGIApp, Message, Scope + + +class _StreamingResponseBody(httpx.AsyncByteStream): + """A response body that yields chunks as the application produces them. + + Closing it tells the application the client has gone away (`http.disconnect`), mirroring a + peer that drops the connection mid-response. + """ + + def __init__(self, chunks: MemoryObjectReceiveStream[bytes], client_disconnected: anyio.Event) -> None: + self._chunks = chunks + self._client_disconnected = client_disconnected + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for chunk in self._chunks: + yield chunk + + async def aclose(self) -> None: + self._client_disconnected.set() + await self._chunks.aclose() + + +class StreamingASGITransport(httpx.AsyncBaseTransport): + """Drive an ASGI application in-process, streaming each response as it is produced. + + With `cancel_on_close` (the default), closing the transport cancels every application task + still running so harness teardown can never hang. Setting it to False makes the transport wait + for the application's own disconnect handling to complete instead, which is the path the legacy + SSE server transport relies on for resource cleanup. + """ + + _task_group: anyio.abc.TaskGroup + + def __init__(self, app: ASGIApp, *, cancel_on_close: bool = True) -> None: + self._app = app + self._cancel_on_close = cancel_on_close + + async def __aenter__(self) -> "StreamingASGITransport": + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + # httpx closes every streamed response before closing the transport, so by now each + # application task has been delivered `http.disconnect`. Either cancel immediately, or wait + # for the application's own disconnect handling to unwind. + if self._cancel_on_close: + self._task_group.cancel_scope.cancel() + await self._task_group.__aexit__(exc_type, exc_value, traceback) + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + assert isinstance(request.stream, httpx.AsyncByteStream) + request_body = b"".join([chunk async for chunk in request.stream]) + + scope: Scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path.split(b"?", maxsplit=1)[0], + "query_string": request.url.query, + "root_path": "", + "headers": [(name.lower(), value) for name, value in request.headers.raw], + "server": (request.url.host, request.url.port), + "client": ("127.0.0.1", 1234), + } + + request_delivered = False + client_disconnected = anyio.Event() + response_started = anyio.Event() + response_status = 0 + response_headers: list[tuple[bytes, bytes]] = [] + application_error: Exception | None = None + chunk_writer, chunk_reader = anyio.create_memory_object_stream[bytes](math.inf) + + async def receive_request() -> Message: + nonlocal request_delivered + if not request_delivered: + request_delivered = True + return {"type": "http.request", "body": request_body, "more_body": False} + await client_disconnected.wait() + return {"type": "http.disconnect"} + + async def send_response(message: Message) -> None: + nonlocal response_status, response_headers + if message["type"] == "http.response.start": + response_status = message["status"] + response_headers = list(message.get("headers", [])) + response_started.set() + return + assert message["type"] == "http.response.body" + body: bytes = message.get("body", b"") + if body: + await chunk_writer.send(body) + if not message.get("more_body", False): + await chunk_writer.aclose() + + async def run_application() -> None: + nonlocal application_error + try: + await self._app(scope, receive_request, send_response) + except Exception as exc: # The bridge is the application's outermost boundary: a crash + # must fail the originating request (or show up in the already-started response), + # never tear down the task group shared with every other in-flight request. + application_error = exc + finally: + response_started.set() + await chunk_writer.aclose() + + self._task_group.start_soon(run_application) + try: + await response_started.wait() + if application_error is not None: + raise application_error + except BaseException: + # No response will be built, so close the reader the response body would have owned + # and tell the application its peer has gone away. + client_disconnected.set() + await chunk_reader.aclose() + raise + return httpx.Response( + status_code=response_status, + headers=response_headers, + stream=_StreamingResponseBody(chunk_reader, client_disconnected), + request=request, + ) diff --git a/tests/interaction/transports/_event_store.py b/tests/interaction/transports/_event_store.py new file mode 100644 index 0000000000..84d1a2646a --- /dev/null +++ b/tests/interaction/transports/_event_store.py @@ -0,0 +1,55 @@ +"""A predictable event store for resumability tests. + +The SDK's `EventStore` interface lets a streamable-HTTP server stamp every SSE event with an ID +and replay missed events when a client reconnects with `Last-Event-ID`. This implementation +issues sequential integer IDs starting at "1" so tests can assert exact IDs (the example store +uses uuid4, which cannot be snapshotted) and is small enough that every line is exercised by the +resumability tests themselves. +""" + +import anyio + +from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId +from mcp.types import JSONRPCMessage + + +class SequencedEventStore(EventStore): + """Stores every event in order and replays the same-stream tail after a given ID.""" + + def __init__(self) -> None: + self._events: list[tuple[StreamId, JSONRPCMessage | None]] = [] + self._milestones: dict[int, anyio.Event] = {} + + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + self._events.append((stream_id, message)) + count = len(self._events) + milestone = self._milestones.pop(count, None) + if milestone is not None: + milestone.set() + return str(count) + + async def wait_until_stored(self, count: int) -> None: + """Block until at least `count` events have been stored. + + Tests use this to wait for the server's message router (which runs in another task) to + finish storing a known set of events before issuing a replay, so the replay's content is + deterministic rather than depending on task scheduling order. + """ + if len(self._events) >= count: + return + milestone = self._milestones.setdefault(count, anyio.Event()) + await milestone.wait() + + async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None: + try: + cursor = int(last_event_id) + except ValueError: + return None + if not 0 < cursor <= len(self._events): + return None + stream_id, _ = self._events[cursor - 1] + for index in range(cursor, len(self._events)): + event_stream_id, message = self._events[index] + if event_stream_id == stream_id and message is not None: + await send_callback(EventMessage(message, str(index + 1))) + return stream_id diff --git a/tests/interaction/transports/_stdio_server.py b/tests/interaction/transports/_stdio_server.py new file mode 100644 index 0000000000..5977cc3e99 --- /dev/null +++ b/tests/interaction/transports/_stdio_server.py @@ -0,0 +1,63 @@ +"""A real low-level Server over the stdio transport, for the suite's one subprocess test. + +Runnable as `python -m tests.interaction.transports._stdio_server` from the repo root; the test +launches it that way via `stdio_client`. Kept separate from the test module so the server lives in +its own importable file (subprocess coverage applies) while the test file follows the suite's +test-only-functions convention. +""" + +import sys + +import anyio + +from mcp.server import Server, ServerRequestContext +from mcp.server.stdio import stdio_server +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + EmptyResult, + ListToolsResult, + PaginatedRequestParams, + SetLevelRequestParams, + TextContent, + Tool, +) + + +async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo", + input_schema={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, + ) + ] + ) + + +async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + text = params.arguments["text"] + await ctx.session.send_log_message(level="info", data=f"echoing {text}", logger="echo") + return CallToolResult(content=[TextContent(text=text)]) + + +async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + +server = Server("stdio-echo", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) + + +async def main() -> None: + async with stdio_server() as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) + # Reached only when the run loop exits because stdin closed; if the process were terminated + # the test's stderr capture would not see this line. + print("stdio-echo: clean exit", file=sys.stderr, flush=True) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/tests/interaction/transports/test_bridge.py b/tests/interaction/transports/test_bridge.py new file mode 100644 index 0000000000..7420b9d902 --- /dev/null +++ b/tests/interaction/transports/test_bridge.py @@ -0,0 +1,94 @@ +"""Contract tests for the suite's streaming ASGI bridge. + +These pin what `StreamingASGITransport` itself guarantees — chunk-by-chunk delivery, disconnect +propagation, and failure handling — against minimal hand-written ASGI applications, so the MCP +transport tests built on top of it never have to wonder what the harness provides. They are +harness self-tests, not interaction-model tests, and are exempted from the requirement-coverage +contract in `test_coverage.py`. +""" + +import anyio +import httpx +import pytest +from starlette.types import Message, Receive, Scope, Send + +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +async def test_response_chunks_arrive_as_the_application_sends_them() -> None: + """Each body chunk is delivered as sent, empty chunks are skipped, and the stream ends with the application.""" + + async def chunked_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": [(b"content-type", b"text/plain")]}) + await send({"type": "http.response.body", "body": b"first", "more_body": True}) + await send({"type": "http.response.body", "body": b"", "more_body": True}) + await send({"type": "http.response.body", "body": b"second", "more_body": False}) + + async with ( + httpx.AsyncClient(transport=StreamingASGITransport(chunked_app), base_url="http://bridge") as http, + http.stream("GET", "/chunks") as response, + ): + with anyio.fail_after(5): + chunks = [chunk async for chunk in response.aiter_raw()] + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/plain" + assert chunks == [b"first", b"second"] + + +async def test_closing_the_response_delivers_a_disconnect_to_the_application() -> None: + """A client that closes the response early is seen by the application as an http.disconnect.""" + seen_after_request: list[Message] = [] + disconnect_seen = anyio.Event() + + async def waiting_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": []}) + seen_after_request.append(await receive()) + disconnect_seen.set() + + async with httpx.AsyncClient(transport=StreamingASGITransport(waiting_app), base_url="http://bridge") as http: + async with http.stream("GET", "/wait") as response: + assert response.status_code == 200 + # Leaving the stream block closes the response while the application is still mid-response. + with anyio.fail_after(5): + await disconnect_seen.wait() + + assert seen_after_request == [{"type": "http.disconnect"}] + + +async def test_an_application_failure_before_the_response_starts_fails_the_request() -> None: + """An exception raised before http.response.start reaches the caller as that same exception.""" + + async def broken_app(scope: Scope, receive: Receive, send: Send) -> None: + raise RuntimeError("the demo application is broken") + + async with httpx.AsyncClient(transport=StreamingASGITransport(broken_app), base_url="http://bridge") as http: + with pytest.raises(RuntimeError, match="the demo application is broken"): + await http.get("/broken") + + +async def test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect() -> None: + """With cancel_on_close=False, an application that runs cleanup after seeing http.disconnect + completes that cleanup before the transport finishes closing.""" + cleanup_ran = anyio.Event() + + async def lingering_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + await receive() + await send({"type": "http.response.start", "status": 200, "headers": []}) + assert (await receive())["type"] == "http.disconnect" + cleanup_ran.set() + + transport = StreamingASGITransport(lingering_app, cancel_on_close=False) + with anyio.fail_after(5): + async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: + async with http.stream("GET", "/linger") as response: + assert response.status_code == 200 + assert not cleanup_ran.is_set() + assert cleanup_ran.is_set() diff --git a/tests/interaction/transports/test_client_transport_http.py b/tests/interaction/transports/test_client_transport_http.py new file mode 100644 index 0000000000..65ed03f1e4 --- /dev/null +++ b/tests/interaction/transports/test_client_transport_http.py @@ -0,0 +1,247 @@ +"""Behaviour of the streamable-HTTP client transport itself, observed at the wire. + +These tests connect a real `Client` to a real server over the in-process bridge, recording every +HTTP request the SDK client issues, so the assertions are about what the transport sends (headers, +methods, ordering) rather than what the protocol layer on top of it returns. The recording is the +wire-level instrument; the SDK client never exposes these details. +""" + +from collections.abc import AsyncIterator + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot +from starlette.types import Receive, Scope, Send + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.types import INVALID_REQUEST, CallToolResult, ErrorData, ListToolsResult, TextContent, Tool +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION, client_via_http, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _tooled_server() -> Server: + """A low-level server with one echo tool, used by every test in this file.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", description="Echo text.", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["text"]))]) + + return Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + +@pytest.fixture +async def recorded() -> AsyncIterator[list[httpx.Request]]: + """Connect a `Client` over a recording HTTP client, list tools, exit, and yield every request sent. + + The HTTP client carries one caller-supplied header (`x-trace`) so its propagation can be + asserted; the recording captures the closing DELETE because it is read after the `Client` has + fully exited. + """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record, headers={"x-trace": "abc"}) as (http, _): + async with client_via_http(http) as client: + result = await client.list_tools() + assert [tool.name for tool in result.tools] == ["echo"] + + yield requests + + +def _after_initialize(recorded: list[httpx.Request]) -> list[httpx.Request]: + """Every recorded request after the initialize POST (which carries no session yet).""" + assert recorded[0].method == "POST" + assert "mcp-session-id" not in recorded[0].headers + return recorded[1:] + + +@requirement("client-transport:http:custom-client") +@requirement("client-transport:http:custom-headers") +async def test_the_client_uses_the_supplied_http_client_and_propagates_its_headers( + recorded: list[httpx.Request], +) -> None: + """A caller-supplied `httpx.AsyncClient` is used for every request and carries its own headers. + + The recording itself proves the supplied client is the one in use; the propagated header + proves the SDK transport does not replace the caller's client configuration. + """ + # Exact ordering past the first request is not guaranteed (the standalone GET stream is + # scheduled concurrently with later POSTs), so methods are asserted as a multiset. + assert sorted(request.method for request in recorded) == snapshot(["DELETE", "GET", "POST", "POST", "POST"]) + assert all(request.headers["x-trace"] == "abc" for request in recorded) + + +@requirement("client-transport:http:session-stored") +async def test_every_request_after_initialize_carries_the_issued_session_id(recorded: list[httpx.Request]) -> None: + """The session id from the initialize response is sent on every subsequent request.""" + session_ids = {request.headers["mcp-session-id"] for request in _after_initialize(recorded)} + assert len(session_ids) == 1 + (session_id,) = session_ids + assert session_id + + +@requirement("client-transport:http:protocol-version-stored") +@requirement("client-transport:http:protocol-version-header") +async def test_every_request_after_initialize_carries_the_negotiated_protocol_version( + recorded: list[httpx.Request], +) -> None: + """The negotiated protocol version is sent on every subsequent request (and not on initialize).""" + assert "mcp-protocol-version" not in recorded[0].headers + versions = {request.headers["mcp-protocol-version"] for request in _after_initialize(recorded)} + assert versions == snapshot({"2025-11-25"}) + + +@requirement("client-transport:http:accept-header-post") +@requirement("client-transport:http:accept-header-get") +async def test_accept_headers_cover_the_response_representations_the_transport_handles( + recorded: list[httpx.Request], +) -> None: + """POSTs accept both JSON and SSE; the standalone GET stream accepts SSE.""" + for request in recorded: + if request.method == "POST": + assert "application/json" in request.headers["accept"] + assert "text/event-stream" in request.headers["accept"] + if request.method == "GET": + assert "text/event-stream" in request.headers["accept"] + + +@requirement("client-transport:http:no-reconnect-after-close") +async def test_closing_the_client_sends_delete_and_does_not_reconnect(recorded: list[httpx.Request]) -> None: + """Client teardown sends DELETE and issues no further requests (no resumption GET).""" + assert recorded[-1].method == "DELETE" + assert all("last-event-id" not in request.headers for request in recorded) + + +@requirement("client-transport:http:concurrent-streams") +async def test_concurrent_tool_calls_each_open_a_post_stream_and_receive_their_own_response() -> None: + """Three tool calls issued at once each open their own POST stream and get the right answer.""" + requests: list[httpx.Request] = [] + results: dict[int, CallToolResult] = {} + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record) as (http, _), client_via_http(http) as client: + + async def call(n: int) -> None: + results[n] = await client.call_tool("echo", {"text": str(n)}) + + with anyio.fail_after(5): # pragma: no branch + async with anyio.create_task_group() as tg: # pragma: no branch + for n in (1, 2, 3): + tg.start_soon(call, n) + + assert results == snapshot( + { + 1: CallToolResult(content=[TextContent(text="1")]), + 2: CallToolResult(content=[TextContent(text="2")]), + 3: CallToolResult(content=[TextContent(text="3")]), + } + ) + tools_call_posts = [r for r in requests if r.method == "POST" and b'"tools/call"' in r.content] + assert len(tools_call_posts) == 3 + + +@requirement("client-transport:http:sse-405-tolerated") +@requirement("client-transport:http:terminate-405-ok") +async def test_client_tolerates_405_on_get_and_delete() -> None: + """A 405 on the standalone GET stream or the closing DELETE does not fail the connection. + + The GET-stream task swallows the failure and schedules a reconnect that the closing cancel + interrupts before it ever sleeps the full default delay; the DELETE 405 is logged and ignored. + Neither surfaces to the caller. + """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + + async def filter_methods(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] in ("GET", "DELETE"): + await send({"type": "http.response.start", "status": 405, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + await real_app(scope, receive, send) + + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(filter_methods), base_url=BASE_URL) as http_client, + ): + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): # pragma: no branch + async with Client(transport) as client: # pragma: no branch + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + + +@requirement("client-transport:http:no-reconnect-after-response") +async def test_a_completed_post_stream_is_not_reconnected() -> None: + """A POST stream that delivered its response closes without a resumption GET. + + With an event store the server stamps every SSE event with an ID, so the client transport has a + Last-Event-ID it could resume from -- the test proves it does not, because the response arrived + and the stream completed normally. + """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + server = _tooled_server() + async with ( + mounted_app(server, event_store=SequencedEventStore(), retry_interval=0, on_request=record) as (http, _), + client_via_http(http) as client, + ): + with anyio.fail_after(5): + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + resumption_gets = [r for r in requests if r.method == "GET" and "last-event-id" in r.headers] + assert resumption_gets == [] + + +@requirement("client-transport:http:404-surfaces") +async def test_a_404_mid_session_surfaces_as_a_session_terminated_error() -> None: + """A 404 in response to a request after initialization is reported to the caller as an MCP error. + + The spec says the client MUST start a new session in this situation; the SDK instead surfaces a + `Session terminated` error to the caller. The spec's MUST is tracked at + client-transport:http:session-404-reinitialize; this test pins the SDK's current behaviour. + """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + initialize_seen = anyio.Event() + + async def first_post_then_404(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] == "POST" and initialize_seen.is_set(): + await send({"type": "http.response.start", "status": 404, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + if scope["type"] == "http" and scope["method"] == "POST": + initialize_seen.set() + await real_app(scope, receive, send) + + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(first_post_then_404), base_url=BASE_URL) as http_client, + ): + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): # pragma: no branch + async with Client(transport) as client: # pragma: no branch + with pytest.raises(MCPError) as exc_info: # pragma: no branch + await client.list_tools() + + assert exc_info.value.error == snapshot(ErrorData(code=INVALID_REQUEST, message="Session terminated")) diff --git a/tests/interaction/transports/test_flows.py b/tests/interaction/transports/test_flows.py new file mode 100644 index 0000000000..c428fe2d68 --- /dev/null +++ b/tests/interaction/transports/test_flows.py @@ -0,0 +1,129 @@ +"""Transport-level composed flows: multi-client isolation, reconnection, and dual-transport hosting. + +These scenarios are about how the transport layer holds together across more than one connection +or more than one transport, so they connect real `Client`s against one mounted server rather than +running over the matrix. +""" + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.session import LoggingFnT +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import CallToolResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import client_via_http, connect_over_sse, mounted_app +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("flow:multi-client:stateful-isolation") +async def test_concurrent_clients_on_one_stateful_server_receive_only_their_own_notifications() -> None: + """Two clients on one stateful manager each receive only the notifications their own request produced. + + Complements `test_terminating_one_session_leaves_others_working` (which proves session + independence under termination) with the notification-isolation dimension: a notification + emitted by one session's handler does not leak to another session's client. + """ + mcp = MCPServer("multi") + + @mcp.tool() + async def announce(label: str, ctx: Context) -> str: + """Emit one info-level log carrying the caller's label, then return it.""" + await ctx.info(label) + return label + + received_a: list[object] = [] + received_b: list[object] = [] + + async def collect_a(params: LoggingMessageNotificationParams) -> None: + received_a.append(params.data) + + async def collect_b(params: LoggingMessageNotificationParams) -> None: + received_b.append(params.data) + + async with mounted_app(mcp) as (http, _): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call(label: str, collect: LoggingFnT) -> None: + async with client_via_http(http, logging_callback=collect) as client: + await client.call_tool("announce", {"label": label}) + + tg.start_soon(call, "a", collect_a) + tg.start_soon(call, "b", collect_b) + + assert received_a == ["a"] + assert received_b == ["b"] + + +@requirement("flow:session:terminate-then-reconnect") +async def test_a_fresh_connection_after_termination_obtains_a_new_session_and_operates() -> None: + """After a client terminates, a fresh connection to the same manager gets a distinct session. + + Steps: (1) connect a client and call list_tools, (2) the client exits (its DELETE fires), + (3) connect a second client to the same mounted app, (4) the second client's call_tool + succeeds and the recorded session ids show two distinct sessions were issued. + """ + mcp = MCPServer("reconnectable") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + session_ids: list[str] = [] + + async def record(request: httpx.Request) -> None: + session_id = request.headers.get("mcp-session-id") + if session_id is not None: + session_ids.append(session_id) + + async with mounted_app(mcp, on_request=record) as (http, _): + async with client_via_http(http) as first: + first_result = await first.list_tools() + async with client_via_http(http) as second: + second_result = await second.call_tool("echo", {"text": "again"}) + + assert {tool.name for tool in first_result.tools} == {"echo"} + assert second_result == snapshot( + CallToolResult(content=[TextContent(text="again")], structured_content={"result": "again"}) + ) + distinct = set(session_ids) + assert len(distinct) == 2, f"expected two distinct session ids across the two connections, saw {distinct}" + + +@requirement("flow:compat:dual-transport-server") +async def test_one_server_serves_streamable_http_and_sse_clients_concurrently() -> None: + """One MCPServer instance serves a streamable-HTTP client and a legacy-SSE client at the same time. + + The two transports have independent connection management (the streamable-HTTP session manager + versus a per-connection SSE handler), but both dispatch into the same server's request + handlers. The test connects one client over each transport against the same instance and + proves both reach the same tool. Uses MCPServer because the low-level Server has no SSE + convenience; the entry is about hosting composition, not the low-level API. + """ + mcp = MCPServer("dual") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + async with ( + mounted_app(mcp) as (http, _), + connect_over_sse(mcp) as sse_client, + client_via_http(http) as shttp_client, + ): + with anyio.fail_after(5): + shttp_result = await shttp_client.call_tool("echo", {"text": "via http"}) + sse_result = await sse_client.call_tool("echo", {"text": "via sse"}) + + assert shttp_result == snapshot( + CallToolResult(content=[TextContent(text="via http")], structured_content={"result": "via http"}) + ) + assert sse_result == snapshot( + CallToolResult(content=[TextContent(text="via sse")], structured_content={"result": "via sse"}) + ) diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py new file mode 100644 index 0000000000..85e64ded42 --- /dev/null +++ b/tests/interaction/transports/test_hosting_http.py @@ -0,0 +1,344 @@ +"""Streamable HTTP semantics: status codes, header validation, message routing, and security. + +These tests speak HTTP directly to the server's mounted ASGI app via the in-process bridge, +asserting the wire contract -- which status code answers which condition, which stream a message +travels on -- that the SDK client never exposes. Transport-agnostic behaviour is covered by the +`connect`-fixture matrix. +""" + +import anyio +import pytest +from anyio.lowlevel import checkpoint +from httpx_sse import ServerSentEvent, aconnect_sse +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + INVALID_PARAMS, + PARSE_ERROR, + CallToolRequestParams, + CallToolResult, + EmptyResult, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListResourcesResult, + ListToolsResult, + PaginatedRequestParams, + SetLevelRequestParams, + SubscribeRequestParams, + TextContent, +) +from tests.interaction._connect import ( + base_headers, + initialize_body, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A low-level server with one tool that emits a related and an unrelated notification.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "narrate" + await ctx.session.send_log_message(level="info", data="related", logger=None, related_request_id=ctx.request_id) + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="done")]) + + async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + async def list_resources(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: + """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" + raise NotImplementedError + + return Server( + "hosted", + on_list_tools=list_tools, + on_call_tool=call_tool, + on_set_logging_level=set_logging_level, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + ) + + +@requirement("hosting:http:method-405") +async def test_unsupported_http_methods_return_405() -> None: + """PUT and PATCH on the MCP endpoint return 405 with an Allow header naming the supported methods.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + put = await http.put("/mcp", json={}, headers=base_headers(session_id=session_id)) + patch = await http.patch("/mcp", json={}, headers=base_headers(session_id=session_id)) + + assert (put.status_code, put.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + assert (patch.status_code, patch.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + + +@requirement("hosting:http:accept-406") +async def test_missing_accept_media_types_return_406() -> None: + """A POST whose Accept header lacks both required types, or a GET lacking text/event-stream, returns 406.""" + async with mounted_app(_server()) as (http, _): + post = await http.post( + "/mcp", json=initialize_body(), headers={"accept": "text/plain", "mcp-protocol-version": "2025-11-25"} + ) + session_id = await initialize_via_http(http) + get = await http.get( + "/mcp", + headers={"accept": "application/json", "mcp-protocol-version": "2025-11-25", "mcp-session-id": session_id}, + ) + + assert (post.status_code, post.json()["error"]["message"]) == snapshot( + (406, "Not Acceptable: Client must accept both application/json and text/event-stream") + ) + assert (get.status_code, get.json()["error"]["message"]) == snapshot( + (406, "Not Acceptable: Client must accept text/event-stream") + ) + + +@requirement("hosting:http:content-type-415") +async def test_non_json_content_type_is_rejected() -> None: + """A POST with a non-JSON Content-Type is rejected before reaching the transport. + + See the divergence on the requirement: the security middleware rejects with 400, so the + transport's own 415 path is unreachable through any public entry point. + """ + async with mounted_app(_server()) as (http, _): + response = await http.post( + "/mcp", content=b"", headers=base_headers() | {"content-type": "text/plain"} + ) + + assert (response.status_code, response.text) == snapshot((400, "Invalid Content-Type header")) + + +@requirement("hosting:http:parse-error-400") +@requirement("hosting:http:batch") +async def test_malformed_and_batched_bodies_return_400() -> None: + """A non-JSON body returns 400 Parse error; a JSON array of requests returns 400 Invalid params.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + not_json = await http.post( + "/mcp", + content=b"this is not json", + headers=base_headers(session_id=session_id) | {"content-type": "application/json"}, + ) + batched = await http.post( + "/mcp", + json=[ + {"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + {"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + ], + headers=base_headers(session_id=session_id), + ) + + assert not_json.status_code == 400 + assert JSONRPCError.model_validate_json(not_json.text).error.code == PARSE_ERROR + assert batched.status_code == 400 + assert JSONRPCError.model_validate_json(batched.text).error.code == INVALID_PARAMS + + +@requirement("hosting:http:protocol-version-400") +@requirement("hosting:http:protocol-version-default") +async def test_protocol_version_header_is_validated() -> None: + """An unsupported MCP-Protocol-Version header returns 400; an absent header is accepted as the default.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + + bad = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + headers=base_headers(session_id=session_id) | {"mcp-protocol-version": "1991-01-01"}, + ) + # Only Accept and the session ID -- no MCP-Protocol-Version header at all. + defaulted = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers={"accept": "application/json, text/event-stream", "mcp-session-id": session_id}, + ) + + assert bad.status_code == 400 + assert JSONRPCError.model_validate_json(bad.text).error.message.startswith( + "Bad Request: Unsupported protocol version: 1991-01-01." + ) + # 202 proves the request was accepted under the assumed default version (2025-03-26). + assert defaulted.status_code == 202 + + +@requirement("hosting:http:json-response-mode") +async def test_json_response_mode_answers_with_application_json_not_sse() -> None: + """With JSON response mode enabled, request POSTs are answered with a single application/json body. + + Asserted at the wire level because the SDK client parses either representation, so a + Client-driven round trip cannot distinguish a JSON response from an SSE one. + """ + async with mounted_app(_server(), json_response=True) as (http, _): + initialized = await http.post("/mcp", json=initialize_body(), headers=base_headers()) + session_id = initialized.headers["mcp-session-id"] + ping = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "ping"}, + headers=base_headers(session_id=session_id), + ) + + assert initialized.status_code == 200 + assert initialized.headers["content-type"].split(";", 1)[0] == "application/json" + assert JSONRPCResponse.model_validate(initialized.json()).id == 1 + assert ping.status_code == 200 + assert ping.headers["content-type"].split(";", 1)[0] == "application/json" + assert JSONRPCResponse.model_validate(ping.json()).id == 2 + + +@requirement("hosting:http:notifications-202") +async def test_notification_post_returns_202_with_no_body() -> None: + """A POST containing only a notification (no request ID) returns 202 Accepted with no body.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + response = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers=base_headers(session_id=session_id), + ) + + assert (response.status_code, response.content) == snapshot((202, b"")) + + +@requirement("hosting:http:second-sse-rejected") +async def test_a_second_standalone_get_stream_on_the_same_session_returns_409() -> None: + """Opening a second standalone GET SSE stream while one is already established returns 409 Conflict.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + + async with aconnect_sse(http, "GET", "/mcp", headers=base_headers(session_id=session_id)) as first: + assert first.response.status_code == 200 + # The standalone-stream writer registers its key as its first action, then parks + # awaiting messages; one yield to the loop lets that registration complete before the + # second GET is dispatched. + await checkpoint() + second = await http.get("/mcp", headers=base_headers(session_id=session_id)) + + assert (second.status_code, second.json()["error"]["message"]) == snapshot( + (409, "Conflict: Only one SSE stream is allowed per session") + ) + + +@requirement("hosting:http:standalone-sse") +@requirement("hosting:http:standalone-sse-no-response") +@requirement("hosting:http:response-same-connection") +@requirement("hosting:http:sse-close-after-response") +@requirement("hosting:http:no-broadcast") +async def test_messages_are_routed_to_exactly_one_stream() -> None: + """Each server message travels on exactly one SSE stream and is never broadcast. + + A streamable-HTTP session has two kinds of server-to-client SSE stream: one short-lived stream + per POST request, carrying that request's response and any notifications related to it, and one + long-lived standalone stream (opened by GET) for notifications not tied to any request. The + spec's routing rule is that the POST stream delivers the response (and its related + notifications) and then closes, the standalone stream carries only unrelated notifications and + never a JSON-RPC response, and no message appears on both. The test opens both streams, calls a + tool whose handler emits one related and one unrelated notification, and asserts each message's + routing. + """ + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + post_events: list[ServerSentEvent] = [] + get_events: list[ServerSentEvent] = [] + + async def read_standalone_stream() -> None: + async with aconnect_sse(http, "GET", "/mcp", headers=base_headers(session_id=session_id)) as get: + assert get.response.status_code == 200 + standalone_ready.set() + async for event in get.aiter_sse(): + get_events.append(event) + seen_on_standalone.set() + + standalone_ready = anyio.Event() + seen_on_standalone = anyio.Event() + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(read_standalone_stream) + await standalone_ready.wait() + + params = CallToolRequestParams(name="narrate", arguments={}) + body = JSONRPCRequest(jsonrpc="2.0", id=5, method="tools/call", params=params.model_dump()) + async with aconnect_sse( + http, + "POST", + "/mcp", + json=body.model_dump(by_alias=True, exclude_none=True), + headers=base_headers(session_id=session_id), + ) as post: + assert post.response.status_code == 200 + # The POST stream iterator ends when the server closes the stream after the response. + post_events = [event async for event in post.aiter_sse()] + + await seen_on_standalone.wait() + tg.cancel_scope.cancel() + + post_messages = parse_sse_messages(post_events) + get_messages = parse_sse_messages(get_events) + + # POST stream: the related log notification, then the response, then the iterator ends (close). + assert [type(m).__name__ for m in post_messages] == snapshot(["JSONRPCNotification", "JSONRPCResponse"]) + assert isinstance(post_messages[0], JSONRPCNotification) + assert (post_messages[0].method, post_messages[0].params) == snapshot( + ("notifications/message", {"level": "info", "data": "related"}) + ) + assert isinstance(post_messages[1], JSONRPCResponse) + assert post_messages[1].id == 5 + + # Standalone stream: only the unrelated resource-updated notification, never a response. + assert [type(m).__name__ for m in get_messages] == snapshot(["JSONRPCNotification"]) + assert isinstance(get_messages[0], JSONRPCNotification) + assert get_messages[0].method == snapshot("notifications/resources/updated") + + +@requirement("hosting:http:dns-rebinding") +@requirement("transport:streamable-http:origin-validation") +async def test_origin_validation_rejects_disallowed_origins_when_enabled() -> None: + """A disallowed Origin returns 403 (and Host 421) with protection enabled; disabled lets both through. + + See the divergence on hosting:http:dns-rebinding: the spec's Origin validation is an + unconditional MUST, but the SDK enables it only when the host is localhost (or settings are + passed explicitly) and additionally checks the Host header (returning 421), which the spec + does not require. + """ + # transport_security=None triggers the localhost auto-enable behaviour. + async with mounted_app(Server("guarded"), transport_security=None) as (http, _): + bad_origin = await http.post( + "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://evil.example"} + ) + bad_host = await http.post("/mcp", json=initialize_body(), headers=base_headers() | {"host": "evil.example"}) + async with aconnect_sse( + http, "POST", "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://127.0.0.1:8000"} + ) as ok: + assert ok.response.status_code == 200 + assert [event async for event in ok.aiter_sse()] + + assert (bad_origin.status_code, bad_origin.text) == snapshot((403, "Invalid Origin header")) + assert (bad_host.status_code, bad_host.text) == snapshot((421, "Invalid Host header")) + + async with mounted_app( + Server("unguarded"), transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) as (http, _): + async with aconnect_sse( + http, "POST", "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://evil.example"} + ) as unguarded: + status = unguarded.response.status_code + assert [event async for event in unguarded.aiter_sse()] + + assert status == 200 diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py new file mode 100644 index 0000000000..c7945d56c3 --- /dev/null +++ b/tests/interaction/transports/test_hosting_resume.py @@ -0,0 +1,372 @@ +"""Resumability over the streamable HTTP transport, exercised entirely in process. + +These tests configure the server with an event store, so every SSE event is stamped with an ID +and a client that loses its connection can resume by sending `Last-Event-ID`. The wire-level +tests (`mounted_app` + raw httpx) assert exactly what travels on the wire; the end-to-end test +drives the SDK client through a server-initiated stream close and proves the call still +completes. The bridge's `aclose()` delivers `http.disconnect` to the running application, so +closing a streaming response mid-read is a deterministic in-process disconnect -- no sockets, +no real time. Every server here uses `retry_interval=0` so reconnection waits are no-ops. +""" + +import json + +import anyio +import httpx +import pytest +from httpx_sse import EventSource, ServerSentEvent +from inline_snapshot import snapshot + +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client +from mcp.server.mcpserver import Context, MCPServer +from mcp.shared.message import ClientMessageMetadata +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, + jsonrpc_message_adapter, +) +from tests.interaction._connect import ( + BASE_URL, + base_headers, + connect_over_streamable_http, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _counting_server() -> MCPServer: + """A server with one tool that emits related notifications and one unrelated notification.""" + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context, n: int) -> str: + """Emit n log notifications related to this call, plus one unrelated resource update.""" + for i in range(1, n + 1): + await ctx.info(f"tick {i}") + await ctx.session.send_resource_updated("file:///elsewhere.txt") + return f"counted to {n}" + + return mcp + + +def _tools_call(request_id: int, name: str, arguments: dict[str, object]) -> str: + """A serialized tools/call JSON-RPC request body.""" + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="tools/call", params={"name": name, "arguments": arguments} + ).model_dump_json(by_alias=True, exclude_none=True) + + +async def _read_events(response: httpx.Response, count: int) -> list[ServerSentEvent]: + """Read exactly `count` SSE events from a streaming response without closing it.""" + source = EventSource(response).aiter_sse() + return [await anext(source) for _ in range(count)] + + +@requirement("hosting:resume:event-ids") +@requirement("hosting:resume:priming") +async def test_a_post_sse_stream_begins_with_a_priming_event_and_stamps_every_event() -> None: + """A request's SSE stream opens with a priming event (id, empty data, retry) then stamps each message.""" + async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( # pragma: no branch + "POST", "/mcp", content=_tools_call(1, "count", {"n": 2}), headers=base_headers(session_id=session_id) + ) as response: + assert response.status_code == 200 + events = await _read_events(response, 4) + + priming, first, second, result = events + # The priming event is the only event a client could have seen before any work happened, so it + # is the resumption anchor: it carries an ID and empty data. The SDK attaches the retry hint + # to this event (see the divergence on hosting:resume:priming). + assert (priming.id, priming.data, priming.retry) == snapshot(("3", "", 0)) + assert priming.event == snapshot("message") + # Every subsequent event carries an event-store ID; the related notifications and the response + # all ride this stream and close it after the response. + assert [event.id for event in (first, second, result)] == snapshot(["4", "5", "7"]) + assert [json.loads(event.data)["method"] for event in (first, second)] == snapshot( + ["notifications/message", "notifications/message"] + ) + assert jsonrpc_message_adapter.validate_json(result.data) == snapshot( + JSONRPCResponse( + jsonrpc="2.0", + id=1, + result={ + "content": [{"type": "text", "text": "counted to 2"}], + "structuredContent": {"result": "counted to 2"}, + "isError": False, + }, + ) + ) + + +@requirement("hosting:resume:replay") +@requirement("hosting:resume:stream-scoped") +@requirement("hosting:resume:buffered-replay") +async def test_get_with_last_event_id_replays_only_that_streams_missed_events() -> None: + """Reconnecting with Last-Event-ID returns the missed events from that one stream, in order. + + The handler also emits an unrelated notification (which the server stores under the + standalone-stream key); replay must not return it, proving replay is scoped to the stream + the given event ID belongs to. + + Steps: (1) initialize; (2) POST a tool call and read events until the first notification is + captured; (3) close the response mid-stream -- the bridge delivers `http.disconnect`, the + handler keeps running; (4) release the handler so it emits the remaining messages, which the + server buffers in the event store; (5) wait on the event store for the handler's response to + be stored, so the replay's content is independent of task scheduling; (6) GET with + `Last-Event-ID` and assert the replay is exactly the missed events from this request's stream. + """ + release = anyio.Event() + store = SequencedEventStore() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context) -> str: + """Emit one related notification, wait for the test, then emit two more plus an unrelated one.""" + await ctx.info("tick 1") + await release.wait() + await ctx.info("tick 2") + await ctx.info("tick 3") + await ctx.session.send_resource_updated("file:///elsewhere.txt") + return "counted" + + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "count", {}), headers=base_headers(session_id=session_id) + ) as response: + # Read the priming event and the first notification, then drop the connection. + priming, first = await _read_events(response, 2) + assert (priming.id, first.id) == snapshot(("3", "4")) + last_seen = first.id + release.set() + # The handler keeps running after the disconnect; its remaining messages are stored. + # The first wait returns immediately (the priming and first tick are already stored); + # the second blocks until the response itself is stored so the replay content is fixed. + await store.wait_until_stored(4) + await store.wait_until_stored(8) + replay_headers = base_headers(session_id=session_id) | {"last-event-id": last_seen} + async with http.stream("GET", "/mcp", headers=replay_headers) as replay: # pragma: no branch + assert replay.status_code == 200 + missed = await _read_events(replay, 3) + + decoded = parse_sse_messages(missed) + # Exactly the two remaining related notifications and the response, with their original IDs. + assert [event.id for event in missed] == snapshot(["5", "6", "8"]) + assert [type(message).__name__ for message in decoded] == snapshot( + ["JSONRPCNotification", "JSONRPCNotification", "JSONRPCResponse"] + ) + assert isinstance(decoded[2], JSONRPCResponse) + assert decoded[2].id == 1 + # The unrelated resource-updated notification was stored under the standalone-stream key, not + # this request's stream, so it must not appear in the replay. + assert all( + not (isinstance(message, JSONRPCNotification) and message.method == "notifications/resources/updated") + for message in decoded + ) + + +@requirement("hosting:resume:bad-event-id") +async def test_an_unknown_last_event_id_yields_an_empty_replay_stream() -> None: + """A Last-Event-ID the event store cannot map produces an empty SSE stream rather than an error. + + See the divergence on hosting:resume:bad-event-id: this pins current behaviour. + """ + async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + for unknown in ("no-such-event", "0"): + headers = base_headers(session_id=session_id) | {"last-event-id": unknown} + async with http.stream("GET", "/mcp", headers=headers) as replay: + assert replay.status_code == 200 + assert replay.headers["content-type"].startswith("text/event-stream") + events = [event async for event in EventSource(replay).aiter_sse()] + assert events == [] + + +@requirement("hosting:http:disconnect-not-cancel") +async def test_dropping_the_connection_mid_request_does_not_cancel_the_handler() -> None: + """Closing the request's SSE connection while the handler is running leaves the handler running. + + The handler signals when it has started and when it has finished; the test drops the + connection in between and then releases the handler. If the disconnect cancelled the handler, + `finished` would never be set and the test would time out. + """ + started = anyio.Event() + release = anyio.Event() + finished = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def hold(ctx: Context) -> str: + """Signal start, wait for the test, signal completion.""" + started.set() + await release.wait() + await ctx.info("released") + finished.set() + return "held" + + async with mounted_app(mcp, event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "hold", {}), headers=base_headers(session_id=session_id) + ) as response: + await _read_events(response, 1) + await started.wait() + assert not finished.is_set() + release.set() + await finished.wait() + + +# This test intentionally carries every automatic-reconnection requirement: the +# close-then-resume scenario is indivisible, so splitting it would mean five near-identical bodies. +@requirement("hosting:resume:close-stream") +@requirement("transport:streamable-http:resumability") +@requirement("client-transport:http:reconnect-post-priming") +@requirement("client-transport:http:reconnect-retry-value") +@requirement("flow:resume:tool-call-resumption-token") +async def test_a_call_whose_stream_the_server_closes_is_resumed_by_the_client() -> None: + """A server-closed request stream is reconnected by the client and the call completes. + + The handler emits one notification, closes its own SSE stream, then (once released) emits + another and returns. The client observed the priming event (so it has a Last-Event-ID and a + retry hint of 0ms), sees the stream end, reconnects via GET with Last-Event-ID, and receives + the post-close notification and the result over the replay stream. The shared events make the + test deterministic: the handler only proceeds once the test knows the first notification has + arrived (and so the client's reconnection has begun). + """ + received: list[object] = [] + before_seen = anyio.Event() + gate = anyio.Event() + done = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def interrupt(ctx: Context) -> str: + """Emit, close this call's SSE stream, then emit again after the test releases the gate.""" + await ctx.info("before close") + await ctx.close_sse_stream() + await gate.wait() + await ctx.info("after close") + done.set() + return "resumed" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params.data) + if params.data == "before close": + before_seen.set() + + result: list[CallToolResult] = [] + async with connect_over_streamable_http( + mcp, event_store=SequencedEventStore(), retry_interval=0, logging_callback=collect + ) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call() -> None: + result.append(await client.call_tool("interrupt", {})) + + tg.start_soon(call) + await before_seen.wait() + gate.set() + await done.wait() + + assert result == snapshot( + [CallToolResult(content=[TextContent(text="resumed")], structured_content={"result": "resumed"})] + ) + assert received == snapshot(["before close", "after close"]) + + +@requirement("client-transport:http:resume-stream-api") +async def test_a_captured_resumption_token_replays_missed_messages_on_a_new_connection() -> None: + """A resumption token captured via on_resumption_token_update on one connection lets a fresh + connection retrieve the messages it missed by passing resumption_token to send_request. + + This is the explicit ClientMessageMetadata API, distinct from the automatic reconnection the + previous test covers: the transport dispatches a resumption_token request as a GET with + Last-Event-ID instead of POSTing the body, and remaps the replayed response onto the new + request's id. Client.call_tool does not expose ClientMessageMetadata, so the test drives a + bare ClientSession via session.send_request -- the sanctioned drop-down for behaviour Client + cannot express. The second connection carries the original session id but does not initialize + (the server-side session already is), modelling a caller that resumes after a process restart. + """ + captured: list[str] = [] + received: list[object] = [] + first_seen = anyio.Event() + token_seen = anyio.Event() + release = anyio.Event() + store = SequencedEventStore() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def hold(ctx: Context) -> str: + """Emit one notification, wait for the test, emit another, return.""" + await ctx.info("first") + await release.wait() + await ctx.info("second") + return "done" + + async def on_token(token: str) -> None: + captured.append(token) + if len(captured) >= 2: + token_seen.set() + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params.data) + first_seen.set() + + call = CallToolRequest(params=CallToolRequestParams(name="hold", arguments={})) + capture = ClientMessageMetadata(on_resumption_token_update=on_token) + + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, manager): + with anyio.fail_after(5): # pragma: no branch + async with ( # pragma: no branch + streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r1, w1), + ClientSession(r1, w1, logging_callback=collect) as first, + anyio.create_task_group() as tg, + ): + await first.initialize() + tg.start_soon(first.send_request, call, CallToolResult, None, capture) + await first_seen.wait() + await token_seen.wait() + assert captured == snapshot(["3", "4"]) + assert received == snapshot(["first"]) + # The session id is only observable via the manager (the client transport does not expose it). + (session_id,) = manager._server_instances + http.headers["mcp-session-id"] = session_id + http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION + tg.cancel_scope.cancel() + + with anyio.fail_after(5): # pragma: no branch + release.set() # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event + # init priming + init response + call priming + "first" + "second" + result = 6 stored events. + await store.wait_until_stored(6) + async with ( # pragma: no branch + streamable_http_client(f"{BASE_URL}/mcp", http_client=http) as (r2, w2), + ClientSession(r2, w2, logging_callback=collect) as second, + ): + result = await second.send_request( + call, CallToolResult, metadata=ClientMessageMetadata(resumption_token=captured[-1]) + ) + assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert received == snapshot(["first", "second"]) diff --git a/tests/interaction/transports/test_hosting_session.py b/tests/interaction/transports/test_hosting_session.py new file mode 100644 index 0000000000..a926c3e8a2 --- /dev/null +++ b/tests/interaction/transports/test_hosting_session.py @@ -0,0 +1,202 @@ +"""Streamable HTTP session lifecycle: creation, routing, termination, and stateless mode. + +A test here speaks raw HTTP only when its assertion is the wire contract -- which header is +issued, which status code answers which condition -- that the SDK `Client` cannot observe. +Everything else is `Client`-driven against the same mounted session manager. Transport-agnostic +behaviour is covered by the `connect`-fixture matrix. +""" + +import re + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.types import JSONRPCResponse, ListToolsResult, PaginatedRequestParams, Tool +from tests.interaction._connect import ( + base_headers, + client_via_http, + initialize_body, + initialize_via_http, + mounted_app, + post_jsonrpc, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A minimal low-level server with one tool, so subsequent-request routing can be observed.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="noop", description="Does nothing.", input_schema={"type": "object"})]) + + return Server("hosted", on_list_tools=list_tools) + + +@requirement("hosting:session:create") +@requirement("hosting:session:id-charset") +async def test_initialize_issues_a_visible_ascii_session_id() -> None: + """An initialize POST without a session ID creates a session and returns a visible-ASCII Mcp-Session-Id.""" + async with mounted_app(_server()) as (http, _): + response, messages = await post_jsonrpc(http, initialize_body()) + + assert response.status_code == 200 + session_id = response.headers.get("mcp-session-id") + assert session_id is not None + # The spec requires the session ID to consist only of visible ASCII (0x21-0x7E). + assert re.fullmatch(r"[\x21-\x7E]+", session_id) + assert isinstance(messages[0], JSONRPCResponse) + assert messages[0].id == 1 + + +@requirement("hosting:session:reuse") +async def test_subsequent_requests_with_the_session_id_route_to_the_same_session() -> None: + """Requests carrying the issued Mcp-Session-Id reuse that session's transport rather than creating another.""" + async with mounted_app(_server()) as (http, manager): + async with client_via_http(http) as client: + await client.list_tools() + await client.list_tools() + # The session count is the only signal that distinguishes routing-to-existing from + # silently creating a second session: both produce a successful result. + assert len(manager._server_instances) == 1 + + +@requirement("hosting:session:unknown-id") +async def test_requests_with_an_unknown_session_id_return_404() -> None: + """POST, GET, and DELETE each carrying an unknown Mcp-Session-Id are answered 404 by the manager.""" + async with mounted_app(_server()) as (http, _): + post = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + headers=base_headers(session_id="not-a-session"), + ) + get = await http.get("/mcp", headers=base_headers(session_id="not-a-session")) + delete = await http.delete("/mcp", headers=base_headers(session_id="not-a-session")) + + assert (post.status_code, post.json()) == snapshot( + (404, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Session not found"}}) + ) + assert (get.status_code, delete.status_code) == (404, 404) + + +@requirement("hosting:session:missing-id") +async def test_non_initialize_post_without_a_session_id_returns_400() -> None: + """A non-initialize POST that omits Mcp-Session-Id in stateful mode is rejected with 400.""" + async with mounted_app(_server()) as (http, _): + await initialize_via_http(http) + response = await http.post( + "/mcp", json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, headers=base_headers() + ) + + assert (response.status_code, response.json()) == snapshot( + (400, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Bad Request: Missing session ID"}}) + ) + + +@requirement("hosting:session:delete") +@requirement("hosting:session:post-termination-404") +async def test_delete_terminates_the_session_and_subsequent_requests_return_404() -> None: + """DELETE with a valid Mcp-Session-Id terminates the session; further requests on that ID return 404.""" + async with mounted_app(_server()) as (http, manager): + session_id = await initialize_via_http(http) + + delete = await http.delete("/mcp", headers=base_headers(session_id=session_id)) + assert delete.status_code == 200 + + # The manager keeps the terminated transport registered, so the next request reaches the + # transport's own _terminated check rather than the manager's unknown-session path. + assert session_id in manager._server_instances + post = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + headers=base_headers(session_id=session_id), + ) + assert (post.status_code, post.json()) == snapshot( + ( + 404, + { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32600, "message": "Not Found: Session has been terminated"}, + }, + ) + ) + + +@requirement("hosting:session:isolation") +async def test_terminating_one_session_leaves_others_working() -> None: + """Terminating one session on a manager does not disturb a concurrent session on the same manager.""" + async with mounted_app(_server()) as (http, manager): + async with client_via_http(http) as survivor: + async with client_via_http(http) as terminated: + await terminated.list_tools() + assert len(manager._server_instances) == 2 + # `terminated` has exited (its DELETE has been sent); `survivor` still answers. + result = await survivor.list_tools() + + assert result.tools[0].name == "noop" + + +@requirement("hosting:session:reinitialize") +async def test_second_initialize_on_an_existing_session_is_accepted() -> None: + """A second initialize POST carrying an existing session ID is processed rather than rejected. + + See the divergence on the requirement: the entry expects a rejection, but the SDK forwards the + second initialize to the running server, which answers it as a fresh handshake. + """ + async with mounted_app(_server()) as (http, manager): + session_id = await initialize_via_http(http) + response, messages = await post_jsonrpc(http, initialize_body(request_id=2), session_id=session_id) + assert len(manager._server_instances) == 1 + + assert response.status_code == snapshot(200) + assert isinstance(messages[0], JSONRPCResponse) + assert messages[0].id == 2 + + +@requirement("hosting:stateless:no-session-id") +@requirement("hosting:stateless:no-reuse") +async def test_stateless_mode_never_issues_a_session_id() -> None: + """A stateless server issues no Mcp-Session-Id and creates no persistent transport. + + The recording proves no request the SDK client sent carried an Mcp-Session-Id (the server + cannot have issued one, or the client would echo it); the empty instance map proves the + manager kept no transport between requests. + """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_server(), stateless_http=True, on_request=record) as (http, manager): + async with client_via_http(http) as client: + result = await client.list_tools() + assert manager._server_instances == {} + + assert result.tools[0].name == "noop" + assert all("mcp-session-id" not in request.headers for request in requests) + assert "DELETE" not in {request.method for request in requests} + + +@requirement("hosting:stateless:concurrent-clients") +async def test_stateless_mode_serves_concurrent_clients_independently() -> None: + """Two clients connected concurrently to the same stateless app each complete a round trip.""" + results: dict[str, ListToolsResult] = {} + + async with mounted_app(_server(), stateless_http=True) as (http, _): + + async def list_via(label: str) -> None: + async with client_via_http(http) as client: + results[label] = await client.list_tools() + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(list_via, "a") + tg.start_soon(list_via, "b") + + assert results["a"].tools[0].name == "noop" + assert results["b"].tools[0].name == "noop" diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py new file mode 100644 index 0000000000..9c7353dda5 --- /dev/null +++ b/tests/interaction/transports/test_sse.py @@ -0,0 +1,90 @@ +"""Behaviour specific to the legacy HTTP+SSE transport, exercised entirely in process. + +Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of +the suite over this transport as well; this file pins only what is observable on the SSE wiring +itself: the GET-then-POST connection lifecycle, the endpoint event, and how the message endpoint +rejects requests it cannot route to a session. Every test drives the server's real Starlette app +through the suite's streaming ASGI bridge. +""" + +from uuid import UUID, uuid4 + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.sse import sse_client +from mcp.server import Server +from mcp.types import EmptyResult +from tests.interaction._connect import BASE_URL, build_sse_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +@requirement("transport:sse") +@requirement("transport:sse:endpoint-event") +async def test_endpoint_event_names_the_message_endpoint_with_a_fresh_session_id() -> None: + """Connecting opens a GET stream whose first event names the POST endpoint and a fresh + session id; messages POSTed there are answered on that stream, and disconnecting releases the + server's session entry.""" + app, sse = build_sse_app(Server("legacy")) + captured_session_id: list[str] = [] + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client( + f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append + ) + with anyio.fail_after(5): + async with Client(transport) as client: + assert len(captured_session_id) == 1 + assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers + assert await client.send_ping() == snapshot(EmptyResult()) + + assert sse._read_stream_writers == {} + + +@requirement("transport:sse:post:session-routing") +async def test_post_without_a_session_id_is_rejected() -> None: + """A POST to the message endpoint with no session_id query parameter is answered 400.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post("/messages/", json={"jsonrpc": "2.0", "method": "ping", "id": 1}) + assert (response.status_code, response.text) == snapshot((400, "session_id is required")) + + +@requirement("transport:sse:post:session-routing") +async def test_post_with_a_malformed_session_id_is_rejected() -> None: + """A POST whose session_id query parameter is not a UUID is answered 400.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post( + "/messages/", params={"session_id": "not-a-uuid"}, json={"jsonrpc": "2.0", "method": "ping", "id": 1} + ) + assert (response.status_code, response.text) == snapshot((400, "Invalid session ID")) + + +@requirement("transport:sse:post:session-routing") +async def test_post_for_an_unknown_session_is_rejected() -> None: + """A POST naming a well-formed session_id that no SSE stream owns is answered 404.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + response = await http.post( + "/messages/", params={"session_id": uuid4().hex}, json={"jsonrpc": "2.0", "method": "ping", "id": 1} + ) + assert (response.status_code, response.text) == snapshot((404, "Could not find session")) diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py new file mode 100644 index 0000000000..27cc65de42 --- /dev/null +++ b/tests/interaction/transports/test_stdio.py @@ -0,0 +1,143 @@ +"""The stdio transport: one subprocess end-to-end test and one in-process framing test. + +Everything else in the suite runs in a single process; the subprocess test exists to prove the same +client↔server round trip works over the stdio transport's real boundary (a child process whose +stdin/stdout carry one newline-delimited JSON-RPC message per line). The server lives in +`_stdio_server.py` and is launched via `python -m` so subprocess coverage measurement applies. + +The framing test drives `stdio_server` in-process by passing it injected text streams instead of the +real stdin/stdout, so the raw lines the transport writes can be asserted directly without a process +boundary. + +stdio is deliberately not a leg of the `connect`-fixture matrix: spawning a subprocess per test +would be slow, and the matrix already proves transport-agnosticism over three in-process +transports. Process-lifecycle edge cases (escalation to terminate/kill, parse errors) are covered by +`tests/client/test_stdio.py` and stay deferred here. +""" + +import io +import json +import os +import sys +import tempfile +from pathlib import Path + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.server.stdio import stdio_server +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, +) +from mcp.types.jsonrpc import jsonrpc_message_adapter +from tests.interaction._connect import initialize_body +from tests.interaction._requirements import requirement +from tests.interaction.transports import _stdio_server + +pytestmark = pytest.mark.anyio + +_REPO_ROOT = Path(__file__).parents[3] + + +@requirement("transport:stdio") +@requirement("transport:stdio:clean-shutdown") +@requirement("transport:stdio:stderr-passthrough") +async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess() -> None: + """A Client connected over stdio initializes, calls a tool with arguments, receives the + server's log notification before the call returns, and the server exits when the transport + closes its stdin.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + with tempfile.TemporaryFile(mode="w+") as errlog: + transport = stdio_client( + StdioServerParameters( + command=sys.executable, + args=["-m", _stdio_server.__name__], + cwd=str(_REPO_ROOT), + # stdio_client deliberately filters the inherited environment to a safe minimum, + # which drops the variables coverage.py's subprocess support uses; pass them through + # so the server module is measured. Empty when not running under coverage. + env={key: value for key, value in os.environ.items() if key.startswith("COVERAGE_")}, + ), + errlog=errlog, + ) + + with anyio.fail_after(10): + async with Client(transport, logging_callback=collect) as client: + assert client.initialize_result.server_info.name == "stdio-echo" + result = await client.call_tool("echo", {"text": "across\nprocesses"}) + + errlog.seek(0) + captured_stderr = errlog.read() + + assert result == snapshot(CallToolResult(content=[TextContent(text="across\nprocesses")])) + # stdio carries one ordered server→client stream, so the same notification-before-response + # guarantee holds here as for the in-memory transport. + assert received == snapshot( + [LoggingMessageNotificationParams(level="info", logger="echo", data="echoing across\nprocesses")] + ) + # The server writes this line only after its run loop returns, which happens when stdin closes: + # seeing it proves the process exited on its own rather than via the transport's terminate + # escalation, without a timing-based assertion. The capture itself proves stderr passthrough: + # the transport routes the child's stderr to the caller's `errlog` without consuming it. + assert captured_stderr == snapshot("stdio-echo: clean exit\n") + + +@requirement("transport:stdio:stream-purity") +@requirement("transport:stdio:no-embedded-newlines") +async def test_stdio_server_writes_one_jsonrpc_message_per_line() -> None: + """Everything `stdio_server` writes is a valid JSON-RPC message on its own line, and nothing else. + + The transport's stdin/stdout parameters are public, so the test injects in-process text streams + instead of the real process handles and drives the read/write streams directly: a JSON-RPC line on + stdin is parsed and delivered, and every message sent on the write stream appears as exactly one + newline-terminated line whose payload newlines are JSON-escaped. This proves the transport's own + framing; it does not guard `sys.stdout` against handler code that prints to it directly (see the + divergence on `transport:stdio:stream-purity`). + """ + captured = io.StringIO() + sent_line = json.dumps(initialize_body(request_id=1)) + "\n" + + with anyio.fail_after(5): + async with ( + stdio_server(stdin=anyio.wrap_file(io.StringIO(sent_line)), stdout=anyio.wrap_file(captured)) as ( + read_stream, + write_stream, + ), + read_stream, + write_stream, + ): + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert isinstance(received.message, JSONRPCRequest) + assert received.message.method == "initialize" + + response = JSONRPCResponse(jsonrpc="2.0", id=1, result={"text": "line\nbreak"}) + notification = JSONRPCNotification( + jsonrpc="2.0", method="notifications/message", params={"level": "info", "data": "two\nlines"} + ) + await write_stream.send(SessionMessage(response)) + await write_stream.send(SessionMessage(notification)) + + output = captured.getvalue() + assert output.endswith("\n") + lines = output.removesuffix("\n").split("\n") + assert len(lines) == 2 + messages = [jsonrpc_message_adapter.validate_json(line) for line in lines] + assert [type(message).__name__ for message in messages] == snapshot(["JSONRPCResponse", "JSONRPCNotification"]) + # The newline inside the payload is JSON-escaped on the wire, not a literal newline that would + # break the one-message-per-line framing. + assert r"line\nbreak" in lines[0] + assert r"two\nlines" in lines[1] diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py new file mode 100644 index 0000000000..d38e2a0bb3 --- /dev/null +++ b/tests/interaction/transports/test_streamable_http.py @@ -0,0 +1,168 @@ +"""Behaviour specific to the streamable HTTP transport, exercised entirely in process. + +Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of +the suite over this transport as well; this file only pins what cannot be observed in memory: the +server's stateless and JSON-response modes, the standalone GET stream, and the full-duplex +server-initiated exchange on a still-open call. Every test drives the server's real Starlette app +through the suite's streaming ASGI bridge — no sockets, threads, or subprocesses. +""" + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp.client import ClientRequestContext +from mcp.server.elicitation import AcceptedElicitation +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + CallToolResult, + ElicitRequestParams, + ElicitResult, + LoggingMessageNotification, + LoggingMessageNotificationParams, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + TextContent, +) +from tests.interaction._connect import connect_over_streamable_http +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _smoke_server() -> MCPServer: + """A server exercising each message shape the transport-specific tests need.""" + mcp = MCPServer("smoke", instructions="Talk to the smoke server.") + + @mcp.tool() + def echo(text: str) -> str: + """Echo the text back.""" + return text + + class Confirmation(BaseModel): + confirmed: bool + + @mcp.tool() + async def ask(ctx: Context) -> str: + """Elicit a confirmation from the client and report the outcome.""" + answer = await ctx.elicit("Proceed?", Confirmation) + # In stateless mode the elicit raises before this point: there is no session to call back through. + assert isinstance(answer, AcceptedElicitation) + return f"confirmed={answer.data.confirmed}" + + @mcp.tool() + async def announce(ctx: Context) -> str: + """Send one notification related to this request and one that is not.""" + await ctx.info("about to announce") + await ctx.session.send_resource_updated("file:///watched.txt") + return "announced" + + return mcp + + +@requirement("transport:streamable-http:json-response") +@requirement("client-transport:http:json-response-parsed") +async def test_tool_call_over_streamable_http_with_json_responses() -> None: + """The round trip works when the server answers with a single JSON body instead of an SSE stream.""" + async with connect_over_streamable_http(_smoke_server(), json_response=True) as client: + assert client.initialize_result.server_info.name == "smoke" + result = await client.call_tool("echo", {"text": "as json"}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="as json")], structured_content={"result": "as json"}) + ) + + +@requirement("transport:streamable-http:stateless") +async def test_tool_calls_over_stateless_streamable_http() -> None: + """Consecutive requests each succeed against a stateless server with no session to share.""" + async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: + first = await client.call_tool("echo", {"text": "first"}) + second = await client.call_tool("echo", {"text": "second"}) + + assert first == snapshot( + CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + ) + assert second == snapshot( + CallToolResult(content=[TextContent(text="second")], structured_content={"result": "second"}) + ) + + +@requirement("transport:streamable-http:stateless-restrictions") +async def test_stateless_streamable_http_rejects_server_initiated_requests() -> None: + """A handler that tries to call back to the client in stateless mode fails: there is no session.""" + async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: + result = await client.call_tool("ask", {}) + + assert result.is_error is True + assert isinstance(result.content[0], TextContent) + # The exact message is the StatelessModeNotSupported exception text wrapped by the tool-error + # path; pin the stable prefix rather than the full exception prose. + assert result.content[0].text.startswith("Error executing tool ask:") + + +@requirement("transport:streamable-http:notifications") +@requirement("transport:streamable-http:unrelated-messages") +@requirement("hosting:http:standalone-sse") +async def test_unrelated_server_messages_arrive_on_the_standalone_stream() -> None: + """A server message with no related request reaches the client through the standalone GET stream. + + The log notification is related to the tool call and travels on that call's own SSE stream; + the resource-updated notification is not related to any request, so the only way it can reach + the client is the standalone stream the client opens after initialization. Delivery order + across the two streams is not guaranteed, so the unrelated message is awaited rather than + assumed to beat the tool result. + """ + received: list[IncomingMessage] = [] + resource_update_seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + if isinstance(message, ResourceUpdatedNotification): + resource_update_seen.set() + + async with connect_over_streamable_http(_smoke_server(), message_handler=collect) as client: + result = await client.call_tool("announce", {}) + with anyio.fail_after(5): + await resource_update_seen.wait() + + assert result == snapshot( + CallToolResult(content=[TextContent(text="announced")], structured_content={"result": "announced"}) + ) + # The related log notification rides the call's stream; the unrelated resource-updated + # notification rides the standalone stream. Both arrive, nothing else does. + assert [message for message in received if isinstance(message, LoggingMessageNotification)] == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="about to announce"))] + ) + assert [message for message in received if isinstance(message, ResourceUpdatedNotification)] == snapshot( + [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + ) + assert len(received) == 2 + + +@requirement("transport:streamable-http:stateful") +@requirement("transport:streamable-http:server-to-client") +async def test_server_initiated_elicitation_round_trips_during_a_tool_call() -> None: + """An elicitation issued mid-call reaches the client and its answer reaches the handler over stateful HTTP. + + The elicitation request travels on the still-open SSE response of the tool call that triggered + it, and the client's answer arrives as a separate POST -- the full-duplex exchange the + streamable HTTP transport exists to provide. + """ + asked: list[ElicitRequestParams] = [] + + async def answer(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + asked.append(params) + return ElicitResult(action="accept", content={"confirmed": True}) + + async with connect_over_streamable_http(_smoke_server(), elicitation_callback=answer) as client: + # Bounded because a harness regression here historically meant deadlock, not failure. + with anyio.fail_after(5): + result = await client.call_tool("ask", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="confirmed=True")], structured_content={"result": "confirmed=True"}) + ) + assert [params.message for params in asked] == snapshot(["Proceed?"]) diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index 39e2c6f2ad..bb4735121f 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -1,15 +1,13 @@ import pytest -from mcp import types +from mcp import Client from mcp.server.mcpserver import MCPServer @pytest.mark.anyio async def test_resource_templates(): - # Create an MCP server mcp = MCPServer("Demo") - # Add a dynamic greeting resource @mcp.resource("greeting://{name}") def get_greeting(name: str) -> str: # pragma: no cover """Get a personalized greeting""" @@ -20,23 +18,16 @@ def get_user_profile(user_id: str) -> str: # pragma: no cover """Dynamic user data""" return f"Profile data for user {user_id}" - # Get the list of resource templates using the underlying server - # Note: list_resource_templates() returns a decorator that wraps the handler - # The handler returns a ServerResult with a ListResourceTemplatesResult inside - result = await mcp._lowlevel_server.request_handlers[types.ListResourceTemplatesRequest]( - types.ListResourceTemplatesRequest(params=None) - ) - assert isinstance(result, types.ListResourceTemplatesResult) - templates = result.resource_templates - - # Verify we get both templates back - assert len(templates) == 2 - - # Verify template details - greeting_template = next(t for t in templates if t.name == "get_greeting") - assert greeting_template.uri_template == "greeting://{name}" - assert greeting_template.description == "Get a personalized greeting" - - profile_template = next(t for t in templates if t.name == "get_user_profile") - assert profile_template.uri_template == "users://{user_id}/profile" - assert profile_template.description == "Dynamic user data" + async with Client(mcp) as client: + result = await client.list_resource_templates() + templates = result.resource_templates + + assert len(templates) == 2 + + greeting_template = next(t for t in templates if t.name == "get_greeting") + assert greeting_template.uri_template == "greeting://{name}" + assert greeting_template.description == "Get a personalized greeting" + + profile_template = next(t for t in templates if t.name == "get_user_profile") + assert profile_template.uri_template == "users://{user_id}/profile" + assert profile_template.description == "Dynamic user data" diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index e738017f85..851e89979f 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -3,9 +3,16 @@ import pytest from mcp import Client, types -from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer +from mcp.types import ( + BlobResourceContents, + ListResourcesResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextResourceContents, +) pytestmark = pytest.mark.anyio @@ -58,7 +65,6 @@ def get_image_as_bytes() -> bytes: async def test_lowlevel_resource_mime_type(): """Test that mime_type parameter is respected for resources.""" - server = Server("test") # Create a small test image as bytes image_bytes = b"fake_image_data" @@ -74,17 +80,24 @@ async def test_lowlevel_resource_mime_type(): ), ] - @server.list_resources() - async def handle_list_resources(): - return test_resources - - @server.read_resource() - async def handle_read_resource(uri: str): - if str(uri) == "test://image": - return [ReadResourceContents(content=base64_string, mime_type="image/png")] - elif str(uri) == "test://image_bytes": - return [ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")] - raise Exception(f"Resource not found: {uri}") # pragma: no cover + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=test_resources) + + resource_contents: dict[str, list[TextResourceContents | BlobResourceContents]] = { + "test://image": [TextResourceContents(uri="test://image", text=base64_string, mime_type="image/png")], + "test://image_bytes": [ + BlobResourceContents( + uri="test://image_bytes", blob=base64.b64encode(image_bytes).decode("utf-8"), mime_type="image/png" + ) + ], + } + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult(contents=resource_contents[str(params.uri)]) + + server = Server("test", on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) # Test that resources are listed with correct mime type async with Client(server) as client: diff --git a/tests/issues/test_1574_resource_uri_validation.py b/tests/issues/test_1574_resource_uri_validation.py index e6ff568774..c677081282 100644 --- a/tests/issues/test_1574_resource_uri_validation.py +++ b/tests/issues/test_1574_resource_uri_validation.py @@ -13,8 +13,14 @@ import pytest from mcp import Client, types -from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + ListResourcesResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextResourceContents, +) pytestmark = pytest.mark.anyio @@ -26,24 +32,24 @@ async def test_relative_uri_roundtrip(): the server would fail to serialize resources with relative URIs, or the URI would be transformed during the roundtrip. """ - server = Server("test") - - @server.list_resources() - async def list_resources(): - return [ - types.Resource(name="user", uri="users/me"), - types.Resource(name="config", uri="./config"), - types.Resource(name="parent", uri="../parent/resource"), - ] - - @server.read_resource() - async def read_resource(uri: str): - return [ - ReadResourceContents( - content=f"data for {uri}", - mime_type="text/plain", - ) - ] + + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + types.Resource(name="user", uri="users/me"), + types.Resource(name="config", uri="./config"), + types.Resource(name="parent", uri="../parent/resource"), + ] + ) + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text=f"data for {params.uri}", mime_type="text/plain")] + ) + + server = Server("test", on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) async with Client(server) as client: # List should return the exact URIs we specified @@ -67,18 +73,23 @@ async def test_custom_scheme_uri_roundtrip(): Some MCP servers use custom schemes like "custom://resource". These should work end-to-end. """ - server = Server("test") - - @server.list_resources() - async def list_resources(): - return [ - types.Resource(name="custom", uri="custom://my-resource"), - types.Resource(name="file", uri="file:///path/to/file"), - ] - - @server.read_resource() - async def read_resource(uri: str): - return [ReadResourceContents(content="data", mime_type="text/plain")] + + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + types.Resource(name="custom", uri="custom://my-resource"), + types.Resource(name="file", uri="file:///path/to/file"), + ] + ) + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text="data", mime_type="text/plain")] + ) + + server = Server("test", on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) async with Client(server) as client: resources = await client.list_resources() diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index fb4bb0101f..5d5f8b8fc9 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -35,6 +35,12 @@ async def test_progress_token_zero_first_call(): # Verify progress notifications assert mock_session.send_progress_notification.call_count == 3, "All progress notifications should be sent" - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=0.0, total=10.0, message=None) - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=5.0, total=10.0, message=None) - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=10.0, total=10.0, message=None) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, progress=0.0, total=10.0, message=None, related_request_id="test-request" + ) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, progress=5.0, total=10.0, message=None, related_request_id="test-request" + ) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, progress=10.0, total=10.0, message=None, related_request_id="test-request" + ) diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index 44b17d3372..2bccedf8d2 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -1,83 +1,52 @@ """Test for base64 encoding issue in MCP server. -This test demonstrates the issue in server.py where the server uses -urlsafe_b64encode but the BlobResourceContents validator expects standard -base64 encoding. - -The test should FAIL before fixing server.py to use b64encode instead of -urlsafe_b64encode. -After the fix, the test should PASS. +This test verifies that binary resource data is encoded with standard base64 +(not urlsafe_b64encode), so BlobResourceContents validation succeeds. """ import base64 -from typing import cast import pytest -from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import Server -from mcp.types import ( - BlobResourceContents, - ReadResourceRequest, - ReadResourceRequestParams, - ReadResourceResult, - ServerResult, -) +from mcp import Client +from mcp.server.mcpserver import MCPServer +from mcp.types import BlobResourceContents +pytestmark = pytest.mark.anyio -@pytest.mark.anyio -async def test_server_base64_encoding_issue(): - """Tests that server response can be validated by BlobResourceContents. - This test will: - 1. Set up a server that returns binary data - 2. Extract the base64-encoded blob from the server's response - 3. Verify the encoded data can be properly validated by BlobResourceContents +async def test_server_base64_encoding(): + """Tests that binary resource data round-trips correctly through base64 encoding. - BEFORE FIX: The test will fail because server uses urlsafe_b64encode - AFTER FIX: The test will pass because server uses standard b64encode + The test uses binary data that produces different results with urlsafe vs standard + base64, ensuring the server uses standard encoding. """ - server = Server("test") + mcp = MCPServer("test") # Create binary data that will definitely result in + and / characters # when encoded with standard base64 binary_data = bytes(list(range(255)) * 4) - # Register a resource handler that returns our test data - @server.read_resource() - async def read_resource(uri: str) -> list[ReadResourceContents]: - return [ReadResourceContents(content=binary_data, mime_type="application/octet-stream")] - - # Get the handler directly from the server - handler = server.request_handlers[ReadResourceRequest] - - # Create a request - request = ReadResourceRequest( - params=ReadResourceRequestParams(uri="test://resource"), - ) - - # Call the handler to get the response - result: ServerResult = await handler(request) - - # After (fixed code): - read_result: ReadResourceResult = cast(ReadResourceResult, result) - blob_content = read_result.contents[0] - - # First verify our test data actually produces different encodings + # Sanity check: our test data produces different encodings urlsafe_b64 = base64.urlsafe_b64encode(binary_data).decode() standard_b64 = base64.b64encode(binary_data).decode() - assert urlsafe_b64 != standard_b64, "Test data doesn't demonstrate" - " encoding difference" + assert urlsafe_b64 != standard_b64, "Test data doesn't demonstrate encoding difference" + + @mcp.resource("test://binary", mime_type="application/octet-stream") + def get_binary() -> bytes: + """Return binary test data.""" + return binary_data + + async with Client(mcp) as client: + result = await client.read_resource("test://binary") + assert len(result.contents) == 1 - # Now validate the server's output with BlobResourceContents.model_validate - # Before the fix: This should fail with "Invalid base64" because server - # uses urlsafe_b64encode - # After the fix: This should pass because server will use standard b64encode - model_dict = blob_content.model_dump() + blob_content = result.contents[0] + assert isinstance(blob_content, BlobResourceContents) - # Direct validation - this will fail before fix, pass after fix - blob_model = BlobResourceContents.model_validate(model_dict) + # Verify standard base64 was used (not urlsafe) + assert blob_content.blob == standard_b64 - # Verify we can decode the data back correctly - decoded = base64.b64decode(blob_model.blob) - assert decoded == binary_data + # Verify we can decode the data back correctly + decoded = base64.b64decode(blob_content.blob) + assert decoded == binary_data diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index cd27698e66..6b593d2a54 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -1,8 +1,6 @@ """Test to reproduce issue #88: Random error thrown on response.""" -from collections.abc import Sequence from pathlib import Path -from typing import Any import anyio import pytest @@ -11,10 +9,10 @@ from mcp import types from mcp.client.session import ClientSession -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage -from mcp.types import ContentBlock, TextContent +from mcp.types import CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams, TextContent @pytest.mark.anyio @@ -32,36 +30,38 @@ async def test_notification_validation_error(tmp_path: Path): - Slow operations use minimal timeout (10ms) for quick test execution """ - server = Server(name="test") request_count = 0 slow_request_lock = anyio.Event() - @server.list_tools() - async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="slow", - description="A slow tool", - input_schema={"type": "object"}, - ), - types.Tool( - name="fast", - description="A fast tool", - input_schema={"type": "object"}, - ), - ] - - @server.call_tool() - async def slow_tool(name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock]: + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + types.Tool( + name="slow", + description="A slow tool", + input_schema={"type": "object"}, + ), + types.Tool( + name="fast", + description="A fast tool", + input_schema={"type": "object"}, + ), + ] + ) + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal request_count request_count += 1 + assert params.name in ("slow", "fast"), f"Unknown tool: {params.name}" - if name == "slow": + if params.name == "slow": await slow_request_lock.wait() # it should timeout here - return [TextContent(type="text", text=f"slow {request_count}")] - elif name == "fast": - return [TextContent(type="text", text=f"fast {request_count}")] - return [TextContent(type="text", text=f"unknown {request_count}")] # pragma: no cover + text = f"slow {request_count}" + else: + text = f"fast {request_count}" + return CallToolResult(content=[TextContent(type="text", text=text)]) + + server = Server(name="test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) async def server_handler( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], diff --git a/tests/server/auth/test_routes.py b/tests/server/auth/test_routes.py new file mode 100644 index 0000000000..3d13b5ba53 --- /dev/null +++ b/tests/server/auth/test_routes.py @@ -0,0 +1,47 @@ +import pytest +from pydantic import AnyHttpUrl + +from mcp.server.auth.routes import validate_issuer_url + + +def test_validate_issuer_url_https_allowed(): + validate_issuer_url(AnyHttpUrl("https://example.com/path")) + + +def test_validate_issuer_url_http_localhost_allowed(): + validate_issuer_url(AnyHttpUrl("http://localhost:8080/path")) + + +def test_validate_issuer_url_http_127_0_0_1_allowed(): + validate_issuer_url(AnyHttpUrl("http://127.0.0.1:8080/path")) + + +def test_validate_issuer_url_http_ipv6_loopback_allowed(): + validate_issuer_url(AnyHttpUrl("http://[::1]:8080/path")) + + +def test_validate_issuer_url_http_non_loopback_rejected(): + with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): + validate_issuer_url(AnyHttpUrl("http://evil.com/path")) + + +def test_validate_issuer_url_http_127_prefix_domain_rejected(): + """A domain like 127.0.0.1.evil.com is not loopback.""" + with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): + validate_issuer_url(AnyHttpUrl("http://127.0.0.1.evil.com/path")) + + +def test_validate_issuer_url_http_127_prefix_subdomain_rejected(): + """A domain like 127.0.0.1something.example.com is not loopback.""" + with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): + validate_issuer_url(AnyHttpUrl("http://127.0.0.1something.example.com/path")) + + +def test_validate_issuer_url_fragment_rejected(): + with pytest.raises(ValueError, match="fragment"): + validate_issuer_url(AnyHttpUrl("https://example.com/path#frag")) + + +def test_validate_issuer_url_query_rejected(): + with pytest.raises(ValueError, match="query"): + validate_issuer_url(AnyHttpUrl("https://example.com/path?q=1")) diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py deleted file mode 100644 index 9cb2b561ac..0000000000 --- a/tests/server/lowlevel/test_func_inspection.py +++ /dev/null @@ -1,292 +0,0 @@ -"""Unit tests for func_inspection module. - -Tests the create_call_wrapper function which determines how to call handler functions -with different parameter signatures and type hints. -""" - -from typing import Any, Generic, TypeVar - -import pytest - -from mcp.server.lowlevel.func_inspection import create_call_wrapper -from mcp.types import ListPromptsRequest, ListResourcesRequest, ListToolsRequest, PaginatedRequestParams - -T = TypeVar("T") - - -@pytest.mark.anyio -async def test_no_params_returns_deprecated_wrapper() -> None: - """Test: def foo() - should call without request.""" - called_without_request = False - - async def handler() -> list[str]: - nonlocal called_without_request - called_without_request = True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test"] - - -@pytest.mark.anyio -async def test_param_with_default_returns_deprecated_wrapper() -> None: - """Test: def foo(thing: int = 1) - should call without request.""" - called_without_request = False - - async def handler(thing: int = 1) -> list[str]: - nonlocal called_without_request - called_without_request = True - return [f"test-{thing}"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request (uses default value) - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test-1"] - - -@pytest.mark.anyio -async def test_typed_request_param_passes_request() -> None: - """Test: def foo(req: ListPromptsRequest) - should pass request through.""" - received_request = None - - async def handler(req: ListPromptsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor="test-cursor")) - await wrapper(request) - - assert received_request is not None - assert received_request is request - params = getattr(received_request, "params", None) - assert params is not None - assert params.cursor == "test-cursor" - - -@pytest.mark.anyio -async def test_typed_request_with_default_param_passes_request() -> None: - """Test: def foo(req: ListPromptsRequest, thing: int = 1) - should pass request through.""" - received_request = None - received_thing = None - - async def handler(req: ListPromptsRequest, thing: int = 1) -> list[str]: - nonlocal received_request, received_thing - received_request = req - received_thing = thing - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - assert received_thing == 1 # default value - - -@pytest.mark.anyio -async def test_optional_typed_request_with_default_none_is_deprecated() -> None: - """Test: def foo(thing: int = 1, req: ListPromptsRequest | None = None) - old style.""" - called_without_request = False - - async def handler(thing: int = 1, req: ListPromptsRequest | None = None) -> list[str]: - nonlocal called_without_request - called_without_request = True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test"] - - -@pytest.mark.anyio -async def test_untyped_request_param_is_deprecated() -> None: - """Test: def foo(req) - should call without request.""" - called = False - - async def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[reportMissingParameterType] # pragma: no cover - nonlocal called - called = True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) # pyright: ignore[reportUnknownArgumentType] - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_any_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: Any) - should call without request.""" - - async def handler(req: Any) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_generic_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: Generic[T]) - should call without request.""" - - async def handler(req: Generic[T]) -> list[str]: # pyright: ignore[reportGeneralTypeIssues] # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_wrong_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: str) - should call without request.""" - - async def handler(req: str) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_required_param_before_typed_request_attempts_to_pass() -> None: - """Test: def foo(thing: int, req: ListPromptsRequest) - attempts to pass request (will fail at runtime).""" - received_request = None - - async def handler(thing: int, req: ListPromptsRequest) -> list[str]: # pragma: no cover - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper will attempt to pass request, but it will fail at runtime - # because 'thing' is required and has no default - request = ListPromptsRequest(method="prompts/list", params=None) - - # This will raise TypeError because 'thing' is missing - with pytest.raises(TypeError, match="missing 1 required positional argument: 'thing'"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_positional_only_param_with_correct_type() -> None: - """Test: def foo(req: ListPromptsRequest, /) - should pass request through.""" - received_request = None - - async def handler(req: ListPromptsRequest, /) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - - -@pytest.mark.anyio -async def test_keyword_only_param_with_correct_type() -> None: - """Test: def foo(*, req: ListPromptsRequest) - should pass request through.""" - received_request = None - - async def handler(*, req: ListPromptsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler with keyword argument - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - - -@pytest.mark.anyio -async def test_different_request_types() -> None: - """Test that wrapper works with different request types.""" - # Test with ListResourcesRequest - received_request = None - - async def handler(req: ListResourcesRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListResourcesRequest) - - request = ListResourcesRequest(method="resources/list", params=None) - await wrapper(request) - - assert received_request is request - - # Test with ListToolsRequest - received_request = None - - async def handler2(req: ListToolsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper2 = create_call_wrapper(handler2, ListToolsRequest) - - request2 = ListToolsRequest(method="tools/list", params=None) - await wrapper2(request2) - - assert received_request is request2 - - -@pytest.mark.anyio -async def test_mixed_params_with_typed_request() -> None: - """Test: def foo(a: str, req: ListPromptsRequest, b: int = 5) - attempts to pass request.""" - - async def handler(a: str, req: ListPromptsRequest, b: int = 5) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Will fail at runtime due to missing 'a' - request = ListPromptsRequest(method="prompts/list", params=None) - - with pytest.raises(TypeError, match="missing 1 required positional argument: 'a'"): - await wrapper(request) diff --git a/tests/server/lowlevel/test_server_listing.py b/tests/server/lowlevel/test_server_listing.py index 6bf4cddb39..2c3d303a92 100644 --- a/tests/server/lowlevel/test_server_listing.py +++ b/tests/server/lowlevel/test_server_listing.py @@ -1,20 +1,16 @@ -"""Basic tests for list_prompts, list_resources, and list_tools decorators without pagination.""" - -import warnings +"""Basic tests for list_prompts, list_resources, and list_tools handlers without pagination.""" import pytest -from mcp.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.types import ( - ListPromptsRequest, ListPromptsResult, - ListResourcesRequest, ListResourcesResult, - ListToolsRequest, ListToolsResult, + PaginatedRequestParams, Prompt, Resource, - ServerResult, Tool, ) @@ -22,60 +18,44 @@ @pytest.mark.anyio async def test_list_prompts_basic() -> None: """Test basic prompt listing without pagination.""" - server = Server("test") - test_prompts = [ Prompt(name="prompt1", description="First prompt"), Prompt(name="prompt2", description="Second prompt"), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return test_prompts - - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=test_prompts) - assert isinstance(result, ServerResult) - assert isinstance(result, ListPromptsResult) - assert result.prompts == test_prompts + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + result = await client.list_prompts() + assert result.prompts == test_prompts @pytest.mark.anyio async def test_list_resources_basic() -> None: """Test basic resource listing without pagination.""" - server = Server("test") - test_resources = [ Resource(uri="file:///test1.txt", name="Test 1"), Resource(uri="file:///test2.txt", name="Test 2"), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return test_resources + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=test_resources) - handler = server.request_handlers[ListResourcesRequest] - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListResourcesResult) - assert result.resources == test_resources + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + result = await client.list_resources() + assert result.resources == test_resources @pytest.mark.anyio async def test_list_tools_basic() -> None: """Test basic tool listing without pagination.""" - server = Server("test") - test_tools = [ Tool( name="tool1", @@ -102,80 +82,53 @@ async def test_list_tools_basic() -> None: ), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=test_tools) - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return test_tools - - handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListToolsResult) - assert result.tools == test_tools + server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + result = await client.list_tools() + assert result.tools == test_tools @pytest.mark.anyio async def test_list_prompts_empty() -> None: """Test listing with empty results.""" - server = Server("test") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return [] - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=[]) - assert isinstance(result, ServerResult) - assert isinstance(result, ListPromptsResult) - assert result.prompts == [] + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + result = await client.list_prompts() + assert result.prompts == [] @pytest.mark.anyio async def test_list_resources_empty() -> None: """Test listing with empty results.""" - server = Server("test") - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=[]) - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return [] - - handler = server.request_handlers[ListResourcesRequest] - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListResourcesResult) - assert result.resources == [] + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + result = await client.list_resources() + assert result.resources == [] @pytest.mark.anyio async def test_list_tools_empty() -> None: """Test listing with empty results.""" - server = Server("test") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [] - handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[]) - assert isinstance(result, ServerResult) - assert isinstance(result, ListToolsResult) - assert result.tools == [] + server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + result = await client.list_tools() + assert result.tools == [] diff --git a/tests/server/lowlevel/test_server_pagination.py b/tests/server/lowlevel/test_server_pagination.py index 081fb262ab..a4627b316d 100644 --- a/tests/server/lowlevel/test_server_pagination.py +++ b/tests/server/lowlevel/test_server_pagination.py @@ -1,111 +1,83 @@ import pytest -from mcp.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.types import ( - ListPromptsRequest, ListPromptsResult, - ListResourcesRequest, ListResourcesResult, - ListToolsRequest, ListToolsResult, PaginatedRequestParams, - ServerResult, ) @pytest.mark.anyio async def test_list_prompts_pagination() -> None: - server = Server("test") test_cursor = "test-cursor-123" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListPromptsRequest | None = None - - @server.list_prompts() - async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: - nonlocal received_request - received_request = request + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + nonlocal received_params + received_params = params return ListPromptsResult(prompts=[], next_cursor="next") - handler = server.request_handlers[ListPromptsRequest] - - # Test: No cursor provided -> handler receives request with None params - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + # No cursor provided + await client.list_prompts() + assert received_params is not None + assert received_params.cursor is None - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor=test_cursor)) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + # Cursor provided + await client.list_prompts(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor @pytest.mark.anyio async def test_list_resources_pagination() -> None: - server = Server("test") test_cursor = "resource-cursor-456" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListResourcesRequest | None = None - - @server.list_resources() - async def handle_list_resources(request: ListResourcesRequest) -> ListResourcesResult: - nonlocal received_request - received_request = request + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + nonlocal received_params + received_params = params return ListResourcesResult(resources=[], next_cursor="next") - handler = server.request_handlers[ListResourcesRequest] + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + # No cursor provided + await client.list_resources() + assert received_params is not None + assert received_params.cursor is None - # Test: No cursor provided -> handler receives request with None params - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) - - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListResourcesRequest( - method="resources/list", params=PaginatedRequestParams(cursor=test_cursor) - ) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + # Cursor provided + await client.list_resources(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor @pytest.mark.anyio async def test_list_tools_pagination() -> None: - server = Server("test") test_cursor = "tools-cursor-789" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListToolsRequest | None = None - - @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: - nonlocal received_request - received_request = request + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + nonlocal received_params + received_params = params return ListToolsResult(tools=[], next_cursor="next") - handler = server.request_handlers[ListToolsRequest] - - # Test: No cursor provided -> handler receives request with None params - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) - - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListToolsRequest(method="tools/list", params=PaginatedRequestParams(cursor=test_cursor)) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + # No cursor provided + await client.list_tools() + assert received_params is not None + assert received_params.cursor is None + + # Cursor provided + await client.list_tools(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor diff --git a/tests/server/mcpserver/auth/test_auth_integration.py b/tests/server/mcpserver/auth/test_auth_integration.py index a78a86cf0b..35fec1c57e 100644 --- a/tests/server/mcpserver/auth/test_auth_integration.py +++ b/tests/server/mcpserver/auth/test_auth_integration.py @@ -21,7 +21,8 @@ RefreshToken, construct_redirect_uri, ) -from mcp.server.auth.routes import ClientRegistrationOptions, RevocationOptions, create_auth_routes +from mcp.server.auth.routes import create_auth_routes +from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions from mcp.shared.auth import OAuthClientInformationFull, OAuthToken @@ -52,6 +53,7 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, expires_at=time.time() + 300, scopes=params.scopes or ["read", "write"], + subject="test-user", ) self.auth_codes[code.code] = code @@ -78,6 +80,7 @@ async def exchange_authorization_code( client_id=client.client_id, scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, + subject=authorization_code.subject, ) self.refresh_tokens[refresh_token] = access_token @@ -107,6 +110,7 @@ async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_t client_id=token_info.client_id, scopes=token_info.scopes, expires_at=token_info.expires_at, + subject=token_info.subject, ) return refresh_obj @@ -140,6 +144,7 @@ async def exchange_refresh_token( client_id=client.client_id, scopes=scopes or token_info.scopes, expires_at=int(time.time()) + 3600, + subject=refresh_token.subject, ) self.refresh_tokens[new_refresh_token] = new_access_token @@ -168,6 +173,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: client_id=token_info.client_id, scopes=token_info.scopes, expires_at=token_info.expires_at, + subject=token_info.subject, ) async def revoke_token(self, token: AccessToken | RefreshToken) -> None: @@ -831,6 +837,7 @@ async def test_authorization_get( assert auth_info.client_id == client_info["client_id"] assert "read" in auth_info.scopes assert "write" in auth_info.scopes + assert auth_info.subject == "test-user" # 6. Refresh the token response = await test_client.post( @@ -851,6 +858,10 @@ async def test_authorization_get( assert new_token_response["access_token"] != access_token assert new_token_response["refresh_token"] != refresh_token + refreshed_auth_info = await mock_oauth_provider.load_access_token(new_token_response["access_token"]) + assert refreshed_auth_info + assert refreshed_auth_info.subject == "test-user" + # 7. Revoke the token response = await test_client.post( "/revoke", diff --git a/tests/server/mcpserver/prompts/test_base.py b/tests/server/mcpserver/prompts/test_base.py index 035e1cc81d..d4e4e6b5a6 100644 --- a/tests/server/mcpserver/prompts/test_base.py +++ b/tests/server/mcpserver/prompts/test_base.py @@ -1,9 +1,11 @@ +import threading from typing import Any import pytest -from mcp.server.mcpserver.prompts.base import AssistantMessage, Message, Prompt, TextContent, UserMessage -from mcp.types import EmbeddedResource, TextResourceContents +from mcp.server.mcpserver import Context +from mcp.server.mcpserver.prompts.base import AssistantMessage, Message, Prompt, UserMessage +from mcp.types import EmbeddedResource, TextContent, TextResourceContents class TestRenderPrompt: @@ -13,7 +15,9 @@ def fn() -> str: return "Hello, world!" prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_async_fn(self): @@ -21,7 +25,9 @@ async def fn() -> str: return "Hello, world!" prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_fn_with_args(self): @@ -29,7 +35,7 @@ async def fn(name: str, age: int = 30) -> str: return f"Hello, {name}! You're {age} years old." prompt = Prompt.from_function(fn) - assert await prompt.render(arguments={"name": "World"}) == [ + assert await prompt.render({"name": "World"}, Context()) == [ UserMessage(content=TextContent(type="text", text="Hello, World! You're 30 years old.")) ] @@ -40,7 +46,7 @@ async def fn(name: str, age: int = 30) -> str: # pragma: no cover prompt = Prompt.from_function(fn) with pytest.raises(ValueError): - await prompt.render(arguments={"age": 40}) + await prompt.render({"age": 40}, Context()) @pytest.mark.anyio async def test_fn_returns_message(self): @@ -48,7 +54,9 @@ async def fn() -> UserMessage: return UserMessage(content="Hello, world!") prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + UserMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_fn_returns_assistant_message(self): @@ -56,7 +64,9 @@ async def fn() -> AssistantMessage: return AssistantMessage(content=TextContent(type="text", text="Hello, world!")) prompt = Prompt.from_function(fn) - assert await prompt.render() == [AssistantMessage(content=TextContent(type="text", text="Hello, world!"))] + assert await prompt.render(None, Context()) == [ + AssistantMessage(content=TextContent(type="text", text="Hello, world!")) + ] @pytest.mark.anyio async def test_fn_returns_multiple_messages(self): @@ -70,7 +80,7 @@ async def fn() -> list[Message]: return expected prompt = Prompt.from_function(fn) - assert await prompt.render() == expected + assert await prompt.render(None, Context()) == expected @pytest.mark.anyio async def test_fn_returns_list_of_strings(self): @@ -83,7 +93,7 @@ async def fn() -> list[str]: return expected prompt = Prompt.from_function(fn) - assert await prompt.render() == [UserMessage(t) for t in expected] + assert await prompt.render(None, Context()) == [UserMessage(t) for t in expected] @pytest.mark.anyio async def test_fn_returns_resource_content(self): @@ -102,7 +112,7 @@ async def fn() -> UserMessage: ) prompt = Prompt.from_function(fn) - assert await prompt.render() == [ + assert await prompt.render(None, Context()) == [ UserMessage( content=EmbeddedResource( type="resource", @@ -136,7 +146,7 @@ async def fn() -> list[Message]: ] prompt = Prompt.from_function(fn) - assert await prompt.render() == [ + assert await prompt.render(None, Context()) == [ UserMessage(content=TextContent(type="text", text="Please analyze this file:")), UserMessage( content=EmbeddedResource( @@ -169,7 +179,7 @@ async def fn() -> dict[str, Any]: } prompt = Prompt.from_function(fn) - assert await prompt.render() == [ + assert await prompt.render(None, Context()) == [ UserMessage( content=EmbeddedResource( type="resource", @@ -181,3 +191,21 @@ async def fn() -> dict[str, Any]: ) ) ] + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync prompt functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn() -> str: + fn_thread.append(threading.get_ident()) + return "hello" + + prompt = Prompt.from_function(blocking_fn) + messages = await prompt.render(None, Context()) + + assert messages == [UserMessage(content=TextContent(type="text", text="hello"))] + assert fn_thread[0] != main_thread diff --git a/tests/server/mcpserver/prompts/test_manager.py b/tests/server/mcpserver/prompts/test_manager.py index 0e30b2e697..99a03db565 100644 --- a/tests/server/mcpserver/prompts/test_manager.py +++ b/tests/server/mcpserver/prompts/test_manager.py @@ -1,7 +1,9 @@ import pytest -from mcp.server.mcpserver.prompts.base import Prompt, TextContent, UserMessage +from mcp.server.mcpserver import Context +from mcp.server.mcpserver.prompts.base import Prompt, UserMessage from mcp.server.mcpserver.prompts.manager import PromptManager +from mcp.types import TextContent class TestPromptManager: @@ -71,7 +73,7 @@ def fn() -> str: manager = PromptManager() prompt = Prompt.from_function(fn) manager.add_prompt(prompt) - messages = await manager.render_prompt("fn") + messages = await manager.render_prompt("fn", None, Context()) assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio @@ -84,7 +86,7 @@ def fn(name: str) -> str: manager = PromptManager() prompt = Prompt.from_function(fn) manager.add_prompt(prompt) - messages = await manager.render_prompt("fn", arguments={"name": "World"}) + messages = await manager.render_prompt("fn", {"name": "World"}, Context()) assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))] @pytest.mark.anyio @@ -92,7 +94,7 @@ async def test_render_unknown_prompt(self): """Test rendering a non-existent prompt.""" manager = PromptManager() with pytest.raises(ValueError, match="Unknown prompt: unknown"): - await manager.render_prompt("unknown") + await manager.render_prompt("unknown", None, Context()) @pytest.mark.anyio async def test_render_prompt_with_missing_args(self): @@ -105,4 +107,4 @@ def fn(name: str) -> str: # pragma: no cover prompt = Prompt.from_function(fn) manager.add_prompt(prompt) with pytest.raises(ValueError, match="Missing required arguments"): - await manager.render_prompt("fn") + await manager.render_prompt("fn", None, Context()) diff --git a/tests/server/mcpserver/resources/test_function_resources.py b/tests/server/mcpserver/resources/test_function_resources.py index 5f5c216ed1..c1ff960617 100644 --- a/tests/server/mcpserver/resources/test_function_resources.py +++ b/tests/server/mcpserver/resources/test_function_resources.py @@ -1,3 +1,7 @@ +import threading + +import anyio +import anyio.from_thread import pytest from pydantic import BaseModel @@ -190,3 +194,51 @@ def get_data() -> str: # pragma: no cover ) assert resource.meta is None + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync resource functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn() -> str: + fn_thread.append(threading.get_ident()) + return "data" + + resource = FunctionResource(uri="resource://test", name="test", fn=blocking_fn) + result = await resource.read() + + assert result == "data" + assert fn_thread[0] != main_thread + + +@pytest.mark.anyio +async def test_sync_fn_does_not_block_event_loop(): + """A blocking sync resource function must not stall the event loop. + + On regression (sync runs inline), anyio.from_thread.run_sync raises + RuntimeError because there is no worker-thread context, failing fast. + """ + handler_entered = anyio.Event() + release = threading.Event() + + def blocking_fn() -> str: + anyio.from_thread.run_sync(handler_entered.set) + release.wait() + return "done" + + resource = FunctionResource(uri="resource://test", name="test", fn=blocking_fn) + result: list[str | bytes] = [] + + async def run() -> None: + result.append(await resource.read()) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(run) + await handler_entered.wait() + release.set() + + assert result == ["done"] diff --git a/tests/server/mcpserver/resources/test_resource_manager.py b/tests/server/mcpserver/resources/test_resource_manager.py index eb9b355aaf..b91c71581c 100644 --- a/tests/server/mcpserver/resources/test_resource_manager.py +++ b/tests/server/mcpserver/resources/test_resource_manager.py @@ -1,176 +1,141 @@ +import logging from pathlib import Path -from tempfile import NamedTemporaryFile import pytest from pydantic import AnyUrl +from mcp.server.mcpserver import Context from mcp.server.mcpserver.resources import FileResource, FunctionResource, ResourceManager, ResourceTemplate -@pytest.fixture -def temp_file(): +@pytest.fixture() +def temp_file(tmp_path: Path): """Create a temporary file for testing. File is automatically cleaned up after the test if it still exists. """ - content = "test content" - with NamedTemporaryFile(mode="w", delete=False) as f: - f.write(content) - path = Path(f.name).resolve() - yield path - try: # pragma: lax no cover - path.unlink() - except FileNotFoundError: # pragma: lax no cover - pass # File was already deleted by the test - - -class TestResourceManager: - """Test ResourceManager functionality.""" - - def test_add_resource(self, temp_file: Path): - """Test adding a resource.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - added = manager.add_resource(resource) - assert added == resource - assert manager.list_resources() == [resource] - - def test_add_duplicate_resource(self, temp_file: Path): - """Test adding the same resource twice.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - first = manager.add_resource(resource) - second = manager.add_resource(resource) - assert first == second - assert manager.list_resources() == [resource] - - def test_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): - """Test warning on duplicate resources.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - manager.add_resource(resource) - manager.add_resource(resource) - assert "Resource already exists" in caplog.text - - def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): - """Test disabling warning on duplicate resources.""" - manager = ResourceManager(warn_on_duplicate_resources=False) - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - manager.add_resource(resource) - manager.add_resource(resource) - assert "Resource already exists" not in caplog.text - - @pytest.mark.anyio - async def test_get_resource(self, temp_file: Path): - """Test getting a resource by URI.""" - manager = ResourceManager() - resource = FileResource( - uri=f"file://{temp_file}", - name="test", - path=temp_file, - ) - manager.add_resource(resource) - retrieved = await manager.get_resource(resource.uri) - assert retrieved == resource - - @pytest.mark.anyio - async def test_get_resource_from_template(self): - """Test getting a resource through a template.""" - manager = ResourceManager() - - def greet(name: str) -> str: - return f"Hello, {name}!" - - template = ResourceTemplate.from_function( - fn=greet, - uri_template="greet://{name}", - name="greeter", - ) - manager._templates[template.uri_template] = template - - resource = await manager.get_resource(AnyUrl("greet://world")) - assert isinstance(resource, FunctionResource) - content = await resource.read() - assert content == "Hello, world!" - - @pytest.mark.anyio - async def test_get_unknown_resource(self): - """Test getting a non-existent resource.""" - manager = ResourceManager() - with pytest.raises(ValueError, match="Unknown resource"): - await manager.get_resource(AnyUrl("unknown://test")) - - def test_list_resources(self, temp_file: Path): - """Test listing all resources.""" - manager = ResourceManager() - resource1 = FileResource( - uri=f"file://{temp_file}", - name="test1", - path=temp_file, - ) - resource2 = FileResource( - uri=f"file://{temp_file}2", - name="test2", - path=temp_file, - ) - manager.add_resource(resource1) - manager.add_resource(resource2) - resources = manager.list_resources() - assert len(resources) == 2 - assert resources == [resource1, resource2] - - -class TestResourceManagerMetadata: - """Test ResourceManager Metadata""" - - def test_add_template_with_metadata(self): - """Test that ResourceManager.add_template() accepts and passes meta parameter.""" - - manager = ResourceManager() - - def get_item(id: str) -> str: # pragma: no cover - return f"Item {id}" - - metadata = {"source": "database", "cached": True} - - template = manager.add_template( - fn=get_item, - uri_template="resource://items/{id}", - meta=metadata, - ) - - assert template.meta is not None - assert template.meta == metadata - assert template.meta["source"] == "database" - assert template.meta["cached"] is True - - def test_add_template_without_metadata(self): - """Test that ResourceManager.add_template() works without meta parameter.""" - - manager = ResourceManager() - - def get_item(id: str) -> str: # pragma: no cover - return f"Item {id}" - - template = manager.add_template( - fn=get_item, - uri_template="resource://items/{id}", - ) - - assert template.meta is None + tmp_file = tmp_path / "file" + tmp_file.touch() + yield tmp_file + + +def test_init_with_resources(temp_file: Path, caplog: pytest.LogCaptureFixture): + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager = ResourceManager(resources=[resource]) + assert manager.list_resources() == [resource] + + duplicate_resource = FileResource(uri=f"file://{temp_file}", name="duplicate", path=temp_file) + + with caplog.at_level(logging.WARNING): + manager = ResourceManager(True, resources=[resource, duplicate_resource]) + + assert "Resource already exists" in caplog.text + assert manager.list_resources() == [resource] + + +def test_add_resource(temp_file: Path): + """Test adding a resource.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + added = manager.add_resource(resource) + assert added == resource + assert manager.list_resources() == [resource] + + +def test_add_duplicate_resource(temp_file: Path): + """Test adding the same resource twice.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + first = manager.add_resource(resource) + second = manager.add_resource(resource) + assert first == second + assert manager.list_resources() == [resource] + + +def test_warn_on_duplicate_resources(temp_file: Path, caplog: pytest.LogCaptureFixture): + """Test warning on duplicate resources.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager.add_resource(resource) + manager.add_resource(resource) + assert "Resource already exists" in caplog.text + + +def test_disable_warn_on_duplicate_resources(temp_file: Path, caplog: pytest.LogCaptureFixture): + """Test disabling warning on duplicate resources.""" + manager = ResourceManager(warn_on_duplicate_resources=False) + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager.add_resource(resource) + manager.add_resource(resource) + assert "Resource already exists" not in caplog.text + + +@pytest.mark.anyio +async def test_get_resource(temp_file: Path): + """Test getting a resource by URI.""" + manager = ResourceManager() + resource = FileResource(uri=f"file://{temp_file}", name="test", path=temp_file) + manager.add_resource(resource) + retrieved = await manager.get_resource(resource.uri, Context()) + assert retrieved == resource + + +@pytest.mark.anyio +async def test_get_resource_from_template(): + """Test getting a resource through a template.""" + manager = ResourceManager() + + def greet(name: str) -> str: + return f"Hello, {name}!" + + template = ResourceTemplate.from_function(fn=greet, uri_template="greet://{name}", name="greeter") + manager._templates[template.uri_template] = template + + resource = await manager.get_resource(AnyUrl("greet://world"), Context()) + assert isinstance(resource, FunctionResource) + content = await resource.read() + assert content == "Hello, world!" + + +@pytest.mark.anyio +async def test_get_unknown_resource(): + """Test getting a non-existent resource.""" + manager = ResourceManager() + with pytest.raises(ValueError, match="Unknown resource"): + await manager.get_resource(AnyUrl("unknown://test"), Context()) + + +def test_list_resources(temp_file: Path): + """Test listing all resources.""" + manager = ResourceManager() + resource1 = FileResource(uri=f"file://{temp_file}", name="test1", path=temp_file) + resource2 = FileResource(uri=f"file://{temp_file}2", name="test2", path=temp_file) + + manager.add_resource(resource1) + manager.add_resource(resource2) + + resources = manager.list_resources() + assert len(resources) == 2 + assert resources == [resource1, resource2] + + +def get_item(id: str) -> str: ... + + +def test_add_template_with_metadata(): + """Test that ResourceManager.add_template() accepts and passes meta parameter.""" + manager = ResourceManager() + metadata = {"source": "database", "cached": True} + template = manager.add_template(fn=get_item, uri_template="resource://items/{id}", meta=metadata) + + assert template.meta is not None + assert template.meta == metadata + assert template.meta["source"] == "database" + assert template.meta["cached"] is True + + +def test_add_template_without_metadata(): + """Test that ResourceManager.add_template() works without meta parameter.""" + manager = ResourceManager() + template = manager.add_template(fn=get_item, uri_template="resource://items/{id}") + assert template.meta is None diff --git a/tests/server/mcpserver/resources/test_resource_template.py b/tests/server/mcpserver/resources/test_resource_template.py index 08f2033bff..2a7ba8d503 100644 --- a/tests/server/mcpserver/resources/test_resource_template.py +++ b/tests/server/mcpserver/resources/test_resource_template.py @@ -1,10 +1,11 @@ import json +import threading from typing import Any import pytest from pydantic import BaseModel -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.resources import FunctionResource, ResourceTemplate from mcp.types import Annotations @@ -64,6 +65,7 @@ def my_func(key: str, value: int) -> dict[str, Any]: resource = await template.create_resource( "test://foo/123", {"key": "foo", "value": 123}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -86,7 +88,7 @@ def failing_func(x: str) -> str: ) with pytest.raises(ValueError, match="Error creating resource from template"): - await template.create_resource("fail://test", {"x": "test"}) + await template.create_resource("fail://test", {"x": "test"}, Context()) @pytest.mark.anyio async def test_async_text_resource(self): @@ -104,6 +106,7 @@ async def greet(name: str) -> str: resource = await template.create_resource( "greet://world", {"name": "world"}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -126,6 +129,7 @@ async def get_bytes(value: str) -> bytes: resource = await template.create_resource( "bytes://test", {"value": "test"}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -152,6 +156,7 @@ def get_data(key: str, value: int) -> MyModel: resource = await template.create_resource( "test://foo/123", {"key": "foo", "value": 123}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -183,6 +188,7 @@ def get_data(value: str) -> CustomData: resource = await template.create_resource( "test://hello", {"value": "hello"}, + Context(), ) assert isinstance(resource, FunctionResource) @@ -249,7 +255,7 @@ def get_item(item_id: str) -> str: ) # Create a resource from the template - resource = await template.create_resource("resource://items/123", {"item_id": "123"}) + resource = await template.create_resource("resource://items/123", {"item_id": "123"}, Context()) # The resource should inherit the template's annotations assert resource.annotations is not None @@ -298,10 +304,29 @@ def get_item(item_id: str) -> str: ) # Create a resource from the template - resource = await template.create_resource("resource://items/123", {"item_id": "123"}) + resource = await template.create_resource("resource://items/123", {"item_id": "123"}, Context()) # The resource should inherit the template's metadata assert resource.meta is not None assert resource.meta == metadata assert resource.meta["category"] == "inventory" assert resource.meta["cacheable"] is True + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync template functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn(name: str) -> str: + fn_thread.append(threading.get_ident()) + return f"hello {name}" + + template = ResourceTemplate.from_function(fn=blocking_fn, uri_template="test://{name}") + resource = await template.create_resource("test://world", {"name": "world"}, Context()) + + assert isinstance(resource, FunctionResource) + assert await resource.read() == "hello world" + assert fn_thread[0] != main_thread diff --git a/tests/server/mcpserver/resources/test_resources.py b/tests/server/mcpserver/resources/test_resources.py index 93dc438d5d..5d36beda85 100644 --- a/tests/server/mcpserver/resources/test_resources.py +++ b/tests/server/mcpserver/resources/test_resources.py @@ -91,6 +91,14 @@ def dummy_func() -> str: # pragma: no cover ) assert resource.mime_type == "application/json" + # RFC 2045 quoted parameter value (gh-1756) + resource = FunctionResource( + uri="resource://test", + fn=dummy_func, + mime_type='text/plain; charset="utf-8"', + ) + assert resource.mime_type == 'text/plain; charset="utf-8"' + @pytest.mark.anyio async def test_resource_read_abstract(self): """Test that Resource.read() is abstract.""" diff --git a/tests/server/mcpserver/test_elicitation.py b/tests/server/mcpserver/test_elicitation.py index 6cf49fbd7a..679fb848f5 100644 --- a/tests/server/mcpserver/test_elicitation.py +++ b/tests/server/mcpserver/test_elicitation.py @@ -8,7 +8,6 @@ from mcp import Client, types from mcp.client.session import ClientSession, ElicitationFnT from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -22,7 +21,7 @@ def create_ask_user_tool(mcp: MCPServer): """Create a standard ask_user tool that handles all elicitation responses.""" @mcp.tool(description="A tool that uses elicitation") - async def ask_user(prompt: str, ctx: Context[ServerSession, None]) -> str: + async def ask_user(prompt: str, ctx: Context) -> str: result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) if result.action == "accept" and result.data: @@ -97,7 +96,7 @@ async def test_elicitation_schema_validation(): def create_validation_tool(name: str, schema_class: type[BaseModel]): @mcp.tool(name=name, description=f"Tool testing {name}") - async def tool(ctx: Context[ServerSession, None]) -> str: + async def tool(ctx: Context) -> str: try: await ctx.elicit(message="This should fail validation", schema=schema_class) return "Should not reach here" # pragma: no cover @@ -147,7 +146,7 @@ class OptionalSchema(BaseModel): subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?") @mcp.tool(description="Tool with optional fields") - async def optional_tool(ctx: Context[ServerSession, None]) -> str: + async def optional_tool(ctx: Context) -> str: result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema) if result.action == "accept" and result.data: @@ -188,7 +187,7 @@ class InvalidOptionalSchema(BaseModel): optional_list: list[int] | None = Field(default=None, description="Invalid optional list") @mcp.tool(description="Tool with invalid optional field") - async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: + async def invalid_optional_tool(ctx: Context) -> str: try: await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) return "Should not reach here" # pragma: no cover @@ -214,7 +213,7 @@ class ValidMultiSelectSchema(BaseModel): tags: list[str] = Field(description="Tags") @mcp.tool(description="Tool with valid list[str] field") - async def valid_multiselect_tool(ctx: Context[ServerSession, None]) -> str: + async def valid_multiselect_tool(ctx: Context) -> str: result = await ctx.elicit(message="Please provide tags", schema=ValidMultiSelectSchema) if result.action == "accept" and result.data: return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}" @@ -233,7 +232,7 @@ class OptionalMultiSelectSchema(BaseModel): tags: list[str] | None = Field(default=None, description="Optional tags") @mcp.tool(description="Tool with optional list[str] field") - async def optional_multiselect_tool(ctx: Context[ServerSession, None]) -> str: + async def optional_multiselect_tool(ctx: Context) -> str: result = await ctx.elicit(message="Please provide optional tags", schema=OptionalMultiSelectSchema) if result.action == "accept" and result.data: tags_str = ", ".join(result.data.tags) if result.data.tags else "none" @@ -262,7 +261,7 @@ class DefaultsSchema(BaseModel): email: str = Field(description="Email address (required)") @mcp.tool(description="Tool with default values") - async def defaults_tool(ctx: Context[ServerSession, None]) -> str: + async def defaults_tool(ctx: Context) -> str: result = await ctx.elicit(message="Please provide your information", schema=DefaultsSchema) if result.action == "accept" and result.data: @@ -327,7 +326,7 @@ class FavoriteColorSchema(BaseModel): ) @mcp.tool(description="Single color selection") - async def select_favorite_color(ctx: Context[ServerSession, None]) -> str: + async def select_favorite_color(ctx: Context) -> str: result = await ctx.elicit(message="Select your favorite color", schema=FavoriteColorSchema) if result.action == "accept" and result.data: return f"User: {result.data.user_name}, Favorite: {result.data.favorite_color}" @@ -351,7 +350,7 @@ class FavoriteColorsSchema(BaseModel): ) @mcp.tool(description="Multiple color selection") - async def select_favorite_colors(ctx: Context[ServerSession, None]) -> str: + async def select_favorite_colors(ctx: Context) -> str: result = await ctx.elicit(message="Select your favorite colors", schema=FavoriteColorsSchema) if result.action == "accept" and result.data: return f"User: {result.data.user_name}, Colors: {', '.join(result.data.favorite_colors)}" @@ -366,7 +365,7 @@ class LegacyColorSchema(BaseModel): ) @mcp.tool(description="Legacy enum format") - async def select_color_legacy(ctx: Context[ServerSession, None]) -> str: + async def select_color_legacy(ctx: Context) -> str: result = await ctx.elicit(message="Select a color (legacy format)", schema=LegacyColorSchema) if result.action == "accept" and result.data: return f"User: {result.data.user_name}, Color: {result.data.color}" diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index c4ea2dad65..f71c0574cd 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -1,7 +1,7 @@ """Integration tests for MCPServer server functionality. These tests validate the proper functioning of MCPServer features using focused, -single-feature servers across different transports (SSE and StreamableHTTP). +single-feature example servers over an in-memory transport. """ # TODO(Marcelo): The `examples` package is not being imported as package. We need to solve this. # pyright: reportUnknownMemberType=false @@ -10,12 +10,8 @@ # pyright: reportUnknownArgumentType=false import json -import multiprocessing -import socket -from collections.abc import Generator import pytest -import uvicorn from inline_snapshot import snapshot from examples.snippets.servers import ( @@ -30,9 +26,8 @@ structured_output, tool_progress, ) +from mcp.client import Client from mcp.client.session import ClientSession -from mcp.client.sse import sse_client -from mcp.client.streamable_http import streamable_http_client from mcp.shared._context import RequestContext from mcp.shared.session import RequestResponder from mcp.types import ( @@ -42,7 +37,6 @@ ElicitRequestParams, ElicitResult, GetPromptResult, - InitializeResult, LoggingMessageNotification, LoggingMessageNotificationParams, NotificationParams, @@ -58,7 +52,8 @@ TextResourceContents, ToolListChangedNotification, ) -from tests.test_helpers import wait_for_server + +pytestmark = pytest.mark.anyio class NotificationCollector: @@ -85,105 +80,6 @@ async def handle_generic_notification( self.tool_notifications.append(message.params) -# Common fixtures -@pytest.fixture -def server_port() -> int: - """Get a free port for testing.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - """Get the server URL for testing.""" - return f"http://127.0.0.1:{server_port}" - - -def run_server_with_transport(module_name: str, port: int, transport: str) -> None: # pragma: no cover - """Run server with specified transport.""" - # Get the MCP instance based on module name - if module_name == "basic_tool": - mcp = basic_tool.mcp - elif module_name == "basic_resource": - mcp = basic_resource.mcp - elif module_name == "basic_prompt": - mcp = basic_prompt.mcp - elif module_name == "tool_progress": - mcp = tool_progress.mcp - elif module_name == "sampling": - mcp = sampling.mcp - elif module_name == "elicitation": - mcp = elicitation.mcp - elif module_name == "completion": - mcp = completion.mcp - elif module_name == "notifications": - mcp = notifications.mcp - elif module_name == "mcpserver_quickstart": - mcp = mcpserver_quickstart.mcp - elif module_name == "structured_output": - mcp = structured_output.mcp - else: - raise ImportError(f"Unknown module: {module_name}") - - # Create app based on transport type - if transport == "sse": - app = mcp.sse_app() - elif transport == "streamable-http": - app = mcp.streamable_http_app() - else: - raise ValueError(f"Invalid transport for test server: {transport}") - - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error")) - print(f"Starting {transport} server on port {port}") - server.run() - - -@pytest.fixture -def server_transport(request: pytest.FixtureRequest, server_port: int) -> Generator[str, None, None]: - """Start server in a separate process with specified MCP instance and transport. - - Args: - request: pytest request with param tuple of (module_name, transport) - server_port: Port to run the server on - - Yields: - str: The transport type ('sse' or 'streamable_http') - """ - module_name, transport = request.param - - proc = multiprocessing.Process( - target=run_server_with_transport, - args=(module_name, server_port, transport), - daemon=True, - ) - proc.start() - - # Wait for server to be ready - wait_for_server(server_port) - - yield transport - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Server process failed to terminate") - - -# Helper function to create client based on transport -def create_client_for_transport(transport: str, server_url: str): - """Create the appropriate client context manager based on transport type.""" - if transport == "sse": - endpoint = f"{server_url}/sse" - return sse_client(endpoint) - elif transport == "streamable-http": - endpoint = f"{server_url}/mcp" - return streamable_http_client(endpoint) - else: # pragma: no cover - raise ValueError(f"Invalid transport: {transport}") - - -# Callback functions for testing async def sampling_callback( context: RequestContext[ClientSession], params: CreateMessageRequestParams ) -> CreateMessageResult: @@ -210,147 +106,82 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E return ElicitResult(action="decline") -# Test basic tools -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_tool", "sse"), - ("basic_tool", "streamable-http"), - ], - indirect=True, -) -async def test_basic_tools(server_transport: str, server_url: str) -> None: +async def test_basic_tools() -> None: """Test basic tool functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Tool Example" - assert result.capabilities.tools is not None - - # Test sum tool - tool_result = await session.call_tool("sum", {"a": 5, "b": 3}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "8" - - # Test weather tool - weather_result = await session.call_tool("get_weather", {"city": "London"}) - assert len(weather_result.content) == 1 - assert isinstance(weather_result.content[0], TextContent) - assert "Weather in London: 22degreesC" in weather_result.content[0].text - - -# Test resources -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_resource", "sse"), - ("basic_resource", "streamable-http"), - ], - indirect=True, -) -async def test_basic_resources(server_transport: str, server_url: str) -> None: + async with Client(basic_tool.mcp) as client: + assert client.initialize_result.capabilities.tools is not None + + # Test sum tool + tool_result = await client.call_tool("sum", {"a": 5, "b": 3}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "8" + + # Test weather tool + weather_result = await client.call_tool("get_weather", {"city": "London"}) + assert len(weather_result.content) == 1 + assert isinstance(weather_result.content[0], TextContent) + assert "Weather in London: 22degreesC" in weather_result.content[0].text + + +async def test_basic_resources() -> None: """Test basic resource functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Resource Example" - assert result.capabilities.resources is not None - - # Test document resource - doc_content = await session.read_resource("file://documents/readme") - assert isinstance(doc_content, ReadResourceResult) - assert len(doc_content.contents) == 1 - assert isinstance(doc_content.contents[0], TextResourceContents) - assert "Content of readme" in doc_content.contents[0].text - - # Test settings resource - settings_content = await session.read_resource("config://settings") - assert isinstance(settings_content, ReadResourceResult) - assert len(settings_content.contents) == 1 - assert isinstance(settings_content.contents[0], TextResourceContents) - settings_json = json.loads(settings_content.contents[0].text) - assert settings_json["theme"] == "dark" - assert settings_json["language"] == "en" - - -# Test prompts -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_prompt", "sse"), - ("basic_prompt", "streamable-http"), - ], - indirect=True, -) -async def test_basic_prompts(server_transport: str, server_url: str) -> None: + async with Client(basic_resource.mcp) as client: + assert client.initialize_result.capabilities.resources is not None + + # Test document resource + doc_content = await client.read_resource("file://documents/readme") + assert isinstance(doc_content, ReadResourceResult) + assert len(doc_content.contents) == 1 + assert isinstance(doc_content.contents[0], TextResourceContents) + assert "Content of readme" in doc_content.contents[0].text + + # Test settings resource + settings_content = await client.read_resource("config://settings") + assert isinstance(settings_content, ReadResourceResult) + assert len(settings_content.contents) == 1 + assert isinstance(settings_content.contents[0], TextResourceContents) + settings_json = json.loads(settings_content.contents[0].text) + assert settings_json["theme"] == "dark" + assert settings_json["language"] == "en" + + +async def test_basic_prompts() -> None: """Test basic prompt functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Prompt Example" - assert result.capabilities.prompts is not None - - # Test review_code prompt - prompts = await session.list_prompts() - review_prompt = next((p for p in prompts.prompts if p.name == "review_code"), None) - assert review_prompt is not None - - prompt_result = await session.get_prompt("review_code", {"code": "def hello():\n print('Hello')"}) - assert isinstance(prompt_result, GetPromptResult) - assert len(prompt_result.messages) == 1 - assert isinstance(prompt_result.messages[0].content, TextContent) - assert "Please review this code:" in prompt_result.messages[0].content.text - assert "def hello():" in prompt_result.messages[0].content.text - - # Test debug_error prompt - debug_result = await session.get_prompt( - "debug_error", {"error": "TypeError: 'NoneType' object is not subscriptable"} - ) - assert isinstance(debug_result, GetPromptResult) - assert len(debug_result.messages) == 3 - assert debug_result.messages[0].role == "user" - assert isinstance(debug_result.messages[0].content, TextContent) - assert "I'm seeing this error:" in debug_result.messages[0].content.text - assert debug_result.messages[1].role == "user" - assert isinstance(debug_result.messages[1].content, TextContent) - assert "TypeError" in debug_result.messages[1].content.text - assert debug_result.messages[2].role == "assistant" - assert isinstance(debug_result.messages[2].content, TextContent) - assert "I'll help debug that" in debug_result.messages[2].content.text - - -# Test progress reporting -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("tool_progress", "sse"), - ("tool_progress", "streamable-http"), - ], - indirect=True, -) -async def test_tool_progress(server_transport: str, server_url: str) -> None: + async with Client(basic_prompt.mcp) as client: + assert client.initialize_result.capabilities.prompts is not None + + # Test review_code prompt + prompts = await client.list_prompts() + review_prompt = next((p for p in prompts.prompts if p.name == "review_code"), None) + assert review_prompt is not None + + prompt_result = await client.get_prompt("review_code", {"code": "def hello():\n print('Hello')"}) + assert isinstance(prompt_result, GetPromptResult) + assert len(prompt_result.messages) == 1 + assert isinstance(prompt_result.messages[0].content, TextContent) + assert "Please review this code:" in prompt_result.messages[0].content.text + assert "def hello():" in prompt_result.messages[0].content.text + + # Test debug_error prompt + debug_result = await client.get_prompt( + "debug_error", {"error": "TypeError: 'NoneType' object is not subscriptable"} + ) + assert isinstance(debug_result, GetPromptResult) + assert len(debug_result.messages) == 3 + assert debug_result.messages[0].role == "user" + assert isinstance(debug_result.messages[0].content, TextContent) + assert "I'm seeing this error:" in debug_result.messages[0].content.text + assert debug_result.messages[1].role == "user" + assert isinstance(debug_result.messages[1].content, TextContent) + assert "TypeError" in debug_result.messages[1].content.text + assert debug_result.messages[2].role == "assistant" + assert isinstance(debug_result.messages[2].content, TextContent) + assert "I'll help debug that" in debug_result.messages[2].content.text + + +async def test_tool_progress() -> None: """Test tool progress reporting.""" - transport = server_transport collector = NotificationCollector() async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): @@ -358,134 +189,78 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] if isinstance(message, Exception): # pragma: no cover raise message - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Progress Example" - - # Test progress callback - progress_updates = [] - - async def progress_callback(progress: float, total: float | None, message: str | None) -> None: - progress_updates.append((progress, total, message)) - - # Call tool with progress - steps = 3 - tool_result = await session.call_tool( - "long_running_task", - {"task_name": "Test Task", "steps": steps}, - progress_callback=progress_callback, - ) - assert tool_result.content == snapshot([TextContent(text="Task 'Test Task' completed")]) - - # Verify progress updates - assert len(progress_updates) == steps - for i, (progress, total, message) in enumerate(progress_updates): - expected_progress = (i + 1) / steps - assert abs(progress - expected_progress) < 0.01 - assert total == 1.0 - assert f"Step {i + 1}/{steps}" in message - - # Verify log messages - assert len(collector.log_messages) > 0 - - -# Test sampling -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("sampling", "sse"), - ("sampling", "streamable-http"), - ], - indirect=True, -) -async def test_sampling(server_transport: str, server_url: str) -> None: + async with Client(tool_progress.mcp, message_handler=message_handler) as client: + # Test progress callback + progress_updates = [] + + async def progress_callback(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append((progress, total, message)) + + # Call tool with progress + steps = 3 + tool_result = await client.call_tool( + "long_running_task", + {"task_name": "Test Task", "steps": steps}, + progress_callback=progress_callback, + ) + assert tool_result.content == snapshot([TextContent(text="Task 'Test Task' completed")]) + + # Verify progress updates + assert len(progress_updates) == steps + for i, (progress, total, message) in enumerate(progress_updates): + expected_progress = (i + 1) / steps + assert abs(progress - expected_progress) < 0.01 + assert total == 1.0 + assert f"Step {i + 1}/{steps}" in message + + # Verify log messages + assert len(collector.log_messages) > 0 + + +async def test_sampling() -> None: """Test sampling (LLM interaction) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Sampling Example" - assert result.capabilities.tools is not None - - # Test sampling tool - sampling_result = await session.call_tool("generate_poem", {"topic": "nature"}) - assert len(sampling_result.content) == 1 - assert isinstance(sampling_result.content[0], TextContent) - assert "This is a simulated LLM response" in sampling_result.content[0].text - - -# Test elicitation -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("elicitation", "sse"), - ("elicitation", "streamable-http"), - ], - indirect=True, -) -async def test_elicitation(server_transport: str, server_url: str) -> None: + async with Client(sampling.mcp, sampling_callback=sampling_callback) as client: + assert client.initialize_result.capabilities.tools is not None + + # Test sampling tool + sampling_result = await client.call_tool("generate_poem", {"topic": "nature"}) + assert len(sampling_result.content) == 1 + assert isinstance(sampling_result.content[0], TextContent) + assert "This is a simulated LLM response" in sampling_result.content[0].text + + +async def test_elicitation() -> None: """Test elicitation (user interaction) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, elicitation_callback=elicitation_callback) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Elicitation Example" - - # Test booking with unavailable date (triggers elicitation) - booking_result = await session.call_tool( - "book_table", - { - "date": "2024-12-25", # Unavailable date - "time": "19:00", - "party_size": 4, - }, - ) - assert len(booking_result.content) == 1 - assert isinstance(booking_result.content[0], TextContent) - assert "[SUCCESS] Booked for 2024-12-26" in booking_result.content[0].text - - # Test booking with available date (no elicitation) - booking_result = await session.call_tool( - "book_table", - { - "date": "2024-12-20", # Available date - "time": "20:00", - "party_size": 2, - }, - ) - assert len(booking_result.content) == 1 - assert isinstance(booking_result.content[0], TextContent) - assert "[SUCCESS] Booked for 2024-12-20 at 20:00" in booking_result.content[0].text - - -# Test notifications -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("notifications", "sse"), - ("notifications", "streamable-http"), - ], - indirect=True, -) -async def test_notifications(server_transport: str, server_url: str) -> None: + async with Client(elicitation.mcp, elicitation_callback=elicitation_callback) as client: + # Test booking with unavailable date (triggers elicitation) + booking_result = await client.call_tool( + "book_table", + { + "date": "2024-12-25", # Unavailable date + "time": "19:00", + "party_size": 4, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + assert "[SUCCESS] Booked for 2024-12-26" in booking_result.content[0].text + + # Test booking with available date (no elicitation) + booking_result = await client.call_tool( + "book_table", + { + "date": "2024-12-20", # Available date + "time": "20:00", + "party_size": 2, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + assert "[SUCCESS] Booked for 2024-12-20 at 20:00" in booking_result.content[0].text + + +async def test_notifications() -> None: """Test notifications and logging functionality.""" - transport = server_transport collector = NotificationCollector() async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): @@ -493,150 +268,86 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] if isinstance(message, Exception): # pragma: no cover raise message - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Notifications Example" - - # Call tool that generates notifications - tool_result = await session.call_tool("process_data", {"data": "test_data"}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert "Processed: test_data" in tool_result.content[0].text - - # Verify log messages at different levels - assert len(collector.log_messages) >= 4 - log_levels = {msg.level for msg in collector.log_messages} - assert "debug" in log_levels - assert "info" in log_levels - assert "warning" in log_levels - assert "error" in log_levels - - # Verify resource list changed notification - assert len(collector.resource_notifications) > 0 - - -# Test completion -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("completion", "sse"), - ("completion", "streamable-http"), - ], - indirect=True, -) -async def test_completion(server_transport: str, server_url: str) -> None: + async with Client(notifications.mcp, message_handler=message_handler) as client: + # Call tool that generates notifications + tool_result = await client.call_tool("process_data", {"data": "test_data"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert "Processed: test_data" in tool_result.content[0].text + + # Verify log messages at different levels + assert len(collector.log_messages) >= 4 + log_levels = {msg.level for msg in collector.log_messages} + assert "debug" in log_levels + assert "info" in log_levels + assert "warning" in log_levels + assert "error" in log_levels + + # Verify resource list changed notification + assert len(collector.resource_notifications) > 0 + + +async def test_completion() -> None: """Test completion (autocomplete) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Example" - assert result.capabilities.resources is not None - assert result.capabilities.prompts is not None - - # Test resource completion - completion_result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), - argument={"name": "repo", "value": ""}, - context_arguments={"owner": "modelcontextprotocol"}, - ) - - assert completion_result is not None - assert hasattr(completion_result, "completion") - assert completion_result.completion is not None - assert len(completion_result.completion.values) == 3 - assert "python-sdk" in completion_result.completion.values - assert "typescript-sdk" in completion_result.completion.values - assert "specification" in completion_result.completion.values - - # Test prompt completion - completion_result = await session.complete( - ref=PromptReference(type="ref/prompt", name="review_code"), - argument={"name": "language", "value": "py"}, - ) - - assert completion_result is not None - assert hasattr(completion_result, "completion") - assert completion_result.completion is not None - assert "python" in completion_result.completion.values - assert all(lang.startswith("py") for lang in completion_result.completion.values) - - -# Test MCPServer quickstart example -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("mcpserver_quickstart", "sse"), - ("mcpserver_quickstart", "streamable-http"), - ], - indirect=True, -) -async def test_mcpserver_quickstart(server_transport: str, server_url: str) -> None: + async with Client(completion.mcp) as client: + assert client.initialize_result.capabilities.resources is not None + assert client.initialize_result.capabilities.prompts is not None + + # Test resource completion + completion_result = await client.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "modelcontextprotocol"}, + ) + + assert completion_result is not None + assert hasattr(completion_result, "completion") + assert completion_result.completion is not None + assert len(completion_result.completion.values) == 3 + assert "python-sdk" in completion_result.completion.values + assert "typescript-sdk" in completion_result.completion.values + assert "specification" in completion_result.completion.values + + # Test prompt completion + completion_result = await client.complete( + ref=PromptReference(type="ref/prompt", name="review_code"), + argument={"name": "language", "value": "py"}, + ) + + assert completion_result is not None + assert hasattr(completion_result, "completion") + assert completion_result.completion is not None + assert "python" in completion_result.completion.values + assert all(lang.startswith("py") for lang in completion_result.completion.values) + + +async def test_mcpserver_quickstart() -> None: """Test MCPServer quickstart example.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Demo" - - # Test add tool - tool_result = await session.call_tool("add", {"a": 10, "b": 20}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "30" - - # Test greeting resource directly - resource_result = await session.read_resource("greeting://Alice") - assert len(resource_result.contents) == 1 - assert isinstance(resource_result.contents[0], TextResourceContents) - assert resource_result.contents[0].text == "Hello, Alice!" - - -# Test structured output example -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("structured_output", "sse"), - ("structured_output", "streamable-http"), - ], - indirect=True, -) -async def test_structured_output(server_transport: str, server_url: str) -> None: + async with Client(mcpserver_quickstart.mcp) as client: + # Test add tool + tool_result = await client.call_tool("add", {"a": 10, "b": 20}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "30" + + # Test greeting resource directly + resource_result = await client.read_resource("greeting://Alice") + assert len(resource_result.contents) == 1 + assert isinstance(resource_result.contents[0], TextResourceContents) + assert resource_result.contents[0].text == "Hello, Alice!" + + +async def test_structured_output() -> None: """Test structured output functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Structured Output Example" - - # Test get_weather tool - weather_result = await session.call_tool("get_weather", {"city": "New York"}) - assert len(weather_result.content) == 1 - assert isinstance(weather_result.content[0], TextContent) - - # Check that the result contains expected weather data - result_text = weather_result.content[0].text - assert "22.5" in result_text # temperature - assert "sunny" in result_text # condition - assert "45" in result_text # humidity - assert "5.2" in result_text # wind_speed + async with Client(structured_output.mcp) as client: + # Test get_weather tool + weather_result = await client.call_tool("get_weather", {"city": "New York"}) + assert len(weather_result.content) == 1 + assert isinstance(weather_result.content[0], TextContent) + + # Check that the result contains expected weather data + result_text = weather_result.content[0].text + assert "22.5" in result_text # temperature + assert "sunny" in result_text # condition + assert "45" in result_text # humidity + assert "5.2" in result_text # wind_speed diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 979dc580f8..3457ec944a 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1,7 +1,7 @@ import base64 from pathlib import Path from typing import Any -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from inline_snapshot import snapshot @@ -10,17 +10,21 @@ from starlette.routing import Mount, Route from mcp.client import Client +from mcp.server.context import ServerRequestContext +from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.prompts.base import Message, UserMessage from mcp.server.mcpserver.resources import FileResource, FunctionResource from mcp.server.mcpserver.utilities.types import Audio, Image -from mcp.server.session import ServerSession from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError from mcp.types import ( AudioContent, BlobResourceContents, + Completion, + CompletionArgument, + CompletionContext, ContentBlock, EmbeddedResource, GetPromptResult, @@ -30,6 +34,7 @@ Prompt, PromptArgument, PromptMessage, + PromptReference, ReadResourceResult, Resource, ResourceTemplate, @@ -60,6 +65,15 @@ async def test_create_server(self): assert len(mcp.icons) == 1 assert mcp.icons[0].src == "https://example.com/icon.png" + def test_dependencies(self): + """Dependencies list is read by `mcp install` / `mcp dev` CLI commands.""" + mcp = MCPServer("test", dependencies=["pandas", "numpy"]) + assert mcp.dependencies == ["pandas", "numpy"] + assert mcp.settings.dependencies == ["pandas", "numpy"] + + mcp_no_deps = MCPServer("test") + assert mcp_no_deps.dependencies == [] + async def test_sse_app_returns_starlette_app(self): """Test that sse_app returns a Starlette application with correct routes.""" mcp = MCPServer("test") @@ -671,6 +685,32 @@ async def test_remove_tool_and_call(self): class TestServerResources: + async def test_init_with_resources(self): + def get_text() -> str: + """Seeded resource.""" + return "Hello from init!" + + resource = FunctionResource.from_function(fn=get_text, uri="resource://init", name="init_resource") + + mcp = MCPServer(resources=[resource]) + + async with Client(mcp) as client: + assert client.initialize_result.capabilities.resources is not None + + resources = await client.list_resources() + assert len(resources.resources) == 1 + listed = resources.resources[0] + assert listed.uri == "resource://init" + assert listed.name == "init_resource" + assert listed.description == "Seeded resource." + + result = await client.read_resource("resource://init") + + assert len(result.contents) == 1 + content = result.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Hello from init!" + async def test_text_resource(self): mcp = MCPServer() @@ -885,7 +925,7 @@ def get_data(name: str) -> str: assert len(await mcp.list_resources()) == 0 # When accessed, should create a concrete resource - resource = await mcp._resource_manager.get_resource("resource://test/data") + resource = await mcp._resource_manager.get_resource("resource://test/data", Context()) assert isinstance(resource, FunctionResource) result = await resource.read() assert result == "Data for test" @@ -997,7 +1037,7 @@ async def test_context_detection(self): """Test that context parameters are properly detected.""" mcp = MCPServer() - def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: # pragma: no cover + def tool_with_context(x: int, ctx: Context) -> str: # pragma: no cover return f"Request {ctx.request_id}: {x}" tool = mcp._tool_manager.add_tool(tool_with_context) @@ -1007,7 +1047,7 @@ async def test_context_injection(self): """Test that context is properly injected into tool calls.""" mcp = MCPServer() - def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: + def tool_with_context(x: int, ctx: Context) -> str: assert ctx.request_id is not None return f"Request {ctx.request_id}: {x}" @@ -1024,7 +1064,7 @@ async def test_async_context(self): """Test that context works in async functions.""" mcp = MCPServer() - async def async_tool(x: int, ctx: Context[ServerSession, None]) -> str: + async def async_tool(x: int, ctx: Context) -> str: assert ctx.request_id is not None return f"Async request {ctx.request_id}: {x}" @@ -1041,7 +1081,7 @@ async def test_context_logging(self): """Test that context logging methods work.""" mcp = MCPServer() - async def logging_tool(msg: str, ctx: Context[ServerSession, None]) -> str: + async def logging_tool(msg: str, ctx: Context) -> str: await ctx.debug("Debug message") await ctx.info("Info message") await ctx.warning("Warning message") @@ -1088,7 +1128,7 @@ def test_resource() -> str: return "resource data" @mcp.tool() - async def tool_with_resource(ctx: Context[ServerSession, None]) -> str: + async def tool_with_resource(ctx: Context) -> str: r_iter = await ctx.read_resource("test://data") r_list = list(r_iter) assert len(r_list) == 1 @@ -1107,7 +1147,7 @@ async def test_resource_with_context(self): mcp = MCPServer() @mcp.resource("resource://context/{name}") - def resource_with_context(name: str, ctx: Context[ServerSession, None]) -> str: + def resource_with_context(name: str, ctx: Context) -> str: """Resource that receives context.""" assert ctx is not None return f"Resource {name} - context injected" @@ -1160,7 +1200,7 @@ async def test_resource_context_custom_name(self): mcp = MCPServer() @mcp.resource("resource://custom/{id}") - def resource_custom_ctx(id: str, my_ctx: Context[ServerSession, None]) -> str: + def resource_custom_ctx(id: str, my_ctx: Context) -> str: """Resource with custom context parameter name.""" assert my_ctx is not None return f"Resource {id} with context" @@ -1188,7 +1228,7 @@ async def test_prompt_with_context(self): mcp = MCPServer() @mcp.prompt("prompt_with_ctx") - def prompt_with_context(text: str, ctx: Context[ServerSession, None]) -> str: + def prompt_with_context(text: str, ctx: Context) -> str: """Prompt that expects context.""" assert ctx is not None return f"Prompt '{text}' - context injected" @@ -1225,6 +1265,19 @@ def prompt_no_context(text: str) -> str: class TestServerPrompts: """Test prompt functionality in MCPServer server.""" + async def test_get_prompt_direct_call_without_context(self): + """Test calling mcp.get_prompt() directly without passing context.""" + mcp = MCPServer() + + @mcp.prompt() + def fn() -> str: + return "Hello, world!" + + result = await mcp.get_prompt("fn") + content = result.messages[0].content + assert isinstance(content, TextContent) + assert content.text == "Hello, world!" + async def test_prompt_decorator(self): """Test that the prompt decorator registers prompts correctly.""" mcp = MCPServer() @@ -1237,7 +1290,7 @@ def fn() -> str: assert len(prompts) == 1 assert prompts[0].name == "fn" # Don't compare functions directly since validate_call wraps them - content = await prompts[0].render() + content = await prompts[0].render(None, Context()) assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" @@ -1252,7 +1305,7 @@ def fn() -> str: prompts = mcp._prompt_manager.list_prompts() assert len(prompts) == 1 assert prompts[0].name == "custom_name" - content = await prompts[0].render() + content = await prompts[0].render(None, Context()) assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" @@ -1267,7 +1320,7 @@ def fn() -> str: prompts = mcp._prompt_manager.list_prompts() assert len(prompts) == 1 assert prompts[0].description == "A custom description" - content = await prompts[0].render() + content = await prompts[0].render(None, Context()) assert isinstance(content[0].content, TextContent) assert content[0].content.text == "Hello, world!" @@ -1401,6 +1454,23 @@ def prompt_fn(name: str) -> str: ... # pragma: no branch await client.get_prompt("prompt_fn") +async def test_completion_decorator() -> None: + """Test that the completion decorator registers a working handler.""" + mcp = MCPServer() + + @mcp.completion() + async def handle_completion( + ref: PromptReference, argument: CompletionArgument, context: CompletionContext | None + ) -> Completion: + assert argument.name == "style" + return Completion(values=["bold", "italic", "underline"]) + + async with Client(mcp) as client: + ref = PromptReference(type="ref/prompt", name="test") + result = await client.complete(ref=ref, argument={"name": "style", "value": "b"}) + assert result.completion.values == ["bold", "italic", "underline"] + + def test_streamable_http_no_redirect() -> None: """Test that streamable HTTP routes are correctly configured.""" mcp = MCPServer() @@ -1415,3 +1485,34 @@ def test_streamable_http_no_redirect() -> None: # Verify path values assert streamable_routes[0].path == "/mcp", "Streamable route path should be /mcp" + + +async def test_report_progress_passes_related_request_id(): + """Test that report_progress passes the request_id as related_request_id. + + Without related_request_id, the streamable HTTP transport cannot route + progress notifications to the correct SSE stream, causing them to be + silently dropped. See #953 and #2001. + """ + mock_session = AsyncMock() + mock_session.send_progress_notification = AsyncMock() + + request_context = ServerRequestContext( + request_id="req-abc-123", + session=mock_session, + meta={"progress_token": "tok-1"}, + lifespan_context=None, + experimental=Experimental(), + ) + + ctx = Context(request_context=request_context, mcp_server=MagicMock()) + + await ctx.report_progress(50, 100, message="halfway") + + mock_session.send_progress_notification.assert_awaited_once_with( + progress_token="tok-1", + progress=50, + total=100, + message="halfway", + related_request_id="req-abc-123", + ) diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index 550bba50a3..e4dfd4ff9b 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -11,7 +11,6 @@ from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.tools import Tool, ToolManager from mcp.server.mcpserver.utilities.func_metadata import ArgModelBase, FuncMetadata -from mcp.server.session import ServerSessionT from mcp.types import TextContent, ToolAnnotations @@ -188,7 +187,7 @@ def sum(a: int, b: int) -> int: manager = ToolManager() manager.add_tool(sum) - result = await manager.call_tool("sum", {"a": 1, "b": 2}) + result = await manager.call_tool("sum", {"a": 1, "b": 2}, Context()) assert result == 3 @pytest.mark.anyio @@ -199,7 +198,7 @@ async def double(n: int) -> int: manager = ToolManager() manager.add_tool(double) - result = await manager.call_tool("double", {"n": 5}) + result = await manager.call_tool("double", {"n": 5}, Context()) assert result == 10 @pytest.mark.anyio @@ -213,7 +212,7 @@ def __call__(self, x: int) -> int: manager = ToolManager() tool = manager.add_tool(MyTool()) - result = await tool.run({"x": 5}) + result = await tool.run({"x": 5}, Context()) assert result == 10 @pytest.mark.anyio @@ -227,7 +226,7 @@ async def __call__(self, x: int) -> int: manager = ToolManager() tool = manager.add_tool(MyAsyncTool()) - result = await tool.run({"x": 5}) + result = await tool.run({"x": 5}, Context()) assert result == 10 @pytest.mark.anyio @@ -238,7 +237,7 @@ def sum(a: int, b: int = 1) -> int: manager = ToolManager() manager.add_tool(sum) - result = await manager.call_tool("sum", {"a": 1}) + result = await manager.call_tool("sum", {"a": 1}, Context()) assert result == 2 @pytest.mark.anyio @@ -250,13 +249,13 @@ def sum(a: int, b: int) -> int: # pragma: no cover manager = ToolManager() manager.add_tool(sum) with pytest.raises(ToolError): - await manager.call_tool("sum", {"a": 1}) + await manager.call_tool("sum", {"a": 1}, Context()) @pytest.mark.anyio async def test_call_unknown_tool(self): manager = ToolManager() with pytest.raises(ToolError): - await manager.call_tool("unknown", {"a": 1}) + await manager.call_tool("unknown", {"a": 1}, Context()) @pytest.mark.anyio async def test_call_tool_with_list_int_input(self): @@ -266,9 +265,9 @@ def sum_vals(vals: list[int]) -> int: manager = ToolManager() manager.add_tool(sum_vals) # Try both with plain list and with JSON list - result = await manager.call_tool("sum_vals", {"vals": "[1, 2, 3]"}) + result = await manager.call_tool("sum_vals", {"vals": "[1, 2, 3]"}, Context()) assert result == 6 - result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]}) + result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]}, Context()) assert result == 6 @pytest.mark.anyio @@ -279,13 +278,13 @@ def concat_strs(vals: list[str] | str) -> str: manager = ToolManager() manager.add_tool(concat_strs) # Try both with plain python object and with JSON list - result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}) + result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}, Context()) assert result == "abc" - result = await manager.call_tool("concat_strs", {"vals": '["a", "b", "c"]'}) + result = await manager.call_tool("concat_strs", {"vals": '["a", "b", "c"]'}, Context()) assert result == "abc" - result = await manager.call_tool("concat_strs", {"vals": "a"}) + result = await manager.call_tool("concat_strs", {"vals": "a"}, Context()) assert result == "a" - result = await manager.call_tool("concat_strs", {"vals": '"a"'}) + result = await manager.call_tool("concat_strs", {"vals": '"a"'}, Context()) assert result == '"a"' @pytest.mark.anyio @@ -297,7 +296,7 @@ class Shrimp(BaseModel): shrimp: list[Shrimp] x: None - def name_shrimp(tank: MyShrimpTank, ctx: Context[ServerSessionT, None]) -> list[str]: + def name_shrimp(tank: MyShrimpTank) -> list[str]: return [x.name for x in tank.shrimp] manager = ToolManager() @@ -305,11 +304,13 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context[ServerSessionT, None]) -> list[ result = await manager.call_tool( "name_shrimp", {"tank": {"x": None, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}}, + Context(), ) assert result == ["rex", "gertrude"] result = await manager.call_tool( "name_shrimp", {"tank": '{"x": null, "shrimp": [{"name": "rex"}, {"name": "gertrude"}]}'}, + Context(), ) assert result == ["rex", "gertrude"] @@ -317,7 +318,7 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context[ServerSessionT, None]) -> list[ class TestToolSchema: @pytest.mark.anyio async def test_context_arg_excluded_from_schema(self): - def something(a: int, ctx: Context[ServerSessionT, None]) -> int: # pragma: no cover + def something(a: int, ctx: Context) -> int: # pragma: no cover return a manager = ToolManager() @@ -334,7 +335,7 @@ def test_context_parameter_detection(self): """Test that context parameters are properly detected in Tool.from_function().""" - def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: # pragma: no cover + def tool_with_context(x: int, ctx: Context) -> str: # pragma: no cover return str(x) manager = ToolManager() @@ -357,61 +358,42 @@ def tool_with_parametrized_context(x: int, ctx: Context[LifespanContextT, Reques async def test_context_injection(self): """Test that context is properly injected during tool execution.""" - def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: + def tool_with_context(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) return str(x) manager = ToolManager() manager.add_tool(tool_with_context) - mcp = MCPServer() - ctx = mcp.get_context() - result = await manager.call_tool("tool_with_context", {"x": 42}, context=ctx) + result = await manager.call_tool("tool_with_context", {"x": 42}, context=Context()) assert result == "42" @pytest.mark.anyio async def test_context_injection_async(self): """Test that context is properly injected in async tools.""" - async def async_tool(x: int, ctx: Context[ServerSessionT, None]) -> str: + async def async_tool(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) return str(x) manager = ToolManager() manager.add_tool(async_tool) - mcp = MCPServer() - ctx = mcp.get_context() - result = await manager.call_tool("async_tool", {"x": 42}, context=ctx) - assert result == "42" - - @pytest.mark.anyio - async def test_context_optional(self): - """Test that context is optional when calling tools.""" - - def tool_with_context(x: int, ctx: Context[ServerSessionT, None] | None = None) -> str: - return str(x) - - manager = ToolManager() - manager.add_tool(tool_with_context) - # Should not raise an error when context is not provided - result = await manager.call_tool("tool_with_context", {"x": 42}) + result = await manager.call_tool("async_tool", {"x": 42}, context=Context()) assert result == "42" @pytest.mark.anyio async def test_context_error_handling(self): """Test error handling when context injection fails.""" - def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: + def tool_with_context(x: int, ctx: Context) -> str: raise ValueError("Test error") manager = ToolManager() manager.add_tool(tool_with_context) - mcp = MCPServer() - ctx = mcp.get_context() with pytest.raises(ToolError, match="Error executing tool tool_with_context"): - await manager.call_tool("tool_with_context", {"x": 42}, context=ctx) + await manager.call_tool("tool_with_context", {"x": 42}, context=Context()) class TestToolAnnotations: @@ -471,7 +453,7 @@ def get_user(user_id: int) -> UserOutput: manager = ToolManager() manager.add_tool(get_user) - result = await manager.call_tool("get_user", {"user_id": 1}, convert_result=True) + result = await manager.call_tool("get_user", {"user_id": 1}, Context(), convert_result=True) # don't test unstructured output here, just the structured conversion assert len(result) == 2 and result[1] == {"name": "John", "age": 30} @@ -485,9 +467,9 @@ def double_number(n: int) -> int: manager = ToolManager() manager.add_tool(double_number) - result = await manager.call_tool("double_number", {"n": 5}) + result = await manager.call_tool("double_number", {"n": 5}, Context()) assert result == 10 - result = await manager.call_tool("double_number", {"n": 5}, convert_result=True) + result = await manager.call_tool("double_number", {"n": 5}, Context(), convert_result=True) assert isinstance(result[0][0], TextContent) and result[1] == {"result": 10} @pytest.mark.anyio @@ -506,7 +488,7 @@ def get_user_dict(user_id: int) -> UserDict: manager = ToolManager() manager.add_tool(get_user_dict) - result = await manager.call_tool("get_user_dict", {"user_id": 1}) + result = await manager.call_tool("get_user_dict", {"user_id": 1}, Context()) assert result == expected_output @pytest.mark.anyio @@ -526,7 +508,7 @@ def get_person() -> Person: manager = ToolManager() manager.add_tool(get_person) - result = await manager.call_tool("get_person", {}, convert_result=True) + result = await manager.call_tool("get_person", {}, Context(), convert_result=True) # don't test unstructured output here, just the structured conversion assert len(result) == 2 and result[1] == expected_output @@ -543,9 +525,9 @@ def get_numbers() -> list[int]: manager = ToolManager() manager.add_tool(get_numbers) - result = await manager.call_tool("get_numbers", {}) + result = await manager.call_tool("get_numbers", {}, Context()) assert result == expected_list - result = await manager.call_tool("get_numbers", {}, convert_result=True) + result = await manager.call_tool("get_numbers", {}, Context(), convert_result=True) assert isinstance(result[0][0], TextContent) and result[1] == expected_output @pytest.mark.anyio @@ -558,7 +540,7 @@ def get_dict() -> dict[str, Any]: manager = ToolManager() manager.add_tool(get_dict, structured_output=False) - result = await manager.call_tool("get_dict", {}) + result = await manager.call_tool("get_dict", {}, Context()) assert isinstance(result, dict) assert result == {"key": "value"} @@ -601,12 +583,12 @@ def get_config() -> dict[str, Any]: assert "properties" not in tool.output_schema # dict[str, Any] has no constraints # Test raw result - result = await manager.call_tool("get_config", {}) + result = await manager.call_tool("get_config", {}, Context()) expected = {"debug": True, "port": 8080, "features": ["auth", "logging"]} assert result == expected # Test converted result - result = await manager.call_tool("get_config", {}) + result = await manager.call_tool("get_config", {}, Context()) assert result == expected @pytest.mark.anyio @@ -626,12 +608,12 @@ def get_scores() -> dict[str, int]: assert tool.output_schema["additionalProperties"]["type"] == "integer" # Test raw result - result = await manager.call_tool("get_scores", {}) + result = await manager.call_tool("get_scores", {}, Context()) expected = {"alice": 100, "bob": 85, "charlie": 92} assert result == expected # Test converted result - result = await manager.call_tool("get_scores", {}) + result = await manager.call_tool("get_scores", {}, Context()) assert result == expected @@ -885,7 +867,7 @@ def greet(name: str) -> str: manager.add_tool(greet) # Verify tool works before removal - result = await manager.call_tool("greet", {"name": "World"}) + result = await manager.call_tool("greet", {"name": "World"}, Context()) assert result == "Hello, World!" # Remove the tool @@ -893,7 +875,7 @@ def greet(name: str) -> str: # Verify calling removed tool raises error with pytest.raises(ToolError, match="Unknown tool: greet"): - await manager.call_tool("greet", {"name": "World"}) + await manager.call_tool("greet", {"name": "World"}, Context()) def test_remove_tool_case_sensitive(self): """Test that tool removal is case-sensitive.""" diff --git a/tests/server/mcpserver/test_url_elicitation.py b/tests/server/mcpserver/test_url_elicitation.py index 1311bd6728..af90dc208b 100644 --- a/tests/server/mcpserver/test_url_elicitation.py +++ b/tests/server/mcpserver/test_url_elicitation.py @@ -8,7 +8,6 @@ from mcp.client.session import ClientSession from mcp.server.elicitation import CancelledElicitation, DeclinedElicitation, elicit_url from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -19,7 +18,7 @@ async def test_url_elicitation_accept(): mcp = MCPServer(name="URLElicitationServer") @mcp.tool(description="A tool that uses URL elicitation") - async def request_api_key(ctx: Context[ServerSession, None]) -> str: + async def request_api_key(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Please provide your API key to continue.", url="https://example.com/api_key_setup", @@ -49,7 +48,7 @@ async def test_url_elicitation_decline(): mcp = MCPServer(name="URLElicitationDeclineServer") @mcp.tool(description="A tool that uses URL elicitation") - async def oauth_flow(ctx: Context[ServerSession, None]) -> str: + async def oauth_flow(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Authorize access to your files.", url="https://example.com/oauth/authorize", @@ -75,7 +74,7 @@ async def test_url_elicitation_cancel(): mcp = MCPServer(name="URLElicitationCancelServer") @mcp.tool(description="A tool that uses URL elicitation") - async def payment_flow(ctx: Context[ServerSession, None]) -> str: + async def payment_flow(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Complete payment to proceed.", url="https://example.com/payment", @@ -101,7 +100,7 @@ async def test_url_elicitation_helper_function(): mcp = MCPServer(name="URLElicitationHelperServer") @mcp.tool(description="Tool using elicit_url helper") - async def setup_credentials(ctx: Context[ServerSession, None]) -> str: + async def setup_credentials(ctx: Context) -> str: result = await elicit_url( session=ctx.session, message="Set up your credentials", @@ -127,7 +126,7 @@ async def test_url_no_content_in_response(): mcp = MCPServer(name="URLContentCheckServer") @mcp.tool(description="Check URL response format") - async def check_url_response(ctx: Context[ServerSession, None]) -> str: + async def check_url_response(ctx: Context) -> str: result = await ctx.session.elicit_url( message="Test message", url="https://example.com/test", @@ -164,7 +163,7 @@ class NameSchema(BaseModel): name: str = Field(description="Your name") @mcp.tool(description="Test form mode") - async def ask_name(ctx: Context[ServerSession, None]) -> str: + async def ask_name(ctx: Context) -> str: result = await ctx.elicit(message="What is your name?", schema=NameSchema) # Test only checks accept path with data assert result.action == "accept" @@ -195,7 +194,7 @@ async def test_elicit_complete_notification(): notification_sent = False @mcp.tool(description="Tool that sends completion notification") - async def trigger_elicitation(ctx: Context[ServerSession, None]) -> str: + async def trigger_elicitation(ctx: Context) -> str: nonlocal notification_sent # Simulate an async operation (e.g., user completing auth in browser) @@ -238,7 +237,7 @@ async def test_elicit_url_typed_results(): mcp = MCPServer(name="TypedResultsServer") @mcp.tool(description="Test declined result") - async def test_decline(ctx: Context[ServerSession, None]) -> str: + async def test_decline(ctx: Context) -> str: result = await elicit_url( session=ctx.session, message="Test decline", @@ -251,7 +250,7 @@ async def test_decline(ctx: Context[ServerSession, None]) -> str: return "Not declined" # pragma: no cover @mcp.tool(description="Test cancelled result") - async def test_cancel(ctx: Context[ServerSession, None]) -> str: + async def test_cancel(ctx: Context) -> str: result = await elicit_url( session=ctx.session, message="Test cancel", @@ -293,7 +292,7 @@ class EmailSchema(BaseModel): email: str = Field(description="Email address") @mcp.tool(description="Test deprecated elicit method") - async def use_deprecated_elicit(ctx: Context[ServerSession, None]) -> str: + async def use_deprecated_elicit(ctx: Context) -> str: # Use the deprecated elicit() method which should call elicit_form() result = await ctx.session.elicit( message="Enter your email", @@ -323,7 +322,7 @@ async def test_ctx_elicit_url_convenience_method(): mcp = MCPServer(name="CtxElicitUrlServer") @mcp.tool(description="A tool that uses ctx.elicit_url() directly") - async def direct_elicit_url(ctx: Context[ServerSession, None]) -> str: + async def direct_elicit_url(ctx: Context) -> str: # Use ctx.elicit_url() directly instead of ctx.session.elicit_url() result = await ctx.elicit_url( message="Test the convenience method", diff --git a/tests/server/mcpserver/test_url_elicitation_error_throw.py b/tests/server/mcpserver/test_url_elicitation_error_throw.py index 2d29937995..1f45fd60f0 100644 --- a/tests/server/mcpserver/test_url_elicitation_error_throw.py +++ b/tests/server/mcpserver/test_url_elicitation_error_throw.py @@ -5,7 +5,6 @@ from mcp import Client, ErrorData, types from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError @@ -15,7 +14,7 @@ async def test_url_elicitation_error_thrown_from_tool(): mcp = MCPServer(name="UrlElicitationErrorServer") @mcp.tool(description="A tool that raises UrlElicitationRequiredError") - async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: + async def connect_service(service_name: str, ctx: Context) -> str: # This tool cannot proceed without authorization raise UrlElicitationRequiredError( [ @@ -56,7 +55,7 @@ async def test_url_elicitation_error_from_error(): mcp = MCPServer(name="UrlElicitationErrorServer") @mcp.tool(description="A tool that raises UrlElicitationRequiredError with multiple elicitations") - async def multi_auth(ctx: Context[ServerSession, None]) -> str: + async def multi_auth(ctx: Context) -> str: raise UrlElicitationRequiredError( [ types.ElicitRequestURLParams( @@ -97,7 +96,7 @@ async def test_normal_exceptions_still_return_error_result(): mcp = MCPServer(name="NormalErrorServer") @mcp.tool(description="A tool that raises a normal exception") - async def failing_tool(ctx: Context[ServerSession, None]) -> str: + async def failing_tool(ctx: Context) -> str: raise ValueError("Something went wrong") async with Client(mcp) as client: diff --git a/tests/server/mcpserver/tools/__init__.py b/tests/server/mcpserver/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/server/mcpserver/tools/test_base.py b/tests/server/mcpserver/tools/test_base.py new file mode 100644 index 0000000000..22d5f973e9 --- /dev/null +++ b/tests/server/mcpserver/tools/test_base.py @@ -0,0 +1,10 @@ +from mcp.server.mcpserver import Context +from mcp.server.mcpserver.tools.base import Tool + + +def test_context_detected_in_union_annotation(): + def my_tool(x: int, ctx: Context | None) -> str: + raise NotImplementedError + + tool = Tool.from_function(my_tool) + assert tool.context_kwarg == "ctx" diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 6d1634f2e9..cff5a37c15 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -1,19 +1,27 @@ """Test that cancelled requests don't cause double responses.""" -from typing import Any - import anyio import pytest -from mcp import Client, types -from mcp.server.lowlevel.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError +from mcp.shared.message import SessionMessage from mcp.types import ( + LATEST_PROTOCOL_VERSION, CallToolRequest, CallToolRequestParams, CallToolResult, CancelledNotification, CancelledNotificationParams, + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCNotification, + JSONRPCRequest, + ListToolsResult, + PaginatedRequestParams, + TextContent, Tool, ) @@ -22,34 +30,34 @@ async def test_server_remains_functional_after_cancel(): """Verify server can handle new requests after a cancellation.""" - server = Server("test-server") - # Track tool calls call_count = 0 ev_first_call = anyio.Event() first_request_id = None - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="Tool for testing", - input_schema={}, - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="Tool for testing", + input_schema={}, + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal call_count, first_request_id - if name == "test_tool": + if params.name == "test_tool": call_count += 1 if call_count == 1: - first_request_id = server.request_context.request_id + first_request_id = ctx.request_id ev_first_call.set() await anyio.sleep(5) # First call is slow - return [types.TextContent(type="text", text=f"Call number: {call_count}")] - raise ValueError(f"Unknown tool: {name}") # pragma: no cover + return CallToolResult(content=[TextContent(type="text", text=f"Call number: {call_count}")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + server = Server("test-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) async with Client(server) as client: # First request (will be cancelled) @@ -86,6 +94,157 @@ async def first_request(): # Type narrowing for pyright content = result.content[0] assert content.type == "text" - assert isinstance(content, types.TextContent) + assert isinstance(content, TextContent) assert content.text == "Call number: 2" assert call_count == 2 + + +@pytest.mark.anyio +async def test_server_cancels_in_flight_handlers_on_transport_close(): + """When the transport closes mid-request, server.run() must cancel in-flight + handlers rather than join on them. + + Without the cancel, the task group waits for the handler, which then tries + to respond through a write stream that _receive_loop already closed, + raising ClosedResourceError and crashing server.run() with exit code 1. + + This drives server.run() with raw memory streams because InMemoryTransport + wraps it in its own finally-cancel (_memory.py) which masks the bug. + """ + handler_started = anyio.Event() + handler_cancelled = anyio.Event() + server_run_returned = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_cancelled.set() + # unreachable: sleep_forever only exits via cancellation + raise AssertionError # pragma: no cover + + server = Server("test", on_call_tool=handle_call_tool) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + await server.run(server_read, server_write, server.create_initialization_options()) + server_run_returned.set() + + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test", version="1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + call_req = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"), + ) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + + await to_server.send(SessionMessage(init_req)) + await from_server.receive() # init response + await to_server.send(SessionMessage(initialized)) + await to_server.send(SessionMessage(call_req)) + + await handler_started.wait() + + # Close the server's input stream — this is what stdin EOF does. + # server.run()'s incoming_messages loop ends, finally-cancel fires, + # handler gets CancelledError, server.run() returns. + await to_server.aclose() + + await server_run_returned.wait() + + assert handler_cancelled.is_set() + + +@pytest.mark.anyio +async def test_server_handles_transport_close_with_pending_server_to_client_requests(): + """When the transport closes while handlers are blocked on server→client + requests (sampling, roots, elicitation), server.run() must still exit cleanly. + + Two bugs covered: + 1. _receive_loop's finally iterates _response_streams with await checkpoints + inside; the woken handler's send_request finally pops from that dict + before the next __next__() — RuntimeError: dictionary changed size. + 2. The woken handler's MCPError is caught in _handle_request, which falls + through to respond() against a write stream _receive_loop already closed. + """ + handlers_started = 0 + both_started = anyio.Event() + server_run_returned = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + nonlocal handlers_started + handlers_started += 1 + if handlers_started == 2: + both_started.set() + # Blocks on send_request waiting for a client response that never comes. + # _receive_loop's finally will wake this with CONNECTION_CLOSED. + await ctx.session.list_roots() + raise AssertionError # pragma: no cover + + server = Server("test", on_call_tool=handle_call_tool) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + await server.run(server_read, server_write, server.create_initialization_options()) + server_run_returned.set() + + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test", version="1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + + await to_server.send(SessionMessage(init_req)) + await from_server.receive() # init response + await to_server.send(SessionMessage(initialized)) + + # Two tool calls → two handlers → two _response_streams entries. + for rid in (2, 3): + call_req = JSONRPCRequest( + jsonrpc="2.0", + id=rid, + method="tools/call", + params=CallToolRequestParams(name="t", arguments={}).model_dump(by_alias=True, mode="json"), + ) + await to_server.send(SessionMessage(call_req)) + + await both_started.wait() + # Drain the two roots/list requests so send_request's _write_stream.send() + # completes and both handlers are parked at response_stream_reader.receive(). + await from_server.receive() + await from_server.receive() + + await to_server.aclose() + + # Without the fixes: RuntimeError (dict mutation) or ClosedResourceError + # (respond after write-stream close) escapes run_server and this hangs. + await server_run_returned.wait() diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index 5a8d67f09e..a01d0d4d72 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -1,15 +1,13 @@ """Tests for completion handler with context functionality.""" -from typing import Any - import pytest from mcp import Client -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.types import ( + CompleteRequestParams, + CompleteResult, Completion, - CompletionArgument, - CompletionContext, PromptReference, ResourceTemplateReference, ) @@ -18,23 +16,15 @@ @pytest.mark.anyio async def test_completion_handler_receives_context(): """Test that the completion handler receives context correctly.""" - server = Server("test-server") - # Track what the handler receives - received_args: dict[str, Any] = {} + received_params: CompleteRequestParams | None = None - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - received_args["ref"] = ref - received_args["argument"] = argument - received_args["context"] = context + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: + nonlocal received_params + received_params = params + return CompleteResult(completion=Completion(values=["test-completion"], total=1, has_more=False)) - # Return test completion - return Completion(values=["test-completion"], total=1, has_more=False) + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Test with context @@ -45,28 +35,23 @@ async def handle_completion( ) # Verify handler received the context - assert received_args["context"] is not None - assert received_args["context"].arguments == {"previous": "value"} + assert received_params is not None + assert received_params.context is not None + assert received_params.context.arguments == {"previous": "value"} assert result.completion.values == ["test-completion"] @pytest.mark.anyio async def test_completion_backward_compatibility(): """Test that completion works without context (backward compatibility).""" - server = Server("test-server") - context_was_none = False - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: nonlocal context_was_none - context_was_none = context is None + context_was_none = params.context is None + return CompleteResult(completion=Completion(values=["no-context-completion"], total=1, has_more=False)) - return Completion(values=["no-context-completion"], total=1, has_more=False) + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Test without context @@ -82,30 +67,31 @@ async def handle_completion( @pytest.mark.anyio async def test_dependent_completion_scenario(): """Test a real-world scenario with dependent completions.""" - server = Server("test-server") - - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: + + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: # Simulate database/table completion scenario - if isinstance(ref, ResourceTemplateReference): - if ref.uri == "db://{database}/{table}": - if argument.name == "database": - # Complete database names - return Completion(values=["users_db", "products_db", "analytics_db"], total=3, has_more=False) - elif argument.name == "table": - # Complete table names based on selected database - if context and context.arguments: - db = context.arguments.get("database") - if db == "users_db": - return Completion(values=["users", "sessions", "permissions"], total=3, has_more=False) - elif db == "products_db": - return Completion(values=["products", "categories", "inventory"], total=3, has_more=False) - - return Completion(values=[], total=0, has_more=False) # pragma: no cover + assert isinstance(params.ref, ResourceTemplateReference) + assert params.ref.uri == "db://{database}/{table}" + + if params.argument.name == "database": + return CompleteResult( + completion=Completion(values=["users_db", "products_db", "analytics_db"], total=3, has_more=False) + ) + + assert params.argument.name == "table" + assert params.context and params.context.arguments + db = params.context.arguments.get("database") + if db == "users_db": + return CompleteResult( + completion=Completion(values=["users", "sessions", "permissions"], total=3, has_more=False) + ) + else: + assert db == "products_db" + return CompleteResult( + completion=Completion(values=["products", "categories", "inventory"], total=3, has_more=False) + ) + + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # First, complete database @@ -136,27 +122,20 @@ async def handle_completion( @pytest.mark.anyio async def test_completion_error_on_missing_context(): """Test that server can raise error when required context is missing.""" - server = Server("test-server") - - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - if isinstance(ref, ResourceTemplateReference): - if ref.uri == "db://{database}/{table}": - if argument.name == "table": - # Check if database context is provided - if not context or not context.arguments or "database" not in context.arguments: - # Raise an error instead of returning error as completion - raise ValueError("Please select a database first to see available tables") - # Normal completion if context is provided - db = context.arguments.get("database") - if db == "test_db": - return Completion(values=["users", "orders", "products"], total=3, has_more=False) - - return Completion(values=[], total=0, has_more=False) # pragma: no cover + + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, ResourceTemplateReference) + assert params.ref.uri == "db://{database}/{table}" + assert params.argument.name == "table" + + if not params.context or not params.context.arguments or "database" not in params.context.arguments: + raise ValueError("Please select a database first to see available tables") + + db = params.context.arguments.get("database") + assert db == "test_db" + return CompleteResult(completion=Completion(values=["users", "orders", "products"], total=3, has_more=False)) + + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Try to complete table without database context - should raise error diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index a303664a54..0d87905042 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -2,18 +2,19 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any import anyio import pytest from pydantic import TypeAdapter +from mcp.server import ServerRequestContext from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.mcpserver import Context, MCPServer from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.types import ( + CallToolRequestParams, + CallToolResult, ClientCapabilities, Implementation, InitializeRequestParams, @@ -39,20 +40,20 @@ async def test_lifespan(server: Server) -> AsyncIterator[dict[str, bool]]: finally: context["shutdown"] = True - server = Server[dict[str, bool]]("test", lifespan=test_lifespan) - - # Create memory streams for testing - send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) - send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) - # Create a tool that accesses lifespan context - @server.call_tool() - async def check_lifespan(name: str, arguments: dict[str, Any]) -> list[TextContent]: - ctx = server.request_context + async def check_lifespan( + ctx: ServerRequestContext[dict[str, bool]], params: CallToolRequestParams + ) -> CallToolResult: assert isinstance(ctx.lifespan_context, dict) assert ctx.lifespan_context["started"] assert not ctx.lifespan_context["shutdown"] - return [TextContent(type="text", text="true")] + return CallToolResult(content=[TextContent(type="text", text="true")]) + + server = Server[dict[str, bool]]("test", lifespan=test_lifespan, on_call_tool=check_lifespan) + + # Create memory streams for testing + send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) # Run server in background task async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: @@ -141,7 +142,7 @@ async def test_lifespan(server: MCPServer) -> AsyncIterator[dict[str, bool]]: # Add a tool that checks lifespan context @server.tool() - def check_lifespan(ctx: Context[ServerSession, None]) -> bool: + def check_lifespan(ctx: Context) -> bool: """Tool that checks lifespan context.""" assert isinstance(ctx.request_context.lifespan_context, dict) assert ctx.request_context.lifespan_context["started"] diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 848b35b299..46925916d9 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -1,55 +1,42 @@ from unittest.mock import AsyncMock, Mock +import anyio import pytest from mcp import types from mcp.server.lowlevel.server import Server from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder @pytest.mark.anyio async def test_exception_handling_with_raise_exceptions_true(): - """Test that exceptions are re-raised when raise_exceptions=True""" + """Transport exceptions are re-raised when raise_exceptions=True.""" server = Server("test-server") session = Mock(spec=ServerSession) - session.send_log_message = AsyncMock() test_exception = RuntimeError("Test error") with pytest.raises(RuntimeError, match="Test error"): await server._handle_message(test_exception, session, {}, raise_exceptions=True) - session.send_log_message.assert_called_once() - @pytest.mark.anyio -@pytest.mark.parametrize( - "exception_class,message", - [ - (ValueError, "Test validation error"), - (RuntimeError, "Test runtime error"), - (KeyError, "Test key error"), - (Exception, "Basic error"), - ], -) -async def test_exception_handling_with_raise_exceptions_false(exception_class: type[Exception], message: str): - """Test that exceptions are logged when raise_exceptions=False""" +async def test_exception_handling_with_raise_exceptions_false(): + """Transport exceptions are logged locally but not sent to the client. + + The transport that reported the error is likely broken; writing back + through it races with stream closure (#1967, #2064). The TypeScript, + Go, and C# SDKs all log locally only. + """ server = Server("test-server") session = Mock(spec=ServerSession) session.send_log_message = AsyncMock() - test_exception = exception_class(message) - - await server._handle_message(test_exception, session, {}, raise_exceptions=False) - - # Should send log message - session.send_log_message.assert_called_once() - call_args = session.send_log_message.call_args + await server._handle_message(RuntimeError("Test error"), session, {}, raise_exceptions=False) - assert call_args.kwargs["level"] == "error" - assert call_args.kwargs["data"] == "Internal Server Error" - assert call_args.kwargs["logger"] == "mcp.server.exception_handler" + session.send_log_message.assert_not_called() @pytest.mark.anyio @@ -72,3 +59,48 @@ async def test_normal_message_handling_not_affected(): # Verify _handle_request was called server._handle_request.assert_called_once() + + +@pytest.mark.anyio +async def test_server_run_exits_cleanly_when_transport_yields_exception_then_closes(): + """Regression test for #1967 / #2064. + + Exercises the real Server.run() path with real memory streams, reproducing + what happens in stateless streamable HTTP when a POST handler throws: + + 1. Transport yields an Exception into the read stream + (streamable_http.py does this in its broad POST-handler except). + 2. Transport closes the read stream (terminate() in stateless mode). + 3. _receive_loop exits its `async with read_stream, write_stream:` block, + closing the write stream. + 4. Meanwhile _handle_message(exc) was spawned via tg.start_soon and runs + after the write stream is closed. + + Before the fix, _handle_message tried to send_log_message through the + closed write stream, raising ClosedResourceError inside the TaskGroup + and crashing server.run(). After the fix, it only logs locally. + """ + server = Server("test-server") + + read_send, read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + # Zero-buffer on the write stream forces send() to block until received. + # With no receiver, a send() sits blocked until _receive_loop exits its + # `async with self._read_stream, self._write_stream:` block and closes the + # stream, at which point the blocked send raises ClosedResourceError. + # This deterministically reproduces the race without sleeps. + write_send, write_recv = anyio.create_memory_object_stream[SessionMessage](0) + + # What the streamable HTTP transport does: push the exception, then close. + read_send.send_nowait(RuntimeError("simulated transport error")) + read_send.close() + + with anyio.fail_after(5): + # stateless=True so server.run doesn't wait for initialize handshake. + # Before this fix, this raised ExceptionGroup(ClosedResourceError). + await server.run(read_recv, write_send, server.create_initialization_options(), stateless=True) + + # write_send was closed inside _receive_loop's `async with`; receive_nowait + # raises EndOfStream iff the buffer is empty (i.e., server wrote nothing). + with pytest.raises(anyio.EndOfStream): + write_recv.receive_nowait() + write_recv.close() diff --git a/tests/server/test_lowlevel_input_validation.py b/tests/server/test_lowlevel_input_validation.py deleted file mode 100644 index 3f977bcc1b..0000000000 --- a/tests/server/test_lowlevel_input_validation.py +++ /dev/null @@ -1,311 +0,0 @@ -"""Test input schema validation for lowlevel server.""" - -import logging -from collections.abc import Awaitable, Callable -from typing import Any - -import anyio -import pytest - -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import CallToolResult, ClientResult, ServerNotification, ServerRequest, TextContent, Tool - - -async def run_tool_test( - tools: list[Tool], - call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[list[TextContent]]], - test_callback: Callable[[ClientSession], Awaitable[CallToolResult]], -) -> CallToolResult | None: - """Helper to run a tool test with minimal boilerplate. - - Args: - tools: List of tools to register - call_tool_handler: Handler function for tool calls - test_callback: Async function that performs the test using the client session - - Returns: - The result of the tool call - """ - server = Server("test") - result = None - - @server.list_tools() - async def list_tools(): - return tools - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: - return await call_tool_handler(name, arguments) - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Message handler for client - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - # Server task - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - # Run the test - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - # Initialize the session - await client_session.initialize() - - # Run the test callback - result = await test_callback(client_session) - - # Cancel the server task - tg.cancel_scope.cancel() - - return result - - -def create_add_tool() -> Tool: - """Create a standard 'add' tool for testing.""" - return Tool( - name="add", - description="Add two numbers", - input_schema={ - "type": "object", - "properties": { - "a": {"type": "number"}, - "b": {"type": "number"}, - }, - "required": ["a", "b"], - "additionalProperties": False, - }, - ) - - -@pytest.mark.anyio -async def test_valid_tool_call(): - """Test that valid arguments pass validation.""" - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "add": - result = arguments["a"] + arguments["b"] - return [TextContent(type="text", text=f"Result: {result}")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("add", {"a": 5, "b": 3}) - - result = await run_tool_test([create_add_tool()], call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Result: 8" - - -@pytest.mark.anyio -async def test_invalid_tool_call_missing_required(): - """Test that missing required arguments fail validation.""" - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation - raise RuntimeError("Should not reach here") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("add", {"a": 5}) # missing 'b' - - result = await run_tool_test([create_add_tool()], call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Input validation error" in result.content[0].text - assert "'b' is a required property" in result.content[0].text - - -@pytest.mark.anyio -async def test_invalid_tool_call_wrong_type(): - """Test that wrong argument types fail validation.""" - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation - raise RuntimeError("Should not reach here") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("add", {"a": "five", "b": 3}) # 'a' should be number - - result = await run_tool_test([create_add_tool()], call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Input validation error" in result.content[0].text - assert "'five' is not of type 'number'" in result.content[0].text - - -@pytest.mark.anyio -async def test_cache_refresh_on_missing_tool(): - """Test that tool cache is refreshed when tool is not found.""" - tools = [ - Tool( - name="multiply", - description="Multiply two numbers", - input_schema={ - "type": "object", - "properties": { - "x": {"type": "number"}, - "y": {"type": "number"}, - }, - "required": ["x", "y"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "multiply": - result = arguments["x"] * arguments["y"] - return [TextContent(type="text", text=f"Result: {result}")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - # Call tool without first listing tools (cache should be empty) - # The cache should be refreshed automatically - return await client_session.call_tool("multiply", {"x": 10, "y": 20}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - should work because cache will be refreshed - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Result: 200" - - -@pytest.mark.anyio -async def test_enum_constraint_validation(): - """Test that enum constraints are validated.""" - tools = [ - Tool( - name="greet", - description="Greet someone", - input_schema={ - "type": "object", - "properties": { - "name": {"type": "string"}, - "title": {"type": "string", "enum": ["Mr", "Ms", "Dr"]}, - }, - "required": ["name"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation failure - raise RuntimeError("Should not reach here") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("greet", {"name": "Smith", "title": "Prof"}) # Invalid title - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Input validation error" in result.content[0].text - assert "'Prof' is not one of" in result.content[0].text - - -@pytest.mark.anyio -async def test_tool_not_in_list_logs_warning(caplog: pytest.LogCaptureFixture): - """Test that calling a tool not in list_tools logs a warning and skips validation.""" - tools = [ - Tool( - name="add", - description="Add two numbers", - input_schema={ - "type": "object", - "properties": { - "a": {"type": "number"}, - "b": {"type": "number"}, - }, - "required": ["a", "b"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - # This should be reached since validation is skipped for unknown tools - if name == "unknown_tool": - # Even with invalid arguments, this should execute since validation is skipped - return [TextContent(type="text", text="Unknown tool executed without validation")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - # Call a tool that's not in the list with invalid arguments - # This should trigger the warning about validation not being performed - return await client_session.call_tool("unknown_tool", {"invalid": "args"}) - - with caplog.at_level(logging.WARNING): - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - should succeed because validation is skipped for unknown tools - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Unknown tool executed without validation" - - # Verify warning was logged - assert any( - "Tool 'unknown_tool' not listed, no validation will be performed" in record.message for record in caplog.records - ) diff --git a/tests/server/test_lowlevel_output_validation.py b/tests/server/test_lowlevel_output_validation.py deleted file mode 100644 index 92d9c047ca..0000000000 --- a/tests/server/test_lowlevel_output_validation.py +++ /dev/null @@ -1,476 +0,0 @@ -"""Test output schema validation for lowlevel server.""" - -import json -from collections.abc import Awaitable, Callable -from typing import Any - -import anyio -import pytest - -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import CallToolResult, ClientResult, ServerNotification, ServerRequest, TextContent, Tool - - -async def run_tool_test( - tools: list[Tool], - call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[Any]], - test_callback: Callable[[ClientSession], Awaitable[CallToolResult]], -) -> CallToolResult | None: - """Helper to run a tool test with minimal boilerplate. - - Args: - tools: List of tools to register - call_tool_handler: Handler function for tool calls - test_callback: Async function that performs the test using the client session - - Returns: - The result of the tool call - """ - server = Server("test") - - result = None - - @server.list_tools() - async def list_tools(): - return tools - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - return await call_tool_handler(name, arguments) - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Message handler for client - async def message_handler( # pragma: no cover - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): - raise message - - # Server task - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - # Run the test - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - # Initialize the session - await client_session.initialize() - - # Run the test callback - result = await test_callback(client_session) - - # Cancel the server task - tg.cancel_scope.cancel() - - return result - - -@pytest.mark.anyio -async def test_content_only_without_output_schema(): - """Test returning content only when no outputSchema is defined.""" - tools = [ - Tool( - name="echo", - description="Echo a message", - input_schema={ - "type": "object", - "properties": { - "message": {"type": "string"}, - }, - "required": ["message"], - }, - # No outputSchema defined - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "echo": - return [TextContent(type="text", text=f"Echo: {arguments['message']}")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("echo", {"message": "Hello"}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Echo: Hello" - assert result.structured_content is None - - -@pytest.mark.anyio -async def test_dict_only_without_output_schema(): - """Test returning dict only when no outputSchema is defined.""" - tools = [ - Tool( - name="get_info", - description="Get structured information", - input_schema={ - "type": "object", - "properties": {}, - }, - # No outputSchema defined - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "get_info": - return {"status": "ok", "data": {"value": 42}} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("get_info", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - # Check that the content is the JSON serialization - assert json.loads(result.content[0].text) == {"status": "ok", "data": {"value": 42}} - assert result.structured_content == {"status": "ok", "data": {"value": 42}} - - -@pytest.mark.anyio -async def test_both_content_and_dict_without_output_schema(): - """Test returning both content and dict when no outputSchema is defined.""" - tools = [ - Tool( - name="process", - description="Process data", - input_schema={ - "type": "object", - "properties": {}, - }, - # No outputSchema defined - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> tuple[list[TextContent], dict[str, Any]]: - if name == "process": - content = [TextContent(type="text", text="Processing complete")] - data = {"result": "success", "count": 10} - return (content, data) - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("process", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Processing complete" - assert result.structured_content == {"result": "success", "count": 10} - - -@pytest.mark.anyio -async def test_content_only_with_output_schema_error(): - """Test error when outputSchema is defined but only content is returned.""" - tools = [ - Tool( - name="structured_tool", - description="Tool expecting structured output", - input_schema={ - "type": "object", - "properties": {}, - }, - output_schema={ - "type": "object", - "properties": { - "result": {"type": "string"}, - }, - "required": ["result"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - # This returns only content, but outputSchema expects structured data - return [TextContent(type="text", text="This is not structured")] - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("structured_tool", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify error - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Output validation error: outputSchema defined but no structured output returned" in result.content[0].text - - -@pytest.mark.anyio -async def test_valid_dict_with_output_schema(): - """Test valid dict output matching outputSchema.""" - tools = [ - Tool( - name="calc", - description="Calculate result", - input_schema={ - "type": "object", - "properties": { - "x": {"type": "number"}, - "y": {"type": "number"}, - }, - "required": ["x", "y"], - }, - output_schema={ - "type": "object", - "properties": { - "sum": {"type": "number"}, - "product": {"type": "number"}, - }, - "required": ["sum", "product"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "calc": - x = arguments["x"] - y = arguments["y"] - return {"sum": x + y, "product": x * y} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("calc", {"x": 3, "y": 4}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - # Check JSON serialization - assert json.loads(result.content[0].text) == {"sum": 7, "product": 12} - assert result.structured_content == {"sum": 7, "product": 12} - - -@pytest.mark.anyio -async def test_invalid_dict_with_output_schema(): - """Test dict output that doesn't match outputSchema.""" - tools = [ - Tool( - name="user_info", - description="Get user information", - input_schema={ - "type": "object", - "properties": {}, - }, - output_schema={ - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"}, - }, - "required": ["name", "age"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "user_info": - # Missing required 'age' field - return {"name": "Alice"} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("user_info", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify error - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Output validation error:" in result.content[0].text - assert "'age' is a required property" in result.content[0].text - - -@pytest.mark.anyio -async def test_both_content_and_valid_dict_with_output_schema(): - """Test returning both content and valid dict with outputSchema.""" - tools = [ - Tool( - name="analyze", - description="Analyze data", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string"}, - }, - "required": ["text"], - }, - output_schema={ - "type": "object", - "properties": { - "sentiment": {"type": "string", "enum": ["positive", "negative", "neutral"]}, - "confidence": {"type": "number", "minimum": 0, "maximum": 1}, - }, - "required": ["sentiment", "confidence"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> tuple[list[TextContent], dict[str, Any]]: - if name == "analyze": - content = [TextContent(type="text", text=f"Analysis of: {arguments['text']}")] - data = {"sentiment": "positive", "confidence": 0.95} - return (content, data) - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("analyze", {"text": "Great job!"}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Analysis of: Great job!" - assert result.structured_content == {"sentiment": "positive", "confidence": 0.95} - - -@pytest.mark.anyio -async def test_tool_call_result(): - """Test returning ToolCallResult when no outputSchema is defined.""" - tools = [ - Tool( - name="get_info", - description="Get structured information", - input_schema={ - "type": "object", - "properties": {}, - }, - # No outputSchema for direct return of tool call result - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> CallToolResult: - if name == "get_info": - return CallToolResult( - content=[TextContent(type="text", text="Results calculated")], - structured_content={"status": "ok", "data": {"value": 42}}, - _meta={"some": "metadata"}, - ) - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("get_info", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Results calculated" - assert isinstance(result.content[0], TextContent) - assert result.structured_content == {"status": "ok", "data": {"value": 42}} - assert result.meta == {"some": "metadata"} - - -@pytest.mark.anyio -async def test_output_schema_type_validation(): - """Test outputSchema validates types correctly.""" - tools = [ - Tool( - name="stats", - description="Get statistics", - input_schema={ - "type": "object", - "properties": {}, - }, - output_schema={ - "type": "object", - "properties": { - "count": {"type": "integer"}, - "average": {"type": "number"}, - "items": {"type": "array", "items": {"type": "string"}}, - }, - "required": ["count", "average", "items"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "stats": - # Wrong type for 'count' - should be integer - return {"count": "five", "average": 2.5, "items": ["a", "b"]} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("stats", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify error - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert "Output validation error:" in result.content[0].text - assert "'five' is not of type 'integer'" in result.content[0].text diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 68543136eb..705abdfe8c 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -1,100 +1,44 @@ """Tests for tool annotations in low-level server.""" -import anyio import pytest -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ClientResult, ServerNotification, ServerRequest, Tool, ToolAnnotations +from mcp import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ListToolsResult, PaginatedRequestParams, Tool, ToolAnnotations @pytest.mark.anyio async def test_lowlevel_server_tool_annotations(): """Test that tool annotations work in low-level server.""" - server = Server("test") - # Create a tool with annotations - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="echo", - description="Echo a message back", - input_schema={ - "type": "object", - "properties": { - "message": {"type": "string"}, + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo", + description="Echo a message back", + input_schema={ + "type": "object", + "properties": { + "message": {"type": "string"}, + }, + "required": ["message"], }, - "required": ["message"], - }, - annotations=ToolAnnotations( - title="Echo Tool", - read_only_hint=True, - ), - ) - ] - - tools_result = None - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Message handler for client - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - # Server task - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - # Run the test - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - # Initialize the session - await client_session.initialize() - - # List tools - tools_result = await client_session.list_tools() - - # Cancel the server task - tg.cancel_scope.cancel() - - # Verify results - assert tools_result is not None - assert len(tools_result.tools) == 1 - assert tools_result.tools[0].name == "echo" - assert tools_result.tools[0].annotations is not None - assert tools_result.tools[0].annotations.title == "Echo Tool" - assert tools_result.tools[0].annotations.read_only_hint is True + annotations=ToolAnnotations( + title="Echo Tool", + read_only_hint=True, + ), + ) + ] + ) + + server = Server("test", on_list_tools=handle_list_tools) + + async with Client(server) as client: + tools_result = await client.list_tools() + + assert len(tools_result.tools) == 1 + assert tools_result.tools[0].name == "echo" + assert tools_result.tools[0].annotations is not None + assert tools_result.tools[0].annotations.title == "Echo Tool" + assert tools_result.tools[0].annotations.read_only_hint is True diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index 88fd1e38ff..102a58d039 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -1,106 +1,58 @@ -from collections.abc import Iterable -from pathlib import Path -from tempfile import NamedTemporaryFile +import base64 import pytest -from mcp import types -from mcp.server.lowlevel.server import ReadResourceContents, Server +from mcp import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + BlobResourceContents, + ReadResourceRequestParams, + ReadResourceResult, + TextResourceContents, +) +pytestmark = pytest.mark.anyio -@pytest.fixture -def temp_file(): - """Create a temporary file for testing.""" - with NamedTemporaryFile(mode="w", delete=False) as f: - f.write("test content") - path = Path(f.name).resolve() - yield path - try: - path.unlink() - except FileNotFoundError: # pragma: no cover - pass +async def test_read_resource_text(): + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text="Hello World", mime_type="text/plain")] + ) -@pytest.mark.anyio -async def test_read_resource_text(temp_file: Path): - server = Server("test") + server = Server("test", on_read_resource=handle_read_resource) - @server.read_resource() - async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ReadResourceContents(content="Hello World", mime_type="text/plain")] + async with Client(server) as client: + result = await client.read_resource("test://resource") + assert len(result.contents) == 1 - # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] + content = result.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Hello World" + assert content.mime_type == "text/plain" - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) - # Call the handler - result = await handler(request) - assert isinstance(result, types.ReadResourceResult) - assert len(result.contents) == 1 +async def test_read_resource_binary(): + binary_data = b"Hello World" - content = result.contents[0] - assert isinstance(content, types.TextResourceContents) - assert content.text == "Hello World" - assert content.mime_type == "text/plain" + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[ + BlobResourceContents( + uri=str(params.uri), + blob=base64.b64encode(binary_data).decode("utf-8"), + mime_type="application/octet-stream", + ) + ] + ) + server = Server("test", on_read_resource=handle_read_resource) -@pytest.mark.anyio -async def test_read_resource_binary(temp_file: Path): - server = Server("test") + async with Client(server) as client: + result = await client.read_resource("test://resource") + assert len(result.contents) == 1 - @server.read_resource() - async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ReadResourceContents(content=b"Hello World", mime_type="application/octet-stream")] - - # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] - - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) - - # Call the handler - result = await handler(request) - assert isinstance(result, types.ReadResourceResult) - assert len(result.contents) == 1 - - content = result.contents[0] - assert isinstance(content, types.BlobResourceContents) - assert content.mime_type == "application/octet-stream" - - -@pytest.mark.anyio -async def test_read_resource_default_mime(temp_file: Path): - server = Server("test") - - @server.read_resource() - async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ - ReadResourceContents( - content="Hello World", - # No mime_type specified, should default to text/plain - ) - ] - - # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] - - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) - - # Call the handler - result = await handler(request) - assert isinstance(result, types.ReadResourceResult) - assert len(result.contents) == 1 - - content = result.contents[0] - assert isinstance(content, types.TextResourceContents) - assert content.text == "Hello World" - assert content.mime_type == "text/plain" + content = result.contents[0] + assert isinstance(content, BlobResourceContents) + assert content.mime_type == "application/octet-stream" + assert base64.b64decode(content.blob) == binary_data diff --git a/tests/server/test_session.py b/tests/server/test_session.py index d353e46e45..a2786d865d 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -5,7 +5,7 @@ from mcp import types from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -14,17 +14,10 @@ from mcp.shared.session import RequestResponder from mcp.types import ( ClientNotification, - Completion, - CompletionArgument, - CompletionContext, CompletionsCapability, InitializedNotification, - Prompt, - PromptReference, PromptsCapability, - Resource, ResourcesCapability, - ResourceTemplateReference, ServerCapabilities, ) @@ -85,47 +78,50 @@ async def run_server(): @pytest.mark.anyio async def test_server_capabilities(): - server = Server("test") notification_options = NotificationOptions() experimental_capabilities: dict[str, Any] = {} - # Initially no capabilities + async def noop_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + raise NotImplementedError + + async def noop_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + raise NotImplementedError + + async def noop_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + raise NotImplementedError + + # No capabilities + server = Server("test") caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts is None assert caps.resources is None assert caps.completions is None - # Add a prompts handler - @server.list_prompts() - async def list_prompts() -> list[Prompt]: # pragma: no cover - return [] - + # With prompts handler + server = Server("test", on_list_prompts=noop_list_prompts) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources is None assert caps.completions is None - # Add a resources handler - @server.list_resources() - async def list_resources() -> list[Resource]: # pragma: no cover - return [] - + # With prompts + resources handlers + server = Server("test", on_list_prompts=noop_list_prompts, on_list_resources=noop_list_resources) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) assert caps.completions is None - # Add a complete handler - @server.completion() - async def complete( # pragma: no cover - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - return Completion( - values=["completion1", "completion2"], - ) - + # With prompts + resources + completion handlers + server = Server( + "test", + on_list_prompts=noop_list_prompts, + on_list_resources=noop_list_resources, + on_completion=noop_completion, + ) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a25..e95dc51b31 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,27 +1,44 @@ -"""Tests for SSE server DNS rebinding protection.""" +"""Tests for SSE server request validation.""" import logging import multiprocessing +import re import socket +import anyio import httpx import pytest +import sse_starlette.sse import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route +from starlette.types import Message, Receive, Scope, Send from mcp.server import Server +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool +from mcp.shared._stream_protocols import WriteStream +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCRequest, JSONRPCResponse, Tool from tests.test_helpers import wait_for_server logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" +@pytest.fixture(autouse=True) +def reset_sse_starlette_exit_event() -> None: + """sse-starlette<2 caches a module-level anyio.Event on AppStatus; reset it + between tests so it is not bound to a previous test's event loop.""" + app_status = getattr(sse_starlette.sse, "AppStatus", None) + if app_status is not None and hasattr(app_status, "should_exit_event"): # pragma: lax no cover + app_status.should_exit_event = None + + @pytest.fixture def server_port() -> int: with socket.socket() as s: @@ -291,3 +308,270 @@ async def test_sse_security_post_valid_content_type(server_port: int): finally: process.terminate() process.join() + + +def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _sse_scope( + method: str, path: str, user: AuthenticatedUser | None, *, query_string: bytes = b"", body: bytes = b"" +) -> tuple[Scope, Receive, Send, list[Message]]: + """Build an ASGI scope/receive/send triple for a request to the SSE transport.""" + scope: Scope = { + "type": "http", + "method": method, + "path": path, + "root_path": "", + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + } + if user is not None: + scope["user"] = user + sent: list[Message] = [] + + async def receive() -> Message: + return {"type": "http.request", "body": body, "more_body": False} + + async def send(message: Message) -> None: + sent.append(message) + + return scope, receive, send, sent + + +def _response_status(sent: list[Message]) -> int: + response_start = next(msg for msg in sent if msg["type"] == "http.response.start") + return response_start["status"] + + +async def _post_message(transport: SseServerTransport, session_id: str, user: AuthenticatedUser | None) -> int: + """POST a message to an SSE session as `user` and return the response status.""" + body = b'{"jsonrpc": "2.0", "id": 1, "method": "ping", "params": null}' + scope, receive, send, sent = _sse_scope( + "POST", "/messages/", user, query_string=f"session_id={session_id}".encode(), body=body + ) + await transport.handle_post_message(scope, receive, send) + return _response_status(sent) + + +_Principal = tuple[str] | tuple[str, str] | tuple[str, str, str] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("creator", "sender", "expected"), + [ + pytest.param(("client-a",), ("client-b",), 404, id="different-client"), + pytest.param(("client-a",), None, 404, id="unauthenticated-sender"), + pytest.param(("client-a", "alice"), ("client-a", "bob"), 404, id="same-client-different-subject"), + pytest.param(("client-a", "alice"), ("client-a",), 404, id="same-client-no-subject"), + pytest.param( + ("client-a", "alice", "https://i1"), ("client-a", "alice", "https://i2"), 404, id="different-issuer" + ), + pytest.param(None, ("client-a",), 404, id="unauthenticated-creator"), + pytest.param(("client-a",), ("client-a",), 202, id="same-client"), + pytest.param(("client-a", "alice"), ("client-a", "alice"), 202, id="same-client-and-subject"), + pytest.param(None, None, 202, id="both-unauthenticated"), + ], +) +async def test_sse_post_requires_the_credential_that_created_the_session( + creator: _Principal | None, + sender: _Principal | None, + expected: int, +): + """The session endpoint URL issued to one authenticated principal must not + accept messages from a request authenticated as a different one.""" + transport = SseServerTransport("/messages/") + session_id_received = anyio.Event() + session_ids: list[str] = [] + client_disconnected = anyio.Event() + + async def get_send(message: Message) -> None: + # The first body chunk is the SSE event announcing the session URI to POST messages to. + if message["type"] == "http.response.body" and not session_ids: + match = re.search(rb"session_id=([0-9a-f]{32})", message.get("body", b"")) + assert match is not None, f"expected the endpoint event first, got {message!r}" + session_ids.append(match.group(1).decode()) + session_id_received.set() + + async def get_receive() -> Message: + # The SSE client stays connected until the test signals otherwise. + await client_disconnected.wait() + return {"type": "http.disconnect"} + + creator_user = _authenticated_user(*creator) if creator is not None else None + sender_user = _authenticated_user(*sender) if sender is not None else None + + async def hold_sse_connection() -> None: + """Establish the SSE session as `creator` and keep it open, as a server would.""" + scope, _, _, _ = _sse_scope("GET", "/sse", creator_user) + with anyio.fail_after(5): + async with transport.connect_sse(scope, get_receive, get_send) as (read_stream, write_stream): + async with read_stream, write_stream: # pragma: no branch + async for _ in read_stream: + pass + + async with anyio.create_task_group() as tg: + tg.start_soon(hold_sse_connection) + with anyio.fail_after(5): + await session_id_received.wait() + + assert await _post_message(transport, session_ids[0], sender_user) == expected + + client_disconnected.set() + + # Once the connection is gone the session is no longer routable. + assert await _post_message(transport, session_ids[0], creator_user) == 404 + + +@pytest.mark.anyio +async def test_sse_connect_rejects_a_non_http_scope(): + """connect_sse refuses ASGI scopes that are not HTTP requests.""" + transport = SseServerTransport("/messages/") + with pytest.raises(ValueError): + async with transport.connect_sse({"type": "websocket"}, _no_receive, _no_send): + raise NotImplementedError + + +@pytest.mark.anyio +async def test_sse_connect_rejects_a_disallowed_host(): + """connect_sse rejects requests whose Host header fails the configured security check.""" + settings = TransportSecuritySettings(allowed_hosts=["allowed.example.com"]) + transport = SseServerTransport("/messages/", security_settings=settings) + scope, receive, send, sent = _sse_scope("GET", "/sse", None) + scope["headers"] = [(b"host", b"disallowed.example.com")] + + with pytest.raises(ValueError): + async with transport.connect_sse(scope, receive, send): + raise NotImplementedError + assert _response_status(sent) == 421 + + +@pytest.mark.anyio +async def test_sse_post_without_a_session_id_returns_400(): + """POSTs to the messages endpoint must include a session_id query parameter.""" + transport = SseServerTransport("/messages/") + scope, receive, send, sent = _sse_scope("POST", "/messages/", None) + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + + +@pytest.mark.anyio +async def test_sse_post_with_a_malformed_session_id_returns_400(): + """A session_id that is not 32 hex characters is rejected before any session lookup.""" + transport = SseServerTransport("/messages/") + scope, receive, send, sent = _sse_scope("POST", "/messages/", None, query_string=b"session_id=not-hex") + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + + +@pytest.mark.anyio +async def test_sse_post_with_a_disallowed_host_is_rejected_before_session_lookup(): + """The transport security check on POST runs before any session-ID handling.""" + settings = TransportSecuritySettings(allowed_hosts=["allowed.example.com"]) + transport = SseServerTransport("/messages/", security_settings=settings) + scope, receive, send, sent = _sse_scope("POST", "/messages/", None) + scope["headers"] = [(b"host", b"disallowed.example.com"), (b"content-type", b"application/json")] + + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 421 + + +@pytest.mark.anyio +async def test_sse_round_trip_delivers_posted_messages_and_streams_responses(): + """A POSTed JSON-RPC message reaches the server's read stream, and a message + written to the server's write stream is sent to the client as an SSE event.""" + transport = SseServerTransport("/messages/") + session = _SseSession(transport) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(session.hold) + await session.ready.wait() + + # POST a parse-failing body: client gets 400, server's read stream receives the error. + scope, receive, send, sent = _sse_scope( + "POST", "/messages/", None, query_string=f"session_id={session.session_id}".encode(), body=b"not json" + ) + await transport.handle_post_message(scope, receive, send) + assert _response_status(sent) == 400 + assert isinstance(await session.next_read_item(), Exception) + + # POST a valid message: client gets 202, server's read stream receives it. + assert await _post_message(transport, session.session_id, None) == 202 + received = await session.next_read_item() + assert isinstance(received, SessionMessage) + assert isinstance(received.message, JSONRPCRequest) + assert received.message.method == "ping" + + # Server writes a response: it appears as an SSE `message` event on the GET stream. + outgoing = JSONRPCResponse(jsonrpc="2.0", id=1, result={}) + await session.write_stream.send(SessionMessage(outgoing)) + chunk = await session.next_body_chunk() + assert b"event: message" in chunk + assert outgoing.model_dump_json(by_alias=True, exclude_unset=True).encode() in chunk + + session.disconnect() + + +class _SseSession: + """Drive an in-process SSE GET connection and surface what the server reads and the client receives. + + `hold` runs the connection in a background task and consumes the server-side read stream + into a buffer so that `handle_post_message` (which writes to that stream with a zero-capacity + channel) never blocks the test body. + """ + + def __init__(self, transport: SseServerTransport) -> None: + self.transport = transport + self.ready = anyio.Event() + self._disconnected = anyio.Event() + self._body_send, self._body_recv = anyio.create_memory_object_stream[bytes](16) + self._read_send, self._read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](16) + self.session_id = "" + self.write_stream: WriteStream[SessionMessage] + + async def hold(self) -> None: + scope, _, _, _ = _sse_scope("GET", "/sse", None) + async with self.transport.connect_sse(scope, self._receive, self._send) as (read, write): + self.write_stream = write + async with read, write, self._body_send, self._body_recv, self._read_send, self._read_recv: + async for item in read: + await self._read_send.send(item) + + def disconnect(self) -> None: + self._disconnected.set() + + async def next_read_item(self) -> SessionMessage | Exception: + return await self._read_recv.receive() + + async def next_body_chunk(self) -> bytes: + return await self._body_recv.receive() + + async def _receive(self) -> Message: + await self._disconnected.wait() + return {"type": "http.disconnect"} + + async def _send(self, message: Message) -> None: + if message["type"] != "http.response.body": + return + body: bytes = message.get("body", b"") + if not self.session_id: + match = re.search(rb"session_id=([0-9a-f]{32})", body) + assert match is not None, f"expected the endpoint event first, got {message!r}" + self.session_id = match.group(1).decode() + self.ready.set() + else: + await self._body_send.send(body) + + +async def _no_receive() -> Message: + raise NotImplementedError + + +async def _no_send(message: Message) -> None: + raise NotImplementedError diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 9a7ddaab40..677a993567 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,4 +1,6 @@ import io +import sys +from io import TextIOWrapper import anyio import pytest @@ -59,3 +61,34 @@ async def test_stdio_server(): assert len(received_responses) == 2 assert received_responses[0] == JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") assert received_responses[1] == JSONRPCResponse(jsonrpc="2.0", id=4, result={}) + + +@pytest.mark.anyio +async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): + """Non-UTF-8 bytes on stdin must not crash the server. + + Invalid bytes are replaced with U+FFFD, which then fails JSON parsing and + is delivered as an in-stream exception. Subsequent valid messages must + still be processed. + """ + # \xff\xfe are invalid UTF-8 start bytes. + valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + raw_stdin = io.BytesIO(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + + # Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that + # stdio_server()'s default path wraps it with errors='replace'. + monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8")) + + with anyio.fail_after(5): + async with stdio_server() as (read_stream, write_stream): + await write_stream.aclose() + async with read_stream: # pragma: no branch + # First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream + first = await read_stream.receive() + assert isinstance(first, Exception) + + # Second line: valid message still comes through + second = await read_stream.receive() + assert isinstance(second, SessionMessage) + assert second.message == valid diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 51572baa9c..ba75547964 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -1,21 +1,23 @@ """Tests for StreamableHTTPSessionManager.""" import json +import logging from typing import Any from unittest.mock import AsyncMock, patch import anyio import httpx import pytest -from starlette.types import Message +from starlette.types import Message, Scope -from mcp import Client, types +from mcp import Client from mcp.client.streamable_http import streamable_http_client -from mcp.server import streamable_http_manager -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext, streamable_http_manager +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from mcp.types import INVALID_REQUEST +from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams @pytest.mark.anyio @@ -218,7 +220,7 @@ async def test_stateless_requests_memory_cleanup(): # Patch StreamableHTTPServerTransport constructor to track instances - original_constructor = streamable_http_manager.StreamableHTTPServerTransport + original_constructor = StreamableHTTPServerTransport def track_transport(*args: Any, **kwargs: Any) -> StreamableHTTPServerTransport: transport = original_constructor(*args, **kwargs) @@ -270,7 +272,7 @@ async def mock_receive(): @pytest.mark.anyio -async def test_unknown_session_id_returns_404(): +async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture): """Test that requests with unknown session IDs return HTTP 404 per MCP spec.""" app = Server("test-unknown-session") manager = StreamableHTTPSessionManager(app=app) @@ -300,7 +302,8 @@ async def mock_send(message: Message): async def mock_receive(): return {"type": "http.request", "body": b"{}", "more_body": False} # pragma: no cover - await manager.handle_request(scope, mock_receive, mock_send) + with caplog.at_level(logging.INFO): + await manager.handle_request(scope, mock_receive, mock_send) # Find the response start message response_start = next( @@ -313,20 +316,20 @@ async def mock_receive(): # Verify JSON-RPC error format error_data = json.loads(response_body) assert error_data["jsonrpc"] == "2.0" - assert error_data["id"] == "server-error" + assert error_data["id"] is None assert error_data["error"]["code"] == INVALID_REQUEST assert error_data["error"]["message"] == "Session not found" + assert "Rejected request with unknown or expired session ID: non-existent-session-id" in caplog.text @pytest.mark.anyio async def test_e2e_streamable_http_server_cleanup(): host = "testserver" - app = Server("test-server") - @app.list_tools() - async def list_tools(req: types.ListToolsRequest) -> types.ListToolsResult: - return types.ListToolsResult(tools=[]) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[]) + app = Server("test-server", on_list_tools=handle_list_tools) mcp_app = app.streamable_http_app(host=host) async with ( mcp_app.router.lifespan_context(mcp_app), @@ -335,3 +338,243 @@ async def list_tools(req: types.ListToolsRequest) -> types.ListToolsResult: Client(streamable_http_client(f"http://{host}/mcp", http_client=http_client)) as client, ): await client.list_tools() + + +@pytest.mark.anyio +async def test_idle_session_is_reaped(): + """After idle timeout fires, the session returns 404.""" + app = Server("test-idle-reap") + manager = StreamableHTTPSessionManager(app=app, session_idle_timeout=0.05) + + async with manager.run(): + sent_messages: list[Message] = [] + + async def mock_send(message: Message): + sent_messages.append(message) + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [(b"content-type", b"application/json")], + } + + async def mock_receive(): # pragma: no cover + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request(scope, mock_receive, mock_send) + + session_id = None + for msg in sent_messages: # pragma: no branch + if msg["type"] == "http.response.start": # pragma: no branch + for header_name, header_value in msg.get("headers", []): # pragma: no branch + if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): + session_id = header_value.decode() + break + if session_id: # pragma: no branch + break + + assert session_id is not None, "Session ID not found in response headers" + + # Wait for the 50ms idle timeout to fire and cleanup to complete + await anyio.sleep(0.1) + + # Verify via public API: old session ID now returns 404 + response_messages: list[Message] = [] + + async def capture_send(message: Message): + response_messages.append(message) + + scope_with_session = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [ + (b"content-type", b"application/json"), + (b"mcp-session-id", session_id.encode()), + ], + } + + await manager.handle_request(scope_with_session, mock_receive, capture_send) + + response_start = next( + (msg for msg in response_messages if msg["type"] == "http.response.start"), + None, + ) + assert response_start is not None + assert response_start["status"] == 404 + + +def test_session_idle_timeout_rejects_non_positive(): + with pytest.raises(ValueError, match="positive number"): + StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=-1) + with pytest.raises(ValueError, match="positive number"): + StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=0) + + +def test_session_idle_timeout_rejects_stateless(): + with pytest.raises(RuntimeError, match="not supported in stateless"): + StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) + + +def _user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _request_scope( + *, session_id: str | None = None, user: AuthenticatedUser | None = None, method: str = "POST" +) -> Scope: + """Build an ASGI scope for a request to the MCP endpoint.""" + headers = [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ] + if session_id is not None: + headers.append((b"mcp-session-id", session_id.encode())) + scope: Scope = { + "type": "http", + "method": method, + "path": "/mcp", + "headers": headers, + } + if user is not None: + scope["user"] = user + return scope + + +async def _open_session(manager: StreamableHTTPSessionManager, user: AuthenticatedUser | None) -> str: + """Create a new session as `user` and return its session ID.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request(_request_scope(user=user), mock_receive, mock_send) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + headers = dict(response_start.get("headers", [])) + return headers[MCP_SESSION_ID_HEADER.encode()].decode() + + +async def _request_session( + manager: StreamableHTTPSessionManager, session_id: str, user: AuthenticatedUser | None, method: str = "POST" +) -> int: + """Send a request for an existing session as `user` and return the response status.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request( + _request_scope(session_id=session_id, user=user, method=method), mock_receive, mock_send + ) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + return response_start["status"] + + +@pytest.fixture +async def manager_with_live_session(): + """A running manager around a real `Server`. Sessions remain registered until + `manager.run()` exits because `Server.run` blocks waiting for an initialize message.""" + manager = StreamableHTTPSessionManager(app=Server("test-session-credentials")) + async with manager.run(): + yield manager + + +@pytest.mark.anyio +async def test_session_accepts_requests_from_the_credential_that_created_it( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Requests presenting the same credential as the one that created the session are served.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + status = await _request_session(manager, session_id, _user("client-a")) + + # The request passes the manager's credential check and reaches the + # session's transport, instead of being answered with 404 by the manager. + assert status != 404 + + +@pytest.mark.anyio +@pytest.mark.parametrize("method", ["POST", "GET", "DELETE"]) +async def test_session_rejects_requests_from_a_different_credential( + manager_with_live_session: StreamableHTTPSessionManager, method: str +) -> None: + """A session created by one credential cannot be used with another credential, whatever the method.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, _user("client-b"), method) == 404 + # The session is still registered and still serves its creator. + assert await _request_session(manager, session_id, _user("client-a")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_from_a_different_subject_of_the_same_client( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Two end-users that share an OAuth client cannot use each other's sessions.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a", subject="alice")) + + assert await _request_session(manager, session_id, _user("client-a", subject="bob")) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject=None)) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_with_the_same_subject_from_a_different_issuer( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A subject is unique only per issuer, so a colliding subject from a different issuer is not the same principal.""" + manager = manager_with_live_session + creator = _user("client-a", subject="alice", issuer="https://issuer.one") + session_id = await _open_session(manager, creator) + + other_issuer = _user("client-a", subject="alice", issuer="https://issuer.two") + assert await _request_session(manager, session_id, other_issuer) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) == 404 + assert await _request_session(manager, session_id, creator) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_unauthenticated_requests_for_an_authenticated_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created with a credential cannot be used without one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, None) == 404 + + +@pytest.mark.anyio +async def test_session_rejects_authenticated_requests_for_an_anonymous_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created without a credential cannot be used with one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, _user("client-a")) == 404 + + +@pytest.mark.anyio +async def test_anonymous_session_accepts_anonymous_requests( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Servers without authentication keep working: no credential on either side.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, None) != 404 diff --git a/tests/server/test_transport_security.py b/tests/server/test_transport_security.py new file mode 100644 index 0000000000..be28980b53 --- /dev/null +++ b/tests/server/test_transport_security.py @@ -0,0 +1,88 @@ +"""Tests for the transport-security request validation middleware.""" + +import pytest +from starlette.requests import Request + +from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings + + +def _request(host: str | None, origin: str | None, content_type: str | None = "application/json") -> Request: + headers: list[tuple[bytes, bytes]] = [] + if content_type is not None: + headers.append((b"content-type", content_type.encode())) + if host is not None: + headers.append((b"host", host.encode())) + if origin is not None: + headers.append((b"origin", origin.encode())) + return Request({"type": "http", "method": "GET", "headers": headers}) + + +SETTINGS = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["good.example", "wild.example:*"], + allowed_origins=["http://good.example", "http://wild.example:*"], +) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("host", "origin", "expected"), + [ + pytest.param(None, None, 421, id="missing-host"), + pytest.param("evil.example", None, 421, id="host-no-match"), + pytest.param("evil.example:9000", None, 421, id="host-wildcard-base-mismatch"), + pytest.param("good.example", None, None, id="host-exact-no-origin"), + pytest.param("wild.example:9000", None, None, id="host-wildcard-match"), + pytest.param("good.example", "http://evil.example", 403, id="origin-no-match"), + pytest.param("good.example", "http://evil.example:9000", 403, id="origin-wildcard-base-mismatch"), + pytest.param("good.example", "http://good.example", None, id="origin-exact"), + pytest.param("good.example", "http://wild.example:9000", None, id="origin-wildcard-match"), + ], +) +async def test_validate_request_checks_host_then_origin( + host: str | None, origin: str | None, expected: int | None +) -> None: + """Host is checked first, then Origin; exact and wildcard-port allowlist entries are honoured.""" + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request(host, origin)) + assert (None if response is None else response.status_code) == expected + + +@pytest.mark.anyio +async def test_validate_request_skips_host_and_origin_when_protection_is_disabled() -> None: + """With DNS-rebinding protection off, any Host/Origin is accepted.""" + middleware = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + assert await middleware.validate_request(_request("evil.example", "http://evil.example")) is None + + +@pytest.mark.anyio +async def test_validate_request_defaults_to_protection_disabled() -> None: + """Constructing the middleware without settings leaves DNS-rebinding protection off.""" + middleware = TransportSecurityMiddleware() + assert await middleware.validate_request(_request("evil.example", "http://evil.example")) is None + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("content_type", "expected"), + [ + pytest.param("application/json", None, id="json"), + pytest.param("application/json; charset=utf-8", None, id="json-with-charset"), + pytest.param("APPLICATION/JSON", None, id="case-insensitive"), + pytest.param("text/plain", 400, id="wrong-type"), + pytest.param(None, 400, id="missing"), + ], +) +async def test_validate_request_checks_content_type_on_post(content_type: str | None, expected: int | None) -> None: + """POST requests must carry an application/json Content-Type, regardless of DNS-rebinding settings.""" + middleware = TransportSecurityMiddleware() + response = await middleware.validate_request(_request("any", None, content_type=content_type), is_post=True) + assert (None if response is None else response.status_code) == expected + + +@pytest.mark.anyio +async def test_validate_request_ignores_content_type_on_get() -> None: + """Content-Type is only enforced for POST requests.""" + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request("good.example", None, content_type="text/plain")) + assert response is None diff --git a/tests/shared/test_auth.py b/tests/shared/test_auth.py index cd3c35332f..7463bc5a8a 100644 --- a/tests/shared/test_auth.py +++ b/tests/shared/test_auth.py @@ -1,6 +1,9 @@ """Tests for OAuth 2.0 shared code.""" -from mcp.shared.auth import OAuthMetadata +import pytest +from pydantic import ValidationError + +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata def test_oauth(): @@ -58,3 +61,80 @@ def test_oauth_with_jarm(): "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], } ) + + +# RFC 7591 §2 marks client_uri/logo_uri/tos_uri/policy_uri/jwks_uri as OPTIONAL. +# Some authorization servers echo the client's omitted metadata back as "" +# instead of dropping the keys; without coercion, AnyHttpUrl rejects "" and +# the whole registration response is thrown away even though the server +# returned a valid client_id. + + +@pytest.mark.parametrize( + "empty_field", + ["client_uri", "logo_uri", "tos_uri", "policy_uri", "jwks_uri"], +) +def test_optional_url_empty_string_coerced_to_none(empty_field: str): + data = { + "redirect_uris": ["https://example.com/callback"], + empty_field: "", + } + metadata = OAuthClientMetadata.model_validate(data) + assert getattr(metadata, empty_field) is None + + +def test_all_optional_urls_empty_together(): + data = { + "redirect_uris": ["https://example.com/callback"], + "client_uri": "", + "logo_uri": "", + "tos_uri": "", + "policy_uri": "", + "jwks_uri": "", + } + metadata = OAuthClientMetadata.model_validate(data) + assert metadata.client_uri is None + assert metadata.logo_uri is None + assert metadata.tos_uri is None + assert metadata.policy_uri is None + assert metadata.jwks_uri is None + + +def test_valid_url_passes_through_unchanged(): + data = { + "redirect_uris": ["https://example.com/callback"], + "client_uri": "https://udemy.com/", + } + metadata = OAuthClientMetadata.model_validate(data) + assert str(metadata.client_uri) == "https://udemy.com/" + + +def test_information_full_inherits_coercion(): + """OAuthClientInformationFull subclasses OAuthClientMetadata, so the + same coercion applies to DCR responses parsed via the full model.""" + data = { + "client_id": "abc123", + "redirect_uris": ["https://example.com/callback"], + "client_uri": "", + "logo_uri": "", + "tos_uri": "", + "policy_uri": "", + "jwks_uri": "", + } + info = OAuthClientInformationFull.model_validate(data) + assert info.client_id == "abc123" + assert info.client_uri is None + assert info.logo_uri is None + assert info.tos_uri is None + assert info.policy_uri is None + assert info.jwks_uri is None + + +def test_invalid_non_empty_url_still_rejected(): + """Coercion must only touch empty strings — garbage URLs still raise.""" + data = { + "redirect_uris": ["https://example.com/callback"], + "client_uri": "not a url", + } + with pytest.raises(ValidationError): + OAuthClientMetadata.model_validate(data) diff --git a/tests/shared/test_auth_utils.py b/tests/shared/test_auth_utils.py index 2c1c16dc32..5ae0e22b0c 100644 --- a/tests/shared/test_auth_utils.py +++ b/tests/shared/test_auth_utils.py @@ -104,7 +104,7 @@ def test_check_resource_allowed_trailing_slash_handling(): """Trailing slashes should be handled correctly.""" # With and without trailing slashes assert check_resource_allowed("https://example.com/api/", "https://example.com/api") is True - assert check_resource_allowed("https://example.com/api", "https://example.com/api/") is False + assert check_resource_allowed("https://example.com/api", "https://example.com/api/") is True assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api") is True assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py deleted file mode 100644 index 31238b9ffd..0000000000 --- a/tests/shared/test_memory.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest - -from mcp import Client -from mcp.server import Server -from mcp.types import EmptyResult, Resource - - -@pytest.fixture -def mcp_server() -> Server: - server = Server(name="test_server") - - @server.list_resources() - async def handle_list_resources(): # pragma: no cover - return [ - Resource( - uri="memory://test", - name="Test Resource", - description="A test resource", - ) - ] - - return server - - -@pytest.mark.anyio -async def test_memory_server_and_client_connection(mcp_server: Server): - """Shows how a client and server can communicate over memory streams.""" - async with Client(mcp_server) as client: - response = await client.send_ping() - assert isinstance(response, EmptyResult) diff --git a/tests/shared/test_otel.py b/tests/shared/test_otel.py new file mode 100644 index 0000000000..ec7ff78cc1 --- /dev/null +++ b/tests/shared/test_otel.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import pytest +from logfire.testing import CaptureLogfire + +from mcp import types +from mcp.client.client import Client +from mcp.server.mcpserver import MCPServer + +pytestmark = pytest.mark.anyio + + +# Logfire warns about propagated trace context by default (distributed_tracing=None). +# This is expected here since we're testing cross-boundary context propagation. +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +async def test_client_and_server_spans(capfire: CaptureLogfire): + """Verify that calling a tool produces client and server spans with correct attributes.""" + server = MCPServer("test") + + @server.tool() + def greet(name: str) -> str: + """Greet someone.""" + return f"Hello, {name}!" + + async with Client(server) as client: + result = await client.call_tool("greet", {"name": "World"}) + + assert isinstance(result.content[0], types.TextContent) + assert result.content[0].text == "Hello, World!" + + spans = capfire.exporter.exported_spans_as_dict() + span_names = {s["name"] for s in spans} + + assert "MCP send tools/call greet" in span_names + assert "MCP handle tools/call greet" in span_names + + client_span = next(s for s in spans if s["name"] == "MCP send tools/call greet") + server_span = next(s for s in spans if s["name"] == "MCP handle tools/call greet") + + assert client_span["attributes"]["mcp.method.name"] == "tools/call" + assert server_span["attributes"]["mcp.method.name"] == "tools/call" + + # Server span should be in the same trace as the client span (context propagation). + assert server_span["context"]["trace_id"] == client_span["context"]["trace_id"] diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index ab117f1f02..aad9e5d439 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -6,13 +6,11 @@ from mcp import Client, types from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.progress import progress from mcp.shared.session import RequestResponder @@ -35,9 +33,6 @@ async def run_server(): capabilities=server.get_capabilities(NotificationOptions(), {}), ), ) as server_session: - global serv_sesh - - serv_sesh = server_session async for message in server_session.incoming_messages: try: await server._handle_message(message, server_session, {}) @@ -52,79 +47,73 @@ async def run_server(): server_progress_token = "server_token_123" client_progress_token = "client_token_456" - # Create a server with progress capability - server = Server(name="ProgressTestServer") - # Register progress handler - @server.progress_notification() - async def handle_progress( - progress_token: str | int, - progress: float, - total: float | None, - message: str | None, - ): + async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: server_progress_updates.append( { - "token": progress_token, - "progress": progress, - "total": total, - "message": message, + "token": params.progress_token, + "progress": params.progress, + "total": params.total, + "message": params.message, } ) # Register list tool handler - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="test_tool", - description="A tool that sends progress notifications types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="A tool that sends progress notifications list[types.TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: # Make sure we received a progress token - if name == "test_tool": - if arguments and "_meta" in arguments: - progressToken = arguments["_meta"]["progressToken"] - - if not progressToken: # pragma: no cover - raise ValueError("Empty progress token received") - - if progressToken != client_progress_token: # pragma: no cover - raise ValueError("Server sending back incorrect progressToken") - - # Send progress notifications - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=0.25, - total=1.0, - message="Server progress 25%", - ) + if params.name == "test_tool": + assert params.meta is not None + progress_token = params.meta.get("progress_token") + assert progress_token is not None + assert progress_token == client_progress_token + + # Send progress notifications using ctx.session + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=0.25, + total=1.0, + message="Server progress 25%", + ) - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=0.5, - total=1.0, - message="Server progress 50%", - ) + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=0.5, + total=1.0, + message="Server progress 50%", + ) - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=1.0, - total=1.0, - message="Server progress 100%", - ) + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=1.0, + total=1.0, + message="Server progress 100%", + ) - else: # pragma: no cover - raise ValueError("Progress token not sent.") + return types.CallToolResult(content=[types.TextContent(type="text", text="Tool executed successfully")]) - return [types.TextContent(type="text", text="Tool executed successfully")] + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover - raise ValueError(f"Unknown tool: {name}") # pragma: no cover + # Create a server with progress capability + server = Server( + name="ProgressTestServer", + on_progress=handle_progress, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Client message handler to store progress notifications async def handle_client_message( @@ -164,7 +153,7 @@ async def handle_client_message( await client_session.list_tools() # Call test_tool with progress token - await client_session.call_tool("test_tool", {"_meta": {"progressToken": client_progress_token}}) + await client_session.call_tool("test_tool", meta={"progress_token": client_progress_token}) # Send progress notifications from client to server await client_session.send_progress_notification( @@ -207,118 +196,6 @@ async def handle_client_message( assert server_progress_updates[2]["progress"] == 1.0 -@pytest.mark.anyio -async def test_progress_context_manager(): - """Test client using progress context manager for sending progress notifications.""" - # Create memory streams for client/server - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) - - # Track progress updates - server_progress_updates: list[dict[str, Any]] = [] - - server = Server(name="ProgressContextTestServer") - - progress_token = None - - # Register progress handler - @server.progress_notification() - async def handle_progress( - progress_token: str | int, - progress: float, - total: float | None, - message: str | None, - ): - server_progress_updates.append( - {"token": progress_token, "progress": progress, "total": total, "message": message} - ) - - # Run server session to receive progress updates - async def run_server(): - # Create a server session - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="ProgressContextTestServer", - server_version="0.1.0", - capabilities=server.get_capabilities(NotificationOptions(), {}), - ), - ) as server_session: - async for message in server_session.incoming_messages: - try: - await server._handle_message(message, server_session, {}) - except Exception as e: # pragma: no cover - raise e - - # Client message handler - async def handle_client_message( - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - # run client session - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=handle_client_message, - ) as client_session, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - - await client_session.initialize() - - progress_token = "client_token_456" - - # Create request context - request_context = RequestContext( - request_id="test-request", - session=client_session, - meta={"progress_token": progress_token}, - ) - - # Utilize progress context manager - with progress(request_context, total=100) as p: - await p.progress(10, message="Loading configuration...") - await p.progress(30, message="Connecting to database...") - await p.progress(40, message="Fetching data...") - await p.progress(20, message="Processing results...") - - # Wait for all messages to be processed - await anyio.sleep(0.5) - tg.cancel_scope.cancel() - - # Verify progress updates were received by server - assert len(server_progress_updates) == 4 - - # first update - assert server_progress_updates[0]["token"] == progress_token - assert server_progress_updates[0]["progress"] == 10 - assert server_progress_updates[0]["total"] == 100 - assert server_progress_updates[0]["message"] == "Loading configuration..." - - # second update - assert server_progress_updates[1]["token"] == progress_token - assert server_progress_updates[1]["progress"] == 40 - assert server_progress_updates[1]["total"] == 100 - assert server_progress_updates[1]["message"] == "Connecting to database..." - - # third update - assert server_progress_updates[2]["token"] == progress_token - assert server_progress_updates[2]["progress"] == 80 - assert server_progress_updates[2]["total"] == 100 - assert server_progress_updates[2]["message"] == "Fetching data..." - - # final update - assert server_progress_updates[3]["token"] == progress_token - assert server_progress_updates[3]["progress"] == 100 - assert server_progress_updates[3]["total"] == 100 - assert server_progress_updates[3]["message"] == "Processing results..." - - @pytest.mark.anyio async def test_progress_callback_exception_logging(): """Test that exceptions in progress callbacks are logged and \ @@ -334,30 +211,37 @@ async def failing_progress_callback(progress: float, total: float | None, messag raise ValueError("Progress callback failed!") # Create a server with a tool that sends progress notifications - server = Server(name="TestProgressServer") - - @server.call_tool() - async def handle_call_tool(name: str, arguments: Any) -> list[types.TextContent]: - if name == "progress_tool": + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name == "progress_tool": + assert ctx.request_id is not None # Send a progress notification - await server.request_context.session.send_progress_notification( - progress_token=server.request_context.request_id, + await ctx.session.send_progress_notification( + progress_token=ctx.request_id, progress=50.0, total=100.0, message="Halfway done", ) - return [types.TextContent(type="text", text="progress_result")] - raise ValueError(f"Unknown tool: {name}") # pragma: no cover - - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="progress_tool", - description="A tool that sends progress notifications", - input_schema={}, - ) - ] + return types.CallToolResult(content=[types.TextContent(type="text", text="progress_result")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="progress_tool", + description="A tool that sends progress notifications", + input_schema={}, + ) + ] + ) + + server = Server( + name="TestProgressServer", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) # Test with mocked logging with patch("mcp.shared.session.logging.exception", side_effect=mock_log_exception): diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 182b4671df..d7c6cc3b5f 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,23 +1,25 @@ -from typing import Any - import anyio import pytest from mcp import Client, types from mcp.client.session import ClientSession -from mcp.server.lowlevel.server import Server +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.shared.memory import create_client_server_memory_streams from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder from mcp.types import ( + PARSE_ERROR, CancelledNotification, CancelledNotificationParams, + ClientResult, EmptyResult, ErrorData, JSONRPCError, JSONRPCRequest, JSONRPCResponse, - TextContent, + ServerNotification, + ServerRequest, ) @@ -42,29 +44,25 @@ async def test_request_cancellation(): request_id = None # Create a server with a slow tool - server = Server(name="TestSessionServer") - - # Register the tool handler - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: nonlocal request_id, ev_tool_called - if name == "slow_tool": - request_id = server.request_context.request_id + if params.name == "slow_tool": + request_id = ctx.request_id ev_tool_called.set() await anyio.sleep(10) # Long enough to ensure we can cancel - return [] # pragma: no cover - raise ValueError(f"Unknown tool: {name}") # pragma: no cover - - # Register the tool so it shows up in list_tools - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="slow_tool", - description="A slow tool that takes 10 seconds to complete", - input_schema={}, - ) - ] + return types.CallToolResult(content=[]) # pragma: no cover + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + raise NotImplementedError + + server = Server( + name="TestSessionServer", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) async def make_request(client: Client): nonlocal ev_cancelled @@ -304,3 +302,117 @@ async def mock_server(): await ev_closed.wait() with anyio.fail_after(1): # pragma: no branch await ev_response.wait() + + +@pytest.mark.anyio +async def test_null_id_error_surfaced_via_message_handler(): + """Test that a JSONRPCError with id=None is surfaced to the message handler. + + Per JSON-RPC 2.0, error responses use id=null when the request id could not + be determined (e.g., parse errors). These cannot be correlated to any pending + request, so they are forwarded to the message handler as MCPError. + """ + ev_error_received = anyio.Event() + error_holder: list[MCPError] = [] + + async def capture_errors( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + assert isinstance(message, MCPError) + error_holder.append(message) + ev_error_received.set() + + sent_error = ErrorData(code=PARSE_ERROR, message="Parse error") + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + _server_read, server_write = server_streams + + async def mock_server(): + """Send a null-id error (simulating a parse error).""" + error_response = JSONRPCError(jsonrpc="2.0", id=None, error=sent_error) + await server_write.send(SessionMessage(message=error_response)) + + async with ( + anyio.create_task_group() as tg, + ClientSession( + read_stream=client_read, + write_stream=client_write, + message_handler=capture_errors, + ) as _client_session, + ): + tg.start_soon(mock_server) + + with anyio.fail_after(2): # pragma: no branch + await ev_error_received.wait() + + assert len(error_holder) == 1 + assert error_holder[0].error == sent_error + + +@pytest.mark.anyio +async def test_null_id_error_does_not_affect_pending_request(): + """Test that a null-id error doesn't interfere with an in-flight request. + + When a null-id error arrives while a request is pending, the error should + go to the message handler and the pending request should still complete + normally with its own response. + """ + ev_error_received = anyio.Event() + ev_response_received = anyio.Event() + error_holder: list[MCPError] = [] + result_holder: list[EmptyResult] = [] + + async def capture_errors( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + assert isinstance(message, MCPError) + error_holder.append(message) + ev_error_received.set() + + sent_error = ErrorData(code=PARSE_ERROR, message="Parse error") + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Read a request, inject a null-id error, then respond normally.""" + message = await server_read.receive() + assert isinstance(message, SessionMessage) + assert isinstance(message.message, JSONRPCRequest) + request_id = message.message.id + + # First, send a null-id error (should go to message handler) + await server_write.send(SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=sent_error))) + + # Then, respond normally to the pending request + await server_write.send(SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) + + async def make_request(client_session: ClientSession): + result = await client_session.send_ping() + result_holder.append(result) + ev_response_received.set() + + async with ( + anyio.create_task_group() as tg, + ClientSession( + read_stream=client_read, + write_stream=client_write, + message_handler=capture_errors, + ) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(2): # pragma: no branch + await ev_error_received.wait() + await ev_response_received.wait() + + # Null-id error reached the message handler + assert len(error_holder) == 1 + assert error_holder[0].error == sent_error + + # Pending request completed successfully + assert len(result_holder) == 1 + assert isinstance(result_holder[0], EmptyResult) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index df2321ba16..5629a5707b 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,7 +1,6 @@ import json import multiprocessing import socket -import time from collections.abc import AsyncGenerator, Generator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -22,15 +21,20 @@ from mcp import types from mcp.client.session import ClientSession from mcp.client.sse import _extract_session_id_from_endpoint, sse_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError from mcp.types import ( + CallToolRequestParams, + CallToolResult, EmptyResult, Implementation, InitializeResult, JSONRPCResponse, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, ReadResourceResult, ServerCapabilities, TextContent, @@ -54,36 +58,48 @@ def server_url(server_port: int) -> str: return f"http://127.0.0.1:{server_port}" -# Test server implementation -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - if parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {parsed.netloc}" - - raise MCPError(code=404, message="OOPS! no resource with that URI was found") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] +async def _handle_read_resource( # pragma: no cover + ctx: ServerRequestContext, params: ReadResourceRequestParams +) -> ReadResourceResult: + uri = str(params.uri) + parsed = urlparse(uri) + if parsed.scheme == "foobar": + text = f"Read {parsed.netloc}" + elif parsed.scheme == "slow": + await anyio.sleep(2.0) + text = f"Slow response from {parsed.netloc}" + else: + raise MCPError(code=404, message="OOPS! no resource with that URI was found") + return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) + + +async def _handle_list_tools( # pragma: no cover + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - return [TextContent(type="text", text=f"Called {name}")] + +async def _handle_call_tool( # pragma: no cover + ctx: ServerRequestContext, params: CallToolRequestParams +) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + + +def _create_server() -> Server: # pragma: no cover + return Server( + SERVER_NAME, + on_read_resource=_handle_read_resource, + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool, + ) # Test fixtures @@ -94,7 +110,7 @@ def make_server_app() -> Starlette: # pragma: no cover allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) sse = SseServerTransport("/messages/", security_settings=security_settings) - server = ServerTest() + server = _create_server() async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: @@ -117,11 +133,6 @@ def run_server(server_port: int) -> None: # pragma: no cover print(f"starting server on {server_port}") server.run() - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: @@ -192,19 +203,15 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.mark.anyio async def test_sse_client_on_session_created(server: None, server_url: str) -> None: - captured_session_id: str | None = None - - def on_session_created(session_id: str) -> None: - nonlocal captured_session_id - captured_session_id = session_id + captured: list[str] = [] - async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: + async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - assert captured_session_id is not None # pragma: lax no cover - assert len(captured_session_id) > 0 # pragma: lax no cover + # Callback fires when the endpoint event arrives, before sse_client yields. + assert len(captured) == 1 + assert len(captured[0]) > 0 @pytest.mark.parametrize( @@ -237,8 +244,9 @@ def mock_extract(url: str) -> None: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - callback_mock.assert_not_called() # pragma: lax no cover + # Callback would have fired by now (endpoint event arrives before + # sse_client yields); if it hasn't, it won't. + callback_mock.assert_not_called() @pytest.fixture @@ -296,11 +304,6 @@ def run_mounted_server(server_port: int) -> None: # pragma: no cover print(f"starting server on {server_port}") server.run() - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - @pytest.fixture() def mounted_server(server_port: int) -> Generator[None, None, None]: @@ -336,47 +339,46 @@ async def test_sse_client_basic_connection_mounted_app(mounted_server: None, ser assert isinstance(ping_result, EmptyResult) -# Test server with request context that returns headers in the response -class RequestContextServer(Server[object, Request]): # pragma: no cover - def __init__(self): - super().__init__("request_context_server") - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - headers_info = {} - context = self.request_context - if context.request: - headers_info = dict(context.request.headers) - - if name == "echo_headers": - return [TextContent(type="text", text=json.dumps(headers_info))] - elif name == "echo_context": - context_data = { - "request_id": args.get("request_id"), - "headers": headers_info, - } - return [TextContent(type="text", text=json.dumps(context_data))] - - return [TextContent(type="text", text=f"Called {name}")] - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echoes request headers", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echoes request context", - input_schema={ - "type": "object", - "properties": {"request_id": {"type": "string"}}, - "required": ["request_id"], - }, - ), - ] +async def _handle_context_call_tool( # pragma: no cover + ctx: ServerRequestContext, params: CallToolRequestParams +) -> CallToolResult: + headers_info: dict[str, Any] = {} + if ctx.request: + headers_info = dict(ctx.request.headers) + + if params.name == "echo_headers": + return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) + elif params.name == "echo_context": + context_data = { + "request_id": (params.arguments or {}).get("request_id"), + "headers": headers_info, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) + + return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + + +async def _handle_context_list_tools( # pragma: no cover + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo_headers", + description="Echoes request headers", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echoes request context", + input_schema={ + "type": "object", + "properties": {"request_id": {"type": "string"}}, + "required": ["request_id"], + }, + ), + ] + ) def run_context_server(server_port: int) -> None: # pragma: no cover @@ -386,7 +388,11 @@ def run_context_server(server_port: int) -> None: # pragma: no cover allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = RequestContextServer() + context_server = Server( + "request_context_server", + on_call_tool=_handle_context_call_tool, + on_list_tools=_handle_context_list_tools, + ) async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: @@ -538,12 +544,6 @@ def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result assert sse._endpoint.startswith("/") -# ResourceWarning filter: When mocking aconnect_sse, the sse_client's internal task -# group doesn't receive proper cancellation signals, so the sse_reader task's finally -# block (which closes read_stream_writer) doesn't execute. This is a test artifact - -# the actual code path (`if not sse.data: continue`) IS exercised and works correctly. -# Production code with real SSE connections cleans up properly. -@pytest.mark.filterwarnings("ignore::ResourceWarning") @pytest.mark.anyio async def test_sse_client_handles_empty_keepalive_pings() -> None: """Test that SSE client properly handles empty data lines (keep-alive pings). @@ -602,3 +602,30 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: assert not isinstance(msg, Exception) assert isinstance(msg.message, types.JSONRPCResponse) assert msg.message.id == 1 + + +@pytest.mark.anyio +async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None: + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227 + + When a client disconnects, the server should remove the session from + _read_stream_writers. Without this cleanup, stale sessions accumulate and + POST requests to disconnected sessions return 202 Accepted followed by a + ClosedResourceError when the server tries to write to the dead stream. + """ + captured: list[str] = [] + + # Connect a client session, then disconnect + async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + + # After disconnect, POST to the stale session should return 404 + # (not 202 as it did before the fix) + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/messages/?session_id={captured[0]}", + json={"jsonrpc": "2.0", "method": "ping", "id": 99}, + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 404 diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index b04b920262..3d5770fb61 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -10,7 +10,9 @@ import socket import time import traceback -from collections.abc import Generator +from collections.abc import AsyncIterator, Generator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field from typing import Any from unittest.mock import MagicMock from urllib.parse import urlparse @@ -28,7 +30,7 @@ from mcp import MCPError, types from mcp.client.session import ClientSession from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, @@ -43,6 +45,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._context import RequestContext +from mcp.shared._context_streams import create_context_streams from mcp.shared._httpx_utils import ( MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT, @@ -50,7 +53,19 @@ ) from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import InitializeResult, JSONRPCRequest, TextContent, TextResourceContents, Tool +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + InitializeResult, + JSONRPCRequest, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextContent, + TextResourceContents, + Tool, +) from tests.test_helpers import wait_for_server # Test constants @@ -124,263 +139,258 @@ async def replay_events_after( # pragma: no cover return target_stream_id -# Test server implementation that follows MCP protocol -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - self._lock = None # Will be initialized in async context - - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - if parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {parsed.netloc}" - - raise ValueError(f"Unknown resource: {uri}") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="test_tool_with_standalone_notification", - description="A test tool that sends a notification", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="long_running_with_checkpoints", - description="A long-running tool that sends periodic notifications", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="test_sampling_tool", - description="A tool that triggers server-side sampling", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="wait_for_lock_with_notification", - description="A tool that sends a notification and waits for lock", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="release_lock", - description="A tool that releases the lock", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_stream_close", - description="A tool that closes SSE stream mid-operation", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_notifications_and_close", - description="Tool that sends notification1, closes stream, sends notification2, notification3", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_stream_closes", - description="Tool that closes SSE stream multiple times during execution", - input_schema={ - "type": "object", - "properties": { - "checkpoints": {"type": "integer", "default": 3}, - "sleep_time": {"type": "number", "default": 0.2}, - }, +@dataclass +class ServerState: + lock: anyio.Event = field(default_factory=anyio.Event) + + +@asynccontextmanager +async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: no cover + yield ServerState() + + +async def _handle_read_resource( # pragma: no cover + ctx: ServerRequestContext[ServerState], params: ReadResourceRequestParams +) -> ReadResourceResult: + uri = str(params.uri) + parsed = urlparse(uri) + if parsed.scheme == "foobar": + text = f"Read {parsed.netloc}" + elif parsed.scheme == "slow": + await anyio.sleep(2.0) + text = f"Slow response from {parsed.netloc}" + else: + raise ValueError(f"Unknown resource: {uri}") + return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) + + +async def _handle_list_tools( # pragma: no cover + ctx: ServerRequestContext[ServerState], params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="test_tool_with_standalone_notification", + description="A test tool that sends a notification", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="long_running_with_checkpoints", + description="A long-running tool that sends periodic notifications", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="test_sampling_tool", + description="A tool that triggers server-side sampling", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="wait_for_lock_with_notification", + description="A tool that sends a notification and waits for lock", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="release_lock", + description="A tool that releases the lock", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_stream_close", + description="A tool that closes SSE stream mid-operation", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_notifications_and_close", + description="Tool that sends notification1, closes stream, sends notification2, notification3", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_stream_closes", + description="Tool that closes SSE stream multiple times during execution", + input_schema={ + "type": "object", + "properties": { + "checkpoints": {"type": "integer", "default": 3}, + "sleep_time": {"type": "number", "default": 0.2}, }, - ), - Tool( - name="tool_with_standalone_stream_close", - description="Tool that closes standalone GET stream mid-operation", - input_schema={"type": "object", "properties": {}}, - ), - ] + }, + ), + Tool( + name="tool_with_standalone_stream_close", + description="Tool that closes standalone GET stream mid-operation", + input_schema={"type": "object", "properties": {}}, + ), + ] + ) - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context - # When the tool is called, send a notification to test GET stream - if name == "test_tool_with_standalone_notification": - await ctx.session.send_resource_updated(uri="http://test_resource") - return [TextContent(type="text", text=f"Called {name}")] +async def _handle_call_tool( # pragma: no cover + ctx: ServerRequestContext[ServerState], params: CallToolRequestParams +) -> CallToolResult: + name = params.name + args = params.arguments or {} - elif name == "long_running_with_checkpoints": - # Send notifications that are part of the response stream - # This simulates a long-running tool that sends logs + # When the tool is called, send a notification to test GET stream + if name == "test_tool_with_standalone_notification": + await ctx.session.send_resource_updated(uri="http://test_resource") + return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - await ctx.session.send_log_message( - level="info", - data="Tool started", - logger="tool", - related_request_id=ctx.request_id, # need for stream association - ) + elif name == "long_running_with_checkpoints": + await ctx.session.send_log_message( + level="info", + data="Tool started", + logger="tool", + related_request_id=ctx.request_id, + ) - await anyio.sleep(0.1) + await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="Tool is almost done", - logger="tool", - related_request_id=ctx.request_id, - ) + await ctx.session.send_log_message( + level="info", + data="Tool is almost done", + logger="tool", + related_request_id=ctx.request_id, + ) - return [TextContent(type="text", text="Completed!")] + return CallToolResult(content=[TextContent(type="text", text="Completed!")]) - elif name == "test_sampling_tool": - # Test sampling by requesting the client to sample a message - sampling_result = await ctx.session.create_message( - messages=[ - types.SamplingMessage( - role="user", - content=types.TextContent(type="text", text="Server needs client sampling"), - ) - ], - max_tokens=100, - related_request_id=ctx.request_id, + elif name == "test_sampling_tool": + sampling_result = await ctx.session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent(type="text", text="Server needs client sampling"), ) + ], + max_tokens=100, + related_request_id=ctx.request_id, + ) - # Return the sampling result in the tool response - # Since we're not passing tools param, result.content is single content - if sampling_result.content.type == "text": - response = sampling_result.content.text - else: - response = str(sampling_result.content) - return [ - TextContent( - type="text", - text=f"Response from sampling: {response}", - ) - ] - - elif name == "wait_for_lock_with_notification": - # Initialize lock if not already done - if self._lock is None: - self._lock = anyio.Event() - - # First send a notification - await ctx.session.send_log_message( - level="info", - data="First notification before lock", - logger="lock_tool", - related_request_id=ctx.request_id, - ) - - # Now wait for the lock to be released - await self._lock.wait() - - # Send second notification after lock is released - await ctx.session.send_log_message( - level="info", - data="Second notification after lock", - logger="lock_tool", - related_request_id=ctx.request_id, + if sampling_result.content.type == "text": + response = sampling_result.content.text + else: + response = str(sampling_result.content) + return CallToolResult( + content=[ + TextContent( + type="text", + text=f"Response from sampling: {response}", ) + ] + ) - return [TextContent(type="text", text="Completed")] + elif name == "wait_for_lock_with_notification": + await ctx.session.send_log_message( + level="info", + data="First notification before lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - elif name == "release_lock": - assert self._lock is not None, "Lock must be initialized before releasing" + await ctx.lifespan_context.lock.wait() - # Release the lock - self._lock.set() - return [TextContent(type="text", text="Lock released")] + await ctx.session.send_log_message( + level="info", + data="Second notification after lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - elif name == "tool_with_stream_close": - # Send notification before closing - await ctx.session.send_log_message( - level="info", - data="Before close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream (triggers client reconnect) - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Continue processing (events stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="After close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="Done")] - - elif name == "tool_with_multiple_notifications_and_close": - # Send notification1 - await ctx.session.send_log_message( - level="info", - data="notification1", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Send notification2, notification3 (stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="notification2", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - await ctx.session.send_log_message( - level="info", - data="notification3", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="All notifications sent")] + return CallToolResult(content=[TextContent(type="text", text="Completed")]) - elif name == "tool_with_multiple_stream_closes": - num_checkpoints = args.get("checkpoints", 3) - sleep_time = args.get("sleep_time", 0.2) + elif name == "release_lock": + ctx.lifespan_context.lock.set() + return CallToolResult(content=[TextContent(type="text", text="Lock released")]) - for i in range(num_checkpoints): - await ctx.session.send_log_message( - level="info", - data=f"checkpoint_{i}", - logger="multi_close_tool", - related_request_id=ctx.request_id, - ) + elif name == "tool_with_stream_close": + await ctx.session.send_log_message( + level="info", + data="Before close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="After close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + return CallToolResult(content=[TextContent(type="text", text="Done")]) + + elif name == "tool_with_multiple_notifications_and_close": + await ctx.session.send_log_message( + level="info", + data="notification1", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="notification2", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + await ctx.session.send_log_message( + level="info", + data="notification3", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + return CallToolResult(content=[TextContent(type="text", text="All notifications sent")]) + + elif name == "tool_with_multiple_stream_closes": + num_checkpoints = args.get("checkpoints", 3) + sleep_time = args.get("sleep_time", 0.2) + + for i in range(num_checkpoints): + await ctx.session.send_log_message( + level="info", + data=f"checkpoint_{i}", + logger="multi_close_tool", + related_request_id=ctx.request_id, + ) - if ctx.close_sse_stream: - await ctx.close_sse_stream() + if ctx.close_sse_stream: + await ctx.close_sse_stream() - await anyio.sleep(sleep_time) + await anyio.sleep(sleep_time) - return [TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")] + return CallToolResult(content=[TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")]) - elif name == "tool_with_standalone_stream_close": - # Test for GET stream reconnection - # 1. Send unsolicited notification via GET stream (no related_request_id) - await ctx.session.send_resource_updated(uri="http://notification_1") + elif name == "tool_with_standalone_stream_close": + await ctx.session.send_resource_updated(uri="http://notification_1") + await anyio.sleep(0.1) - # Small delay to ensure notification is flushed before closing - await anyio.sleep(0.1) + if ctx.close_standalone_sse_stream: + await ctx.close_standalone_sse_stream() - # 2. Close the standalone GET stream - if ctx.close_standalone_sse_stream: - await ctx.close_standalone_sse_stream() + await anyio.sleep(1.5) + await ctx.session.send_resource_updated(uri="http://notification_2") - # 3. Wait for client to reconnect (uses retry_interval from server, default 1000ms) - await anyio.sleep(1.5) + return CallToolResult(content=[TextContent(type="text", text="Standalone stream close test done")]) - # 4. Send another notification on the new GET stream connection - await ctx.session.send_resource_updated(uri="http://notification_2") + return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - return [TextContent(type="text", text="Standalone stream close test done")] - return [TextContent(type="text", text=f"Called {name}")] +def _create_server() -> Server[ServerState]: # pragma: no cover + return Server( + SERVER_NAME, + lifespan=_server_lifespan, + on_read_resource=_handle_read_resource, + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool, + ) def create_app( @@ -396,7 +406,7 @@ def create_app( retry_interval: Retry interval in milliseconds for SSE polling. """ # Create server instance - server = ServerTest() + server = _create_server() # Create the session manager security_settings = TransportSecuritySettings( @@ -563,8 +573,10 @@ def json_server_url(json_server_port: int) -> str: # Basic request validation tests def test_accept_header_validation(basic_server: None, basic_server_url: str): """Test that Accept header is properly validated.""" - # Test without Accept header - response = requests.post( + # Test without Accept header (suppress requests library default Accept: */*) + session = requests.Session() + session.headers.pop("Accept") + response = session.post( f"{basic_server_url}/mcp", headers={"Content-Type": "application/json"}, json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, @@ -573,6 +585,52 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): assert "Not Acceptable" in response.text +@pytest.mark.parametrize( + "accept_header", + [ + "*/*", + "application/*, text/*", + "text/*, application/json", + "application/json, text/*", + "*/*;q=0.8", + "application/*;q=0.9, text/*;q=0.8", + ], +) +def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): + """Test that wildcard Accept headers are accepted per RFC 7231.""" + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + +@pytest.mark.parametrize( + "accept_header", + [ + "text/html", + "application/*", + "text/*", + ], +) +def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): + """Test that incompatible Accept headers are rejected for SSE mode.""" + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + def test_content_type_validation(basic_server: None, basic_server_url: str): """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type @@ -817,7 +875,10 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_ def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): """Test that json_response servers reject requests without Accept header.""" mcp_url = f"{json_server_url}/mcp" - response = requests.post( + # Suppress requests library default Accept: */* header + session = requests.Session() + session.headers.pop("Accept") + response = session.post( mcp_url, headers={ "Content-Type": "application/json", @@ -844,6 +905,29 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ assert "Not Acceptable" in response.text +@pytest.mark.parametrize( + "accept_header", + [ + "*/*", + "application/*", + "application/*;q=0.9", + ], +) +def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str): + """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" + mcp_url = f"{json_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" + + def test_get_sse_stream(basic_server: None, basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session @@ -932,8 +1016,10 @@ def test_get_validation(basic_server: None, basic_server_url: str): assert init_data is not None negotiated_version = init_data["result"]["protocolVersion"] - # Test without Accept header - response = requests.get( + # Test without Accept header (suppress requests library default Accept: */*) + session = requests.Session() + session.headers.pop("Accept") + response = session.get( mcp_url, headers={ MCP_SESSION_ID_HEADER: session_id, @@ -1123,22 +1209,19 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream) as session: + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} # Make a request to confirm session is working tools = await session.list_tools() assert len(tools.tools) == 10 - headers: dict[str, str] = {} # pragma: lax no cover - if captured_session_id: # pragma: lax no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -1187,22 +1270,19 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream) as session: + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} # Make a request to confirm session is working tools = await session.list_tools() assert len(tools.tools) == 10 - headers: dict[str, str] = {} # pragma: lax no cover - if captured_session_id: # pragma: lax no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -1222,7 +1302,6 @@ async def test_streamable_http_client_resumption(event_server: tuple[SimpleEvent # Variables to track the state captured_resumption_token: str | None = None captured_notifications: list[types.ServerNotification] = [] - captured_protocol_version: str | int | None = None first_notification_received = False async def message_handler( # pragma: no branch @@ -1249,15 +1328,20 @@ async def on_resumption_token_update(token: str) -> None: read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession( # pragma: no branch + read_stream, write_stream, message_handler=message_handler + ) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None - # Capture the negotiated protocol version - captured_protocol_version = result.protocol_version + # Build phase-2 headers now while both values are in scope + headers: dict[str, Any] = { + MCP_SESSION_ID_HEADER: captured_session_id, + MCP_PROTOCOL_VERSION_HEADER: result.protocol_version, + } # Start the tool that will wait on lock in a task async with anyio.create_task_group() as tg: # pragma: no branch @@ -1282,25 +1366,19 @@ async def run_tool(): while not first_notification_received or not captured_resumption_token: await anyio.sleep(0.1) + # The while loop only exits after first_notification_received=True, + # which is set by message_handler immediately after appending to + # captured_notifications. The server tool is blocked on its lock, + # so nothing else can arrive before we cancel. + assert len(captured_notifications) == 1 + assert isinstance(captured_notifications[0], types.LoggingMessageNotification) + assert captured_notifications[0].params.data == "First notification before lock" + # Reset for phase 2 before cancelling + captured_notifications.clear() + # Kill the client session while tool is waiting on lock tg.cancel_scope.cancel() - # Verify we received exactly one notification (inside ClientSession - # so coverage tracks these on Python 3.11, see PR #1897 for details) - assert len(captured_notifications) == 1 # pragma: lax no cover - assert isinstance(captured_notifications[0], types.LoggingMessageNotification) # pragma: lax no cover - assert captured_notifications[0].params.data == "First notification before lock" # pragma: lax no cover - - # Clear notifications and set up headers for phase 2 (between connections, - # not tracked by coverage on Python 3.11 due to cancel scope + sys.settrace bug) - captured_notifications = [] # pragma: lax no cover - assert captured_session_id is not None # pragma: lax no cover - assert captured_protocol_version is not None # pragma: lax no cover - headers: dict[str, Any] = { # pragma: lax no cover - MCP_SESSION_ID_HEADER: captured_session_id, - MCP_PROTOCOL_VERSION_HEADER: captured_protocol_version, - } - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -1385,69 +1463,68 @@ async def sampling_callback( # Context-aware server implementation for testing request context propagation -class ContextAwareServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__("ContextAwareServer") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echo request headers from context", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echo request context with custom data", - input_schema={ - "type": "object", - "properties": { - "request_id": {"type": "string"}, - }, - "required": ["request_id"], +async def _handle_context_list_tools( # pragma: no cover + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo_headers", + description="Echo request headers from context", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echo request context with custom data", + input_schema={ + "type": "object", + "properties": { + "request_id": {"type": "string"}, }, - ), - ] + "required": ["request_id"], + }, + ), + ] + ) - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context - - if name == "echo_headers": - # Access the request object from context - headers_info = {} - if ctx.request and isinstance(ctx.request, Request): - headers_info = dict(ctx.request.headers) - return [TextContent(type="text", text=json.dumps(headers_info))] - - elif name == "echo_context": - # Return full context information - context_data: dict[str, Any] = { - "request_id": args.get("request_id"), - "headers": {}, - "method": None, - "path": None, - } - if ctx.request and isinstance(ctx.request, Request): - request = ctx.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return [ - TextContent( - type="text", - text=json.dumps(context_data), - ) - ] - return [TextContent(type="text", text=f"Unknown tool: {name}")] +async def _handle_context_call_tool( # pragma: no cover + ctx: ServerRequestContext, params: CallToolRequestParams +) -> CallToolResult: + name = params.name + args = params.arguments or {} + + if name == "echo_headers": + headers_info: dict[str, Any] = {} + if ctx.request and isinstance(ctx.request, Request): + headers_info = dict(ctx.request.headers) + return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) + + elif name == "echo_context": + context_data: dict[str, Any] = { + "request_id": args.get("request_id"), + "headers": {}, + "method": None, + "path": None, + } + if ctx.request and isinstance(ctx.request, Request): + request = ctx.request + context_data["headers"] = dict(request.headers) + context_data["method"] = request.method + context_data["path"] = request.url.path + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) + + return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) # Server runner for context-aware testing def run_context_aware_server(port: int): # pragma: no cover """Run the context-aware test server.""" - server = ContextAwareServerTest() + server = Server( + "ContextAwareServer", + on_list_tools=_handle_context_list_tools, + on_call_tool=_handle_context_call_tool, + ) session_manager = StreamableHTTPSessionManager( app=server, @@ -1707,8 +1784,8 @@ async def test_handle_sse_event_skips_empty_data(): # Create a mock SSE event with empty data (keep-alive ping) mock_sse = ServerSentEvent(event="message", data="", id=None, retry=None) - # Create a mock stream writer - write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + # Create a context-aware stream writer (matches StreamWriter type alias) + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) try: # Call _handle_sse_event with empty data - should return False and not raise @@ -1718,8 +1795,9 @@ async def test_handle_sse_event_skips_empty_data(): assert result is False # Nothing should have been written to the stream - # Check buffer is empty (statistics().current_buffer_used returns buffer size) - assert write_stream.statistics().current_buffer_used == 0 + with pytest.raises(TimeoutError): + with anyio.fail_after(0): + await read_stream.receive() finally: await write_stream.aclose() await read_stream.aclose() @@ -2084,11 +2162,12 @@ async def on_resumption_token(token: str) -> None: assert isinstance(result.content[0], TextContent) assert "Completed 3 checkpoints" in result.content[0].text - # 4 priming + 3 notifications + 1 response = 8 tokens - assert len(resumption_tokens) == 8, ( # pragma: lax no cover - f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " - f"got {len(resumption_tokens)}: {resumption_tokens}" - ) + # 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are + # captured before send_request returns, so this is safe to check here. + assert len(resumption_tokens) == 8, ( + f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " + f"got {len(resumption_tokens)}: {resumption_tokens}" + ) @pytest.mark.anyio diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 07e19195d5..482dcdcf32 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -1,188 +1,51 @@ -import multiprocessing -import socket -import time -from collections.abc import AsyncGenerator, Generator -from typing import Any -from urllib.parse import urlparse +"""Smoke test for the WebSocket transport. + +Runs the full WS stack end-to-end over a real TCP connection, covering both +``src/mcp/client/websocket.py`` and ``src/mcp/server/websocket.py``. MCP +semantics (error propagation, timeouts, etc.) are transport-agnostic and are +covered in ``tests/client/test_client.py`` and ``tests/issues/test_88_random_error.py``. +""" + +from collections.abc import Generator -import anyio import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp import MCPError from mcp.client.session import ClientSession from mcp.client.websocket import websocket_client from mcp.server import Server from mcp.server.websocket import websocket_server -from mcp.types import EmptyResult, InitializeResult, ReadResourceResult, TextContent, TextResourceContents, Tool -from tests.test_helpers import wait_for_server +from mcp.types import EmptyResult, InitializeResult +from tests.test_helpers import run_uvicorn_in_thread SERVER_NAME = "test_server_for_WS" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"ws://127.0.0.1:{server_port}" - - -# Test server implementation -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - elif parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {parsed.netloc}" - - raise MCPError(code=404, message="OOPS! no resource with that URI was found") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] +def make_server_app() -> Starlette: + srv = Server(SERVER_NAME) - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - return [TextContent(type="text", text=f"Called {name}")] - - -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover - """Create test Starlette app with WebSocket transport""" - server = ServerTest() - - async def handle_ws(websocket: WebSocket): + async def handle_ws(websocket: WebSocket) -> None: async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: - await server.run(streams[0], streams[1], server.create_initialization_options()) - - app = Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) - return app - - -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - - -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) + await srv.run(streams[0], streams[1], srv.create_initialization_options()) - yield + return Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - -@pytest.fixture() -async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - """Create and initialize a WebSocket client session""" - async with websocket_client(server_url + "/ws") as streams: - async with ClientSession(*streams) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == SERVER_NAME - - # Test ping - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) - - yield session +@pytest.fixture +def ws_server_url() -> Generator[str, None, None]: + with run_uvicorn_in_thread(make_server_app()) as base_url: + yield base_url.replace("http://", "ws://") + "/ws" -# Tests @pytest.mark.anyio -async def test_ws_client_basic_connection(server: None, server_url: str) -> None: - """Test the WebSocket connection establishment""" - async with websocket_client(server_url + "/ws") as streams: +async def test_ws_client_basic_connection(ws_server_url: str) -> None: + async with websocket_client(ws_server_url) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.server_info.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) - - -@pytest.mark.anyio -async def test_ws_client_happy_request_and_response( - initialized_ws_client_session: ClientSession, -) -> None: - """Test a successful request and response via WebSocket""" - result = await initialized_ws_client_session.read_resource("foobar://example") - assert isinstance(result, ReadResourceResult) - assert isinstance(result.contents, list) - assert len(result.contents) > 0 - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Read example" - - -@pytest.mark.anyio -async def test_ws_client_exception_handling( - initialized_ws_client_session: ClientSession, -) -> None: - """Test exception handling in WebSocket communication""" - with pytest.raises(MCPError) as exc_info: - await initialized_ws_client_session.read_resource("unknown://example") - assert exc_info.value.error.code == 404 - - -@pytest.mark.anyio -async def test_ws_client_timeout( - initialized_ws_client_session: ClientSession, -) -> None: - """Test timeout handling in WebSocket communication""" - # Set a very short timeout to trigger a timeout exception - with pytest.raises(TimeoutError): - with anyio.fail_after(0.1): # 100ms timeout - await initialized_ws_client_session.read_resource("slow://example") - - # Now test that we can still use the session after a timeout - with anyio.fail_after(5): # Longer timeout to allow completion - result = await initialized_ws_client_session.read_resource("foobar://example") - assert isinstance(result, ReadResourceResult) - assert isinstance(result.contents, list) - assert len(result.contents) > 0 - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Read example" diff --git a/tests/test_examples.py b/tests/test_examples.py index aa9de09579..3af82f04c5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -5,7 +5,6 @@ # pyright: reportUnknownArgumentType=false # pyright: reportUnknownMemberType=false -import sys from pathlib import Path import pytest @@ -65,12 +64,17 @@ async def test_direct_call_tool_result_return(): @pytest.mark.anyio -async def test_desktop(monkeypatch: pytest.MonkeyPatch): +async def test_desktop(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """Test the desktop server""" - # Mock desktop directory listing - mock_files = [Path("/fake/path/file1.txt"), Path("/fake/path/file2.txt")] - monkeypatch.setattr(Path, "iterdir", lambda self: mock_files) # type: ignore[reportUnknownArgumentType] - monkeypatch.setattr(Path, "home", lambda: Path("/fake/home")) + # Build a real Desktop directory under tmp_path rather than patching + # Path.iterdir — a class-level patch breaks jsonschema_specifications' + # import-time schema discovery when this test happens to be the first + # tool call in an xdist worker. + desktop = tmp_path / "Desktop" + desktop.mkdir() + (desktop / "file1.txt").touch() + (desktop / "file2.txt").touch() + monkeypatch.setattr(Path, "home", lambda: tmp_path) from examples.mcpserver.desktop import mcp @@ -85,15 +89,8 @@ async def test_desktop(monkeypatch: pytest.MonkeyPatch): content = result.contents[0] assert isinstance(content, TextResourceContents) assert isinstance(content.text, str) - if sys.platform == "win32": # pragma: no cover - file_1 = "/fake/path/file1.txt".replace("/", "\\\\") # might be a bug - file_2 = "/fake/path/file2.txt".replace("/", "\\\\") # might be a bug - assert file_1 in content.text - assert file_2 in content.text - # might be a bug, but the test is passing - else: # pragma: lax no cover - assert "/fake/path/file1.txt" in content.text - assert "/fake/path/file2.txt" in content.text + assert "file1.txt" in content.text + assert "file2.txt" in content.text # TODO(v2): Change back to README.md when v2 is released diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5c04c269ff..810c72820b 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,7 +1,61 @@ """Common test utilities for MCP server tests.""" import socket +import threading import time +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +import uvicorn + +_SERVER_SHUTDOWN_TIMEOUT_S = 5.0 + + +@contextmanager +def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None, None]: + """Run a uvicorn server in a background thread on an ephemeral port. + + The socket is bound and put into listening state *before* the thread + starts, so the port is known immediately with no wait. The kernel's + listen queue buffers any connections that arrive before uvicorn's event + loop reaches ``accept()``, so callers can connect as soon as this + function yields — no polling, no sleeps, no startup race. + + This also avoids the TOCTOU race of the old pick-a-port-then-rebind + pattern: the socket passed here is the one uvicorn serves on, with no + gap where another pytest-xdist worker could claim it. + + Args: + app: ASGI application to serve. + **config_kwargs: Additional keyword arguments for :class:`uvicorn.Config` + (e.g. ``log_level``). ``host``/``port`` are ignored since the + socket is pre-bound. + + Yields: + The base URL of the running server, e.g. ``http://127.0.0.1:54321``. + """ + host = "127.0.0.1" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((host, 0)) + sock.listen() + port = sock.getsockname()[1] + + config_kwargs.setdefault("log_level", "error") + # Uvicorn's interface autodetection calls asyncio.iscoroutinefunction, + # which Python 3.14 deprecates. Under filterwarnings=error this crashes + # the server thread silently. Starlette is asgi3; skip the autodetect. + config_kwargs.setdefault("interface", "asgi3") + server = uvicorn.Server(config=uvicorn.Config(app=app, **config_kwargs)) + + thread = threading.Thread(target=server.run, kwargs={"sockets": [sock]}, daemon=True) + thread.start() + try: + yield f"http://{host}:{port}" + finally: + server.should_exit = True + thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S) def wait_for_server(port: int, timeout: float = 20.0) -> None: diff --git a/uv.lock b/uv.lock index 5d3a83f376..5b72e97fce 100644 --- a/uv.lock +++ b/uv.lock @@ -28,6 +28,19 @@ members = [ "mcp-sse-polling-demo", "mcp-structured-output-lowlevel", ] +build-constraints = [ + { name = "dunamai", specifier = "==1.26.1" }, + { name = "hatchling", specifier = "==1.29.0" }, + { name = "jinja2", specifier = "==3.1.6" }, + { name = "markupsafe", specifier = "==3.0.3" }, + { name = "packaging", specifier = "==26.1" }, + { name = "pathspec", specifier = "==1.0.4" }, + { name = "pluggy", specifier = "==1.6.0" }, + { name = "setuptools", specifier = "==82.0.1" }, + { name = "tomlkit", specifier = "==0.14.0" }, + { name = "trove-classifiers", specifier = "==2026.1.14.14" }, + { name = "uv-dynamic-versioning", specifier = "==0.14.0" }, +] [[package]] name = "annotated-types" @@ -96,7 +109,7 @@ wheels = [ [[package]] name = "black" -version = "25.1.0" +version = "26.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -104,28 +117,66 @@ dependencies = [ { name = "packaging" }, { name = "pathspec" }, { name = "platformdirs" }, + { name = "pytokens" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449, upload-time = "2025-01-29T04:15:40.373Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e1/c5/61175d618685d42b005847464b8fb4743a67b1b8fdb75e50e5a96c31a27a/black-26.3.1.tar.gz", hash = "sha256:2c50f5063a9641c7eed7795014ba37b0f5fa227f3d408b968936e24bc0566b07", size = 666155, upload-time = "2026-03-12T03:36:03.593Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/a8/11170031095655d36ebc6664fe0897866f6023892396900eec0e8fdc4299/black-26.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:86a8b5035fce64f5dcd1b794cf8ec4d31fe458cf6ce3986a30deb434df82a1d2", size = 1866562, upload-time = "2026-03-12T03:39:58.639Z" }, + { url = "https://files.pythonhosted.org/packages/69/ce/9e7548d719c3248c6c2abfd555d11169457cbd584d98d179111338423790/black-26.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5602bdb96d52d2d0672f24f6ffe5218795736dd34807fd0fd55ccd6bf206168b", size = 1703623, upload-time = "2026-03-12T03:40:00.347Z" }, + { url = "https://files.pythonhosted.org/packages/7f/0a/8d17d1a9c06f88d3d030d0b1d4373c1551146e252afe4547ed601c0e697f/black-26.3.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6c54a4a82e291a1fee5137371ab488866b7c86a3305af4026bdd4dc78642e1ac", size = 1768388, upload-time = "2026-03-12T03:40:01.765Z" }, + { url = "https://files.pythonhosted.org/packages/52/79/c1ee726e221c863cde5164f925bacf183dfdf0397d4e3f94889439b947b4/black-26.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:6e131579c243c98f35bce64a7e08e87fb2d610544754675d4a0e73a070a5aa3a", size = 1412969, upload-time = "2026-03-12T03:40:03.252Z" }, + { url = "https://files.pythonhosted.org/packages/73/a5/15c01d613f5756f68ed8f6d4ec0a1e24b82b18889fa71affd3d1f7fad058/black-26.3.1-cp310-cp310-win_arm64.whl", hash = "sha256:5ed0ca58586c8d9a487352a96b15272b7fa55d139fc8496b519e78023a8dab0a", size = 1220345, upload-time = "2026-03-12T03:40:04.892Z" }, + { url = "https://files.pythonhosted.org/packages/17/57/5f11c92861f9c92eb9dddf515530bc2d06db843e44bdcf1c83c1427824bc/black-26.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:28ef38aee69e4b12fda8dba75e21f9b4f979b490c8ac0baa7cb505369ac9e1ff", size = 1851987, upload-time = "2026-03-12T03:40:06.248Z" }, + { url = "https://files.pythonhosted.org/packages/54/aa/340a1463660bf6831f9e39646bf774086dbd8ca7fc3cded9d59bbdf4ad0a/black-26.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf9bf162ed91a26f1adba8efda0b573bc6924ec1408a52cc6f82cb73ec2b142c", size = 1689499, upload-time = "2026-03-12T03:40:07.642Z" }, + { url = "https://files.pythonhosted.org/packages/f3/01/b726c93d717d72733da031d2de10b92c9fa4c8d0c67e8a8a372076579279/black-26.3.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:474c27574d6d7037c1bc875a81d9be0a9a4f9ee95e62800dab3cfaadbf75acd5", size = 1754369, upload-time = "2026-03-12T03:40:09.279Z" }, + { url = "https://files.pythonhosted.org/packages/e3/09/61e91881ca291f150cfc9eb7ba19473c2e59df28859a11a88248b5cbbc4d/black-26.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5e9d0d86df21f2e1677cc4bd090cd0e446278bcbbe49bf3659c308c3e402843e", size = 1413613, upload-time = "2026-03-12T03:40:10.943Z" }, + { url = "https://files.pythonhosted.org/packages/16/73/544f23891b22e7efe4d8f812371ab85b57f6a01b2fc45e3ba2e52ba985b8/black-26.3.1-cp311-cp311-win_arm64.whl", hash = "sha256:9a5e9f45e5d5e1c5b5c29b3bd4265dcc90e8b92cf4534520896ed77f791f4da5", size = 1219719, upload-time = "2026-03-12T03:40:12.597Z" }, + { url = "https://files.pythonhosted.org/packages/dc/f8/da5eae4fc75e78e6dceb60624e1b9662ab00d6b452996046dfa9b8a6025b/black-26.3.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b5e6f89631eb88a7302d416594a32faeee9fb8fb848290da9d0a5f2903519fc1", size = 1895920, upload-time = "2026-03-12T03:40:13.921Z" }, + { url = "https://files.pythonhosted.org/packages/2c/9f/04e6f26534da2e1629b2b48255c264cabf5eedc5141d04516d9d68a24111/black-26.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41cd2012d35b47d589cb8a16faf8a32ef7a336f56356babd9fcf70939ad1897f", size = 1718499, upload-time = "2026-03-12T03:40:15.239Z" }, + { url = "https://files.pythonhosted.org/packages/04/91/a5935b2a63e31b331060c4a9fdb5a6c725840858c599032a6f3aac94055f/black-26.3.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f76ff19ec5297dd8e66eb64deda23631e642c9393ab592826fd4bdc97a4bce7", size = 1794994, upload-time = "2026-03-12T03:40:17.124Z" }, + { url = "https://files.pythonhosted.org/packages/e7/0a/86e462cdd311a3c2a8ece708d22aba17d0b2a0d5348ca34b40cdcbea512e/black-26.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:ddb113db38838eb9f043623ba274cfaf7d51d5b0c22ecb30afe58b1bb8322983", size = 1420867, upload-time = "2026-03-12T03:40:18.83Z" }, + { url = "https://files.pythonhosted.org/packages/5b/e5/22515a19cb7eaee3440325a6b0d95d2c0e88dd180cb011b12ae488e031d1/black-26.3.1-cp312-cp312-win_arm64.whl", hash = "sha256:dfdd51fc3e64ea4f35873d1b3fb25326773d55d2329ff8449139ebaad7357efb", size = 1230124, upload-time = "2026-03-12T03:40:20.425Z" }, + { url = "https://files.pythonhosted.org/packages/f5/77/5728052a3c0450c53d9bb3945c4c46b91baa62b2cafab6801411b6271e45/black-26.3.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:855822d90f884905362f602880ed8b5df1b7e3ee7d0db2502d4388a954cc8c54", size = 1895034, upload-time = "2026-03-12T03:40:21.813Z" }, + { url = "https://files.pythonhosted.org/packages/52/73/7cae55fdfdfbe9d19e9a8d25d145018965fe2079fa908101c3733b0c55a0/black-26.3.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8a33d657f3276328ce00e4d37fe70361e1ec7614da5d7b6e78de5426cb56332f", size = 1718503, upload-time = "2026-03-12T03:40:23.666Z" }, + { url = "https://files.pythonhosted.org/packages/e1/87/af89ad449e8254fdbc74654e6467e3c9381b61472cc532ee350d28cfdafb/black-26.3.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f1cd08e99d2f9317292a311dfe578fd2a24b15dbce97792f9c4d752275c1fa56", size = 1793557, upload-time = "2026-03-12T03:40:25.497Z" }, + { url = "https://files.pythonhosted.org/packages/43/10/d6c06a791d8124b843bf325ab4ac7d2f5b98731dff84d6064eafd687ded1/black-26.3.1-cp313-cp313-win_amd64.whl", hash = "sha256:c7e72339f841b5a237ff14f7d3880ddd0fc7f98a1199e8c4327f9a4f478c1839", size = 1422766, upload-time = "2026-03-12T03:40:27.14Z" }, + { url = "https://files.pythonhosted.org/packages/59/4f/40a582c015f2d841ac24fed6390bd68f0fc896069ff3a886317959c9daf8/black-26.3.1-cp313-cp313-win_arm64.whl", hash = "sha256:afc622538b430aa4c8c853f7f63bc582b3b8030fd8c80b70fb5fa5b834e575c2", size = 1232140, upload-time = "2026-03-12T03:40:28.882Z" }, + { url = "https://files.pythonhosted.org/packages/d5/da/e36e27c9cebc1311b7579210df6f1c86e50f2d7143ae4fcf8a5017dc8809/black-26.3.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:2d6bfaf7fd0993b420bed691f20f9492d53ce9a2bcccea4b797d34e947318a78", size = 1889234, upload-time = "2026-03-12T03:40:30.964Z" }, + { url = "https://files.pythonhosted.org/packages/0e/7b/9871acf393f64a5fa33668c19350ca87177b181f44bb3d0c33b2d534f22c/black-26.3.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:f89f2ab047c76a9c03f78d0d66ca519e389519902fa27e7a91117ef7611c0568", size = 1720522, upload-time = "2026-03-12T03:40:32.346Z" }, + { url = "https://files.pythonhosted.org/packages/03/87/e766c7f2e90c07fb7586cc787c9ae6462b1eedab390191f2b7fc7f6170a9/black-26.3.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b07fc0dab849d24a80a29cfab8d8a19187d1c4685d8a5e6385a5ce323c1f015f", size = 1787824, upload-time = "2026-03-12T03:40:33.636Z" }, + { url = "https://files.pythonhosted.org/packages/ac/94/2424338fb2d1875e9e83eed4c8e9c67f6905ec25afd826a911aea2b02535/black-26.3.1-cp314-cp314-win_amd64.whl", hash = "sha256:0126ae5b7c09957da2bdbd91a9ba1207453feada9e9fe51992848658c6c8e01c", size = 1445855, upload-time = "2026-03-12T03:40:35.442Z" }, + { url = "https://files.pythonhosted.org/packages/86/43/0c3338bd928afb8ee7471f1a4eec3bdbe2245ccb4a646092a222e8669840/black-26.3.1-cp314-cp314-win_arm64.whl", hash = "sha256:92c0ec1f2cc149551a2b7b47efc32c866406b6891b0ee4625e95967c8f4acfb1", size = 1258109, upload-time = "2026-03-12T03:40:36.832Z" }, + { url = "https://files.pythonhosted.org/packages/8e/0d/52d98722666d6fc6c3dd4c76df339501d6efd40e0ff95e6186a7b7f0befd/black-26.3.1-py3-none-any.whl", hash = "sha256:2bd5aa94fc267d38bb21a70d7410a89f1a1d318841855f698746f8e7f51acd1b", size = 207542, upload-time = "2026-03-12T03:36:01.668Z" }, +] + +[[package]] +name = "cairocffi" +version = "1.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/70/c5/1a4dc131459e68a173cbdab5fad6b524f53f9c1ef7861b7698e998b837cc/cairocffi-1.7.1.tar.gz", hash = "sha256:2e48ee864884ec4a3a34bfa8c9ab9999f688286eb714a15a43ec9d068c36557b", size = 88096, upload-time = "2024-06-18T10:56:06.741Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/d8/ba13451aa6b745c49536e87b6bf8f629b950e84bd0e8308f7dc6883b67e2/cairocffi-1.7.1-py3-none-any.whl", hash = "sha256:9803a0e11f6c962f3b0ae2ec8ba6ae45e957a146a004697a1ac1bbf16b073b3f", size = 75611, upload-time = "2024-06-18T10:55:59.489Z" }, +] + +[[package]] +name = "cairosvg" +version = "2.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cairocffi" }, + { name = "cssselect2" }, + { name = "defusedxml" }, + { name = "pillow" }, + { name = "tinycss2" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/b9/5106168bd43d7cd8b7cc2a2ee465b385f14b63f4c092bb89eee2d48c8e67/cairosvg-2.8.2.tar.gz", hash = "sha256:07cbf4e86317b27a92318a4cac2a4bb37a5e9c1b8a27355d06874b22f85bef9f", size = 8398590, upload-time = "2025-05-15T06:56:32.653Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/3b/4ba3f93ac8d90410423fdd31d7541ada9bcee1df32fb90d26de41ed40e1d/black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32", size = 1629419, upload-time = "2025-01-29T05:37:06.642Z" }, - { url = "https://files.pythonhosted.org/packages/b4/02/0bde0485146a8a5e694daed47561785e8b77a0466ccc1f3e485d5ef2925e/black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da", size = 1461080, upload-time = "2025-01-29T05:37:09.321Z" }, - { url = "https://files.pythonhosted.org/packages/52/0e/abdf75183c830eaca7589144ff96d49bce73d7ec6ad12ef62185cc0f79a2/black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7", size = 1766886, upload-time = "2025-01-29T04:18:24.432Z" }, - { url = "https://files.pythonhosted.org/packages/dc/a6/97d8bb65b1d8a41f8a6736222ba0a334db7b7b77b8023ab4568288f23973/black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9", size = 1419404, upload-time = "2025-01-29T04:19:04.296Z" }, - { url = "https://files.pythonhosted.org/packages/7e/4f/87f596aca05c3ce5b94b8663dbfe242a12843caaa82dd3f85f1ffdc3f177/black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0", size = 1614372, upload-time = "2025-01-29T05:37:11.71Z" }, - { url = "https://files.pythonhosted.org/packages/e7/d0/2c34c36190b741c59c901e56ab7f6e54dad8df05a6272a9747ecef7c6036/black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299", size = 1442865, upload-time = "2025-01-29T05:37:14.309Z" }, - { url = "https://files.pythonhosted.org/packages/21/d4/7518c72262468430ead45cf22bd86c883a6448b9eb43672765d69a8f1248/black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096", size = 1749699, upload-time = "2025-01-29T04:18:17.688Z" }, - { url = "https://files.pythonhosted.org/packages/58/db/4f5beb989b547f79096e035c4981ceb36ac2b552d0ac5f2620e941501c99/black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2", size = 1428028, upload-time = "2025-01-29T04:18:51.711Z" }, - { url = "https://files.pythonhosted.org/packages/83/71/3fe4741df7adf015ad8dfa082dd36c94ca86bb21f25608eb247b4afb15b2/black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b", size = 1650988, upload-time = "2025-01-29T05:37:16.707Z" }, - { url = "https://files.pythonhosted.org/packages/13/f3/89aac8a83d73937ccd39bbe8fc6ac8860c11cfa0af5b1c96d081facac844/black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc", size = 1453985, upload-time = "2025-01-29T05:37:18.273Z" }, - { url = "https://files.pythonhosted.org/packages/6f/22/b99efca33f1f3a1d2552c714b1e1b5ae92efac6c43e790ad539a163d1754/black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f", size = 1783816, upload-time = "2025-01-29T04:18:33.823Z" }, - { url = "https://files.pythonhosted.org/packages/18/7e/a27c3ad3822b6f2e0e00d63d58ff6299a99a5b3aee69fa77cd4b0076b261/black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba", size = 1440860, upload-time = "2025-01-29T04:19:12.944Z" }, - { url = "https://files.pythonhosted.org/packages/98/87/0edf98916640efa5d0696e1abb0a8357b52e69e82322628f25bf14d263d1/black-25.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f0b18a02996a836cc9c9c78e5babec10930862827b1b724ddfe98ccf2f2fe4f", size = 1650673, upload-time = "2025-01-29T05:37:20.574Z" }, - { url = "https://files.pythonhosted.org/packages/52/e5/f7bf17207cf87fa6e9b676576749c6b6ed0d70f179a3d812c997870291c3/black-25.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afebb7098bfbc70037a053b91ae8437c3857482d3a690fefc03e9ff7aa9a5fd3", size = 1453190, upload-time = "2025-01-29T05:37:22.106Z" }, - { url = "https://files.pythonhosted.org/packages/e3/ee/adda3d46d4a9120772fae6de454c8495603c37c4c3b9c60f25b1ab6401fe/black-25.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:030b9759066a4ee5e5aca28c3c77f9c64789cdd4de8ac1df642c40b708be6171", size = 1782926, upload-time = "2025-01-29T04:18:58.564Z" }, - { url = "https://files.pythonhosted.org/packages/cc/64/94eb5f45dcb997d2082f097a3944cfc7fe87e071907f677e80788a2d7b7a/black-25.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a22f402b410566e2d1c950708c77ebf5ebd5d0d88a6a2e87c86d9fb48afa0d18", size = 1442613, upload-time = "2025-01-29T04:19:27.63Z" }, - { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, + { url = "https://files.pythonhosted.org/packages/67/48/816bd4aaae93dbf9e408c58598bc32f4a8c65f4b86ab560864cb3ee60adb/cairosvg-2.8.2-py3-none-any.whl", hash = "sha256:eab46dad4674f33267a671dce39b64be245911c901c70d65d2b7b0821e852bf5", size = 45773, upload-time = "2025-05-15T06:56:28.552Z" }, ] [[package]] @@ -410,62 +461,84 @@ toml = [ [[package]] name = "cryptography" -version = "46.0.5" +version = "46.0.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/60/04/ee2a9e8542e4fa2773b81771ff8349ff19cdd56b7258a0cc442639052edb/cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d", size = 750064, upload-time = "2026-02-10T19:18:38.255Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/81/b0bb27f2ba931a65409c6b8a8b358a7f03c0e46eceacddff55f7c84b1f3b/cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad", size = 7176289, upload-time = "2026-02-10T19:17:08.274Z" }, - { url = "https://files.pythonhosted.org/packages/ff/9e/6b4397a3e3d15123de3b1806ef342522393d50736c13b20ec4c9ea6693a6/cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b", size = 4275637, upload-time = "2026-02-10T19:17:10.53Z" }, - { url = "https://files.pythonhosted.org/packages/63/e7/471ab61099a3920b0c77852ea3f0ea611c9702f651600397ac567848b897/cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b", size = 4424742, upload-time = "2026-02-10T19:17:12.388Z" }, - { url = "https://files.pythonhosted.org/packages/37/53/a18500f270342d66bf7e4d9f091114e31e5ee9e7375a5aba2e85a91e0044/cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263", size = 4277528, upload-time = "2026-02-10T19:17:13.853Z" }, - { url = "https://files.pythonhosted.org/packages/22/29/c2e812ebc38c57b40e7c583895e73c8c5adb4d1e4a0cc4c5a4fdab2b1acc/cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d", size = 4947993, upload-time = "2026-02-10T19:17:15.618Z" }, - { url = "https://files.pythonhosted.org/packages/6b/e7/237155ae19a9023de7e30ec64e5d99a9431a567407ac21170a046d22a5a3/cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed", size = 4456855, upload-time = "2026-02-10T19:17:17.221Z" }, - { url = "https://files.pythonhosted.org/packages/2d/87/fc628a7ad85b81206738abbd213b07702bcbdada1dd43f72236ef3cffbb5/cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2", size = 3984635, upload-time = "2026-02-10T19:17:18.792Z" }, - { url = "https://files.pythonhosted.org/packages/84/29/65b55622bde135aedf4565dc509d99b560ee4095e56989e815f8fd2aa910/cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2", size = 4277038, upload-time = "2026-02-10T19:17:20.256Z" }, - { url = "https://files.pythonhosted.org/packages/bc/36/45e76c68d7311432741faf1fbf7fac8a196a0a735ca21f504c75d37e2558/cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0", size = 4912181, upload-time = "2026-02-10T19:17:21.825Z" }, - { url = "https://files.pythonhosted.org/packages/6d/1a/c1ba8fead184d6e3d5afcf03d569acac5ad063f3ac9fb7258af158f7e378/cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731", size = 4456482, upload-time = "2026-02-10T19:17:25.133Z" }, - { url = "https://files.pythonhosted.org/packages/f9/e5/3fb22e37f66827ced3b902cf895e6a6bc1d095b5b26be26bd13c441fdf19/cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82", size = 4405497, upload-time = "2026-02-10T19:17:26.66Z" }, - { url = "https://files.pythonhosted.org/packages/1a/df/9d58bb32b1121a8a2f27383fabae4d63080c7ca60b9b5c88be742be04ee7/cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1", size = 4667819, upload-time = "2026-02-10T19:17:28.569Z" }, - { url = "https://files.pythonhosted.org/packages/ea/ed/325d2a490c5e94038cdb0117da9397ece1f11201f425c4e9c57fe5b9f08b/cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48", size = 3028230, upload-time = "2026-02-10T19:17:30.518Z" }, - { url = "https://files.pythonhosted.org/packages/e9/5a/ac0f49e48063ab4255d9e3b79f5def51697fce1a95ea1370f03dc9db76f6/cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4", size = 3480909, upload-time = "2026-02-10T19:17:32.083Z" }, - { url = "https://files.pythonhosted.org/packages/00/13/3d278bfa7a15a96b9dc22db5a12ad1e48a9eb3d40e1827ef66a5df75d0d0/cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2", size = 7119287, upload-time = "2026-02-10T19:17:33.801Z" }, - { url = "https://files.pythonhosted.org/packages/67/c8/581a6702e14f0898a0848105cbefd20c058099e2c2d22ef4e476dfec75d7/cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678", size = 4265728, upload-time = "2026-02-10T19:17:35.569Z" }, - { url = "https://files.pythonhosted.org/packages/dd/4a/ba1a65ce8fc65435e5a849558379896c957870dd64fecea97b1ad5f46a37/cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87", size = 4408287, upload-time = "2026-02-10T19:17:36.938Z" }, - { url = "https://files.pythonhosted.org/packages/f8/67/8ffdbf7b65ed1ac224d1c2df3943553766914a8ca718747ee3871da6107e/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee", size = 4270291, upload-time = "2026-02-10T19:17:38.748Z" }, - { url = "https://files.pythonhosted.org/packages/f8/e5/f52377ee93bc2f2bba55a41a886fd208c15276ffbd2569f2ddc89d50e2c5/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981", size = 4927539, upload-time = "2026-02-10T19:17:40.241Z" }, - { url = "https://files.pythonhosted.org/packages/3b/02/cfe39181b02419bbbbcf3abdd16c1c5c8541f03ca8bda240debc467d5a12/cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9", size = 4442199, upload-time = "2026-02-10T19:17:41.789Z" }, - { url = "https://files.pythonhosted.org/packages/c0/96/2fcaeb4873e536cf71421a388a6c11b5bc846e986b2b069c79363dc1648e/cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648", size = 3960131, upload-time = "2026-02-10T19:17:43.379Z" }, - { url = "https://files.pythonhosted.org/packages/d8/d2/b27631f401ddd644e94c5cf33c9a4069f72011821cf3dc7309546b0642a0/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4", size = 4270072, upload-time = "2026-02-10T19:17:45.481Z" }, - { url = "https://files.pythonhosted.org/packages/f4/a7/60d32b0370dae0b4ebe55ffa10e8599a2a59935b5ece1b9f06edb73abdeb/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0", size = 4892170, upload-time = "2026-02-10T19:17:46.997Z" }, - { url = "https://files.pythonhosted.org/packages/d2/b9/cf73ddf8ef1164330eb0b199a589103c363afa0cf794218c24d524a58eab/cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663", size = 4441741, upload-time = "2026-02-10T19:17:48.661Z" }, - { url = "https://files.pythonhosted.org/packages/5f/eb/eee00b28c84c726fe8fa0158c65afe312d9c3b78d9d01daf700f1f6e37ff/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826", size = 4396728, upload-time = "2026-02-10T19:17:50.058Z" }, - { url = "https://files.pythonhosted.org/packages/65/f4/6bc1a9ed5aef7145045114b75b77c2a8261b4d38717bd8dea111a63c3442/cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d", size = 4652001, upload-time = "2026-02-10T19:17:51.54Z" }, - { url = "https://files.pythonhosted.org/packages/86/ef/5d00ef966ddd71ac2e6951d278884a84a40ffbd88948ef0e294b214ae9e4/cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a", size = 3003637, upload-time = "2026-02-10T19:17:52.997Z" }, - { url = "https://files.pythonhosted.org/packages/b7/57/f3f4160123da6d098db78350fdfd9705057aad21de7388eacb2401dceab9/cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4", size = 3469487, upload-time = "2026-02-10T19:17:54.549Z" }, - { url = "https://files.pythonhosted.org/packages/e2/fa/a66aa722105ad6a458bebd64086ca2b72cdd361fed31763d20390f6f1389/cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31", size = 7170514, upload-time = "2026-02-10T19:17:56.267Z" }, - { url = "https://files.pythonhosted.org/packages/0f/04/c85bdeab78c8bc77b701bf0d9bdcf514c044e18a46dcff330df5448631b0/cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18", size = 4275349, upload-time = "2026-02-10T19:17:58.419Z" }, - { url = "https://files.pythonhosted.org/packages/5c/32/9b87132a2f91ee7f5223b091dc963055503e9b442c98fc0b8a5ca765fab0/cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235", size = 4420667, upload-time = "2026-02-10T19:18:00.619Z" }, - { url = "https://files.pythonhosted.org/packages/a1/a6/a7cb7010bec4b7c5692ca6f024150371b295ee1c108bdc1c400e4c44562b/cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a", size = 4276980, upload-time = "2026-02-10T19:18:02.379Z" }, - { url = "https://files.pythonhosted.org/packages/8e/7c/c4f45e0eeff9b91e3f12dbd0e165fcf2a38847288fcfd889deea99fb7b6d/cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76", size = 4939143, upload-time = "2026-02-10T19:18:03.964Z" }, - { url = "https://files.pythonhosted.org/packages/37/19/e1b8f964a834eddb44fa1b9a9976f4e414cbb7aa62809b6760c8803d22d1/cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614", size = 4453674, upload-time = "2026-02-10T19:18:05.588Z" }, - { url = "https://files.pythonhosted.org/packages/db/ed/db15d3956f65264ca204625597c410d420e26530c4e2943e05a0d2f24d51/cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229", size = 3978801, upload-time = "2026-02-10T19:18:07.167Z" }, - { url = "https://files.pythonhosted.org/packages/41/e2/df40a31d82df0a70a0daf69791f91dbb70e47644c58581d654879b382d11/cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1", size = 4276755, upload-time = "2026-02-10T19:18:09.813Z" }, - { url = "https://files.pythonhosted.org/packages/33/45/726809d1176959f4a896b86907b98ff4391a8aa29c0aaaf9450a8a10630e/cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d", size = 4901539, upload-time = "2026-02-10T19:18:11.263Z" }, - { url = "https://files.pythonhosted.org/packages/99/0f/a3076874e9c88ecb2ecc31382f6e7c21b428ede6f55aafa1aa272613e3cd/cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c", size = 4452794, upload-time = "2026-02-10T19:18:12.914Z" }, - { url = "https://files.pythonhosted.org/packages/02/ef/ffeb542d3683d24194a38f66ca17c0a4b8bf10631feef44a7ef64e631b1a/cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4", size = 4404160, upload-time = "2026-02-10T19:18:14.375Z" }, - { url = "https://files.pythonhosted.org/packages/96/93/682d2b43c1d5f1406ed048f377c0fc9fc8f7b0447a478d5c65ab3d3a66eb/cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9", size = 4667123, upload-time = "2026-02-10T19:18:15.886Z" }, - { url = "https://files.pythonhosted.org/packages/45/2d/9c5f2926cb5300a8eefc3f4f0b3f3df39db7f7ce40c8365444c49363cbda/cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72", size = 3010220, upload-time = "2026-02-10T19:18:17.361Z" }, - { url = "https://files.pythonhosted.org/packages/48/ef/0c2f4a8e31018a986949d34a01115dd057bf536905dca38897bacd21fac3/cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595", size = 3467050, upload-time = "2026-02-10T19:18:18.899Z" }, - { url = "https://files.pythonhosted.org/packages/eb/dd/2d9fdb07cebdf3d51179730afb7d5e576153c6744c3ff8fded23030c204e/cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c", size = 3476964, upload-time = "2026-02-10T19:18:20.687Z" }, - { url = "https://files.pythonhosted.org/packages/e9/6f/6cc6cc9955caa6eaf83660b0da2b077c7fe8ff9950a3c5e45d605038d439/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a", size = 4218321, upload-time = "2026-02-10T19:18:22.349Z" }, - { url = "https://files.pythonhosted.org/packages/3e/5d/c4da701939eeee699566a6c1367427ab91a8b7088cc2328c09dbee940415/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356", size = 4381786, upload-time = "2026-02-10T19:18:24.529Z" }, - { url = "https://files.pythonhosted.org/packages/ac/97/a538654732974a94ff96c1db621fa464f455c02d4bb7d2652f4edc21d600/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da", size = 4217990, upload-time = "2026-02-10T19:18:25.957Z" }, - { url = "https://files.pythonhosted.org/packages/ae/11/7e500d2dd3ba891197b9efd2da5454b74336d64a7cc419aa7327ab74e5f6/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257", size = 4381252, upload-time = "2026-02-10T19:18:27.496Z" }, - { url = "https://files.pythonhosted.org/packages/bc/58/6b3d24e6b9bc474a2dcdee65dfd1f008867015408a271562e4b690561a4d/cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7", size = 3407605, upload-time = "2026-02-10T19:18:29.233Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/47/93/ac8f3d5ff04d54bc814e961a43ae5b0b146154c89c61b47bb07557679b18/cryptography-46.0.7.tar.gz", hash = "sha256:e4cfd68c5f3e0bfdad0d38e023239b96a2fe84146481852dffbcca442c245aa5", size = 750652, upload-time = "2026-04-08T01:57:54.692Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/5d/4a8f770695d73be252331e60e526291e3df0c9b27556a90a6b47bccca4c2/cryptography-46.0.7-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:ea42cbe97209df307fdc3b155f1b6fa2577c0defa8f1f7d3be7d31d189108ad4", size = 7179869, upload-time = "2026-04-08T01:56:17.157Z" }, + { url = "https://files.pythonhosted.org/packages/5f/45/6d80dc379b0bbc1f9d1e429f42e4cb9e1d319c7a8201beffd967c516ea01/cryptography-46.0.7-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b36a4695e29fe69215d75960b22577197aca3f7a25b9cf9d165dcfe9d80bc325", size = 4275492, upload-time = "2026-04-08T01:56:19.36Z" }, + { url = "https://files.pythonhosted.org/packages/4a/9a/1765afe9f572e239c3469f2cb429f3ba7b31878c893b246b4b2994ffe2fe/cryptography-46.0.7-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ad9ef796328c5e3c4ceed237a183f5d41d21150f972455a9d926593a1dcb308", size = 4426670, upload-time = "2026-04-08T01:56:21.415Z" }, + { url = "https://files.pythonhosted.org/packages/8f/3e/af9246aaf23cd4ee060699adab1e47ced3f5f7e7a8ffdd339f817b446462/cryptography-46.0.7-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:73510b83623e080a2c35c62c15298096e2a5dc8d51c3b4e1740211839d0dea77", size = 4280275, upload-time = "2026-04-08T01:56:23.539Z" }, + { url = "https://files.pythonhosted.org/packages/0f/54/6bbbfc5efe86f9d71041827b793c24811a017c6ac0fd12883e4caa86b8ed/cryptography-46.0.7-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cbd5fb06b62bd0721e1170273d3f4d5a277044c47ca27ee257025146c34cbdd1", size = 4928402, upload-time = "2026-04-08T01:56:25.624Z" }, + { url = "https://files.pythonhosted.org/packages/2d/cf/054b9d8220f81509939599c8bdbc0c408dbd2bdd41688616a20731371fe0/cryptography-46.0.7-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:420b1e4109cc95f0e5700eed79908cef9268265c773d3a66f7af1eef53d409ef", size = 4459985, upload-time = "2026-04-08T01:56:27.309Z" }, + { url = "https://files.pythonhosted.org/packages/f9/46/4e4e9c6040fb01c7467d47217d2f882daddeb8828f7df800cb806d8a2288/cryptography-46.0.7-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:24402210aa54baae71d99441d15bb5a1919c195398a87b563df84468160a65de", size = 3990652, upload-time = "2026-04-08T01:56:29.095Z" }, + { url = "https://files.pythonhosted.org/packages/36/5f/313586c3be5a2fbe87e4c9a254207b860155a8e1f3cca99f9910008e7d08/cryptography-46.0.7-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8a469028a86f12eb7d2fe97162d0634026d92a21f3ae0ac87ed1c4a447886c83", size = 4279805, upload-time = "2026-04-08T01:56:30.928Z" }, + { url = "https://files.pythonhosted.org/packages/69/33/60dfc4595f334a2082749673386a4d05e4f0cf4df8248e63b2c3437585f2/cryptography-46.0.7-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:9694078c5d44c157ef3162e3bf3946510b857df5a3955458381d1c7cfc143ddb", size = 4892883, upload-time = "2026-04-08T01:56:32.614Z" }, + { url = "https://files.pythonhosted.org/packages/c7/0b/333ddab4270c4f5b972f980adef4faa66951a4aaf646ca067af597f15563/cryptography-46.0.7-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:42a1e5f98abb6391717978baf9f90dc28a743b7d9be7f0751a6f56a75d14065b", size = 4459756, upload-time = "2026-04-08T01:56:34.306Z" }, + { url = "https://files.pythonhosted.org/packages/d2/14/633913398b43b75f1234834170947957c6b623d1701ffc7a9600da907e89/cryptography-46.0.7-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91bbcb08347344f810cbe49065914fe048949648f6bd5c2519f34619142bbe85", size = 4410244, upload-time = "2026-04-08T01:56:35.977Z" }, + { url = "https://files.pythonhosted.org/packages/10/f2/19ceb3b3dc14009373432af0c13f46aa08e3ce334ec6eff13492e1812ccd/cryptography-46.0.7-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5d1c02a14ceb9148cc7816249f64f623fbfee39e8c03b3650d842ad3f34d637e", size = 4674868, upload-time = "2026-04-08T01:56:38.034Z" }, + { url = "https://files.pythonhosted.org/packages/1a/bb/a5c213c19ee94b15dfccc48f363738633a493812687f5567addbcbba9f6f/cryptography-46.0.7-cp311-abi3-win32.whl", hash = "sha256:d23c8ca48e44ee015cd0a54aeccdf9f09004eba9fc96f38c911011d9ff1bd457", size = 3026504, upload-time = "2026-04-08T01:56:39.666Z" }, + { url = "https://files.pythonhosted.org/packages/2b/02/7788f9fefa1d060ca68717c3901ae7fffa21ee087a90b7f23c7a603c32ae/cryptography-46.0.7-cp311-abi3-win_amd64.whl", hash = "sha256:397655da831414d165029da9bc483bed2fe0e75dde6a1523ec2fe63f3c46046b", size = 3488363, upload-time = "2026-04-08T01:56:41.893Z" }, + { url = "https://files.pythonhosted.org/packages/7b/56/15619b210e689c5403bb0540e4cb7dbf11a6bf42e483b7644e471a2812b3/cryptography-46.0.7-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:d151173275e1728cf7839aaa80c34fe550c04ddb27b34f48c232193df8db5842", size = 7119671, upload-time = "2026-04-08T01:56:44Z" }, + { url = "https://files.pythonhosted.org/packages/74/66/e3ce040721b0b5599e175ba91ab08884c75928fbeb74597dd10ef13505d2/cryptography-46.0.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:db0f493b9181c7820c8134437eb8b0b4792085d37dbb24da050476ccb664e59c", size = 4268551, upload-time = "2026-04-08T01:56:46.071Z" }, + { url = "https://files.pythonhosted.org/packages/03/11/5e395f961d6868269835dee1bafec6a1ac176505a167f68b7d8818431068/cryptography-46.0.7-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ebd6daf519b9f189f85c479427bbd6e9c9037862cf8fe89ee35503bd209ed902", size = 4408887, upload-time = "2026-04-08T01:56:47.718Z" }, + { url = "https://files.pythonhosted.org/packages/40/53/8ed1cf4c3b9c8e611e7122fb56f1c32d09e1fff0f1d77e78d9ff7c82653e/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:b7b412817be92117ec5ed95f880defe9cf18a832e8cafacf0a22337dc1981b4d", size = 4271354, upload-time = "2026-04-08T01:56:49.312Z" }, + { url = "https://files.pythonhosted.org/packages/50/46/cf71e26025c2e767c5609162c866a78e8a2915bbcfa408b7ca495c6140c4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:fbfd0e5f273877695cb93baf14b185f4878128b250cc9f8e617ea0c025dfb022", size = 4905845, upload-time = "2026-04-08T01:56:50.916Z" }, + { url = "https://files.pythonhosted.org/packages/c0/ea/01276740375bac6249d0a971ebdf6b4dc9ead0ee0a34ef3b5a88c1a9b0d4/cryptography-46.0.7-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:ffca7aa1d00cf7d6469b988c581598f2259e46215e0140af408966a24cf086ce", size = 4444641, upload-time = "2026-04-08T01:56:52.882Z" }, + { url = "https://files.pythonhosted.org/packages/3d/4c/7d258f169ae71230f25d9f3d06caabcff8c3baf0978e2b7d65e0acac3827/cryptography-46.0.7-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:60627cf07e0d9274338521205899337c5d18249db56865f943cbe753aa96f40f", size = 3967749, upload-time = "2026-04-08T01:56:54.597Z" }, + { url = "https://files.pythonhosted.org/packages/b5/2a/2ea0767cad19e71b3530e4cad9605d0b5e338b6a1e72c37c9c1ceb86c333/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:80406c3065e2c55d7f49a9550fe0c49b3f12e5bfff5dedb727e319e1afb9bf99", size = 4270942, upload-time = "2026-04-08T01:56:56.416Z" }, + { url = "https://files.pythonhosted.org/packages/41/3d/fe14df95a83319af25717677e956567a105bb6ab25641acaa093db79975d/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:c5b1ccd1239f48b7151a65bc6dd54bcfcc15e028c8ac126d3fada09db0e07ef1", size = 4871079, upload-time = "2026-04-08T01:56:58.31Z" }, + { url = "https://files.pythonhosted.org/packages/9c/59/4a479e0f36f8f378d397f4eab4c850b4ffb79a2f0d58704b8fa0703ddc11/cryptography-46.0.7-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:d5f7520159cd9c2154eb61eb67548ca05c5774d39e9c2c4339fd793fe7d097b2", size = 4443999, upload-time = "2026-04-08T01:57:00.508Z" }, + { url = "https://files.pythonhosted.org/packages/28/17/b59a741645822ec6d04732b43c5d35e4ef58be7bfa84a81e5ae6f05a1d33/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:fcd8eac50d9138c1d7fc53a653ba60a2bee81a505f9f8850b6b2888555a45d0e", size = 4399191, upload-time = "2026-04-08T01:57:02.654Z" }, + { url = "https://files.pythonhosted.org/packages/59/6a/bb2e166d6d0e0955f1e9ff70f10ec4b2824c9cfcdb4da772c7dd69cc7d80/cryptography-46.0.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:65814c60f8cc400c63131584e3e1fad01235edba2614b61fbfbfa954082db0ee", size = 4655782, upload-time = "2026-04-08T01:57:04.592Z" }, + { url = "https://files.pythonhosted.org/packages/95/b6/3da51d48415bcb63b00dc17c2eff3a651b7c4fed484308d0f19b30e8cb2c/cryptography-46.0.7-cp314-cp314t-win32.whl", hash = "sha256:fdd1736fed309b4300346f88f74cd120c27c56852c3838cab416e7a166f67298", size = 3002227, upload-time = "2026-04-08T01:57:06.91Z" }, + { url = "https://files.pythonhosted.org/packages/32/a8/9f0e4ed57ec9cebe506e58db11ae472972ecb0c659e4d52bbaee80ca340a/cryptography-46.0.7-cp314-cp314t-win_amd64.whl", hash = "sha256:e06acf3c99be55aa3b516397fe42f5855597f430add9c17fa46bf2e0fb34c9bb", size = 3475332, upload-time = "2026-04-08T01:57:08.807Z" }, + { url = "https://files.pythonhosted.org/packages/a7/7f/cd42fc3614386bc0c12f0cb3c4ae1fc2bbca5c9662dfed031514911d513d/cryptography-46.0.7-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:462ad5cb1c148a22b2e3bcc5ad52504dff325d17daf5df8d88c17dda1f75f2a4", size = 7165618, upload-time = "2026-04-08T01:57:10.645Z" }, + { url = "https://files.pythonhosted.org/packages/a5/d0/36a49f0262d2319139d2829f773f1b97ef8aef7f97e6e5bd21455e5a8fb5/cryptography-46.0.7-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:84d4cced91f0f159a7ddacad249cc077e63195c36aac40b4150e7a57e84fffe7", size = 4270628, upload-time = "2026-04-08T01:57:12.885Z" }, + { url = "https://files.pythonhosted.org/packages/8a/6c/1a42450f464dda6ffbe578a911f773e54dd48c10f9895a23a7e88b3e7db5/cryptography-46.0.7-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:128c5edfe5e5938b86b03941e94fac9ee793a94452ad1365c9fc3f4f62216832", size = 4415405, upload-time = "2026-04-08T01:57:14.923Z" }, + { url = "https://files.pythonhosted.org/packages/9a/92/4ed714dbe93a066dc1f4b4581a464d2d7dbec9046f7c8b7016f5286329e2/cryptography-46.0.7-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5e51be372b26ef4ba3de3c167cd3d1022934bc838ae9eaad7e644986d2a3d163", size = 4272715, upload-time = "2026-04-08T01:57:16.638Z" }, + { url = "https://files.pythonhosted.org/packages/b7/e6/a26b84096eddd51494bba19111f8fffe976f6a09f132706f8f1bf03f51f7/cryptography-46.0.7-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cdf1a610ef82abb396451862739e3fc93b071c844399e15b90726ef7470eeaf2", size = 4918400, upload-time = "2026-04-08T01:57:19.021Z" }, + { url = "https://files.pythonhosted.org/packages/c7/08/ffd537b605568a148543ac3c2b239708ae0bd635064bab41359252ef88ed/cryptography-46.0.7-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1d25aee46d0c6f1a501adcddb2d2fee4b979381346a78558ed13e50aa8a59067", size = 4450634, upload-time = "2026-04-08T01:57:21.185Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/0cd51dd86ab5b9befe0d031e276510491976c3a80e9f6e31810cce46c4ad/cryptography-46.0.7-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:cdfbe22376065ffcf8be74dc9a909f032df19bc58a699456a21712d6e5eabfd0", size = 3985233, upload-time = "2026-04-08T01:57:22.862Z" }, + { url = "https://files.pythonhosted.org/packages/92/49/819d6ed3a7d9349c2939f81b500a738cb733ab62fbecdbc1e38e83d45e12/cryptography-46.0.7-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:abad9dac36cbf55de6eb49badd4016806b3165d396f64925bf2999bcb67837ba", size = 4271955, upload-time = "2026-04-08T01:57:24.814Z" }, + { url = "https://files.pythonhosted.org/packages/80/07/ad9b3c56ebb95ed2473d46df0847357e01583f4c52a85754d1a55e29e4d0/cryptography-46.0.7-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:935ce7e3cfdb53e3536119a542b839bb94ec1ad081013e9ab9b7cfd478b05006", size = 4879888, upload-time = "2026-04-08T01:57:26.88Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c7/201d3d58f30c4c2bdbe9b03844c291feb77c20511cc3586daf7edc12a47b/cryptography-46.0.7-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:35719dc79d4730d30f1c2b6474bd6acda36ae2dfae1e3c16f2051f215df33ce0", size = 4449961, upload-time = "2026-04-08T01:57:29.068Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ef/649750cbf96f3033c3c976e112265c33906f8e462291a33d77f90356548c/cryptography-46.0.7-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:7bbc6ccf49d05ac8f7d7b5e2e2c33830d4fe2061def88210a126d130d7f71a85", size = 4401696, upload-time = "2026-04-08T01:57:31.029Z" }, + { url = "https://files.pythonhosted.org/packages/41/52/a8908dcb1a389a459a29008c29966c1d552588d4ae6d43f3a1a4512e0ebe/cryptography-46.0.7-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a1529d614f44b863a7b480c6d000fe93b59acee9c82ffa027cfadc77521a9f5e", size = 4664256, upload-time = "2026-04-08T01:57:33.144Z" }, + { url = "https://files.pythonhosted.org/packages/4b/fa/f0ab06238e899cc3fb332623f337a7364f36f4bb3f2534c2bb95a35b132c/cryptography-46.0.7-cp38-abi3-win32.whl", hash = "sha256:f247c8c1a1fb45e12586afbb436ef21ff1e80670b2861a90353d9b025583d246", size = 3013001, upload-time = "2026-04-08T01:57:34.933Z" }, + { url = "https://files.pythonhosted.org/packages/d2/f1/00ce3bde3ca542d1acd8f8cfa38e446840945aa6363f9b74746394b14127/cryptography-46.0.7-cp38-abi3-win_amd64.whl", hash = "sha256:506c4ff91eff4f82bdac7633318a526b1d1309fc07ca76a3ad182cb5b686d6d3", size = 3472985, upload-time = "2026-04-08T01:57:36.714Z" }, + { url = "https://files.pythonhosted.org/packages/63/0c/dca8abb64e7ca4f6b2978769f6fea5ad06686a190cec381f0a796fdcaaba/cryptography-46.0.7-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:fc9ab8856ae6cf7c9358430e49b368f3108f050031442eaeb6b9d87e4dcf4e4f", size = 3476879, upload-time = "2026-04-08T01:57:38.664Z" }, + { url = "https://files.pythonhosted.org/packages/3a/ea/075aac6a84b7c271578d81a2f9968acb6e273002408729f2ddff517fed4a/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d3b99c535a9de0adced13d159c5a9cf65c325601aa30f4be08afd680643e9c15", size = 4219700, upload-time = "2026-04-08T01:57:40.625Z" }, + { url = "https://files.pythonhosted.org/packages/6c/7b/1c55db7242b5e5612b29fc7a630e91ee7a6e3c8e7bf5406d22e206875fbd/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d02c738dacda7dc2a74d1b2b3177042009d5cab7c7079db74afc19e56ca1b455", size = 4385982, upload-time = "2026-04-08T01:57:42.725Z" }, + { url = "https://files.pythonhosted.org/packages/cb/da/9870eec4b69c63ef5925bf7d8342b7e13bc2ee3d47791461c4e49ca212f4/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:04959522f938493042d595a736e7dbdff6eb6cc2339c11465b3ff89343b65f65", size = 4219115, upload-time = "2026-04-08T01:57:44.939Z" }, + { url = "https://files.pythonhosted.org/packages/f4/72/05aa5832b82dd341969e9a734d1812a6aadb088d9eb6f0430fc337cc5a8f/cryptography-46.0.7-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:3986ac1dee6def53797289999eabe84798ad7817f3e97779b5061a95b0ee4968", size = 4385479, upload-time = "2026-04-08T01:57:46.86Z" }, + { url = "https://files.pythonhosted.org/packages/20/2a/1b016902351a523aa2bd446b50a5bc1175d7a7d1cf90fe2ef904f9b84ebc/cryptography-46.0.7-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:258514877e15963bd43b558917bc9f54cf7cf866c38aa576ebf47a77ddbc43a4", size = 3412829, upload-time = "2026-04-08T01:57:48.874Z" }, +] + +[[package]] +name = "cssselect2" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tinycss2" }, + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e0/20/92eaa6b0aec7189fa4b75c890640e076e9e793095721db69c5c81142c2e1/cssselect2-0.9.0.tar.gz", hash = "sha256:759aa22c216326356f65e62e791d66160a0f9c91d1424e8d8adc5e74dddfc6fb", size = 35595, upload-time = "2026-02-12T17:16:39.614Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/0e/8459ca4413e1a21a06c97d134bfaf18adfd27cea068813dc0faae06cbf00/cssselect2-0.9.0-py3-none-any.whl", hash = "sha256:6a99e5f91f9a016a304dd929b0966ca464bcfda15177b6fb4a118fc0fb5d9563", size = 15453, upload-time = "2026-02-12T17:16:38.317Z" }, +] + +[[package]] +name = "defusedxml" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520, upload-time = "2021-03-08T10:59:26.269Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, ] [[package]] @@ -519,6 +592,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034, upload-time = "2022-05-02T15:47:14.552Z" }, ] +[[package]] +name = "googleapis-common-protos" +version = "1.73.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/c0/4a54c386282c13449eca8bbe2ddb518181dc113e78d240458a68856b4d69/googleapis_common_protos-1.73.1.tar.gz", hash = "sha256:13114f0e9d2391756a0194c3a8131974ed7bffb06086569ba193364af59163b6", size = 147506, upload-time = "2026-03-26T22:17:38.451Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/82/fcb6520612bec0c39b973a6c0954b6a0d948aadfe8f7e9487f60ceb8bfa6/googleapis_common_protos-1.73.1-py3-none-any.whl", hash = "sha256:e51f09eb0a43a8602f5a915870972e6b4a394088415c79d79605a46d8e826ee8", size = 297556, upload-time = "2026-03-26T22:15:58.455Z" }, +] + [[package]] name = "griffe" version = "1.14.0" @@ -586,6 +671,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "importlib-metadata" +version = "8.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/49/3b30cad09e7771a4982d9975a8cbf64f00d4a1ececb53297f1d9a7be1b10/importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb", size = 57107, upload-time = "2025-12-21T10:00:19.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, +] + [[package]] name = "iniconfig" version = "2.1.0" @@ -650,6 +747,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, ] +[[package]] +name = "logfire" +version = "4.31.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "executing" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-sdk" }, + { name = "protobuf" }, + { name = "rich" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/61/fc/21f923243d8c3ca2ebfa97de46970ced734e66ac634c1c35b6abb41300f1/logfire-4.31.0.tar.gz", hash = "sha256:361bfda17c9d70ada5d220211033bae06b871ddac9d5b06978bc0ceca6b8e658", size = 1080609, upload-time = "2026-03-27T19:00:46.339Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/1a/8c860e35bf847ac0d647d94bad89dccbb66cbcafdd61d8334f8cc7cfdd58/logfire-4.31.0-py3-none-any.whl", hash = "sha256:49fad38b5e6f199a98e9c8814e860c8a42595bb81479b52a20413e53ee475b72", size = 308896, upload-time = "2026-03-27T19:00:43.107Z" }, +] + [[package]] name = "markdown" version = "3.9" @@ -737,6 +853,7 @@ dependencies = [ { name = "httpx" }, { name = "httpx-sse" }, { name = "jsonschema" }, + { name = "opentelemetry-api" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyjwt", extra = ["crypto"] }, @@ -766,6 +883,7 @@ dev = [ { name = "coverage", extra = ["toml"] }, { name = "dirty-equals" }, { name = "inline-snapshot" }, + { name = "logfire" }, { name = "mcp", extra = ["cli", "ws"] }, { name = "pillow" }, { name = "pyright" }, @@ -780,17 +898,20 @@ dev = [ ] docs = [ { name = "mkdocs" }, + { name = "mkdocs-gen-files" }, { name = "mkdocs-glightbox" }, - { name = "mkdocs-material" }, + { name = "mkdocs-literate-nav" }, + { name = "mkdocs-material", extra = ["imaging"] }, { name = "mkdocstrings-python" }, ] [package.metadata] requires-dist = [ - { name = "anyio", specifier = ">=4.5" }, - { name = "httpx", specifier = ">=0.27.1" }, + { name = "anyio", specifier = ">=4.9" }, + { name = "httpx", specifier = ">=0.27.1,<1.0.0" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" }, + { name = "opentelemetry-api", specifier = ">=1.28.0" }, { name = "pydantic", specifier = ">=2.12.0" }, { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.1" }, @@ -814,10 +935,11 @@ dev = [ { name = "coverage", extras = ["toml"], specifier = ">=7.10.7,<=7.13" }, { name = "dirty-equals", specifier = ">=0.9.0" }, { name = "inline-snapshot", specifier = ">=0.23.0" }, + { name = "logfire", specifier = ">=3.0.0" }, { name = "mcp", extras = ["cli", "ws"], editable = "." }, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, - { name = "pytest", specifier = ">=8.3.4" }, + { name = "pytest", specifier = ">=8.4.0" }, { name = "pytest-examples", specifier = ">=0.0.14" }, { name = "pytest-flakefinder", specifier = ">=1.1.0" }, { name = "pytest-pretty", specifier = ">=1.2.0" }, @@ -828,8 +950,10 @@ dev = [ ] docs = [ { name = "mkdocs", specifier = ">=1.6.1" }, + { name = "mkdocs-gen-files", specifier = ">=0.5.0" }, { name = "mkdocs-glightbox", specifier = ">=0.4.0" }, - { name = "mkdocs-material", specifier = ">=9.5.45" }, + { name = "mkdocs-literate-nav", specifier = ">=0.6.1" }, + { name = "mkdocs-material", extras = ["imaging"], specifier = ">=9.5.45" }, { name = "mkdocstrings-python", specifier = ">=2.0.1" }, ] @@ -1441,6 +1565,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/4d/7123b6fa2278000688ebd338e2a06d16870aaf9eceae6ba047ea05f92df1/mkdocs_autorefs-1.4.3-py3-none-any.whl", hash = "sha256:469d85eb3114801d08e9cc55d102b3ba65917a869b893403b8987b601cf55dc9", size = 25034, upload-time = "2025-08-26T14:23:15.906Z" }, ] +[[package]] +name = "mkdocs-gen-files" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mkdocs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/61/35/f26349f7fa18414eb2e25d75a6fa9c7e3186c36e1d227c0b2d785a7bd5c4/mkdocs_gen_files-0.6.0.tar.gz", hash = "sha256:52022dc14dcc0451e05e54a8f5d5e7760351b6701eff816d1e9739577ec5635e", size = 8642, upload-time = "2025-11-23T12:13:22.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/ec/72417415563c60ae01b36f0d497f1f4c803972f447ef4fb7f7746d6e07db/mkdocs_gen_files-0.6.0-py3-none-any.whl", hash = "sha256:815af15f3e2dbfda379629c1b95c02c8e6f232edf2a901186ea3b204ab1135b2", size = 8182, upload-time = "2025-11-23T12:13:20.756Z" }, +] + [[package]] name = "mkdocs-get-deps" version = "0.2.0" @@ -1467,9 +1603,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/cf/e9a0ce9da269746906fdc595c030f6df66793dad1487abd1699af2ba44f1/mkdocs_glightbox-0.5.1-py3-none-any.whl", hash = "sha256:f47af0daff164edf8d36e553338425be3aab6e34b987d9cbbc2ae7819a98cb01", size = 26431, upload-time = "2025-09-04T13:10:27.933Z" }, ] +[[package]] +name = "mkdocs-literate-nav" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mkdocs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/5f/99aa379b305cd1c2084d42db3d26f6de0ea9bf2cc1d10ed17f61aff35b9a/mkdocs_literate_nav-0.6.2.tar.gz", hash = "sha256:760e1708aa4be86af81a2b56e82c739d5a8388a0eab1517ecfd8e5aa40810a75", size = 17419, upload-time = "2025-03-18T21:53:09.711Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/84/b5b14d2745e4dd1a90115186284e9ee1b4d0863104011ab46abb7355a1c3/mkdocs_literate_nav-0.6.2-py3-none-any.whl", hash = "sha256:0a6489a26ec7598477b56fa112056a5e3a6c15729f0214bea8a4dbc55bd5f630", size = 13261, upload-time = "2025-03-18T21:53:08.1Z" }, +] + [[package]] name = "mkdocs-material" -version = "9.7.1" +version = "9.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "babel" }, @@ -1484,9 +1632,15 @@ dependencies = [ { name = "pymdown-extensions" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/27/e2/2ffc356cd72f1473d07c7719d82a8f2cbd261666828614ecb95b12169f41/mkdocs_material-9.7.1.tar.gz", hash = "sha256:89601b8f2c3e6c6ee0a918cc3566cb201d40bf37c3cd3c2067e26fadb8cce2b8", size = 4094392, upload-time = "2025-12-18T09:49:00.308Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/57/5d3c8c9e2ff9d66dc8f63aa052eb0bac5041fecff7761d8689fe65c39c13/mkdocs_material-9.7.2.tar.gz", hash = "sha256:6776256552290b9b7a7aa002780e25b1e04bc9c3a8516b6b153e82e16b8384bd", size = 4097818, upload-time = "2026-02-18T15:53:07.763Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3e/32/ed071cb721aca8c227718cffcf7bd539620e9799bbf2619e90c757bfd030/mkdocs_material-9.7.1-py3-none-any.whl", hash = "sha256:3f6100937d7d731f87f1e3e3b021c97f7239666b9ba1151ab476cabb96c60d5c", size = 9297166, upload-time = "2025-12-18T09:48:56.664Z" }, + { url = "https://files.pythonhosted.org/packages/cd/19/d194e75e82282b1d688f0720e21b5ac250ed64ddea333a228aaf83105f2e/mkdocs_material-9.7.2-py3-none-any.whl", hash = "sha256:9bf6f53452d4a4d527eac3cef3f92b7b6fc4931c55d57766a7d87890d47e1b92", size = 9305052, upload-time = "2026-02-18T15:53:05.221Z" }, +] + +[package.optional-dependencies] +imaging = [ + { name = "cairosvg" }, + { name = "pillow" }, ] [[package]] @@ -1548,6 +1702,103 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, ] +[[package]] +name = "opentelemetry-api" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/97/b9/3161be15bb8e3ad01be8be5a968a9237c3027c5be504362ff800fca3e442/opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c", size = 65767, upload-time = "2025-12-11T13:32:39.182Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/df/d3f1ddf4bb4cb50ed9b1139cc7b1c54c34a1e7ce8fd1b9a37c0d1551a6bd/opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950", size = 66356, upload-time = "2025-12-11T13:32:17.304Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-common" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-proto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/9d/22d241b66f7bbde88a3bfa6847a351d2c46b84de23e71222c6aae25c7050/opentelemetry_exporter_otlp_proto_common-1.39.1.tar.gz", hash = "sha256:763370d4737a59741c89a67b50f9e39271639ee4afc999dadfe768541c027464", size = 20409, upload-time = "2025-12-11T13:32:40.885Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/02/ffc3e143d89a27ac21fd557365b98bd0653b98de8a101151d5805b5d4c33/opentelemetry_exporter_otlp_proto_common-1.39.1-py3-none-any.whl", hash = "sha256:08f8a5862d64cc3435105686d0216c1365dc5701f86844a8cd56597d0c764fde", size = 18366, upload-time = "2025-12-11T13:32:20.2Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-http" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "googleapis-common-protos" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-common" }, + { name = "opentelemetry-proto" }, + { name = "opentelemetry-sdk" }, + { name = "requests" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/80/04/2a08fa9c0214ae38880df01e8bfae12b067ec0793446578575e5080d6545/opentelemetry_exporter_otlp_proto_http-1.39.1.tar.gz", hash = "sha256:31bdab9745c709ce90a49a0624c2bd445d31a28ba34275951a6a362d16a0b9cb", size = 17288, upload-time = "2025-12-11T13:32:42.029Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/f1/b27d3e2e003cd9a3592c43d099d2ed8d0a947c15281bf8463a256db0b46c/opentelemetry_exporter_otlp_proto_http-1.39.1-py3-none-any.whl", hash = "sha256:d9f5207183dd752a412c4cd564ca8875ececba13be6e9c6c370ffb752fd59985", size = 19641, upload-time = "2025-12-11T13:32:22.248Z" }, +] + +[[package]] +name = "opentelemetry-instrumentation" +version = "0.60b1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "packaging" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/0f/7e6b713ac117c1f5e4e3300748af699b9902a2e5e34c9cf443dde25a01fa/opentelemetry_instrumentation-0.60b1.tar.gz", hash = "sha256:57ddc7974c6eb35865af0426d1a17132b88b2ed8586897fee187fd5b8944bd6a", size = 31706, upload-time = "2025-12-11T13:36:42.515Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/d2/6788e83c5c86a2690101681aeef27eeb2a6bf22df52d3f263a22cee20915/opentelemetry_instrumentation-0.60b1-py3-none-any.whl", hash = "sha256:04480db952b48fb1ed0073f822f0ee26012b7be7c3eac1a3793122737c78632d", size = 33096, upload-time = "2025-12-11T13:35:33.067Z" }, +] + +[[package]] +name = "opentelemetry-proto" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/1d/f25d76d8260c156c40c97c9ed4511ec0f9ce353f8108ca6e7561f82a06b2/opentelemetry_proto-1.39.1.tar.gz", hash = "sha256:6c8e05144fc0d3ed4d22c2289c6b126e03bcd0e6a7da0f16cedd2e1c2772e2c8", size = 46152, upload-time = "2025-12-11T13:32:48.681Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/95/b40c96a7b5203005a0b03d8ce8cd212ff23f1793d5ba289c87a097571b18/opentelemetry_proto-1.39.1-py3-none-any.whl", hash = "sha256:22cdc78efd3b3765d09e68bfbd010d4fc254c9818afd0b6b423387d9dee46007", size = 72535, upload-time = "2025-12-11T13:32:33.866Z" }, +] + +[[package]] +name = "opentelemetry-sdk" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/fb/c76080c9ba07e1e8235d24cdcc4d125ef7aa3edf23eb4e497c2e50889adc/opentelemetry_sdk-1.39.1.tar.gz", hash = "sha256:cf4d4563caf7bff906c9f7967e2be22d0d6b349b908be0d90fb21c8e9c995cc6", size = 171460, upload-time = "2025-12-11T13:32:49.369Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/98/e91cf858f203d86f4eccdf763dcf01cf03f1dae80c3750f7e635bfa206b6/opentelemetry_sdk-1.39.1-py3-none-any.whl", hash = "sha256:4d5482c478513ecb0a5d938dcc61394e647066e0cc2676bee9f3af3f3f45f01c", size = 132565, upload-time = "2025-12-11T13:32:35.069Z" }, +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.60b1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/df/553f93ed38bf22f4b999d9be9c185adb558982214f33eae539d3b5cd0858/opentelemetry_semantic_conventions-0.60b1.tar.gz", hash = "sha256:87c228b5a0669b748c76d76df6c364c369c28f1c465e50f661e39737e84bc953", size = 137935, upload-time = "2025-12-11T13:32:50.487Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/5e/5958555e09635d09b75de3c4f8b9cae7335ca545d77392ffe7331534c402/opentelemetry_semantic_conventions-0.60b1-py3-none-any.whl", hash = "sha256:9fa8c8b0c110da289809292b0591220d3a7b53c1526a23021e977d68597893fb", size = 219982, upload-time = "2025-12-11T13:32:36.955Z" }, +] + [[package]] name = "outcome" version = "1.3.0.post0" @@ -1580,11 +1831,11 @@ wheels = [ [[package]] name = "pathspec" -version = "0.12.1" +version = "1.0.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/36/e27608899f9b8d4dff0617b2d9ab17ca5608956ca44461ac14ac48b44015/pathspec-1.0.4.tar.gz", hash = "sha256:0210e2ae8a21a9137c0d470578cb0e595af87edaa6ebf12ff176f14a02e0e645", size = 131200, upload-time = "2026-01-27T03:59:46.938Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, + { url = "https://files.pythonhosted.org/packages/ef/3c/2c197d226f9ea224a9ab8d197933f9da0ae0aac5b6e0f884e2b8d9c8e9f7/pathspec-1.0.4-py3-none-any.whl", hash = "sha256:fb6ae2fd4e7c921a165808a552060e722767cfa526f99ca5156ed2ce45a5c723", size = 55206, upload-time = "2026-01-27T03:59:45.137Z" }, ] [[package]] @@ -1703,6 +1954,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "protobuf" +version = "6.33.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/70/e908e9c5e52ef7c3a6c7902c9dfbb34c7e29c25d2f81ade3856445fd5c94/protobuf-6.33.6.tar.gz", hash = "sha256:a6768d25248312c297558af96a9f9c929e8c4cee0659cb07e780731095f38135", size = 444531, upload-time = "2026-03-18T19:05:00.988Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/9f/2f509339e89cfa6f6a4c4ff50438db9ca488dec341f7e454adad60150b00/protobuf-6.33.6-cp310-abi3-win32.whl", hash = "sha256:7d29d9b65f8afef196f8334e80d6bc1d5d4adedb449971fefd3723824e6e77d3", size = 425739, upload-time = "2026-03-18T19:04:48.373Z" }, + { url = "https://files.pythonhosted.org/packages/76/5d/683efcd4798e0030c1bab27374fd13a89f7c2515fb1f3123efdfaa5eab57/protobuf-6.33.6-cp310-abi3-win_amd64.whl", hash = "sha256:0cd27b587afca21b7cfa59a74dcbd48a50f0a6400cfb59391340ad729d91d326", size = 437089, upload-time = "2026-03-18T19:04:50.381Z" }, + { url = "https://files.pythonhosted.org/packages/5c/01/a3c3ed5cd186f39e7880f8303cc51385a198a81469d53d0fdecf1f64d929/protobuf-6.33.6-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:9720e6961b251bde64edfdab7d500725a2af5280f3f4c87e57c0208376aa8c3a", size = 427737, upload-time = "2026-03-18T19:04:51.866Z" }, + { url = "https://files.pythonhosted.org/packages/ee/90/b3c01fdec7d2f627b3a6884243ba328c1217ed2d978def5c12dc50d328a3/protobuf-6.33.6-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e2afbae9b8e1825e3529f88d514754e094278bb95eadc0e199751cdd9a2e82a2", size = 324610, upload-time = "2026-03-18T19:04:53.096Z" }, + { url = "https://files.pythonhosted.org/packages/9b/ca/25afc144934014700c52e05103c2421997482d561f3101ff352e1292fb81/protobuf-6.33.6-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:c96c37eec15086b79762ed265d59ab204dabc53056e3443e702d2681f4b39ce3", size = 339381, upload-time = "2026-03-18T19:04:54.616Z" }, + { url = "https://files.pythonhosted.org/packages/16/92/d1e32e3e0d894fe00b15ce28ad4944ab692713f2e7f0a99787405e43533a/protobuf-6.33.6-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:e9db7e292e0ab79dd108d7f1a94fe31601ce1ee3f7b79e0692043423020b0593", size = 323436, upload-time = "2026-03-18T19:04:55.768Z" }, + { url = "https://files.pythonhosted.org/packages/c4/72/02445137af02769918a93807b2b7890047c32bfb9f90371cbc12688819eb/protobuf-6.33.6-py3-none-any.whl", hash = "sha256:77179e006c476e69bf8e8ce866640091ec42e1beb80b213c3900006ecfba6901", size = 170656, upload-time = "2026-03-18T19:04:59.826Z" }, +] + [[package]] name = "pycparser" version = "2.23" @@ -2008,6 +2274,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1b/d0/397f9626e711ff749a95d96b7af99b9c566a9bb5129b8e4c10fc4d100304/python_multipart-0.0.22-py3-none-any.whl", hash = "sha256:2b2cd894c83d21bf49d702499531c7bafd057d730c201782048f7945d82de155", size = 24579, upload-time = "2026-01-25T10:15:54.811Z" }, ] +[[package]] +name = "pytokens" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b6/34/b4e015b99031667a7b960f888889c5bd34ef585c85e1cb56a594b92836ac/pytokens-0.4.1.tar.gz", hash = "sha256:292052fe80923aae2260c073f822ceba21f3872ced9a68bb7953b348e561179a", size = 23015, upload-time = "2026-01-30T01:03:45.924Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/24/f206113e05cb8ef51b3850e7ef88f20da6f4bf932190ceb48bd3da103e10/pytokens-0.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a44ed93ea23415c54f3face3b65ef2b844d96aeb3455b8a69b3df6beab6acc5", size = 161522, upload-time = "2026-01-30T01:02:50.393Z" }, + { url = "https://files.pythonhosted.org/packages/d4/e9/06a6bf1b90c2ed81a9c7d2544232fe5d2891d1cd480e8a1809ca354a8eb2/pytokens-0.4.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:add8bf86b71a5d9fb5b89f023a80b791e04fba57960aa790cc6125f7f1d39dfe", size = 246945, upload-time = "2026-01-30T01:02:52.399Z" }, + { url = "https://files.pythonhosted.org/packages/69/66/f6fb1007a4c3d8b682d5d65b7c1fb33257587a5f782647091e3408abe0b8/pytokens-0.4.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:670d286910b531c7b7e3c0b453fd8156f250adb140146d234a82219459b9640c", size = 259525, upload-time = "2026-01-30T01:02:53.737Z" }, + { url = "https://files.pythonhosted.org/packages/04/92/086f89b4d622a18418bac74ab5db7f68cf0c21cf7cc92de6c7b919d76c88/pytokens-0.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4e691d7f5186bd2842c14813f79f8884bb03f5995f0575272009982c5ac6c0f7", size = 262693, upload-time = "2026-01-30T01:02:54.871Z" }, + { url = "https://files.pythonhosted.org/packages/b4/7b/8b31c347cf94a3f900bdde750b2e9131575a61fdb620d3d3c75832262137/pytokens-0.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:27b83ad28825978742beef057bfe406ad6ed524b2d28c252c5de7b4a6dd48fa2", size = 103567, upload-time = "2026-01-30T01:02:56.414Z" }, + { url = "https://files.pythonhosted.org/packages/3d/92/790ebe03f07b57e53b10884c329b9a1a308648fc083a6d4a39a10a28c8fc/pytokens-0.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d70e77c55ae8380c91c0c18dea05951482e263982911fc7410b1ffd1dadd3440", size = 160864, upload-time = "2026-01-30T01:02:57.882Z" }, + { url = "https://files.pythonhosted.org/packages/13/25/a4f555281d975bfdd1eba731450e2fe3a95870274da73fb12c40aeae7625/pytokens-0.4.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4a58d057208cb9075c144950d789511220b07636dd2e4708d5645d24de666bdc", size = 248565, upload-time = "2026-01-30T01:02:59.912Z" }, + { url = "https://files.pythonhosted.org/packages/17/50/bc0394b4ad5b1601be22fa43652173d47e4c9efbf0044c62e9a59b747c56/pytokens-0.4.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b49750419d300e2b5a3813cf229d4e5a4c728dae470bcc89867a9ad6f25a722d", size = 260824, upload-time = "2026-01-30T01:03:01.471Z" }, + { url = "https://files.pythonhosted.org/packages/4e/54/3e04f9d92a4be4fc6c80016bc396b923d2a6933ae94b5f557c939c460ee0/pytokens-0.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d9907d61f15bf7261d7e775bd5d7ee4d2930e04424bab1972591918497623a16", size = 264075, upload-time = "2026-01-30T01:03:04.143Z" }, + { url = "https://files.pythonhosted.org/packages/d1/1b/44b0326cb5470a4375f37988aea5d61b5cc52407143303015ebee94abfd6/pytokens-0.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:ee44d0f85b803321710f9239f335aafe16553b39106384cef8e6de40cb4ef2f6", size = 103323, upload-time = "2026-01-30T01:03:05.412Z" }, + { url = "https://files.pythonhosted.org/packages/41/5d/e44573011401fb82e9d51e97f1290ceb377800fb4eed650b96f4753b499c/pytokens-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:140709331e846b728475786df8aeb27d24f48cbcf7bcd449f8de75cae7a45083", size = 160663, upload-time = "2026-01-30T01:03:06.473Z" }, + { url = "https://files.pythonhosted.org/packages/f0/e6/5bbc3019f8e6f21d09c41f8b8654536117e5e211a85d89212d59cbdab381/pytokens-0.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6d6c4268598f762bc8e91f5dbf2ab2f61f7b95bdc07953b602db879b3c8c18e1", size = 255626, upload-time = "2026-01-30T01:03:08.177Z" }, + { url = "https://files.pythonhosted.org/packages/bf/3c/2d5297d82286f6f3d92770289fd439956b201c0a4fc7e72efb9b2293758e/pytokens-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:24afde1f53d95348b5a0eb19488661147285ca4dd7ed752bbc3e1c6242a304d1", size = 269779, upload-time = "2026-01-30T01:03:09.756Z" }, + { url = "https://files.pythonhosted.org/packages/20/01/7436e9ad693cebda0551203e0bf28f7669976c60ad07d6402098208476de/pytokens-0.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5ad948d085ed6c16413eb5fec6b3e02fa00dc29a2534f088d3302c47eb59adf9", size = 268076, upload-time = "2026-01-30T01:03:10.957Z" }, + { url = "https://files.pythonhosted.org/packages/2e/df/533c82a3c752ba13ae7ef238b7f8cdd272cf1475f03c63ac6cf3fcfb00b6/pytokens-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:3f901fe783e06e48e8cbdc82d631fca8f118333798193e026a50ce1b3757ea68", size = 103552, upload-time = "2026-01-30T01:03:12.066Z" }, + { url = "https://files.pythonhosted.org/packages/cb/dc/08b1a080372afda3cceb4f3c0a7ba2bde9d6a5241f1edb02a22a019ee147/pytokens-0.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8bdb9d0ce90cbf99c525e75a2fa415144fd570a1ba987380190e8b786bc6ef9b", size = 160720, upload-time = "2026-01-30T01:03:13.843Z" }, + { url = "https://files.pythonhosted.org/packages/64/0c/41ea22205da480837a700e395507e6a24425151dfb7ead73343d6e2d7ffe/pytokens-0.4.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5502408cab1cb18e128570f8d598981c68a50d0cbd7c61312a90507cd3a1276f", size = 254204, upload-time = "2026-01-30T01:03:14.886Z" }, + { url = "https://files.pythonhosted.org/packages/e0/d2/afe5c7f8607018beb99971489dbb846508f1b8f351fcefc225fcf4b2adc0/pytokens-0.4.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:29d1d8fb1030af4d231789959f21821ab6325e463f0503a61d204343c9b355d1", size = 268423, upload-time = "2026-01-30T01:03:15.936Z" }, + { url = "https://files.pythonhosted.org/packages/68/d4/00ffdbd370410c04e9591da9220a68dc1693ef7499173eb3e30d06e05ed1/pytokens-0.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:970b08dd6b86058b6dc07efe9e98414f5102974716232d10f32ff39701e841c4", size = 266859, upload-time = "2026-01-30T01:03:17.458Z" }, + { url = "https://files.pythonhosted.org/packages/a7/c9/c3161313b4ca0c601eeefabd3d3b576edaa9afdefd32da97210700e47652/pytokens-0.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:9bd7d7f544d362576be74f9d5901a22f317efc20046efe2034dced238cbbfe78", size = 103520, upload-time = "2026-01-30T01:03:18.652Z" }, + { url = "https://files.pythonhosted.org/packages/8f/a7/b470f672e6fc5fee0a01d9e75005a0e617e162381974213a945fcd274843/pytokens-0.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4a14d5f5fc78ce85e426aa159489e2d5961acf0e47575e08f35584009178e321", size = 160821, upload-time = "2026-01-30T01:03:19.684Z" }, + { url = "https://files.pythonhosted.org/packages/80/98/e83a36fe8d170c911f864bfded690d2542bfcfacb9c649d11a9e6eb9dc41/pytokens-0.4.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97f50fd18543be72da51dd505e2ed20d2228c74e0464e4262e4899797803d7fa", size = 254263, upload-time = "2026-01-30T01:03:20.834Z" }, + { url = "https://files.pythonhosted.org/packages/0f/95/70d7041273890f9f97a24234c00b746e8da86df462620194cef1d411ddeb/pytokens-0.4.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dc74c035f9bfca0255c1af77ddd2d6ae8419012805453e4b0e7513e17904545d", size = 268071, upload-time = "2026-01-30T01:03:21.888Z" }, + { url = "https://files.pythonhosted.org/packages/da/79/76e6d09ae19c99404656d7db9c35dfd20f2086f3eb6ecb496b5b31163bad/pytokens-0.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:f66a6bbe741bd431f6d741e617e0f39ec7257ca1f89089593479347cc4d13324", size = 271716, upload-time = "2026-01-30T01:03:23.633Z" }, + { url = "https://files.pythonhosted.org/packages/79/37/482e55fa1602e0a7ff012661d8c946bafdc05e480ea5a32f4f7e336d4aa9/pytokens-0.4.1-cp314-cp314-win_amd64.whl", hash = "sha256:b35d7e5ad269804f6697727702da3c517bb8a5228afa450ab0fa787732055fc9", size = 104539, upload-time = "2026-01-30T01:03:24.788Z" }, + { url = "https://files.pythonhosted.org/packages/30/e8/20e7db907c23f3d63b0be3b8a4fd1927f6da2395f5bcc7f72242bb963dfe/pytokens-0.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:8fcb9ba3709ff77e77f1c7022ff11d13553f3c30299a9fe246a166903e9091eb", size = 168474, upload-time = "2026-01-30T01:03:26.428Z" }, + { url = "https://files.pythonhosted.org/packages/d6/81/88a95ee9fafdd8f5f3452107748fd04c24930d500b9aba9738f3ade642cc/pytokens-0.4.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:79fc6b8699564e1f9b521582c35435f1bd32dd06822322ec44afdeba666d8cb3", size = 290473, upload-time = "2026-01-30T01:03:27.415Z" }, + { url = "https://files.pythonhosted.org/packages/cf/35/3aa899645e29b6375b4aed9f8d21df219e7c958c4c186b465e42ee0a06bf/pytokens-0.4.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d31b97b3de0f61571a124a00ffe9a81fb9939146c122c11060725bd5aea79975", size = 303485, upload-time = "2026-01-30T01:03:28.558Z" }, + { url = "https://files.pythonhosted.org/packages/52/a0/07907b6ff512674d9b201859f7d212298c44933633c946703a20c25e9d81/pytokens-0.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:967cf6e3fd4adf7de8fc73cd3043754ae79c36475c1c11d514fc72cf5490094a", size = 306698, upload-time = "2026-01-30T01:03:29.653Z" }, + { url = "https://files.pythonhosted.org/packages/39/2a/cbbf9250020a4a8dd53ba83a46c097b69e5eb49dd14e708f496f548c6612/pytokens-0.4.1-cp314-cp314t-win_amd64.whl", hash = "sha256:584c80c24b078eec1e227079d56dc22ff755e0ba8654d8383b2c549107528918", size = 116287, upload-time = "2026-01-30T01:03:30.912Z" }, + { url = "https://files.pythonhosted.org/packages/c6/78/397db326746f0a342855b81216ae1f0a32965deccfd7c830a2dbc66d2483/pytokens-0.4.1-py3-none-any.whl", hash = "sha256:26cef14744a8385f35d0e095dc8b3a7583f6c953c2e3d269c7f82484bf5ad2de", size = 13729, upload-time = "2026-01-30T01:03:45.029Z" }, +] + [[package]] name = "pywin32" version = "311" @@ -2102,7 +2407,7 @@ wheels = [ [[package]] name = "requests" -version = "2.32.5" +version = "2.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -2110,9 +2415,9 @@ dependencies = [ { name = "idna" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/64/8860370b167a9721e8956ae116825caff829224fbca0ca6e7bf8ddef8430/requests-2.33.0.tar.gz", hash = "sha256:c7ebc5e8b0f21837386ad0e1c8fe8b829fa5f544d8df3b2253bff14ef29d7652", size = 134232, upload-time = "2026-03-25T15:10:41.586Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, + { url = "https://files.pythonhosted.org/packages/56/5d/c814546c2333ceea4ba42262d8c4d55763003e767fa169adc693bd524478/requests-2.33.0-py3-none-any.whl", hash = "sha256:3324635456fa185245e24865e810cecec7b4caf933d7eb133dcde67d48cee69b", size = 65017, upload-time = "2026-03-25T15:10:40.382Z" }, ] [[package]] @@ -2406,6 +2711,18 @@ dependencies = [ { name = "pydantic" }, ] +[[package]] +name = "tinycss2" +version = "1.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/ae/2ca4913e5c0f09781d75482874c3a95db9105462a92ddd303c7d285d3df2/tinycss2-1.5.1.tar.gz", hash = "sha256:d339d2b616ba90ccce58da8495a78f46e55d4d25f9fd71dfd526f07e7d53f957", size = 88195, upload-time = "2025-11-23T10:29:10.082Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/45/c7b5c3168458db837e8ceab06dc77824e18202679d0463f0e8f002143a97/tinycss2-1.5.1-py3-none-any.whl", hash = "sha256:3415ba0f5839c062696996998176c4a3751d18b7edaaeeb658c9ce21ec150661", size = 28404, upload-time = "2025-11-23T10:29:08.676Z" }, +] + [[package]] name = "tomli" version = "2.2.1" @@ -2554,6 +2871,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "webencodings" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/02/ae6ceac1baeda530866a85075641cec12989bd8d31af6d5ab4a3e8c92f47/webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923", size = 9721, upload-time = "2017-04-05T20:21:34.189Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774, upload-time = "2017-04-05T20:21:32.581Z" }, +] + [[package]] name = "websockets" version = "15.0.1" @@ -2612,3 +2938,81 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884, upload-time = "2025-03-05T20:03:27.934Z" }, { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] + +[[package]] +name = "wrapt" +version = "1.17.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/8f/aeb76c5b46e273670962298c23e7ddde79916cb74db802131d49a85e4b7d/wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0", size = 55547, upload-time = "2025-08-12T05:53:21.714Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/23/bb82321b86411eb51e5a5db3fb8f8032fd30bd7c2d74bfe936136b2fa1d6/wrapt-1.17.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88bbae4d40d5a46142e70d58bf664a89b6b4befaea7b2ecc14e03cedb8e06c04", size = 53482, upload-time = "2025-08-12T05:51:44.467Z" }, + { url = "https://files.pythonhosted.org/packages/45/69/f3c47642b79485a30a59c63f6d739ed779fb4cc8323205d047d741d55220/wrapt-1.17.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6b13af258d6a9ad602d57d889f83b9d5543acd471eee12eb51f5b01f8eb1bc2", size = 38676, upload-time = "2025-08-12T05:51:32.636Z" }, + { url = "https://files.pythonhosted.org/packages/d1/71/e7e7f5670c1eafd9e990438e69d8fb46fa91a50785332e06b560c869454f/wrapt-1.17.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd341868a4b6714a5962c1af0bd44f7c404ef78720c7de4892901e540417111c", size = 38957, upload-time = "2025-08-12T05:51:54.655Z" }, + { url = "https://files.pythonhosted.org/packages/de/17/9f8f86755c191d6779d7ddead1a53c7a8aa18bccb7cea8e7e72dfa6a8a09/wrapt-1.17.3-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f9b2601381be482f70e5d1051a5965c25fb3625455a2bf520b5a077b22afb775", size = 81975, upload-time = "2025-08-12T05:52:30.109Z" }, + { url = "https://files.pythonhosted.org/packages/f2/15/dd576273491f9f43dd09fce517f6c2ce6eb4fe21681726068db0d0467096/wrapt-1.17.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343e44b2a8e60e06a7e0d29c1671a0d9951f59174f3709962b5143f60a2a98bd", size = 83149, upload-time = "2025-08-12T05:52:09.316Z" }, + { url = "https://files.pythonhosted.org/packages/0c/c4/5eb4ce0d4814521fee7aa806264bf7a114e748ad05110441cd5b8a5c744b/wrapt-1.17.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:33486899acd2d7d3066156b03465b949da3fd41a5da6e394ec49d271baefcf05", size = 82209, upload-time = "2025-08-12T05:52:10.331Z" }, + { url = "https://files.pythonhosted.org/packages/31/4b/819e9e0eb5c8dc86f60dfc42aa4e2c0d6c3db8732bce93cc752e604bb5f5/wrapt-1.17.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e6f40a8aa5a92f150bdb3e1c44b7e98fb7113955b2e5394122fa5532fec4b418", size = 81551, upload-time = "2025-08-12T05:52:31.137Z" }, + { url = "https://files.pythonhosted.org/packages/f8/83/ed6baf89ba3a56694700139698cf703aac9f0f9eb03dab92f57551bd5385/wrapt-1.17.3-cp310-cp310-win32.whl", hash = "sha256:a36692b8491d30a8c75f1dfee65bef119d6f39ea84ee04d9f9311f83c5ad9390", size = 36464, upload-time = "2025-08-12T05:53:01.204Z" }, + { url = "https://files.pythonhosted.org/packages/2f/90/ee61d36862340ad7e9d15a02529df6b948676b9a5829fd5e16640156627d/wrapt-1.17.3-cp310-cp310-win_amd64.whl", hash = "sha256:afd964fd43b10c12213574db492cb8f73b2f0826c8df07a68288f8f19af2ebe6", size = 38748, upload-time = "2025-08-12T05:53:00.209Z" }, + { url = "https://files.pythonhosted.org/packages/bd/c3/cefe0bd330d389c9983ced15d326f45373f4073c9f4a8c2f99b50bfea329/wrapt-1.17.3-cp310-cp310-win_arm64.whl", hash = "sha256:af338aa93554be859173c39c85243970dc6a289fa907402289eeae7543e1ae18", size = 36810, upload-time = "2025-08-12T05:52:51.906Z" }, + { url = "https://files.pythonhosted.org/packages/52/db/00e2a219213856074a213503fdac0511203dceefff26e1daa15250cc01a0/wrapt-1.17.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:273a736c4645e63ac582c60a56b0acb529ef07f78e08dc6bfadf6a46b19c0da7", size = 53482, upload-time = "2025-08-12T05:51:45.79Z" }, + { url = "https://files.pythonhosted.org/packages/5e/30/ca3c4a5eba478408572096fe9ce36e6e915994dd26a4e9e98b4f729c06d9/wrapt-1.17.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5531d911795e3f935a9c23eb1c8c03c211661a5060aab167065896bbf62a5f85", size = 38674, upload-time = "2025-08-12T05:51:34.629Z" }, + { url = "https://files.pythonhosted.org/packages/31/25/3e8cc2c46b5329c5957cec959cb76a10718e1a513309c31399a4dad07eb3/wrapt-1.17.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0610b46293c59a3adbae3dee552b648b984176f8562ee0dba099a56cfbe4df1f", size = 38959, upload-time = "2025-08-12T05:51:56.074Z" }, + { url = "https://files.pythonhosted.org/packages/5d/8f/a32a99fc03e4b37e31b57cb9cefc65050ea08147a8ce12f288616b05ef54/wrapt-1.17.3-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b32888aad8b6e68f83a8fdccbf3165f5469702a7544472bdf41f582970ed3311", size = 82376, upload-time = "2025-08-12T05:52:32.134Z" }, + { url = "https://files.pythonhosted.org/packages/31/57/4930cb8d9d70d59c27ee1332a318c20291749b4fba31f113c2f8ac49a72e/wrapt-1.17.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cccf4f81371f257440c88faed6b74f1053eef90807b77e31ca057b2db74edb1", size = 83604, upload-time = "2025-08-12T05:52:11.663Z" }, + { url = "https://files.pythonhosted.org/packages/a8/f3/1afd48de81d63dd66e01b263a6fbb86e1b5053b419b9b33d13e1f6d0f7d0/wrapt-1.17.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8a210b158a34164de8bb68b0e7780041a903d7b00c87e906fb69928bf7890d5", size = 82782, upload-time = "2025-08-12T05:52:12.626Z" }, + { url = "https://files.pythonhosted.org/packages/1e/d7/4ad5327612173b144998232f98a85bb24b60c352afb73bc48e3e0d2bdc4e/wrapt-1.17.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:79573c24a46ce11aab457b472efd8d125e5a51da2d1d24387666cd85f54c05b2", size = 82076, upload-time = "2025-08-12T05:52:33.168Z" }, + { url = "https://files.pythonhosted.org/packages/bb/59/e0adfc831674a65694f18ea6dc821f9fcb9ec82c2ce7e3d73a88ba2e8718/wrapt-1.17.3-cp311-cp311-win32.whl", hash = "sha256:c31eebe420a9a5d2887b13000b043ff6ca27c452a9a22fa71f35f118e8d4bf89", size = 36457, upload-time = "2025-08-12T05:53:03.936Z" }, + { url = "https://files.pythonhosted.org/packages/83/88/16b7231ba49861b6f75fc309b11012ede4d6b0a9c90969d9e0db8d991aeb/wrapt-1.17.3-cp311-cp311-win_amd64.whl", hash = "sha256:0b1831115c97f0663cb77aa27d381237e73ad4f721391a9bfb2fe8bc25fa6e77", size = 38745, upload-time = "2025-08-12T05:53:02.885Z" }, + { url = "https://files.pythonhosted.org/packages/9a/1e/c4d4f3398ec073012c51d1c8d87f715f56765444e1a4b11e5180577b7e6e/wrapt-1.17.3-cp311-cp311-win_arm64.whl", hash = "sha256:5a7b3c1ee8265eb4c8f1b7d29943f195c00673f5ab60c192eba2d4a7eae5f46a", size = 36806, upload-time = "2025-08-12T05:52:53.368Z" }, + { url = "https://files.pythonhosted.org/packages/9f/41/cad1aba93e752f1f9268c77270da3c469883d56e2798e7df6240dcb2287b/wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0", size = 53998, upload-time = "2025-08-12T05:51:47.138Z" }, + { url = "https://files.pythonhosted.org/packages/60/f8/096a7cc13097a1869fe44efe68dace40d2a16ecb853141394047f0780b96/wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba", size = 39020, upload-time = "2025-08-12T05:51:35.906Z" }, + { url = "https://files.pythonhosted.org/packages/33/df/bdf864b8997aab4febb96a9ae5c124f700a5abd9b5e13d2a3214ec4be705/wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd", size = 39098, upload-time = "2025-08-12T05:51:57.474Z" }, + { url = "https://files.pythonhosted.org/packages/9f/81/5d931d78d0eb732b95dc3ddaeeb71c8bb572fb01356e9133916cd729ecdd/wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828", size = 88036, upload-time = "2025-08-12T05:52:34.784Z" }, + { url = "https://files.pythonhosted.org/packages/ca/38/2e1785df03b3d72d34fc6252d91d9d12dc27a5c89caef3335a1bbb8908ca/wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9", size = 88156, upload-time = "2025-08-12T05:52:13.599Z" }, + { url = "https://files.pythonhosted.org/packages/b3/8b/48cdb60fe0603e34e05cffda0b2a4adab81fd43718e11111a4b0100fd7c1/wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396", size = 87102, upload-time = "2025-08-12T05:52:14.56Z" }, + { url = "https://files.pythonhosted.org/packages/3c/51/d81abca783b58f40a154f1b2c56db1d2d9e0d04fa2d4224e357529f57a57/wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc", size = 87732, upload-time = "2025-08-12T05:52:36.165Z" }, + { url = "https://files.pythonhosted.org/packages/9e/b1/43b286ca1392a006d5336412d41663eeef1ad57485f3e52c767376ba7e5a/wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe", size = 36705, upload-time = "2025-08-12T05:53:07.123Z" }, + { url = "https://files.pythonhosted.org/packages/28/de/49493f962bd3c586ab4b88066e967aa2e0703d6ef2c43aa28cb83bf7b507/wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c", size = 38877, upload-time = "2025-08-12T05:53:05.436Z" }, + { url = "https://files.pythonhosted.org/packages/f1/48/0f7102fe9cb1e8a5a77f80d4f0956d62d97034bbe88d33e94699f99d181d/wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6", size = 36885, upload-time = "2025-08-12T05:52:54.367Z" }, + { url = "https://files.pythonhosted.org/packages/fc/f6/759ece88472157acb55fc195e5b116e06730f1b651b5b314c66291729193/wrapt-1.17.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a47681378a0439215912ef542c45a783484d4dd82bac412b71e59cf9c0e1cea0", size = 54003, upload-time = "2025-08-12T05:51:48.627Z" }, + { url = "https://files.pythonhosted.org/packages/4f/a9/49940b9dc6d47027dc850c116d79b4155f15c08547d04db0f07121499347/wrapt-1.17.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:54a30837587c6ee3cd1a4d1c2ec5d24e77984d44e2f34547e2323ddb4e22eb77", size = 39025, upload-time = "2025-08-12T05:51:37.156Z" }, + { url = "https://files.pythonhosted.org/packages/45/35/6a08de0f2c96dcdd7fe464d7420ddb9a7655a6561150e5fc4da9356aeaab/wrapt-1.17.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ecf15d6af39246fe33e507105d67e4b81d8f8d2c6598ff7e3ca1b8a37213f7", size = 39108, upload-time = "2025-08-12T05:51:58.425Z" }, + { url = "https://files.pythonhosted.org/packages/0c/37/6faf15cfa41bf1f3dba80cd3f5ccc6622dfccb660ab26ed79f0178c7497f/wrapt-1.17.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6fd1ad24dc235e4ab88cda009e19bf347aabb975e44fd5c2fb22a3f6e4141277", size = 88072, upload-time = "2025-08-12T05:52:37.53Z" }, + { url = "https://files.pythonhosted.org/packages/78/f2/efe19ada4a38e4e15b6dff39c3e3f3f73f5decf901f66e6f72fe79623a06/wrapt-1.17.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ed61b7c2d49cee3c027372df5809a59d60cf1b6c2f81ee980a091f3afed6a2d", size = 88214, upload-time = "2025-08-12T05:52:15.886Z" }, + { url = "https://files.pythonhosted.org/packages/40/90/ca86701e9de1622b16e09689fc24b76f69b06bb0150990f6f4e8b0eeb576/wrapt-1.17.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:423ed5420ad5f5529db9ce89eac09c8a2f97da18eb1c870237e84c5a5c2d60aa", size = 87105, upload-time = "2025-08-12T05:52:17.914Z" }, + { url = "https://files.pythonhosted.org/packages/fd/e0/d10bd257c9a3e15cbf5523025252cc14d77468e8ed644aafb2d6f54cb95d/wrapt-1.17.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e01375f275f010fcbf7f643b4279896d04e571889b8a5b3f848423d91bf07050", size = 87766, upload-time = "2025-08-12T05:52:39.243Z" }, + { url = "https://files.pythonhosted.org/packages/e8/cf/7d848740203c7b4b27eb55dbfede11aca974a51c3d894f6cc4b865f42f58/wrapt-1.17.3-cp313-cp313-win32.whl", hash = "sha256:53e5e39ff71b3fc484df8a522c933ea2b7cdd0d5d15ae82e5b23fde87d44cbd8", size = 36711, upload-time = "2025-08-12T05:53:10.074Z" }, + { url = "https://files.pythonhosted.org/packages/57/54/35a84d0a4d23ea675994104e667ceff49227ce473ba6a59ba2c84f250b74/wrapt-1.17.3-cp313-cp313-win_amd64.whl", hash = "sha256:1f0b2f40cf341ee8cc1a97d51ff50dddb9fcc73241b9143ec74b30fc4f44f6cb", size = 38885, upload-time = "2025-08-12T05:53:08.695Z" }, + { url = "https://files.pythonhosted.org/packages/01/77/66e54407c59d7b02a3c4e0af3783168fff8e5d61def52cda8728439d86bc/wrapt-1.17.3-cp313-cp313-win_arm64.whl", hash = "sha256:7425ac3c54430f5fc5e7b6f41d41e704db073309acfc09305816bc6a0b26bb16", size = 36896, upload-time = "2025-08-12T05:52:55.34Z" }, + { url = "https://files.pythonhosted.org/packages/02/a2/cd864b2a14f20d14f4c496fab97802001560f9f41554eef6df201cd7f76c/wrapt-1.17.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:cf30f6e3c077c8e6a9a7809c94551203c8843e74ba0c960f4a98cd80d4665d39", size = 54132, upload-time = "2025-08-12T05:51:49.864Z" }, + { url = "https://files.pythonhosted.org/packages/d5/46/d011725b0c89e853dc44cceb738a307cde5d240d023d6d40a82d1b4e1182/wrapt-1.17.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e228514a06843cae89621384cfe3a80418f3c04aadf8a3b14e46a7be704e4235", size = 39091, upload-time = "2025-08-12T05:51:38.935Z" }, + { url = "https://files.pythonhosted.org/packages/2e/9e/3ad852d77c35aae7ddebdbc3b6d35ec8013af7d7dddad0ad911f3d891dae/wrapt-1.17.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5ea5eb3c0c071862997d6f3e02af1d055f381b1d25b286b9d6644b79db77657c", size = 39172, upload-time = "2025-08-12T05:51:59.365Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f7/c983d2762bcce2326c317c26a6a1e7016f7eb039c27cdf5c4e30f4160f31/wrapt-1.17.3-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:281262213373b6d5e4bb4353bc36d1ba4084e6d6b5d242863721ef2bf2c2930b", size = 87163, upload-time = "2025-08-12T05:52:40.965Z" }, + { url = "https://files.pythonhosted.org/packages/e4/0f/f673f75d489c7f22d17fe0193e84b41540d962f75fce579cf6873167c29b/wrapt-1.17.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dc4a8d2b25efb6681ecacad42fca8859f88092d8732b170de6a5dddd80a1c8fa", size = 87963, upload-time = "2025-08-12T05:52:20.326Z" }, + { url = "https://files.pythonhosted.org/packages/df/61/515ad6caca68995da2fac7a6af97faab8f78ebe3bf4f761e1b77efbc47b5/wrapt-1.17.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:373342dd05b1d07d752cecbec0c41817231f29f3a89aa8b8843f7b95992ed0c7", size = 86945, upload-time = "2025-08-12T05:52:21.581Z" }, + { url = "https://files.pythonhosted.org/packages/d3/bd/4e70162ce398462a467bc09e768bee112f1412e563620adc353de9055d33/wrapt-1.17.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d40770d7c0fd5cbed9d84b2c3f2e156431a12c9a37dc6284060fb4bec0b7ffd4", size = 86857, upload-time = "2025-08-12T05:52:43.043Z" }, + { url = "https://files.pythonhosted.org/packages/2b/b8/da8560695e9284810b8d3df8a19396a6e40e7518059584a1a394a2b35e0a/wrapt-1.17.3-cp314-cp314-win32.whl", hash = "sha256:fbd3c8319de8e1dc79d346929cd71d523622da527cca14e0c1d257e31c2b8b10", size = 37178, upload-time = "2025-08-12T05:53:12.605Z" }, + { url = "https://files.pythonhosted.org/packages/db/c8/b71eeb192c440d67a5a0449aaee2310a1a1e8eca41676046f99ed2487e9f/wrapt-1.17.3-cp314-cp314-win_amd64.whl", hash = "sha256:e1a4120ae5705f673727d3253de3ed0e016f7cd78dc463db1b31e2463e1f3cf6", size = 39310, upload-time = "2025-08-12T05:53:11.106Z" }, + { url = "https://files.pythonhosted.org/packages/45/20/2cda20fd4865fa40f86f6c46ed37a2a8356a7a2fde0773269311f2af56c7/wrapt-1.17.3-cp314-cp314-win_arm64.whl", hash = "sha256:507553480670cab08a800b9463bdb881b2edeed77dc677b0a5915e6106e91a58", size = 37266, upload-time = "2025-08-12T05:52:56.531Z" }, + { url = "https://files.pythonhosted.org/packages/77/ed/dd5cf21aec36c80443c6f900449260b80e2a65cf963668eaef3b9accce36/wrapt-1.17.3-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ed7c635ae45cfbc1a7371f708727bf74690daedc49b4dba310590ca0bd28aa8a", size = 56544, upload-time = "2025-08-12T05:51:51.109Z" }, + { url = "https://files.pythonhosted.org/packages/8d/96/450c651cc753877ad100c7949ab4d2e2ecc4d97157e00fa8f45df682456a/wrapt-1.17.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:249f88ed15503f6492a71f01442abddd73856a0032ae860de6d75ca62eed8067", size = 40283, upload-time = "2025-08-12T05:51:39.912Z" }, + { url = "https://files.pythonhosted.org/packages/d1/86/2fcad95994d9b572db57632acb6f900695a648c3e063f2cd344b3f5c5a37/wrapt-1.17.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a03a38adec8066d5a37bea22f2ba6bbf39fcdefbe2d91419ab864c3fb515454", size = 40366, upload-time = "2025-08-12T05:52:00.693Z" }, + { url = "https://files.pythonhosted.org/packages/64/0e/f4472f2fdde2d4617975144311f8800ef73677a159be7fe61fa50997d6c0/wrapt-1.17.3-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5d4478d72eb61c36e5b446e375bbc49ed002430d17cdec3cecb36993398e1a9e", size = 108571, upload-time = "2025-08-12T05:52:44.521Z" }, + { url = "https://files.pythonhosted.org/packages/cc/01/9b85a99996b0a97c8a17484684f206cbb6ba73c1ce6890ac668bcf3838fb/wrapt-1.17.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:223db574bb38637e8230eb14b185565023ab624474df94d2af18f1cdb625216f", size = 113094, upload-time = "2025-08-12T05:52:22.618Z" }, + { url = "https://files.pythonhosted.org/packages/25/02/78926c1efddcc7b3aa0bc3d6b33a822f7d898059f7cd9ace8c8318e559ef/wrapt-1.17.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e405adefb53a435f01efa7ccdec012c016b5a1d3f35459990afc39b6be4d5056", size = 110659, upload-time = "2025-08-12T05:52:24.057Z" }, + { url = "https://files.pythonhosted.org/packages/dc/ee/c414501ad518ac3e6fe184753632fe5e5ecacdcf0effc23f31c1e4f7bfcf/wrapt-1.17.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:88547535b787a6c9ce4086917b6e1d291aa8ed914fdd3a838b3539dc95c12804", size = 106946, upload-time = "2025-08-12T05:52:45.976Z" }, + { url = "https://files.pythonhosted.org/packages/be/44/a1bd64b723d13bb151d6cc91b986146a1952385e0392a78567e12149c7b4/wrapt-1.17.3-cp314-cp314t-win32.whl", hash = "sha256:41b1d2bc74c2cac6f9074df52b2efbef2b30bdfe5f40cb78f8ca22963bc62977", size = 38717, upload-time = "2025-08-12T05:53:15.214Z" }, + { url = "https://files.pythonhosted.org/packages/79/d9/7cfd5a312760ac4dd8bf0184a6ee9e43c33e47f3dadc303032ce012b8fa3/wrapt-1.17.3-cp314-cp314t-win_amd64.whl", hash = "sha256:73d496de46cd2cdbdbcce4ae4bcdb4afb6a11234a1df9c085249d55166b95116", size = 41334, upload-time = "2025-08-12T05:53:14.178Z" }, + { url = "https://files.pythonhosted.org/packages/46/78/10ad9781128ed2f99dbc474f43283b13fea8ba58723e98844367531c18e9/wrapt-1.17.3-cp314-cp314t-win_arm64.whl", hash = "sha256:f38e60678850c42461d4202739f9bf1e3a737c7ad283638251e79cc49effb6b6", size = 38471, upload-time = "2025-08-12T05:52:57.784Z" }, + { url = "https://files.pythonhosted.org/packages/1f/f6/a933bd70f98e9cf3e08167fc5cd7aaaca49147e48411c0bd5ae701bb2194/wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22", size = 23591, upload-time = "2025-08-12T05:53:20.674Z" }, +] + +[[package]] +name = "zipp" +version = "3.23.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, +]