diff --git a/.agents/claude-code-cli.ts b/.agents/claude-code-cli.ts index 2de48ff5c5..075d9f23e4 100644 --- a/.agents/claude-code-cli.ts +++ b/.agents/claude-code-cli.ts @@ -10,7 +10,7 @@ const baseDefinition = createCliAgent({ startCommand: 'claude --dangerously-skip-permissions', permissionNote: 'Always use `--dangerously-skip-permissions` when testing to avoid permission prompts that would block automated tests.', - model: 'anthropic/claude-opus-4.6', + model: 'anthropic/claude-opus-4.7', }) // Constants must be inside handleSteps since it gets serialized via .toString() diff --git a/.agents/codebuff-local-cli.ts b/.agents/codebuff-local-cli.ts index 1fdf975c62..8cb367a08a 100644 --- a/.agents/codebuff-local-cli.ts +++ b/.agents/codebuff-local-cli.ts @@ -10,7 +10,7 @@ const baseDefinition = createCliAgent({ startCommand: 'bun --cwd=cli run dev', permissionNote: 'No permission flags needed for Codebuff local dev server.', - model: 'anthropic/claude-opus-4.6', + model: 'anthropic/claude-opus-4.7', skipPrepPhase: true, cliSpecificDocs: `## Codebuff CLI Specific Guidance diff --git a/.agents/codex-cli.ts b/.agents/codex-cli.ts index 9914e3d7c7..e7b18473a8 100644 --- a/.agents/codex-cli.ts +++ b/.agents/codex-cli.ts @@ -81,7 +81,7 @@ const baseDefinition = createCliAgent({ startCommand: 'codex -a never -s danger-full-access', permissionNote: 'Always use `-a never -s danger-full-access` when testing to avoid approval prompts that would block automated tests.', - model: 'anthropic/claude-opus-4.6', + model: 'anthropic/claude-opus-4.7', extraInputParams: { reviewType: { type: 'string', diff --git a/.agents/gemini-cli.ts b/.agents/gemini-cli.ts index 38186add48..d5eb7f45e2 100644 --- a/.agents/gemini-cli.ts +++ b/.agents/gemini-cli.ts @@ -10,7 +10,7 @@ const baseDefinition = createCliAgent({ startCommand: 'gemini --yolo', permissionNote: 'Always use `--yolo` (or `--approval-mode yolo`) when testing to auto-approve all tool actions and avoid prompts that would block 
automated tests.', - model: 'anthropic/claude-opus-4.6', + model: 'anthropic/claude-opus-4.7', cliSpecificDocs: `## Gemini CLI Commands Gemini CLI uses slash commands for navigation: diff --git a/.agents/package.json b/.agents/package.json index e6dd6fc4e7..053d1e6c66 100644 --- a/.agents/package.json +++ b/.agents/package.json @@ -5,7 +5,6 @@ "type": "module", "scripts": { "typecheck": "bun x tsc --noEmit -p tsconfig.json", - "test": "bun test __tests__", - "test:e2e": "bun test e2e" + "test": "bun test __tests__" } } diff --git a/.agents/sessions/03-02-1407-chatgpt-oauth-direct/LESSONS.md b/.agents/sessions/03-02-1407-chatgpt-oauth-direct/LESSONS.md new file mode 100644 index 0000000000..0dbb6fd5b9 --- /dev/null +++ b/.agents/sessions/03-02-1407-chatgpt-oauth-direct/LESSONS.md @@ -0,0 +1,42 @@ +# LESSONS — ChatGPT OAuth Direct Routing + +Session: `.agents/sessions/03-02-1407-chatgpt-oauth-direct/` + +## What went well +- Building this feature behind a strict feature flag (`CHATGPT_OAUTH_ENABLED=false`) reduced rollout risk while allowing full end-to-end wiring. +- Reusing the Claude OAuth architectural pattern (credentials helpers, refresh mutex, routing split) accelerated implementation without coupling the two providers. +- Splitting policy logic into `classifyChatGptOAuthStreamError` made fallback/auth/fail-fast behavior easier to test and reason about. +- Adding focused CLI tests for `/connect:chatgpt` gating and utility sanitization caught regression risk early. + +## Current confidence / known gaps +- Runtime ChatGPT stream policy is **partially tested**: `classifyChatGptOAuthStreamError` is covered, but we do not yet have full behavioral tests for `promptAiSdkStream` recursion branches (actual fallback recursion and post-partial-output behavior). +- CLI routing coverage is strongest for **feature-flag OFF** paths; flag-ON auth-code routing should get explicit dedicated tests in a future pass. 
+ +## What was tricky +- The repo had unrelated local drift during implementation; explicit scope cleanup (`git checkout -- <path>`) was necessary to avoid accidental cross-feature commits. +- CLI module mocking is path-sensitive. Test modules under `cli/src/commands/__tests__` must mock sibling modules with correct relative paths (e.g. `../../state/chat-store`), or mocks silently fail. +- Over-mocking analytics can break transitive imports (`setAnalyticsErrorLogger` export expectations). A safe pattern is spreading real analytics exports and overriding only `trackEvent`. + +## Unexpected behaviors / gotchas +- A staged unrelated file can survive despite working-tree revert; both staged and worktree states must be checked before final handoff. +- “Looks correct” tests can still miss runtime branches if they only validate helper classification, not route wiring; reviewer loops were useful to force coverage on practical paths. +- For OAuth tooling/scripts, sanitize error text aggressively. Returning status-only errors avoids accidental token payload leakage. + +## Useful patterns discovered +- Keep direct-provider routing stream-only initially; explicitly forcing non-streaming/structured calls to backend avoided broad compatibility risk. +- Use deterministic model allowlist + normalization mapping in constants to avoid relying on provider-side parsing/errors for unsupported models. +- Treat temporary protocol validation scripts as first-class validation artifacts: they are valuable for real-account smoke checks without coupling to full CLI runtime. + +## Temporary script disposition +- `scripts/chatgpt-oauth-validate.ts` is currently kept as a **dev utility** for manual protocol revalidation while the feature remains experimental/off by default. +- Removal criteria: if protocol endpoints are either officially documented or the CLI flow gets stable automated integration coverage, this script can be retired. 
+ +## Repeatable security verification +- For redaction checks, run targeted searches against changed code/log handling paths for sensitive markers before handoff, e.g. `access_token`, `refresh_token`, and `Authorization: Bearer`. +- Keep surfaced token exchange errors status-only and avoid echoing raw provider response bodies. + +## Follow-up improvements worth considering +- Add deeper runtime-behavior tests for `promptAiSdkStream` recursive fallback branches (not just policy classifier). +- Add explicit CLI test for flag-ON connect flow path once flag toggling is test-harness friendly. +- If feature graduates from experimental, add richer direct-path observability while preserving strict token redaction. +- Add periodic protocol drift checks (authorize/token/callback PKCE assumptions) before enabling the feature flag in production defaults. diff --git a/.agents/sessions/03-02-1407-chatgpt-oauth-direct/PLAN.md b/.agents/sessions/03-02-1407-chatgpt-oauth-direct/PLAN.md new file mode 100644 index 0000000000..9684c95329 --- /dev/null +++ b/.agents/sessions/03-02-1407-chatgpt-oauth-direct/PLAN.md @@ -0,0 +1,104 @@ +# PLAN — ChatGPT Subscription OAuth Direct Routing + +## Implementation Steps +1. **Add shared ChatGPT OAuth constants** + - Create `common/src/constants/chatgpt-oauth.ts` with: + - feature flag (`CHATGPT_OAUTH_ENABLED=false`) + - endpoints/client id/redirect URI/env var + - model allowlist + normalization helpers + - Export through `common/src/constants/index.ts`. + +2. **Build core OAuth utility + temporary protocol validation script (early gate)** + - Create `cli/src/utils/chatgpt-oauth.ts` with PKCE URL generation, browser-open helper, pasted code/URL parsing, token exchange helper. + - Create `scripts/chatgpt-oauth-validate.ts` to test OAuth URL generation + paste parsing + token exchange interaction. + - **Run this script before full integration** as go/no-go checkpoint for endpoint assumptions. + +3. 
**Add SDK env + credential support** + - Extend `sdk/src/env.ts` with `getChatGptOAuthTokenFromEnv()`. + - Extend `sdk/src/credentials.ts` with `chatgptOAuth` schema and helpers: + - get/save/clear + - valid-check + refresh mutex + - get-valid-with-refresh + - Preserve all non-target credentials in read/write operations. + +4. **Add CLI connect flow UI and command routing** + - Create `cli/src/components/chatgpt-connect-banner.tsx` with state machine + `handleChatGptAuthCode`. + - Update input modes (`connect:chatgpt`) and banner registry. + - Add `/connect:chatgpt` command + alias handling and slash command entry (feature-gated). + - Extend router to process pasted auth code in `connect:chatgpt` mode. + - Verify command visibility: hidden when flag OFF, present when flag ON. + +5. **Implement direct routing primitives in model-provider (decomposed)** + - 5.1 Add ChatGPT direct eligibility checks (feature flag + creds + model scope + skip flag + rate-limit cache state). + - 5.2 Add model normalization + prevalidation helpers (OpenRouter-style -> provider-native). + - 5.3 Add strict payload sanitization helper for direct requests. + - 5.4 Add ChatGPT OAuth direct model construction using OpenAI-compatible transport. + - 5.5 Add ChatGPT rate-limit cache helpers (parallel to Claude cache pattern). + - Keep Claude OAuth path unchanged. + +6. **Update stream execution + fallback/error policy** + - Extend `sdk/src/impl/llm.ts` to: + - recognize ChatGPT direct route usage + - emit ChatGPT OAuth analytics + - fallback only on rate-limit errors + - fail with reconnect guidance on auth errors + - fail fast for all other direct errors + - skip cost accounting for successful ChatGPT direct requests + - avoid fallback once output has already streamed + +7. **Wire startup refresh + CLI status surfacing** + - Update `cli/src/init/init-app.ts` for background ChatGPT OAuth credential refresh when enabled. 
+ - Update `cli/src/chat.tsx`, `cli/src/components/bottom-status-line.tsx`, and `cli/src/components/usage-banner.tsx` to surface ChatGPT connection/active status. + +8. **Add analytics constants + SDK exports** + - Extend `common/src/constants/analytics-events.ts` with ChatGPT OAuth request/rate-limit/auth-error events. + - Ensure SDK exports newly needed helper(s) in `sdk/src/index.ts`. + +9. **Add/adjust tests (explicit matrix)** + - SDK credentials tests: + - env precedence + - persisted read/write/clear + - refresh success/failure + mutex + - Model-provider tests: + - rate-limit cache lifecycle + - allowlist prevalidation + unsupported-model error + - normalization behavior for mapped/unknown variants + - LLM routing/fallback tests (targeted): + - 429 fallback + - 401/403 no-fallback + reconnect path + - timeout/5xx fail-fast + - no fallback after content emitted + - CLI tests/wiring checks: + - command/mode visibility by feature flag + - connect mode routing and handler call. + - Non-streaming/structured guard check: + - confirm backend-only behavior unchanged. + +10. **Validation and cleanup decision for temporary script** + - Run targeted tests/typechecks for touched packages. + - Run OAuth validation script in manual mode (with your account interaction if needed). + - Decide and apply final disposition of temporary script: + - keep as dev utility, or + - remove before finalization. + +11. **Security/redaction verification** + - Validate no token values are logged in direct feature code paths. + - Grep/check for accidental logging of authorization headers, token payload fields, or raw callback query params. + +## Dependencies / Ordering +- Step 1 must be first. +- Step 2 must run before deep integration (early protocol validation gate). +- Step 3 precedes Steps 5–7. +- Step 4 can run in parallel with Step 3 after constants/util setup. +- Step 5 must precede Step 6. +- Step 8 can be implemented alongside Steps 5–6 but must complete before final validation. 
+- Step 9 follows core implementation completion. +- Steps 10–11 are final validation/cleanup/security passes. + +## Risk Areas +1. **Unofficial OAuth contract drift** — endpoint/field incompatibility can break token exchange. +2. **Direct payload compatibility** — strict sanitization must retain required OpenAI fields. +3. **Error classification correctness** — misclassification can violate requested fallback policy. +4. **Model normalization accuracy** — wrong mapping yields avoidable provider failures. +5. **Token redaction** — avoid leakage in logs, errors, or analytics payloads. +6. **Streaming boundary behavior** — fallback must not happen after partial output is emitted. diff --git a/.agents/sessions/03-02-1407-chatgpt-oauth-direct/SPEC.md b/.agents/sessions/03-02-1407-chatgpt-oauth-direct/SPEC.md new file mode 100644 index 0000000000..d56a415caf --- /dev/null +++ b/.agents/sessions/03-02-1407-chatgpt-oauth-direct/SPEC.md @@ -0,0 +1,155 @@ +# SPEC — ChatGPT Subscription OAuth Direct Routing + +## Overview +Implement an **experimental, default-disabled** ChatGPT subscription OAuth feature that allows the local CLI to route eligible OpenAI-model **streaming** requests directly to OpenAI instead of Codebuff backend routing, mirroring the prior Claude OAuth architecture pattern. + +## Protocol Assumptions (Explicit) +Because this is unofficial/experimental, this implementation proceeds under the following explicit assumptions: + +1. OAuth authorize endpoint: `https://auth.openai.com/oauth/authorize` +2. OAuth token endpoint: `https://auth.openai.com/oauth/token` +3. Public client id is configurable constant, defaulting to Codex-compatible value from ecosystem references. +4. PKCE (`S256`) is required. +5. Redirect URI is pinned to: `http://localhost:1455/auth/callback` +6. User can paste either: + - raw authorization code, or + - full callback URL containing code/state query params. +7. 
Token response includes at least `access_token`, optional `refresh_token`, and expiry info (`expires_in` or equivalent). +8. Refresh uses standard `grant_type=refresh_token`. + +If any assumption fails at runtime, the feature fails with explicit guidance and remains safely fallbackable only where policy allows. + +## Requirements +1. Add ChatGPT OAuth feature set, default disabled behind `CHATGPT_OAUTH_ENABLED = false`. +2. Add a new CLI command and mode: `/connect:chatgpt` with dedicated banner flow. +3. Implement browser-based PKCE code-paste flow (no device-code flow in this iteration). +4. Keep user-facing warning minimal (per user preference), while leaving code comments clearly marking experimental nature. +5. Store ChatGPT OAuth credentials in local credentials JSON alongside existing credentials. +6. Support env-var token override (power-user/automation use), but env var **must not bypass feature flag**. +7. Add refresh-token support with concurrency guard (mutex) for persisted credentials. +8. Direct routing scope is **streaming only** (`promptAiSdkStream` path); non-streaming and structured stay backend-routed. +9. Add model allowlist for direct routing; include optimistic aliases: + - `openai/gpt-5.3` + - `openai/gpt-5.3-codex` + - `openai/gpt-5.2` + - `openai/gpt-5.2-codex` + - plus selected nearby GPT/Codex IDs already present in repo config. +10. Provide deterministic model normalization for direct requests (OpenRouter-style -> provider-native): + - Example: `openai/gpt-5.3-codex` -> `gpt-5.3-codex` + - Mapping table lives in constants and is used for prevalidation. +11. Unsupported model handling must be deterministic and prevalidated: + - if model is not in allowlist/mapping for direct route, fail with explicit unsupported-model error (no fallback). +12. Fallback policy: + - Rate-limit/overload classification: auto-fallback to Codebuff backend. + - Auth errors (401/403): fail explicitly with reconnect guidance (no fallback). 
+ - All other direct errors: fail fast (no fallback), per user decision. +13. Successful direct ChatGPT OAuth requests do **not** consume Codebuff credits. +14. Add lightweight ChatGPT connection status surfacing in CLI (usage banner and/or bottom status line), without quota API dependency. +15. Preserve existing Claude OAuth behavior unchanged. +16. Add temporary OAuth validation script that tests auth URL generation + token exchange manually before/alongside full wiring. +17. Add/update tests for credential parsing/storage/refresh, model gating, routing/fallback classification, and CLI command/mode wiring. +18. Never log OAuth tokens in analytics or error logs. + +## Direct Request Transformation Rules +Before sending direct streaming requests to OpenAI, enforce strict sanitization: + +1. Rewrite `model` from `openai/*` format to provider-native mapped id. +2. Remove provider-specific/non-OpenAI fields (e.g., codebuff metadata/provider routing payloads). +3. Preserve fields known to be valid for OpenAI-compatible chat completions. +4. Do not inject Codex-specific required prefix by default in v1 (user preference), but structure code so optional future injection is easy. + +## Error Classification Table +| Class | Detection | Behavior | +|---|---|---| +| Rate limit | HTTP 429 or message/body contains rate-limit indicators | Fallback to backend (if no output emitted yet) | +| Auth | HTTP 401/403 or auth-token-invalid indicators | Fail with reconnect guidance; no fallback | +| Unsupported model | Local allowlist/mapping precheck failure | Fail explicit unsupported-model error; no fallback | +| Other | Network timeout, 5xx, malformed payload, unknown 4xx | Fail fast; no fallback | + +## Routing Scope +1. Direct routing applies only to `promptAiSdkStream` eligible requests. +2. `promptAiSdk` and `promptAiSdkStructured` remain backend-only for this iteration. +3. Backend routing remains unchanged for all non-eligible models and when feature disabled/disconnected. 
+ +## Credentials & Precedence Rules +1. Credentials file schema extends with `chatgptOAuth` object. +2. Precedence: env token override > persisted OAuth credentials > none. +3. Env token produces synthetic non-refreshing credentials object. +4. Persisted credentials refresh when expired/near-expiry (5-minute buffer). +5. On refresh failure for persisted credentials, clear only `chatgptOAuth` entry (preserve other credentials). + +## Feature Gating Matrix +1. `CHATGPT_OAUTH_ENABLED = false` + - hide `/connect:chatgpt` command and banner UX + - disable direct routing even if env token exists +2. `CHATGPT_OAUTH_ENABLED = true` and credentials available + - enable command/UI + - enable direct routing for eligible models + +## Logging/Redaction Requirements +1. Never log raw access tokens, refresh tokens, authorization headers, or token response payloads. +2. If callback URL is logged for debugging, redact query values for `code`, `access_token`, `refresh_token`, and similar sensitive keys. +3. Analytics properties must not include token-bearing strings. + +## Technical Approach +1. Create `common/src/constants/chatgpt-oauth.ts`: + - feature flag, endpoints, client id, redirect URI, env var name, model allowlist/mapping helpers. +2. Export new constants via `common/src/constants/index.ts` so legacy `old-constants` re-export path includes them. +3. Extend `sdk/src/env.ts` with ChatGPT OAuth env-token helper. +4. Extend `sdk/src/credentials.ts` with ChatGPT OAuth schema+helpers mirroring Claude pattern. +5. Create `cli/src/utils/chatgpt-oauth.ts` for PKCE start/open/exchange/disconnect/status. +6. Create `cli/src/components/chatgpt-connect-banner.tsx` and auth-code handler. +7. Wire CLI command/input mode/slash menu/router/banner registry for `connect:chatgpt`. +8. 
Extend model provider (`sdk/src/impl/model-provider.ts`): + - add ChatGPT direct route decision path for `openai/*` allowlisted models + - add rate-limit cache helpers for ChatGPT path + - build direct OpenAI-compatible language model with OAuth bearer auth + - enforce strict body sanitization + model normalization in the direct path. +9. Extend stream error handling (`sdk/src/impl/llm.ts`) for ChatGPT direct path with required fallback/fail rules and analytics. +10. Extend app init (`cli/src/init/init-app.ts`) for background ChatGPT credential refresh when enabled. +11. Add analytics events for ChatGPT OAuth request/rate-limit/auth-error. +12. Update usage/status UI text to include ChatGPT connection state. +13. Add temporary validation script (e.g., `scripts/chatgpt-oauth-validate.ts`) to exercise OAuth setup interactively. + +## Acceptance Criteria +1. With feature disabled, `/connect:chatgpt` is unavailable and no direct routing occurs. +2. With feature enabled, user can run `/connect:chatgpt`, complete browser flow, paste code/URL, and connect. +3. Eligible streaming requests on allowlisted `openai/*` models use direct OAuth path. +4. Direct request payloads are sanitized and model ids normalized before transmission. +5. Rate-limited direct requests fallback to backend automatically. +6. Auth failures produce reconnect guidance and do not fallback. +7. Unsupported models fail immediately with explicit unsupported-model message. +8. Successful direct requests skip Codebuff credit accounting path. +9. Existing Claude OAuth flow remains behaviorally unchanged. +10. New/updated tests pass for touched behavior. +11. Temporary validation script can run and guide manual OAuth exchange checks. 
+ +## Files to Create/Modify +- Create: `common/src/constants/chatgpt-oauth.ts` +- Create: `cli/src/utils/chatgpt-oauth.ts` +- Create: `cli/src/components/chatgpt-connect-banner.tsx` +- Create: `scripts/chatgpt-oauth-validate.ts` (temporary validation utility) +- Modify: `common/src/constants/index.ts` +- Modify: `common/src/constants/analytics-events.ts` +- Modify: `sdk/src/env.ts` +- Modify: `sdk/src/credentials.ts` +- Modify: `sdk/src/impl/model-provider.ts` +- Modify: `sdk/src/impl/llm.ts` +- Modify: `sdk/src/index.ts` +- Modify: `cli/src/utils/input-modes.ts` +- Modify: `cli/src/components/input-mode-banner.tsx` +- Modify: `cli/src/data/slash-commands.ts` +- Modify: `cli/src/commands/command-registry.ts` +- Modify: `cli/src/commands/router.ts` +- Modify: `cli/src/chat.tsx` +- Modify: `cli/src/components/usage-banner.tsx` +- Modify: `cli/src/components/bottom-status-line.tsx` +- Modify: `cli/src/init/init-app.ts` +- Modify tests in SDK/CLI for new behavior. + +## Out of Scope +1. Device-code auth flow. +2. Legal/policy guarantees around undocumented endpoints. +3. Full quota/usage API integration for ChatGPT subscription plans. +4. Local callback server daemon beyond paste-based flow. +5. Enabling feature by default. diff --git a/.agents/skills/meta/SKILL.md b/.agents/skills/meta/SKILL.md index a66b88dafb..8b05efdddf 100644 --- a/.agents/skills/meta/SKILL.md +++ b/.agents/skills/meta/SKILL.md @@ -10,3 +10,9 @@ description: Broad project-level implementation and validation heuristics - From monorepo root, run workspace scripts as `bun run --cwd + + +
Top navigation should disappear
+
+
+

Important Answer

+

The web researcher should see this useful paragraph.

+

React 19 useActionState returns state, a form action, and pending state.

+
+
+
Footer boilerplate should disappear
+ + + `) + + expect('errorMessage' in result).toBe(false) + if ('errorMessage' in result) return + + expect(result.title).toBe('Research Source') + expect(result.description).toBe('A concise source description.') + expect(result.text).toContain('Important Answer') + expect(result.text).toContain('useActionState returns state') + expect(result.text).not.toContain('.unused-') + expect(result.text).not.toContain('Top navigation') + }) + + it('prefers article content over a larger page main area', async () => { + const result = await successValue(` + + Repository Page + +
+
+

Folders and files

+ ${Array.from( + { length: 40 }, + (_, index) => `file-${index}.ts`, + ).join('')} +
+
+

Project README

+

This is the source content the researcher needs.

+
+
+ + + `) + + expect('errorMessage' in result).toBe(false) + if ('errorMessage' in result) return + + expect(result.text).toContain('Project README') + expect(result.text).toContain('source content') + expect(result.text).not.toContain('Folders and files') + expect(result.text).not.toContain('file-39.ts') + }) + + it('does not add spaces between syntax-highlighted code tokens', async () => { + const result = await successValue(` +
+
const answer=42;
+
+ `) + + expect('errorMessage' in result).toBe(false) + if ('errorMessage' in result) return + + expect(result.text).toContain('const answer=42;') + }) + + it('leaves invalid numeric HTML entities unchanged', async () => { + const result = await successValue( + '

Bad entity: �

', + ) + + expect('errorMessage' in result).toBe(false) + if ('errorMessage' in result) return + + expect(result.text).toContain('Bad entity: �') + }) + + it('rejects non-http URLs', async () => { + const result = await readUrl({ + url: 'file:///etc/passwd', + fetch: async () => { + throw new Error('fetch should not be called') + }, + }) + + expect(result[0].value).toEqual({ + url: 'file:///etc/passwd', + errorMessage: 'Only http:// and https:// URLs are supported', + }) + }) + + it('rejects non-http URLs at the tool schema boundary', () => { + expect(() => + clientToolCallSchema.parse({ + toolName: 'read_url', + input: { url: 'file:///etc/passwd' }, + }), + ).toThrow() + }) + + it('truncates extracted text to max_chars', async () => { + const result = await readUrl({ + url: 'https://example.com/long', + max_chars: 1_000, + fetch: async () => + new Response(`

${'word '.repeat(1_000)}

`, { + status: 200, + headers: { 'content-type': 'text/html' }, + }), + }) + const value = result[0].value + + expect('errorMessage' in value).toBe(false) + if ('errorMessage' in value) return + + expect(value.truncated).toBe(true) + expect(value.text.length).toBeLessThanOrEqual(1_030) + expect(value.text).toContain('[Content truncated]') + }) + + it('returns pretty-printed JSON for JSON responses', async () => { + const result = await successValue('{"name":"Codebuff","answer":42}', { + contentType: 'application/json', + }) + + expect('errorMessage' in result).toBe(false) + if ('errorMessage' in result) return + + expect(result.text).toContain('"name": "Codebuff"') + expect(result.text).toContain('"answer": 42') + }) + + it('supports vendor JSON content types', async () => { + const result = await successValue('{"type":"metadata"}', { + contentType: 'application/ld+json', + }) + + expect('errorMessage' in result).toBe(false) + if ('errorMessage' in result) return + + expect(result.text).toContain('"type": "metadata"') + }) + + it('extracts markdown frontmatter into metadata and omits it from text', async () => { + const result = await successValue( + [ + '---', + 'title: "Readable Docs"', + "description: 'A useful docs page'", + '---', + '# First Heading', + 'Body with · entity.', + ].join('\n'), + { + contentType: 'text/markdown; charset=utf-8', + }, + ) + + expect('errorMessage' in result).toBe(false) + if ('errorMessage' in result) return + + expect(result.title).toBe('Readable Docs') + expect(result.description).toBe('A useful docs page') + expect(result.text.startsWith('# First Heading')).toBe(true) + expect(result.text).toContain('Body with * entity.') + expect(result.text).not.toContain('title:') + }) + + it('supports CRLF markdown frontmatter', async () => { + const result = await successValue( + '---\r\ntitle: CRLF Docs\r\n---\r\n# Body', + { + contentType: 'text/markdown; charset=utf-8', + }, + ) + + expect('errorMessage' in result).toBe(false) + if 
('errorMessage' in result) return + + expect(result.title).toBe('CRLF Docs') + expect(result.text).toBe('# Body') + }) +}) diff --git a/sdk/src/__tests__/researcher-web.integration.test.ts b/sdk/src/__tests__/researcher-web.integration.test.ts new file mode 100644 index 0000000000..a5e981654a --- /dev/null +++ b/sdk/src/__tests__/researcher-web.integration.test.ts @@ -0,0 +1,202 @@ +import { existsSync, readFileSync } from 'fs' +import { homedir } from 'os' +import path from 'path' + +import { describe, expect, it } from 'bun:test' + +import { CodebuffClient } from '../client' +import { loadLocalAgents } from '../agents/load-agents' + +import type { AgentOutput } from '@codebuff/common/types/session-state' +import type { PrintModeEvent } from '@codebuff/common/types/print-mode' + +const DEFAULT_TIMEOUT_MS = 120_000 +const EXPECTED_KEYWORD = 'useActionState' + +function loadEnvValue(name: string): string | undefined { + if (process.env[name] && process.env[name] !== 'test') { + return process.env[name] + } + + for (const envPath of [ + path.join(homedir(), 'codebuff', '.env.local'), + path.join(process.cwd(), '.env.local'), + ]) { + if (!existsSync(envPath)) continue + + const contents = readFileSync(envPath, 'utf8') + const match = contents.match(new RegExp(`^${name}=(.*)$`, 'm')) + const value = match?.[1]?.trim().replace(/^['"]|['"]$/g, '') + if (value && value !== 'test') return value + } + + return undefined +} + +function extractOutputText(output: AgentOutput): string { + if (output.type === 'error') return output.message + if (output.type === 'structuredOutput') { + return JSON.stringify(output.value ?? 
{}) + } + + const assistantText = output.value.flatMap((message) => { + if ((message as { role?: unknown }).role !== 'assistant') return [] + + const content = (message as { content?: unknown }).content + if (typeof content === 'string') return [content] + if (!Array.isArray(content)) return [] + + return content.flatMap((part) => { + if ( + part && + typeof part === 'object' && + 'type' in part && + part.type === 'text' && + 'text' in part + ) { + return [String(part.text)] + } + return [] + }) + }) + + return assistantText.join('\n') +} + +function summarizeToolTrace(events: PrintModeEvent[]): { + readUrlCount: number + lines: string[] +} { + const lines: string[] = [] + let readUrlCount = 0 + + for (const event of events) { + if (event.type === 'tool_call') { + if (event.toolName === 'web_search') { + lines.push(`tool_call web_search query=${event.input.query}`) + } else if (event.toolName === 'read_url') { + readUrlCount += 1 + lines.push(`tool_call read_url url=${event.input.url}`) + } else { + lines.push(`tool_call ${event.toolName}`) + } + continue + } + + if (event.type !== 'tool_result') continue + + const output = event.output[0] + const value = output?.type === 'json' ? output.value : undefined + if (!value || typeof value !== 'object') { + lines.push(`tool_result ${event.toolName} empty`) + continue + } + + if (event.toolName === 'read_url') { + const result = value as { + url?: string + finalUrl?: string + status?: number + title?: string + text?: string + truncated?: boolean + errorMessage?: string + } + if (result.errorMessage) { + lines.push(`tool_result read_url error=${result.errorMessage}`) + } else { + lines.push( + [ + 'tool_result read_url', + `status=${result.status}`, + `finalUrl=${result.finalUrl}`, + `title=${JSON.stringify(result.title ?? '')}`, + `textChars=${result.text?.length ?? 0}`, + `truncated=${result.truncated ?? 
false}`, + ].join(' '), + ) + } + } else if (event.toolName === 'web_search') { + const result = value as { result?: string; errorMessage?: string } + lines.push( + result.errorMessage + ? `tool_result web_search error=${result.errorMessage}` + : `tool_result web_search chars=${result.result?.length ?? 0}`, + ) + } + } + + return { readUrlCount, lines } +} + +describe('researcher-web SDK integration', () => { + it( + `runs researcher-web through the SDK and answers with ${EXPECTED_KEYWORD}`, + async () => { + const apiKey = loadEnvValue('CODEBUFF_API_KEY') + if (!apiKey) { + console.log( + 'Skipping researcher-web SDK integration test: set CODEBUFF_API_KEY to run.', + ) + return + } + + const agentsPath = path.resolve( + import.meta.dir, + '../../../agents/researcher', + ) + const loadedAgents = await loadLocalAgents({ agentsPath }) + const researcherWeb = loadedAgents['researcher-web'] + expect(researcherWeb).toBeDefined() + + const events: PrintModeEvent[] = [] + const client = new CodebuffClient({ + apiKey, + cwd: process.cwd(), + }) + + const result = await client.run({ + agent: 'researcher-web', + agentDefinitions: [researcherWeb], + maxAgentSteps: 8, + handleEvent: (event) => { + events.push(event) + }, + prompt: [ + 'Use web search to answer this React docs question.', + 'After searching, fetch the most relevant React docs page with read_url before answering.', + 'In React 19, which hook returns state, a form action, and an isPending value for form actions?', + 'Answer with the exact hook name and one short sentence.', + ].join(' '), + }) + + const outputText = extractOutputText(result.output) + const trace = summarizeToolTrace(events) + console.log( + [ + 'researcher-web SDK trace:', + ...trace.lines.map((line) => ` ${line}`), + `read_url fetch count: ${trace.readUrlCount}`, + ].join('\n'), + ) + console.log('researcher-web SDK output:', outputText) + + expect(result.output.type).not.toBe('error') + expect(outputText).toContain(EXPECTED_KEYWORD) + 
expect(events.some((event) => event.type === 'tool_call')).toBe(true) + expect( + events.some( + (event) => + event.type === 'tool_call' && event.toolName === 'web_search', + ), + ).toBe(true) + expect( + events.some( + (event) => + event.type === 'tool_call' && event.toolName === 'read_url', + ), + ).toBe(true) + }, + DEFAULT_TIMEOUT_MS, + ) +}) diff --git a/sdk/src/__tests__/run-cancellation.test.ts b/sdk/src/__tests__/run-cancellation.test.ts index 9ebfbb8614..ae45c19f76 100644 --- a/sdk/src/__tests__/run-cancellation.test.ts +++ b/sdk/src/__tests__/run-cancellation.test.ts @@ -1,10 +1,10 @@ - import * as mainPromptModule from '@codebuff/agent-runtime/main-prompt' import { withSystemTags } from '@codebuff/agent-runtime/util/messages' import { getInitialSessionState } from '@codebuff/common/types/session-state' import { getStubProjectFileContext } from '@codebuff/common/util/file' import { assistantMessage, userMessage } from '@codebuff/common/util/messages' import { afterEach, describe, expect, it, mock, spyOn } from 'bun:test' +import { RetryError } from 'ai' // Type for tool call content blocks in message history interface ToolCallContentBlock { @@ -27,9 +27,9 @@ describe('Run Cancellation Handling', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -37,9 +37,11 @@ describe('Run Cancellation Handling', () => { spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') // Server session state already includes the user's message (as the server would normally do) - const serverSessionState = getInitialSessionState(getStubProjectFileContext()) + const serverSessionState = getInitialSessionState( + getStubProjectFileContext(), + ) serverSessionState.mainAgentState.messageHistory.push( - 
userMessage('Please fix the bug'), // Server added this + userMessage('Please fix the bug'), // Server added this assistantMessage('I will help you with that.'), ) @@ -82,10 +84,10 @@ describe('Run Cancellation Handling', () => { const messageHistory = result.sessionState!.mainAgentState.messageHistory const userMessages = messageHistory.filter((m) => m.role === 'user') - + // Should have exactly 1 user message, not 2 expect(userMessages.length).toBe(1) - + // Total messages should be 2 (user + assistant), not 3 expect(messageHistory.length).toBe(2) }) @@ -95,9 +97,9 @@ describe('Run Cancellation Handling', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -107,9 +109,11 @@ describe('Run Cancellation Handling', () => { const abortController = new AbortController() // Server session state already includes the user's message (server processed it) - const serverSessionState = getInitialSessionState(getStubProjectFileContext()) + const serverSessionState = getInitialSessionState( + getStubProjectFileContext(), + ) serverSessionState.mainAgentState.messageHistory.push( - userMessage('Please fix the bug'), // Server added the user's message + userMessage('Please fix the bug'), // Server added the user's message assistantMessage('I will help you with that.'), ) @@ -131,7 +135,11 @@ describe('Run Cancellation Handling', () => { // Simulate agent runtime adding interruption message on abort serverSessionState.mainAgentState.messageHistory.push( - userMessage(withSystemTags("User interrupted the response. The assistant's previous work has been preserved.")) + userMessage( + withSystemTags( + "User interrupted the response. 
The assistant's previous work has been preserved.", + ), + ), ) // Server still responds with its session state @@ -169,29 +177,237 @@ describe('Run Cancellation Handling', () => { // The user's message should NOT be duplicated const messageHistory = result.sessionState!.mainAgentState.messageHistory - + // Count user messages (excluding system interruption messages) const userPromptMessages = messageHistory.filter( - (m) => m.role === 'user' && - m.content.some((c: any) => c.type === 'text' && c.text.includes('fix the bug')) + (m) => + m.role === 'user' && + m.content.some( + (c: any) => c.type === 'text' && c.text.includes('fix the bug'), + ), ) - + // Should have exactly 1 user message with the prompt, not 2 expect(userPromptMessages.length).toBe(1) - + // Total messages should be: 1 user + 1 assistant (original) + 1 interruption = 3 // The server state already has the content; pendingAgentResponse is not duplicated. expect(messageHistory.length).toBe(3) }) + it('extracts error code and message from AI SDK responseBody on 403', async () => { + spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ + id: 'user-123', + email: 'test@example.com', + discord_id: null, + stripe_customer_id: null, + banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), + }) + spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) + spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') + spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined) + spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') + + // Simulate AI SDK's AI_APICallError with responseBody (what the server returns for free_mode_unavailable) + const apiError = new Error('Forbidden') as Error & { + statusCode: number + responseBody: string + } + apiError.statusCode = 403 + apiError.responseBody = JSON.stringify({ + error: 'free_mode_unavailable', + message: 'Free mode is not available in your country.', + countryCode: 'US', + countryBlockReason: 
'anonymous_network', + ipPrivacySignals: ['vpn', 'hosting'], + }) + + spyOn(mainPromptModule, 'callMainPrompt').mockRejectedValue(apiError) + + const client = new CodebuffClient({ + apiKey: 'test-key', + }) + + const result = await client.run({ + agent: 'base2', + prompt: 'hello', + }) + + expect(result.output.type).toBe('error') + const output = result.output as { + type: 'error' + message: string + statusCode?: number + error?: string + countryCode?: string + countryBlockReason?: string + ipPrivacySignals?: string[] + } + // Should use the message from the response body, not the generic "Forbidden" + expect(output.message).toBe('Free mode is not available in your country.') + expect(output.statusCode).toBe(403) + // Should propagate the error code so isFreeModeUnavailableError can match + expect(output.error).toBe('free_mode_unavailable') + expect(output.countryCode).toBe('US') + expect(output.countryBlockReason).toBe('anonymous_network') + expect(output.ipPrivacySignals).toEqual(['vpn', 'hosting']) + }) + + it('extracts error code and message from nested AI SDK retry errors', async () => { + spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ + id: 'user-123', + email: 'test@example.com', + discord_id: null, + stripe_customer_id: null, + banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), + }) + spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) + spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') + spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined) + spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') + + const apiError = new Error('Conflict') as Error & { + statusCode: number + responseBody: string + } + apiError.statusCode = 409 + apiError.responseBody = JSON.stringify({ + error: 'session_model_mismatch', + message: + 'This session is bound to deepseek; restart freebuff to switch models.', + }) + + spyOn(mainPromptModule, 'callMainPrompt').mockRejectedValue( + new 
RetryError({ + message: 'Failed after 4 attempts. Last error: Conflict', + reason: 'maxRetriesExceeded', + errors: [apiError], + }), + ) + + const client = new CodebuffClient({ + apiKey: 'test-key', + }) + + const result = await client.run({ + agent: 'base2', + prompt: 'hello', + }) + + const output = result.output as { + type: 'error' + message: string + statusCode?: number + error?: string + } + expect(output.message).toBe( + 'This session is bound to deepseek; restart freebuff to switch models.', + ) + expect(output.statusCode).toBe(409) + expect(output.error).toBe('session_model_mismatch') + }) + + it('extracts error code from responseBody for account_suspended 403', async () => { + spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ + id: 'user-123', + email: 'test@example.com', + discord_id: null, + stripe_customer_id: null, + banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), + }) + spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) + spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') + spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined) + spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') + + const apiError = new Error('Forbidden') as Error & { + statusCode: number + responseBody: string + } + apiError.statusCode = 403 + apiError.responseBody = JSON.stringify({ + error: 'account_suspended', + message: 'Your account has been suspended due to billing issues.', + }) + + spyOn(mainPromptModule, 'callMainPrompt').mockRejectedValue(apiError) + + const client = new CodebuffClient({ + apiKey: 'test-key', + }) + + const result = await client.run({ + agent: 'base2', + prompt: 'hello', + }) + + const output = result.output as { + type: 'error' + message: string + statusCode?: number + error?: string + } + expect(output.message).toBe( + 'Your account has been suspended due to billing issues.', + ) + expect(output.statusCode).toBe(403) + 
expect(output.error).toBe('account_suspended') + }) + + it('falls back to error.message when responseBody is not valid JSON', async () => { + spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ + id: 'user-123', + email: 'test@example.com', + discord_id: null, + stripe_customer_id: null, + banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), + }) + spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) + spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') + spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined) + spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') + + const apiError = new Error('Forbidden') as Error & { + statusCode: number + responseBody: string + } + apiError.statusCode = 403 + apiError.responseBody = 'not valid json' + + spyOn(mainPromptModule, 'callMainPrompt').mockRejectedValue(apiError) + + const client = new CodebuffClient({ + apiKey: 'test-key', + }) + + const result = await client.run({ + agent: 'base2', + prompt: 'hello', + }) + + const output = result.output as { + type: 'error' + message: string + statusCode?: number + error?: string + } + expect(output.message).toBe('Forbidden') + expect(output.statusCode).toBe(403) + expect(output.error).toBeUndefined() + }) + it('preserves user message when callMainPrompt throws an error', async () => { spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -214,7 +430,9 @@ describe('Run Cancellation Handling', () => { // Should return an error output expect(result.output.type).toBe('error') - expect((result.output as { type: 'error'; message: string }).message).toBe('Network connection failed') 
+ expect((result.output as { type: 'error'; message: string }).message).toBe( + 'Network connection failed', + ) // The user's message should be preserved in the session state expect(result.sessionState).toBeDefined() @@ -230,7 +448,9 @@ describe('Run Cancellation Handling', () => { expect(userPromptMessage).toBeDefined() // Verify the message content contains the original prompt - const textContent = userPromptMessage!.content.find((c: any) => c.type === 'text') as { type: 'text'; text: string } | undefined + const textContent = userPromptMessage!.content.find( + (c: any) => c.type === 'text', + ) as { type: 'text'; text: string } | undefined expect(textContent).toBeDefined() expect(textContent!.text).toContain('Please fix the bug in my code') }) @@ -240,9 +460,9 @@ describe('Run Cancellation Handling', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -250,11 +470,14 @@ describe('Run Cancellation Handling', () => { spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') const abortController = new AbortController() - const serverSessionState = getInitialSessionState(getStubProjectFileContext()) + const serverSessionState = getInitialSessionState( + getStubProjectFileContext(), + ) serverSessionState.mainAgentState.messageHistory.push( userMessage('User prompt'), ) - const originalHistoryLength = serverSessionState.mainAgentState.messageHistory.length + const originalHistoryLength = + serverSessionState.mainAgentState.messageHistory.length spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( async (params: Parameters[0]) => { @@ -265,7 +488,11 @@ describe('Run Cancellation Handling', () => { // Simulate agent runtime adding interruption message on abort 
serverSessionState.mainAgentState.messageHistory.push( - userMessage(withSystemTags("User interrupted the response. The assistant's previous work has been preserved.")) + userMessage( + withSystemTags( + "User interrupted the response. The assistant's previous work has been preserved.", + ), + ), ) await sendAction({ @@ -308,7 +535,9 @@ describe('Run Cancellation Handling', () => { // The last message should be the interruption (user role), not an empty assistant message const lastMessage = messageHistory[messageHistory.length - 1] expect(lastMessage.role).toBe('user') - expect((lastMessage.content[0] as { type: 'text'; text: string }).text).toContain('User interrupted') + expect( + (lastMessage.content[0] as { type: 'text'; text: string }).text, + ).toContain('User interrupted') // Verify there's no empty assistant message before the interruption const secondToLastMessage = messageHistory[messageHistory.length - 2] @@ -321,9 +550,9 @@ describe('Run Cancellation Handling', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -391,9 +620,9 @@ describe('Run Cancellation Handling', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -403,7 +632,9 @@ describe('Run Cancellation Handling', () => { const abortController = new AbortController() // Create a session state with some existing message history to verify it's preserved - const serverSessionState = getInitialSessionState(getStubProjectFileContext()) + const serverSessionState = 
getInitialSessionState( + getStubProjectFileContext(), + ) serverSessionState.mainAgentState.messageHistory.push( userMessage('User prompt'), assistantMessage('I will help you with that.'), @@ -426,10 +657,13 @@ describe('Run Cancellation Handling', () => { role: 'tool', toolCallId: 'tool-1', toolName: 'read_files', - content: [{ type: 'json', value: [{ path: 'file.ts', content: 'const x = 1;' }] }], + content: [ + { type: 'json', value: [{ path: 'file.ts', content: 'const x = 1;' }] }, + ], }) - const originalHistoryLength = serverSessionState.mainAgentState.messageHistory.length + const originalHistoryLength = + serverSessionState.mainAgentState.messageHistory.length spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( async (params: Parameters[0]) => { @@ -449,7 +683,11 @@ describe('Run Cancellation Handling', () => { // Simulate agent runtime adding interruption message on abort serverSessionState.mainAgentState.messageHistory.push( - userMessage(withSystemTags("User interrupted the response. The assistant's previous work has been preserved.")) + userMessage( + withSystemTags( + "User interrupted the response. 
The assistant's previous work has been preserved.", + ), + ), ) // Server still sends the prompt-response with the full session state @@ -500,7 +738,9 @@ describe('Run Cancellation Handling', () => { const toolCallMessage = messageHistory.find( (m) => m.role === 'assistant' && - m.content.some((c: any) => c.type === 'tool-call' && c.toolCallId === 'tool-1'), + m.content.some( + (c: any) => c.type === 'tool-call' && c.toolCallId === 'tool-1', + ), ) expect(toolCallMessage).toBeDefined() @@ -519,9 +759,9 @@ describe('Run Cancellation Handling', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -529,7 +769,9 @@ describe('Run Cancellation Handling', () => { spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') const abortController = new AbortController() - const serverSessionState = getInitialSessionState(getStubProjectFileContext()) + const serverSessionState = getInitialSessionState( + getStubProjectFileContext(), + ) spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( async (params: Parameters[0]) => { @@ -540,7 +782,11 @@ describe('Run Cancellation Handling', () => { // Simulate agent runtime adding interruption message on abort serverSessionState.mainAgentState.messageHistory.push( - userMessage(withSystemTags("User interrupted the response. The assistant's previous work has been preserved.")) + userMessage( + withSystemTags( + "User interrupted the response. 
The assistant's previous work has been preserved.", + ), + ), ) await sendAction({ @@ -582,7 +828,9 @@ describe('Run Cancellation Handling', () => { expect(lastMessage.role).toBe('user') expect(Array.isArray(lastMessage.content)).toBe(true) - const textContent = lastMessage.content.find((c: any) => c.type === 'text') as { type: 'text'; text: string } | undefined + const textContent = lastMessage.content.find( + (c: any) => c.type === 'text', + ) as { type: 'text'; text: string } | undefined expect(textContent).toBeDefined() // The text should be wrapped in tags @@ -602,9 +850,9 @@ describe('Run Cancellation Handling', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) const abortController = new AbortController() @@ -630,21 +878,24 @@ describe('Run Cancellation Handling', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined) spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') - const serverSessionState = getInitialSessionState(getStubProjectFileContext()) + const serverSessionState = getInitialSessionState( + getStubProjectFileContext(), + ) serverSessionState.mainAgentState.messageHistory.push( userMessage('User prompt'), assistantMessage('Done!'), ) - const originalHistoryLength = serverSessionState.mainAgentState.messageHistory.length + const originalHistoryLength = + serverSessionState.mainAgentState.messageHistory.length spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( async (params: Parameters[0]) => { @@ -691,14 +942,211 @@ describe('Run Cancellation Handling', () => { 
expect(lastMessage.role).toBe('assistant') }) + it('preserves message history across cancelled run and subsequent run', async () => { + spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ + id: 'user-123', + email: 'test@example.com', + discord_id: null, + stripe_customer_id: null, + banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), + }) + spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) + spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') + spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined) + spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') + + const abortController = new AbortController() + + // First run: server processes the user message and does some work, then user cancels + const firstRunServerState = getInitialSessionState( + getStubProjectFileContext(), + ) + firstRunServerState.mainAgentState.messageHistory.push( + userMessage('Fix the bug in auth.ts'), + assistantMessage('I will analyze the authentication module.'), + ) + + spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( + async (params: Parameters[0]) => { + const { sendAction, promptId } = params + + // Stream some content + await sendAction({ + action: { + type: 'response-chunk', + userInputId: promptId, + chunk: 'Analyzing auth.ts...', + }, + }) + + // User cancels mid-stream + abortController.abort() + + // Agent runtime adds interruption message on abort + firstRunServerState.mainAgentState.messageHistory.push( + userMessage( + withSystemTags( + "User interrupted the response. 
The assistant's previous work has been preserved.", + ), + ), + ) + + // Server still sends the prompt-response with its session state + await sendAction({ + action: { + type: 'prompt-response', + promptId, + sessionState: firstRunServerState, + output: { + type: 'lastMessage', + value: [], + }, + }, + }) + + return { + sessionState: firstRunServerState, + output: { + type: 'lastMessage' as const, + value: [], + }, + } + }, + ) + + const client = new CodebuffClient({ + apiKey: 'test-key', + }) + + // Run 1: cancelled mid-stream + const firstRunResult = await client.run({ + agent: 'base2', + prompt: 'Fix the bug in auth.ts', + signal: abortController.signal, + }) + + // Verify the first run preserved the user message and work + expect(firstRunResult.sessionState).toBeDefined() + const firstHistory = + firstRunResult.sessionState!.mainAgentState.messageHistory + expect(firstHistory.length).toBe(3) // user + assistant + interruption + + const firstUserMsg = firstHistory.find( + (m) => + m.role === 'user' && + m.content.some( + (c: any) => c.type === 'text' && c.text.includes('Fix the bug'), + ), + ) + expect(firstUserMsg).toBeDefined() + + // Now set up mock for the second run + mock.restore() + spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ + id: 'user-123', + email: 'test@example.com', + discord_id: null, + stripe_customer_id: null, + banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), + }) + spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) + spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-2') + spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined) + spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-2') + + // Second run: server receives the previous state and adds the new user message + const secondRunServerState = JSON.parse( + JSON.stringify(firstRunResult.sessionState!), + ) as typeof firstRunServerState + secondRunServerState.mainAgentState.messageHistory.push( + 
userMessage('Now also fix the login page'), + assistantMessage('I will fix both issues.'), + ) + + spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( + async (params: Parameters[0]) => { + const { sendAction, promptId } = params + + await sendAction({ + action: { + type: 'prompt-response', + promptId, + sessionState: secondRunServerState, + output: { + type: 'lastMessage', + value: [], + }, + }, + }) + + return { + sessionState: secondRunServerState, + output: { + type: 'lastMessage' as const, + value: [], + }, + } + }, + ) + + // Run 2: uses previousRun from the cancelled first run + const secondRunResult = await client.run({ + agent: 'base2', + prompt: 'Now also fix the login page', + previousRun: firstRunResult, + }) + + // Verify the second run's session state includes history from BOTH runs + expect(secondRunResult.sessionState).toBeDefined() + const secondHistory = + secondRunResult.sessionState!.mainAgentState.messageHistory + + // Should have: first user msg + first assistant msg + interruption + second user msg + second assistant msg + expect(secondHistory.length).toBe(5) + + // The first user message should be present + const firstUserMsgInSecond = secondHistory.find( + (m) => + m.role === 'user' && + m.content.some( + (c: any) => c.type === 'text' && c.text.includes('Fix the bug'), + ), + ) + expect(firstUserMsgInSecond).toBeDefined() + + // The second user message should also be present + const secondUserMsg = secondHistory.find( + (m) => + m.role === 'user' && + m.content.some( + (c: any) => + c.type === 'text' && c.text.includes('fix the login page'), + ), + ) + expect(secondUserMsg).toBeDefined() + + // The first assistant message should be preserved + const firstAssistantMsg = secondHistory.find( + (m) => + m.role === 'assistant' && + m.content.some( + (c: any) => + c.type === 'text' && c.text.includes('authentication module'), + ), + ) + expect(firstAssistantMsg).toBeDefined() + }) + it('preserves session state even when abort happens 
mid-stream', async () => { spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -706,7 +1154,9 @@ describe('Run Cancellation Handling', () => { spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') const abortController = new AbortController() - const serverSessionState = getInitialSessionState(getStubProjectFileContext()) + const serverSessionState = getInitialSessionState( + getStubProjectFileContext(), + ) // Simulate multiple tool calls and results (more complex work done) serverSessionState.mainAgentState.messageHistory.push( @@ -727,7 +1177,12 @@ describe('Run Cancellation Handling', () => { role: 'tool', toolCallId: 'read-1', toolName: 'read_files', - content: [{ type: 'json', value: [{ path: 'src/bug.ts', content: 'buggy code' }] }], + content: [ + { + type: 'json', + value: [{ path: 'src/bug.ts', content: 'buggy code' }], + }, + ], }, { role: 'assistant', @@ -745,7 +1200,12 @@ describe('Run Cancellation Handling', () => { role: 'tool', toolCallId: 'write-1', toolName: 'write_file', - content: [{ type: 'json', value: { file: 'src/bug.ts', message: 'File written' } }], + content: [ + { + type: 'json', + value: { file: 'src/bug.ts', message: 'File written' }, + }, + ], }, ) @@ -771,7 +1231,11 @@ describe('Run Cancellation Handling', () => { // Simulate agent runtime adding interruption message on abort serverSessionState.mainAgentState.messageHistory.push( - userMessage(withSystemTags("User interrupted the response. The assistant's previous work has been preserved.")) + userMessage( + withSystemTags( + "User interrupted the response. 
The assistant's previous work has been preserved.", + ), + ), ) // Server still returns the full session state @@ -829,6 +1293,8 @@ describe('Run Cancellation Handling', () => { // Verify interruption message was added at the end const lastMessage = messageHistory[messageHistory.length - 1] expect(lastMessage.role).toBe('user') - expect((lastMessage.content[0] as { type: 'text'; text: string }).text).toContain('User interrupted the response') + expect( + (lastMessage.content[0] as { type: 'text'; text: string }).text, + ).toContain('User interrupted the response') }) }) diff --git a/sdk/src/__tests__/run-error-preserves-history.test.ts b/sdk/src/__tests__/run-error-preserves-history.test.ts new file mode 100644 index 0000000000..4af0229de9 --- /dev/null +++ b/sdk/src/__tests__/run-error-preserves-history.test.ts @@ -0,0 +1,314 @@ +import * as mainPromptModule from '@codebuff/agent-runtime/main-prompt' +import { getInitialSessionState } from '@codebuff/common/types/session-state' +import { getStubProjectFileContext } from '@codebuff/common/util/file' +import { assistantMessage, userMessage } from '@codebuff/common/util/messages' +import { afterEach, describe, expect, it, mock, spyOn } from 'bun:test' + +import { CodebuffClient } from '../client' +import * as databaseModule from '../impl/database' + +interface ToolCallContentBlock { + type: 'tool-call' + toolCallId: string + toolName: string + input: Record +} + +const setupDatabaseMocks = () => { + spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ + id: 'user-123', + email: 'test@example.com', + discord_id: null, + stripe_customer_id: null, + banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), + }) + spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) + spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') + spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined) + spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') +} + 
+describe('Error preserves in-progress message history', () => { + afterEach(() => { + mock.restore() + }) + + it('preserves in-progress assistant work on error (simulated via shared state mutation)', async () => { + setupDatabaseMocks() + + // Simulate the agent runtime: + // 1. Mutates the shared session state with the user message and partial work + // 2. Then throws due to a downstream timeout/service error + spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( + async (params: Parameters[0]) => { + const mainAgentState = params.action.sessionState.mainAgentState + + // Match the real runtime's behavior: replace messageHistory with a new + // array that includes the user prompt as its first entry. The SDK + // detects runtime progress via reference inequality, so we must + // reassign the array rather than pushing into it. + mainAgentState.messageHistory = [ + ...mainAgentState.messageHistory, + { + role: 'user', + content: [{ type: 'text', text: 'Fix the bug in auth.ts' }], + tags: ['USER_PROMPT'], + }, + { + role: 'assistant', + content: [ + { type: 'text', text: 'Let me read the auth file first.' }, + { + type: 'tool-call', + toolCallId: 'read-1', + toolName: 'read_files', + input: { paths: ['auth.ts'] }, + } as ToolCallContentBlock, + ], + }, + { + role: 'tool', + toolCallId: 'read-1', + toolName: 'read_files', + content: [ + { + type: 'json', + value: [{ path: 'auth.ts', content: 'const auth = ...' }], + }, + ], + }, + { + role: 'assistant', + content: [ + { type: 'text', text: 'Found the issue, writing the fix now.' 
}, + { + type: 'tool-call', + toolCallId: 'write-1', + toolName: 'write_file', + input: { path: 'auth.ts', content: 'const auth = fixed' }, + } as ToolCallContentBlock, + ], + }, + { + role: 'tool', + toolCallId: 'write-1', + toolName: 'write_file', + content: [{ type: 'json', value: { file: 'auth.ts', message: 'File written' } }], + }, + ] + + // Now simulate a server timeout on the next LLM call + const timeoutError = new Error('Service Unavailable') as Error & { + statusCode: number + responseBody: string + } + timeoutError.statusCode = 503 + timeoutError.responseBody = JSON.stringify({ + message: 'Request timeout after 30s', + }) + throw timeoutError + }, + ) + + const client = new CodebuffClient({ apiKey: 'test-key' }) + const result = await client.run({ + agent: 'base2', + prompt: 'Fix the bug in auth.ts', + }) + + // Error output with correct status code + expect(result.output.type).toBe('error') + const errorOutput = result.output as { + type: 'error' + message: string + statusCode?: number + } + expect(errorOutput.statusCode).toBe(503) + + const history = result.sessionState!.mainAgentState.messageHistory + + // The user's prompt should appear exactly once + const userPromptMessages = history.filter( + (m) => + m.role === 'user' && + (m.content as Array<{ type: string; text?: string }>).some( + (c) => c.type === 'text' && c.text?.includes('Fix the bug'), + ), + ) + expect(userPromptMessages.length).toBe(1) + + // Assistant text messages from both steps should be preserved + const firstAssistantText = history.find( + (m) => + m.role === 'assistant' && + (m.content as Array<{ type: string; text?: string }>).some( + (c) => c.type === 'text' && c.text?.includes('read the auth file'), + ), + ) + expect(firstAssistantText).toBeDefined() + + const secondAssistantText = history.find( + (m) => + m.role === 'assistant' && + (m.content as Array<{ type: string; text?: string }>).some( + (c) => c.type === 'text' && c.text?.includes('writing the fix'), + ), + ) + 
expect(secondAssistantText).toBeDefined() + + // Both tool calls and both tool results should be preserved + const readToolCall = history.find( + (m) => + m.role === 'assistant' && + (m.content as Array<{ type: string; toolCallId?: string }>).some( + (c) => c.type === 'tool-call' && c.toolCallId === 'read-1', + ), + ) + expect(readToolCall).toBeDefined() + + const writeToolCall = history.find( + (m) => + m.role === 'assistant' && + (m.content as Array<{ type: string; toolCallId?: string }>).some( + (c) => c.type === 'tool-call' && c.toolCallId === 'write-1', + ), + ) + expect(writeToolCall).toBeDefined() + + const readToolResult = history.find( + (m) => m.role === 'tool' && m.toolCallId === 'read-1', + ) + expect(readToolResult).toBeDefined() + + const writeToolResult = history.find( + (m) => m.role === 'tool' && m.toolCallId === 'write-1', + ) + expect(writeToolResult).toBeDefined() + }) + + it('a subsequent run after error includes the preserved in-progress history', async () => { + setupDatabaseMocks() + + // Run 1: agent does some work then hits an error + spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( + async (params: Parameters[0]) => { + const mainAgentState = params.action.sessionState.mainAgentState + + mainAgentState.messageHistory = [ + ...mainAgentState.messageHistory, + { + role: 'user', + content: [{ type: 'text', text: 'Investigate the login bug' }], + tags: ['USER_PROMPT'], + }, + assistantMessage('I found the problem in auth.ts on line 42.'), + { + role: 'assistant', + content: [ + { + type: 'tool-call', + toolCallId: 'read-login', + toolName: 'read_files', + input: { paths: ['login.ts'] }, + } as ToolCallContentBlock, + ], + }, + { + role: 'tool', + toolCallId: 'read-login', + toolName: 'read_files', + content: [{ type: 'json', value: [{ path: 'login.ts', content: 'login code' }] }], + }, + ] + + const error = new Error('Service Unavailable') as Error & { + statusCode: number + } + error.statusCode = 503 + throw error + }, + ) + + 
const client = new CodebuffClient({ apiKey: 'test-key' }) + const firstResult = await client.run({ + agent: 'base2', + prompt: 'Investigate the login bug', + }) + + expect(firstResult.output.type).toBe('error') + + // Run 2: use the failed run as previousRun + mock.restore() + setupDatabaseMocks() + + let historyReceivedByRuntime: unknown[] | undefined + spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( + async (params: Parameters[0]) => { + const { sendAction, promptId } = params + historyReceivedByRuntime = [ + ...params.action.sessionState.mainAgentState.messageHistory, + ] + + const responseSessionState = getInitialSessionState( + getStubProjectFileContext(), + ) + responseSessionState.mainAgentState.messageHistory = [ + ...params.action.sessionState.mainAgentState.messageHistory, + userMessage('Now try again'), + assistantMessage('Continuing with the fix.'), + ] + + await sendAction({ + action: { + type: 'prompt-response', + promptId, + sessionState: responseSessionState, + output: { type: 'lastMessage', value: [] }, + }, + }) + + return { + sessionState: responseSessionState, + output: { type: 'lastMessage' as const, value: [] }, + } + }, + ) + + const secondResult = await client.run({ + agent: 'base2', + prompt: 'Now try again', + previousRun: firstResult, + }) + + // The runtime should have received history containing the work from the first run + expect(historyReceivedByRuntime).toBeDefined() + const receivedReadCall = historyReceivedByRuntime!.find( + (m) => + (m as { role: string }).role === 'assistant' && + ((m as { content: Array<{ type: string; toolCallId?: string }> }) + .content ?? 
[]).some( + (c) => c.type === 'tool-call' && c.toolCallId === 'read-login', + ), + ) + expect(receivedReadCall).toBeDefined() + + const receivedToolResult = historyReceivedByRuntime!.find( + (m) => + (m as { role: string }).role === 'tool' && + (m as { toolCallId: string }).toolCallId === 'read-login', + ) + expect(receivedToolResult).toBeDefined() + + // Final result should preserve history + const finalHistory = secondResult.sessionState!.mainAgentState.messageHistory + const finalReadCall = finalHistory.find( + (m) => + m.role === 'assistant' && + (m.content as Array<{ type: string; toolCallId?: string }>).some( + (c) => c.type === 'tool-call' && c.toolCallId === 'read-login', + ), + ) + expect(finalReadCall).toBeDefined() + }) +}) diff --git a/sdk/src/__tests__/run-file-filter.test.ts b/sdk/src/__tests__/run-file-filter.test.ts index 78ccdbf37d..5d1be280a2 100644 --- a/sdk/src/__tests__/run-file-filter.test.ts +++ b/sdk/src/__tests__/run-file-filter.test.ts @@ -1,4 +1,3 @@ - import * as mainPromptModule from '@codebuff/agent-runtime/main-prompt' import { FILE_READ_STATUS } from '@codebuff/common/old-constants' import * as projectFileTree from '@codebuff/common/project-file-tree' @@ -71,9 +70,9 @@ describe('CodebuffClientOptions fileFilter', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -91,9 +90,7 @@ describe('CodebuffClientOptions fileFilter', () => { let requestedFiles: Record = {} spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( - async ( - params: Parameters[0], - ) => { + async (params: Parameters[0]) => { const { sendAction, promptId, requestFiles } = params const sessionState = getInitialSessionState(getStubProjectFileContext()) @@ -157,9 +154,9 @@ 
describe('CodebuffClientOptions fileFilter', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -177,9 +174,7 @@ describe('CodebuffClientOptions fileFilter', () => { let requestedFiles: Record = {} spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( - async ( - params: Parameters[0], - ) => { + async (params: Parameters[0]) => { const { sendAction, promptId, requestFiles } = params const sessionState = getInitialSessionState(getStubProjectFileContext()) @@ -240,9 +235,9 @@ describe('CodebuffClientOptions fileFilter', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -259,9 +254,7 @@ describe('CodebuffClientOptions fileFilter', () => { let optionalFileResult: string | null = null spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( - async ( - params: Parameters[0], - ) => { + async (params: Parameters[0]) => { const { sendAction, promptId, requestOptionalFile } = params const sessionState = getInitialSessionState(getStubProjectFileContext()) @@ -319,14 +312,83 @@ describe('CodebuffClientOptions fileFilter', () => { expect(optionalFileResult).toBeNull() }) + it('should tolerate absolute requestOptionalFile paths inside cwd', async () => { + spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ + id: 'user-123', + email: 'test@example.com', + discord_id: null, + stripe_customer_id: null, + banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), + }) + spyOn(databaseModule, 
'fetchAgentFromDatabase').mockResolvedValue(null) + spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') + spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined) + spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') + spyOn(projectFileTree, 'isFileIgnored').mockResolvedValue(false) + + const mockFs = createMockFs({ + files: { + '/project/src/index.ts': { content: 'normal file content' }, + }, + }) + + const optionalFileResult: { current: string | null } = { current: null } + + spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( + async (params: Parameters[0]) => { + const { sendAction, promptId, requestOptionalFile } = params + const sessionState = getInitialSessionState(getStubProjectFileContext()) + + optionalFileResult.current = await requestOptionalFile({ + filePath: '/project/src/index.ts', + }) + + await sendAction({ + action: { + type: 'prompt-response', + promptId, + sessionState, + output: { + type: 'lastMessage', + value: [], + }, + }, + }) + + return { + sessionState, + output: { + type: 'lastMessage' as const, + value: [], + }, + } + }, + ) + + const client = new CodebuffClient({ + apiKey: 'test-key', + cwd: '/project', + fsSource: mockFs, + }) + + const result = await client.run({ + agent: 'base2', + prompt: 'read optional file', + }) + + expect(result.output.type).toBe('lastMessage') + expect(optionalFileResult.current).toBe('normal file content') + }) + it('should allow all files when no fileFilter is provided', async () => { spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -343,9 +405,7 @@ describe('CodebuffClientOptions fileFilter', () => { let requestedFiles: Record = 
{} spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( - async ( - params: Parameters[0], - ) => { + async (params: Parameters[0]) => { const { sendAction, promptId, requestFiles } = params const sessionState = getInitialSessionState(getStubProjectFileContext()) @@ -396,9 +456,9 @@ describe('CodebuffClientOptions fileFilter', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') @@ -417,9 +477,7 @@ describe('CodebuffClientOptions fileFilter', () => { }) spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( - async ( - params: Parameters[0], - ) => { + async (params: Parameters[0]) => { const { sendAction, promptId, requestFiles } = params const sessionState = getInitialSessionState(getStubProjectFileContext()) diff --git a/sdk/src/__tests__/run-handle-event.test.ts b/sdk/src/__tests__/run-handle-event.test.ts index d8f4df3408..d3fc76b3ec 100644 --- a/sdk/src/__tests__/run-handle-event.test.ts +++ b/sdk/src/__tests__/run-handle-event.test.ts @@ -20,9 +20,9 @@ describe('CodebuffClient handleEvent / handleStreamChunk', () => { id: 'user-123', email: 'test@example.com', discord_id: null, - referral_code: null, stripe_customer_id: null, banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), }) spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') diff --git a/sdk/src/__tests__/run-mcp-tool-filter.test.ts b/sdk/src/__tests__/run-mcp-tool-filter.test.ts new file mode 100644 index 0000000000..40960c4c82 --- /dev/null +++ b/sdk/src/__tests__/run-mcp-tool-filter.test.ts @@ -0,0 +1,124 @@ +import * as mainPromptModule from '@codebuff/agent-runtime/main-prompt' +import { getInitialSessionState } from 
'@codebuff/common/types/session-state' +import { getStubProjectFileContext } from '@codebuff/common/util/file' +import { afterEach, describe, expect, it, mock, spyOn } from 'bun:test' + +import { CodebuffClient } from '../client' +import * as mcpClientModule from '@codebuff/common/mcp/client' +import * as databaseModule from '../impl/database' + +import type { AgentDefinition } from '@codebuff/common/templates/initial-agents-dir/types/agent-definition' +import type { MCPConfig } from '@codebuff/common/types/mcp' + +const browserMcpConfig: MCPConfig = { + type: 'stdio', + command: 'npx', + args: ['-y', 'fake-mcp-server'], + env: {}, +} + +const TEST_AGENT: AgentDefinition = { + id: 'mcp-filter-agent', + displayName: 'MCP Filter Agent', + model: 'openai/gpt-5-mini', + reasoningOptions: { effort: 'minimal' }, + mcpServers: { + browser: browserMcpConfig, + }, + toolNames: ['browser/browser_navigate', 'browser/browser_snapshot'], + systemPrompt: 'Test MCP filtering.', +} + +describe('MCP tool filtering', () => { + afterEach(() => { + mock.restore() + }) + + it('returns only allowlisted MCP tools when an agent restricts toolNames', async () => { + spyOn(databaseModule, 'getUserInfoFromApiKey').mockResolvedValue({ + id: 'user-123', + email: 'test@example.com', + discord_id: null, + stripe_customer_id: null, + banned: false, + created_at: new Date('2024-01-01T00:00:00Z'), + }) + spyOn(databaseModule, 'fetchAgentFromDatabase').mockResolvedValue(null) + spyOn(databaseModule, 'startAgentRun').mockResolvedValue('run-1') + spyOn(databaseModule, 'finishAgentRun').mockResolvedValue(undefined) + spyOn(databaseModule, 'addAgentStep').mockResolvedValue('step-1') + + spyOn(mcpClientModule, 'getMCPClient').mockResolvedValue('mcp-client-id') + spyOn(mcpClientModule, 'listMCPTools').mockResolvedValue({ + tools: [ + { + name: 'browser_navigate', + description: 'Navigate to a page', + inputSchema: { type: 'object', properties: {} }, + }, + { + name: 'browser_snapshot', + description: 
'Capture snapshot', + inputSchema: { type: 'object', properties: {} }, + }, + { + name: 'browser_click', + description: 'Click an element', + inputSchema: { type: 'object', properties: {} }, + }, + ], + } as Awaited>) + + let filteredTools: Array<{ name: string }> = [] + + spyOn(mainPromptModule, 'callMainPrompt').mockImplementation( + async (params: Parameters[0]) => { + const { sendAction, promptId, requestMcpToolData } = params + const sessionState = getInitialSessionState(getStubProjectFileContext()) + + filteredTools = await requestMcpToolData({ + mcpConfig: browserMcpConfig, + toolNames: TEST_AGENT.toolNames! + .filter((toolName) => toolName.startsWith('browser/')) + .map((toolName) => toolName.slice('browser/'.length)), + }) + + await sendAction({ + action: { + type: 'prompt-response', + promptId, + sessionState, + output: { + type: 'lastMessage', + value: [], + }, + }, + }) + + return { + sessionState, + output: { + type: 'lastMessage' as const, + value: [], + }, + } + }, + ) + + const client = new CodebuffClient({ + apiKey: 'test-key', + agentDefinitions: [TEST_AGENT], + }) + + const result = await client.run({ + agent: TEST_AGENT.id, + prompt: 'List MCP tools', + }) + + expect(result.output.type).toBe('lastMessage') + expect(filteredTools.map((tool: { name: string }) => tool.name)).toEqual([ + 'browser_navigate', + 'browser_snapshot', + ]) + }) +}) diff --git a/sdk/src/credentials.ts b/sdk/src/credentials.ts index 0bbdfb553f..4d21e717b5 100644 --- a/sdk/src/credentials.ts +++ b/sdk/src/credentials.ts @@ -2,20 +2,20 @@ import fs from 'fs' import path from 'node:path' import os from 'os' -import { CLAUDE_OAUTH_CLIENT_ID } from '@codebuff/common/constants/claude-oauth' +import { + CHATGPT_OAUTH_CLIENT_ID, + CHATGPT_OAUTH_TOKEN_URL, +} from '@codebuff/common/constants/chatgpt-oauth' import { env } from '@codebuff/common/env' import { userSchema } from '@codebuff/common/util/credentials' import { z } from 'zod/v4' -import { getClaudeOAuthTokenFromEnv } from 
'./env' +import { getChatGptOAuthTokenFromEnv } from './env' import type { ClientEnv } from '@codebuff/common/types/contracts/env' import type { User } from '@codebuff/common/util/credentials' -/** - * Schema for Claude OAuth credentials. - */ -const claudeOAuthSchema = z.object({ +const chatGptOAuthSchema = z.object({ accessToken: z.string(), refreshToken: z.string(), expiresAt: z.number(), @@ -24,11 +24,11 @@ const claudeOAuthSchema = z.object({ /** * Unified schema for the credentials file. - * Contains both Codebuff user credentials and Claude OAuth credentials. + * Contains both Codebuff user credentials and ChatGPT OAuth credentials. */ const credentialsFileSchema = z.object({ default: userSchema.optional(), - claudeOAuth: claudeOAuthSchema.optional(), + chatgptOAuth: chatGptOAuthSchema.optional(), }) const ensureDirectoryExistsSync = (dir: string) => { @@ -83,9 +83,9 @@ export const getUserCredentials = (clientEnv: ClientEnv = env): User | null => { } /** - * Claude OAuth credentials stored in the credentials file. + * ChatGPT OAuth credentials stored in the credentials file. */ -export interface ClaudeOAuthCredentials { +export interface ChatGptOAuthCredentials { accessToken: string refreshToken: string expiresAt: number // Unix timestamp in milliseconds @@ -93,50 +93,42 @@ export interface ClaudeOAuthCredentials { } /** - * Get Claude OAuth credentials from file or environment variable. + * Get ChatGPT OAuth credentials from environment variable or stored file. * Environment variable takes precedence. - * @returns OAuth credentials or null if not found */ -export const getClaudeOAuthCredentials = ( +export const getChatGptOAuthCredentials = ( clientEnv: ClientEnv = env, -): ClaudeOAuthCredentials | null => { - // Check environment variable first - const envToken = getClaudeOAuthTokenFromEnv() +): ChatGptOAuthCredentials | null => { + // 1. 
Environment variable takes highest precedence + const envToken = getChatGptOAuthTokenFromEnv() if (envToken) { - // Return a synthetic credentials object for env var tokens - // These tokens are assumed to be valid and non-expiring for simplicity return { accessToken: envToken, refreshToken: '', - expiresAt: Date.now() + 365 * 24 * 60 * 60 * 1000, // 1 year from now + expiresAt: Date.now() + 365 * 24 * 60 * 60 * 1000, connectedAt: Date.now(), } } + // 2. Codebuff's own stored credentials const credentialsPath = getCredentialsPath(clientEnv) - if (!fs.existsSync(credentialsPath)) { - return null - } - - try { - const credentialsFile = fs.readFileSync(credentialsPath, 'utf8') - const parsed = credentialsFileSchema.safeParse(JSON.parse(credentialsFile)) - if (!parsed.success || !parsed.data.claudeOAuth) { - return null + if (fs.existsSync(credentialsPath)) { + try { + const credentialsFile = fs.readFileSync(credentialsPath, 'utf8') + const parsed = credentialsFileSchema.safeParse(JSON.parse(credentialsFile)) + if (parsed.success && parsed.data.chatgptOAuth) { + return parsed.data.chatgptOAuth + } + } catch { + // Fall through } - return parsed.data.claudeOAuth - } catch (error) { - console.error('Error reading Claude OAuth credentials', error) - return null } + + return null } -/** - * Save Claude OAuth credentials to the credentials file. - * Preserves existing user credentials. - */ -export const saveClaudeOAuthCredentials = ( - credentials: ClaudeOAuthCredentials, +export const saveChatGptOAuthCredentials = ( + credentials: ChatGptOAuthCredentials, clientEnv: ClientEnv = env, ): void => { const configDir = getConfigDir(clientEnv) @@ -155,17 +147,13 @@ export const saveClaudeOAuthCredentials = ( const updatedData = { ...existingData, - claudeOAuth: credentials, + chatgptOAuth: credentials, } fs.writeFileSync(credentialsPath, JSON.stringify(updatedData, null, 2)) } -/** - * Clear Claude OAuth credentials from the credentials file. - * Preserves other credentials. 
- */ -export const clearClaudeOAuthCredentials = ( +export const clearChatGptOAuthCredentials = ( clientEnv: ClientEnv = env, ): void => { const credentialsPath = getCredentialsPath(clientEnv) @@ -175,126 +163,107 @@ export const clearClaudeOAuthCredentials = ( try { const existingData = JSON.parse(fs.readFileSync(credentialsPath, 'utf8')) - delete existingData.claudeOAuth + delete existingData.chatgptOAuth fs.writeFileSync(credentialsPath, JSON.stringify(existingData, null, 2)) } catch { // Ignore errors } } -/** - * Check if Claude OAuth credentials are valid (not expired). - * Returns true if credentials exist and haven't expired. - */ -export const isClaudeOAuthValid = (clientEnv: ClientEnv = env): boolean => { - const credentials = getClaudeOAuthCredentials(clientEnv) +export const isChatGptOAuthValid = (clientEnv: ClientEnv = env): boolean => { + const credentials = getChatGptOAuthCredentials(clientEnv) if (!credentials) { return false } - // Add 5 minute buffer before expiry const bufferMs = 5 * 60 * 1000 return credentials.expiresAt > Date.now() + bufferMs } -// Mutex to prevent concurrent refresh attempts -let refreshPromise: Promise | null = null +let chatGptRefreshPromise: Promise | null = null -/** - * Refresh the Claude OAuth access token using the refresh token. - * Returns the new credentials if successful, null if refresh fails. - * Uses a mutex to prevent concurrent refresh attempts. 
- */ -export const refreshClaudeOAuthToken = async ( +export const refreshChatGptOAuthToken = async ( clientEnv: ClientEnv = env, -): Promise => { - // If a refresh is already in progress, wait for it - if (refreshPromise) { - return refreshPromise +): Promise => { + if (chatGptRefreshPromise) { + return chatGptRefreshPromise } - const credentials = getClaudeOAuthCredentials(clientEnv) + const credentials = getChatGptOAuthCredentials(clientEnv) if (!credentials?.refreshToken) { return null } - // Start the refresh and store the promise - refreshPromise = (async () => { + chatGptRefreshPromise = (async () => { try { - const response = await fetch( - 'https://console.anthropic.com/v1/oauth/token', - { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - grant_type: 'refresh_token', - refresh_token: credentials.refreshToken, - client_id: CLAUDE_OAUTH_CLIENT_ID, - }), + const response = await fetch(CHATGPT_OAUTH_TOKEN_URL, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', }, - ) + body: JSON.stringify({ + grant_type: 'refresh_token', + refresh_token: credentials.refreshToken, + client_id: CHATGPT_OAUTH_CLIENT_ID, + }), + }) if (!response.ok) { - // Refresh failed, clear credentials - clearClaudeOAuthCredentials(clientEnv) + console.debug(`ChatGPT OAuth token refresh failed (status ${response.status})`) return null } const data = await response.json() - const newCredentials: ClaudeOAuthCredentials = { + if ( + typeof data?.access_token !== 'string' || + data.access_token.trim().length === 0 + ) { + console.debug('ChatGPT OAuth token refresh returned empty access token') + return null + } + + const expiresIn = + typeof data.expires_in === 'number' ? data.expires_in * 1000 : 3600 * 1000 + + const newCredentials: ChatGptOAuthCredentials = { accessToken: data.access_token, refreshToken: data.refresh_token ?? 
credentials.refreshToken, - expiresAt: Date.now() + data.expires_in * 1000, + expiresAt: Date.now() + expiresIn, connectedAt: credentials.connectedAt, } - // Save updated credentials - saveClaudeOAuthCredentials(newCredentials, clientEnv) + saveChatGptOAuthCredentials(newCredentials, clientEnv) return newCredentials - } catch { - // Refresh failed, clear credentials - clearClaudeOAuthCredentials(clientEnv) + } catch (error) { + console.debug('ChatGPT OAuth token refresh failed:', error instanceof Error ? error.message : String(error)) return null } finally { - // Clear the mutex after completion - refreshPromise = null + chatGptRefreshPromise = null } })() - return refreshPromise + return chatGptRefreshPromise } -/** - * Get valid Claude OAuth credentials, refreshing if necessary. - * This is the main function to use when you need credentials for an API call. - * - * - Returns credentials immediately if valid (>5 min until expiry) - * - Attempts refresh if token is expired or near-expiry - * - Returns null if no credentials or refresh fails - */ -export const getValidClaudeOAuthCredentials = async ( +export const getValidChatGptOAuthCredentials = async ( clientEnv: ClientEnv = env, -): Promise => { - const credentials = getClaudeOAuthCredentials(clientEnv) +): Promise => { + const credentials = getChatGptOAuthCredentials(clientEnv) if (!credentials) { return null } - // Check if token is from environment variable (synthetic credentials, no refresh needed) + const bufferMs = 5 * 60 * 1000 + + // No refresh token (e.g. env var override) — return only if still valid if (!credentials.refreshToken) { - // Environment variable tokens are assumed valid - return credentials + return credentials.expiresAt > Date.now() + bufferMs ? 
credentials : null } - // Check if token is valid with 5 minute buffer - const bufferMs = 5 * 60 * 1000 if (credentials.expiresAt > Date.now() + bufferMs) { return credentials } - // Token is expired or expiring soon, try to refresh - return refreshClaudeOAuthToken(clientEnv) + return refreshChatGptOAuthToken(clientEnv) } diff --git a/sdk/src/env.ts b/sdk/src/env.ts index 325059acdf..033e3f245d 100644 --- a/sdk/src/env.ts +++ b/sdk/src/env.ts @@ -6,7 +6,7 @@ */ import { BYOK_OPENROUTER_ENV_VAR } from '@codebuff/common/constants/byok' -import { CLAUDE_OAUTH_TOKEN_ENV_VAR } from '@codebuff/common/constants/claude-oauth' +import { CHATGPT_OAUTH_TOKEN_ENV_VAR } from '@codebuff/common/constants/chatgpt-oauth' import { API_KEY_ENV_VAR } from '@codebuff/common/constants/paths' import { getBaseEnv } from '@codebuff/common/env-process' @@ -43,9 +43,8 @@ export const getByokOpenrouterApiKeyFromEnv = (): string | undefined => { } /** - * Get Claude OAuth token from environment variable. - * This allows users to provide their Claude Pro/Max OAuth token for direct Anthropic API access. + * Get ChatGPT OAuth token from environment variable. 
*/ -export const getClaudeOAuthTokenFromEnv = (): string | undefined => { - return process.env[CLAUDE_OAUTH_TOKEN_ENV_VAR] +export const getChatGptOAuthTokenFromEnv = (): string | undefined => { + return process.env[CHATGPT_OAUTH_TOKEN_ENV_VAR] } diff --git a/sdk/src/impl/__tests__/llm-chatgpt-oauth-policy.test.ts b/sdk/src/impl/__tests__/llm-chatgpt-oauth-policy.test.ts new file mode 100644 index 0000000000..825853803e --- /dev/null +++ b/sdk/src/impl/__tests__/llm-chatgpt-oauth-policy.test.ts @@ -0,0 +1,67 @@ +import { describe, expect, test } from 'bun:test' + +import { classifyChatGptOAuthStreamError } from '../llm' + +describe('classifyChatGptOAuthStreamError', () => { + test('returns ignore when ChatGPT OAuth is not active', () => { + const result = classifyChatGptOAuthStreamError({ + isChatGptOAuth: false, + hasYieldedContent: false, + error: { statusCode: 429 }, + }) + expect(result).toBe('ignore') + }) + + test('returns fallback-rate-limit for 429 before content is yielded', () => { + const result = classifyChatGptOAuthStreamError({ + isChatGptOAuth: true, + hasYieldedContent: false, + error: { statusCode: 429 }, + }) + expect(result).toBe('fallback-rate-limit') + }) + + test('returns fail-auth-reconnect for 401/403 before content is yielded', () => { + const unauthorized = classifyChatGptOAuthStreamError({ + isChatGptOAuth: true, + hasYieldedContent: false, + error: { statusCode: 401 }, + }) + const forbidden = classifyChatGptOAuthStreamError({ + isChatGptOAuth: true, + hasYieldedContent: false, + error: { statusCode: 403 }, + }) + + expect(unauthorized).toBe('fail-auth-reconnect') + expect(forbidden).toBe('fail-auth-reconnect') + }) + + test('returns fail-fast for non-rate-limit non-auth errors', () => { + const result = classifyChatGptOAuthStreamError({ + isChatGptOAuth: true, + hasYieldedContent: false, + error: { statusCode: 500 }, + }) + expect(result).toBe('fail-fast') + }) + + test('returns ignore after partial output has been yielded', () => { + 
const result = classifyChatGptOAuthStreamError({ + isChatGptOAuth: true, + hasYieldedContent: true, + error: { statusCode: 429 }, + }) + expect(result).toBe('ignore') + }) + + test('returns ignore when skip flag is set', () => { + const result = classifyChatGptOAuthStreamError({ + isChatGptOAuth: true, + skipChatGptOAuth: true, + hasYieldedContent: false, + error: { statusCode: 429 }, + }) + expect(result).toBe('ignore') + }) +}) diff --git a/sdk/src/impl/__tests__/model-provider-free-mode.test.ts b/sdk/src/impl/__tests__/model-provider-free-mode.test.ts new file mode 100644 index 0000000000..2471da37b0 --- /dev/null +++ b/sdk/src/impl/__tests__/model-provider-free-mode.test.ts @@ -0,0 +1,98 @@ +import { describe, expect, test, beforeEach, afterEach, mock } from 'bun:test' +import { + clearMockedModules, + mockModule, +} from '@codebuff/common/testing/mock-modules' + +describe('getModelForRequest free-mode guards', () => { + const mockGetValidChatGptOAuthCredentials = mock(() => + Promise.resolve(null), + ) + + beforeEach(async () => { + // Mock CHATGPT_OAUTH_ENABLED to true so the ChatGPT OAuth path is entered. + // Uses mockModule helper since this is an absolute package specifier. + await mockModule('@codebuff/common/constants/chatgpt-oauth', () => ({ + CHATGPT_OAUTH_ENABLED: true, + })) + + // Mock credentials directly with Bun's mock.module — the helper resolves + // relative paths from common/src/testing/, not from this test file. 
+ mock.module('../../credentials', () => ({ + getValidChatGptOAuthCredentials: mockGetValidChatGptOAuthCredentials, + })) + + mockGetValidChatGptOAuthCredentials.mockReset() + mockGetValidChatGptOAuthCredentials.mockResolvedValue(null) + }) + + afterEach(() => { + mock.restore() + clearMockedModules() + }) + + async function importFresh() { + const mod = await import('../model-provider') + // Ensure clean rate-limit state + mod.resetChatGptOAuthRateLimit() + return mod + } + + test('throws when ChatGPT OAuth is rate-limited in free mode', async () => { + const { getModelForRequest, markChatGptOAuthRateLimited } = + await importFresh() + + markChatGptOAuthRateLimited() + + await expect( + getModelForRequest({ + apiKey: 'test-key', + model: 'openai/gpt-5.3', + costMode: 'free', + }), + ).rejects.toThrow('ChatGPT rate limit reached') + }) + + test('throws when ChatGPT OAuth credentials are unavailable in free mode', async () => { + const { getModelForRequest } = await importFresh() + + mockGetValidChatGptOAuthCredentials.mockResolvedValue(null) + + await expect( + getModelForRequest({ + apiKey: 'test-key', + model: 'openai/gpt-5.3', + costMode: 'free', + }), + ).rejects.toThrow('ChatGPT OAuth credentials unavailable') + }) + + test('falls through to backend when rate-limited in non-free mode', async () => { + const { getModelForRequest, markChatGptOAuthRateLimited } = + await importFresh() + + markChatGptOAuthRateLimited() + + const result = await getModelForRequest({ + apiKey: 'test-key', + model: 'openai/gpt-5.3', + costMode: 'default', + }) + + expect(result.isChatGptOAuth).toBe(false) + }) + + test('falls through to backend when credentials unavailable in non-free mode', async () => { + const { getModelForRequest } = await importFresh() + + mockGetValidChatGptOAuthCredentials.mockResolvedValue(null) + + const result = await getModelForRequest({ + apiKey: 'test-key', + model: 'openai/gpt-5.3', + costMode: 'default', + }) + + 
expect(result.isChatGptOAuth).toBe(false) + }) +}) diff --git a/sdk/src/impl/__tests__/provider-options-metadata.test.ts b/sdk/src/impl/__tests__/provider-options-metadata.test.ts new file mode 100644 index 0000000000..908ce5446f --- /dev/null +++ b/sdk/src/impl/__tests__/provider-options-metadata.test.ts @@ -0,0 +1,72 @@ +import { describe, expect, it } from 'bun:test' + +import { getProviderOptions } from '../llm' + +describe('getProviderOptions — codebuff_metadata', () => { + const baseParams = { + model: 'openrouter/anthropic/claude-sonnet-4-5', + runId: 'run-1', + clientSessionId: 'session-1', + } + + it('includes run_id and client_id in codebuff_metadata', () => { + const opts = getProviderOptions(baseParams) + const meta = (opts.codebuff as any).codebuff_metadata + expect(meta).toMatchObject({ + run_id: 'run-1', + client_id: 'session-1', + }) + }) + + it('merges extraCodebuffMetadata into codebuff_metadata', () => { + const opts = getProviderOptions({ + ...baseParams, + extraCodebuffMetadata: { freebuff_instance_id: 'abc-123' }, + }) + const meta = (opts.codebuff as any).codebuff_metadata + expect(meta).toMatchObject({ + run_id: 'run-1', + client_id: 'session-1', + freebuff_instance_id: 'abc-123', + }) + }) + + it('omits extra keys when extraCodebuffMetadata is undefined', () => { + const opts = getProviderOptions(baseParams) + const meta = (opts.codebuff as any).codebuff_metadata + expect(Object.keys(meta)).toEqual( + expect.arrayContaining(['run_id', 'client_id']), + ) + expect(meta.freebuff_instance_id).toBeUndefined() + }) + + it('cost_mode passes through alongside extra metadata', () => { + const opts = getProviderOptions({ + ...baseParams, + costMode: 'free', + extraCodebuffMetadata: { freebuff_instance_id: 'uuid-xyz' }, + }) + const meta = (opts.codebuff as any).codebuff_metadata + expect(meta).toMatchObject({ + cost_mode: 'free', + freebuff_instance_id: 'uuid-xyz', + }) + }) + + it('extraCodebuffMetadata does not overwrite reserved keys', () => { + 
const opts = getProviderOptions({ + ...baseParams, + costMode: 'free', + extraCodebuffMetadata: { + // These are intentionally the same keys the function already sets — + // make sure a misuse doesn't let callers override server-trusted + // identifiers. The spread currently puts caller keys last, which + // means it WOULD override. If that's ever intentional, change this + // test; for now, lock it down. + run_id: 'evil-override', + }, + }) + const meta = (opts.codebuff as any).codebuff_metadata + expect(meta.run_id).toBe('run-1') + }) +}) diff --git a/sdk/src/impl/agent-runtime.ts b/sdk/src/impl/agent-runtime.ts index 9c8503d128..17858d8196 100644 --- a/sdk/src/impl/agent-runtime.ts +++ b/sdk/src/impl/agent-runtime.ts @@ -1,6 +1,7 @@ -import { trackEvent } from '@codebuff/common/analytics' +import { trackEvent as trackCommonEvent } from '@codebuff/common/analytics' import { env as clientEnvDefault } from '@codebuff/common/env' import { getCiEnv } from '@codebuff/common/env-ci' +import { shouldTrackAnalyticsEvent } from '@codebuff/common/util/analytics-sampling' import { success } from '@codebuff/common/util/error' import { @@ -19,6 +20,7 @@ import type { import type { DatabaseAgentCache } from '@codebuff/common/types/contracts/database' import type { ClientEnv } from '@codebuff/common/types/contracts/env' import type { Logger } from '@codebuff/common/types/contracts/logger' +import type { TrackEventFn } from '@codebuff/common/types/contracts/analytics' const databaseAgentCache: DatabaseAgentCache = new Map() @@ -51,6 +53,21 @@ export function getAgentRuntimeImpl( sendSubagentChunk, } = params + const trackSdkRuntimeEvent: TrackEventFn = (eventParams) => { + if ( + clientEnv.NEXT_PUBLIC_CB_ENVIRONMENT === 'prod' && + !shouldTrackAnalyticsEvent({ + event: eventParams.event, + distinctId: eventParams.userId, + properties: eventParams.properties, + }) + ) { + return + } + + trackCommonEvent(eventParams) + } + return { // Environment clientEnv, @@ -78,7 +95,7 @@ 
export function getAgentRuntimeImpl( databaseAgentCache, // Analytics - trackEvent, + trackEvent: trackSdkRuntimeEvent, // Other logger: logger ?? noopLogger, @@ -102,4 +119,4 @@ const noopLogger: Logger = { info: () => {}, warn: () => {}, error: () => {}, -} \ No newline at end of file +} diff --git a/sdk/src/impl/chatgpt-backend-fetch.ts b/sdk/src/impl/chatgpt-backend-fetch.ts new file mode 100644 index 0000000000..3a645dbf67 --- /dev/null +++ b/sdk/src/impl/chatgpt-backend-fetch.ts @@ -0,0 +1,516 @@ +/** + * Custom fetch for routing ChatGPT OAuth requests through the ChatGPT backend API. + * + * The AI SDK's OpenAICompatibleChatLanguageModel speaks Chat Completions format, + * but ChatGPT OAuth tokens only work with the ChatGPT backend (chatgpt.com/backend-api) + * which uses the Responses API format. + * + * This module transforms: + * - Request: Chat Completions body → Responses API body + * - Response: Responses API SSE → Chat Completions SSE + */ + +import type { FetchFunction } from '@ai-sdk/provider-utils' + +type FetchLike = (input: RequestInfo | URL, init?: RequestInit) => Promise + +// ============================================================================ +// JWT / Account ID +// ============================================================================ + +function base64UrlDecode(str: string): string { + let base64 = str.replace(/-/g, '+').replace(/_/g, '/') + const pad = base64.length % 4 + if (pad === 2) base64 += '==' + else if (pad === 3) base64 += '=' + return Buffer.from(base64, 'base64').toString('utf-8') +} + +export function extractChatGptAccountId(accessToken: string): string | null { + try { + const parts = accessToken.split('.') + if (parts.length !== 3) return null + const payload = JSON.parse(base64UrlDecode(parts[1])) + const auth = payload?.['https://api.openai.com/auth'] + return typeof auth?.chatgpt_account_id === 'string' + ? 
auth.chatgpt_account_id + : null + } catch { + return null + } +} + +// ============================================================================ +// Request Transform: Chat Completions → Responses API +// ============================================================================ + +interface ChatCompletionsToolCall { + id: string + type: string + function: { name: string; arguments: string } +} + +interface ChatCompletionsMessage { + role: string + content?: unknown + tool_calls?: ChatCompletionsToolCall[] + tool_call_id?: string +} + +interface ChatCompletionsTool { + type: string + function?: { + name: string + description?: string + parameters?: unknown + strict?: boolean + } +} + +function convertUserContentParts(content: unknown): unknown { + if (typeof content === 'string') return content + if (!Array.isArray(content)) return String(content ?? '') + return content.map((part: Record) => { + if (part.type === 'text') { + return { type: 'input_text', text: part.text } + } + if (part.type === 'image_url') { + const imageUrl = part.image_url as Record | undefined + return { + type: 'input_image', + image_url: imageUrl?.url ?? 
imageUrl, + } + } + return part + }) +} + +function convertMessages( + messages: ChatCompletionsMessage[], +): unknown[] { + const input: unknown[] = [] + + for (const msg of messages) { + switch (msg.role) { + case 'system': { + // System messages are extracted to top-level `instructions` field; + // if any slip through, convert to developer role + if (msg.content) { + input.push({ type: 'message', role: 'developer', content: msg.content }) + } + break + } + + case 'user': { + const content = convertUserContentParts(msg.content) + if (content) { + input.push({ type: 'message', role: 'user', content }) + } + break + } + + case 'assistant': { + if (msg.content) { + input.push({ type: 'message', role: 'assistant', content: msg.content }) + } + if (msg.tool_calls) { + for (const tc of msg.tool_calls) { + input.push({ + type: 'function_call', + call_id: tc.id, + name: tc.function.name, + arguments: tc.function.arguments, + }) + } + } + break + } + + case 'tool': { + input.push({ + type: 'function_call_output', + call_id: msg.tool_call_id ?? 'unknown', + output: + typeof msg.content === 'string' + ? msg.content + : JSON.stringify(msg.content), + }) + break + } + } + } + + return input +} + +function convertTools(tools: ChatCompletionsTool[]): unknown[] { + return tools.map((tool) => { + if (tool.type === 'function' && tool.function) { + return { + type: 'function', + name: tool.function.name, + description: tool.function.description, + parameters: tool.function.parameters, + ...(tool.function.strict !== undefined && { + strict: tool.function.strict, + }), + } + } + return tool + }) +} + +function transformRequestBody( + body: Record, +): Record { + const messages = (body.messages ?? 
[]) as ChatCompletionsMessage[] + const tools = body.tools as ChatCompletionsTool[] | undefined + + // Extract system messages into the top-level `instructions` field + // (required by the ChatGPT backend API) + const systemMessages = messages.filter((m) => m.role === 'system') + const nonSystemMessages = messages.filter((m) => m.role !== 'system') + const instructions = systemMessages + .map((m) => (typeof m.content === 'string' ? m.content : JSON.stringify(m.content))) + .join('\n\n') + + const transformed: Record = { + model: body.model, + instructions: instructions || 'You are a helpful assistant.', + input: convertMessages(nonSystemMessages), + stream: true, + store: false, + include: ['reasoning.encrypted_content'], + } + + if (tools?.length) { + transformed.tools = convertTools(tools) + } + if (body.tool_choice != null) { + transformed.tool_choice = body.tool_choice + } + + // The ChatGPT backend does not support: max_output_tokens, max_tokens, + // temperature, top_p, stop, frequency_penalty, presence_penalty, logprobs, + // n, stream_options — omit them all. 
+ + const reasoningEffort = body.reasoning_effort as string | undefined + transformed.reasoning = { + effort: reasoningEffort || 'high', + summary: 'auto', + } + + transformed.text = { verbosity: 'medium' } + + return transformed +} + +// ============================================================================ +// Response Transform: Responses API SSE → Chat Completions SSE +// ============================================================================ + +function createSseTransformStream(): TransformStream { + const encoder = new TextEncoder() + const decoder = new TextDecoder() + + let buffer = '' + let responseId: string | null = null + let responseModel: string | null = null + let nextToolCallIndex = 0 + const outputIndexToToolIndex = new Map() + let emittedRole = false + + function emit( + controller: TransformStreamDefaultController, + chunk: Record, + ) { + controller.enqueue(encoder.encode(`data: ${JSON.stringify(chunk)}\n\n`)) + } + + function processEvent( + controller: TransformStreamDefaultController, + data: Record, + ) { + const type = data.type as string | undefined + if (!type) return + + switch (type) { + case 'response.created': { + const resp = data.response as Record | undefined + responseId = (resp?.id as string) ?? null + responseModel = (resp?.model as string) ?? 
null + if (!emittedRole) { + emit(controller, { + id: responseId, + model: responseModel, + choices: [ + { index: 0, delta: { role: 'assistant' }, finish_reason: null }, + ], + }) + emittedRole = true + } + break + } + + case 'response.output_text.delta': { + emit(controller, { + id: responseId, + choices: [ + { + index: 0, + delta: { content: data.delta as string }, + finish_reason: null, + }, + ], + }) + break + } + + case 'response.reasoning_summary_text.delta': { + emit(controller, { + id: responseId, + choices: [ + { + index: 0, + delta: { reasoning_content: data.delta as string }, + finish_reason: null, + }, + ], + }) + break + } + + case 'response.output_item.added': { + const item = data.item as Record | undefined + if (item?.type === 'function_call') { + const tcIndex = nextToolCallIndex++ + const outputIdx = (data.output_index as number) ?? 0 + outputIndexToToolIndex.set(outputIdx, tcIndex) + emit(controller, { + id: responseId, + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: tcIndex, + id: (item.call_id as string) ?? (item.id as string), + function: { + name: item.name as string, + arguments: '', + }, + }, + ], + }, + finish_reason: null, + }, + ], + }) + } + break + } + + case 'response.function_call_arguments.delta': { + const outputIdx = (data.output_index as number) ?? 0 + const tcIdx = outputIndexToToolIndex.get(outputIdx) ?? 
0 + emit(controller, { + id: responseId, + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: tcIdx, + function: { arguments: data.delta as string }, + }, + ], + }, + finish_reason: null, + }, + ], + }) + break + } + + case 'response.completed': + case 'response.done': { + const resp = data.response as Record | undefined + const usage = resp?.usage as Record | undefined + const status = resp?.status as string | undefined + + let finishReason = 'stop' + if (status === 'incomplete') { + finishReason = 'length' + } else if (nextToolCallIndex > 0) { + finishReason = 'tool_calls' + } + + const chunk: Record = { + id: responseId, + choices: [ + { index: 0, delta: {}, finish_reason: finishReason }, + ], + } + + if (usage) { + const outputDetails = usage.output_tokens_details as + | Record + | undefined + chunk.usage = { + prompt_tokens: usage.input_tokens, + completion_tokens: usage.output_tokens, + total_tokens: usage.total_tokens, + ...(outputDetails?.reasoning_tokens != null && { + completion_tokens_details: { + reasoning_tokens: outputDetails.reasoning_tokens, + }, + }), + } + } + + emit(controller, chunk) + controller.enqueue(encoder.encode('data: [DONE]\n\n')) + break + } + + case 'response.failed': { + const resp = data.response as Record | undefined + const errorObj = (resp?.error ?? data.error) as + | Record + | undefined + emit(controller, { + error: { + message: + (errorObj?.message as string) ?? + 'ChatGPT backend request failed', + type: (errorObj?.type as string) ?? 'server_error', + }, + }) + controller.enqueue(encoder.encode('data: [DONE]\n\n')) + break + } + + case 'error': { + const errorObj = (data.error ?? data) as Record + emit(controller, { + error: { + message: + (errorObj.message as string) ?? + 'Unknown error from ChatGPT backend', + type: (errorObj.type as string) ?? 'server_error', + }, + }) + break + } + + // Skip all other events silently (content_part.added, output_text.done, etc.) 
+ } + } + + return new TransformStream({ + transform(chunk, controller) { + buffer += decoder.decode(chunk, { stream: true }) + + const lines = buffer.split('\n') + buffer = lines.pop() ?? '' + + for (const line of lines) { + if (!line.startsWith('data: ')) continue + + const jsonStr = line.slice(6).trim() + if (!jsonStr || jsonStr === '[DONE]') { + continue + } + + try { + const parsed = JSON.parse(jsonStr) as Record + processEvent(controller, parsed) + } catch { + // Skip unparseable lines + } + } + }, + + flush(controller) { + if (buffer.trim().startsWith('data: ')) { + const jsonStr = buffer.trim().slice(6).trim() + if (jsonStr && jsonStr !== '[DONE]') { + try { + const parsed = JSON.parse(jsonStr) as Record + processEvent(controller, parsed) + } catch { + // skip + } + } + } + }, + }) +} + +function transformResponseStream( + inputStream: ReadableStream, +): ReadableStream { + const transform = createSseTransformStream() + inputStream.pipeTo(transform.writable).catch(() => {}) + return transform.readable +} + +// ============================================================================ +// Custom Fetch +// ============================================================================ + +export function createChatGptBackendFetch(): FetchFunction { + const fetchFn: FetchLike = async ( + input: RequestInfo | URL, + init?: RequestInit, + ): Promise => { + let transformedInit = init + + if (init?.body && typeof init.body === 'string') { + try { + const body = JSON.parse(init.body) as Record + const transformedBody = transformRequestBody(body) + transformedInit = { ...init, body: JSON.stringify(transformedBody) } + } catch { + // If body can't be parsed, pass through unchanged + } + } + + const response = await globalThis.fetch(input, transformedInit) + + if (!response.ok) { + // Map 404 usage-limit errors to 429 (same as opencode plugin) + if (response.status === 404) { + try { + const text = await response.clone().text() + if 
(/usage_limit|rate_limit/i.test(text)) { + return new Response(text, { + status: 429, + statusText: 'Too Many Requests', + headers: response.headers, + }) + } + } catch { + // Fall through to return original response + } + } + return response + } + + if (!response.body) return response + + const transformedStream = transformResponseStream(response.body) + + return new Response(transformedStream, { + status: response.status, + statusText: response.statusText, + headers: new Headers({ + 'content-type': 'text/event-stream; charset=utf-8', + }), + }) + } + + return fetchFn as FetchFunction +} diff --git a/sdk/src/impl/llm.ts b/sdk/src/impl/llm.ts index 37ed3a13b8..60bb678bb1 100644 --- a/sdk/src/impl/llm.ts +++ b/sdk/src/impl/llm.ts @@ -1,4 +1,5 @@ import { AnalyticsEvent } from '@codebuff/common/constants/analytics-events' +import { isFreeMode } from '@codebuff/common/constants/free-agents' import { models, PROFIT_MARGIN } from '@codebuff/common/old-constants' import { buildArray } from '@codebuff/common/util/array' import { normalizeProviderRequestBodyForCacheDebug } from '@codebuff/common/util/cache-debug' @@ -17,8 +18,11 @@ import { TypeValidationError, } from 'ai' -import { getModelForRequest, markClaudeOAuthRateLimited, fetchClaudeOAuthResetTime } from './model-provider' -import { getValidClaudeOAuthCredentials } from '../credentials' +import { + getModelForRequest, + markChatGptOAuthRateLimited, +} from './model-provider' +import { refreshChatGptOAuthToken } from '../credentials' import { getErrorStatusCode } from '../error-utils' import type { ModelRequestParams } from './model-provider' @@ -56,7 +60,7 @@ function calculateUsedCredits(params: { costDollars: number }): number { return Math.round(costDollars * (1 + PROFIT_MARGIN) * 100) } -function getProviderOptions(params: { +export function getProviderOptions(params: { model: string runId: string clientSessionId: string @@ -65,6 +69,7 @@ function getProviderOptions(params: { n?: number costMode?: string 
cacheDebugCorrelation?: string + extraCodebuffMetadata?: Record }): { codebuff: JSONObject } { const { model, @@ -75,6 +80,7 @@ function getProviderOptions(params: { n, costMode, cacheDebugCorrelation, + extraCodebuffMetadata, } = params let providerConfig: Record @@ -99,6 +105,9 @@ function getProviderOptions(params: { ...providerOptions?.codebuff, // All values here get appended to the request body codebuff_metadata: { + // Caller-supplied keys go first so they can't override reserved + // identifiers like run_id/client_id/cost_mode that the server trusts. + ...(extraCodebuffMetadata ?? {}), run_id: runId, client_id: clientSessionId, ...(n && { n }), @@ -122,9 +131,9 @@ type OpenRouterUsageAccounting = { } /** - * Check if an error is a Claude OAuth rate limit error that should trigger fallback. + * Check if an error is an OAuth rate limit error that should trigger fallback. */ -function isClaudeOAuthRateLimitError(error: unknown): boolean { +function isOAuthRateLimitError(error: unknown): boolean { if (!error || typeof error !== 'object') return false // Check status code (handles both 'status' from AI SDK and 'statusCode' from our errors) @@ -141,10 +150,9 @@ function isClaudeOAuthRateLimitError(error: unknown): boolean { if (message.includes('rate_limit') || message.includes('rate limit')) return true - if (message.includes('overloaded')) return true if ( responseBody.includes('rate_limit') || - responseBody.includes('overloaded') + responseBody.includes('rate limit') ) return true @@ -152,10 +160,10 @@ function isClaudeOAuthRateLimitError(error: unknown): boolean { } /** - * Check if an error is a Claude OAuth authentication error (expired/invalid token). + * Check if an error is an OAuth authentication error (expired/invalid token). * This indicates we should try refreshing the token. 
*/ -function isClaudeOAuthAuthError(error: unknown): boolean { +function isOAuthAuthError(error: unknown): boolean { if (!error || typeof error !== 'object') return false // Check status code (handles both 'status' from AI SDK and 'statusCode' from our errors) @@ -240,12 +248,46 @@ function emitCacheDebugUsage(params: { }) } +export type ChatGptOAuthStreamErrorPolicy = + | 'fallback-rate-limit' + | 'fail-auth-reconnect' + | 'fail-fast' + | 'ignore' + +export function classifyChatGptOAuthStreamError(params: { + isChatGptOAuth: boolean + skipChatGptOAuth?: boolean + hasYieldedContent: boolean + error: unknown +}): ChatGptOAuthStreamErrorPolicy { + const { isChatGptOAuth, skipChatGptOAuth, hasYieldedContent, error } = params + + if (!isChatGptOAuth || skipChatGptOAuth || hasYieldedContent) { + return 'ignore' + } + + if (isOAuthRateLimitError(error)) { + return 'fallback-rate-limit' + } + + if (isOAuthAuthError(error)) { + return 'fail-auth-reconnect' + } + + return 'fail-fast' +} + export async function* promptAiSdkStream( params: ParamsOf & { - skipClaudeOAuth?: boolean - onClaudeOAuthStatusChange?: (isActive: boolean) => void + skipChatGptOAuth?: boolean + chatGptOAuthRetried?: boolean }, ): ReturnType { + const { + providerOptions: originalProviderOptions, + ...streamParams + } = params + const { logger, trackEvent, userId, userInputId, model: requestedModel } = params const agentChunkMetadata = params.agentId != null ? 
{ agentId: params.agentId } : undefined @@ -264,14 +306,15 @@ export async function* promptAiSdkStream( const modelParams: ModelRequestParams = { apiKey: params.apiKey, model: params.model, - skipClaudeOAuth: params.skipClaudeOAuth, + skipChatGptOAuth: params.skipChatGptOAuth, + costMode: params.costMode, } - const { model: aiSDKModel, isClaudeOAuth } = await getModelForRequest(modelParams) + const { model: aiSDKModel, isChatGptOAuth } = + await getModelForRequest(modelParams) - // Track and notify about Claude OAuth usage - if (isClaudeOAuth) { + if (isChatGptOAuth) { trackEvent({ - event: AnalyticsEvent.CLAUDE_OAUTH_REQUEST, + event: AnalyticsEvent.CHATGPT_OAUTH_REQUEST, userId: userId ?? '', properties: { model: requestedModel, @@ -279,24 +322,24 @@ export async function* promptAiSdkStream( }, logger, }) - if (params.onClaudeOAuthStatusChange) { - params.onClaudeOAuthStatusChange(true) - } } const response = streamText({ - ...params, + ...streamParams, prompt: undefined, model: aiSDKModel, messages: convertCbToModelMessages(params), - // When using Claude OAuth, disable retries so we can immediately fall back to Codebuff - // backend on rate limit errors instead of retrying 4 times first - ...(isClaudeOAuth && { maxRetries: 0 }), - providerOptions: getProviderOptions({ - ...params, - agentProviderOptions: params.agentProviderOptions, - cacheDebugCorrelation: params.cacheDebugCorrelation, - }), + ...(isChatGptOAuth && { maxRetries: 0 }), + // For ChatGPT OAuth direct, don't send codebuff metadata/provider options to OpenAI + ...(isChatGptOAuth + ? {} + : { + providerOptions: getProviderOptions({ + ...params, + providerOptions: originalProviderOptions, + agentProviderOptions: params.agentProviderOptions, + }), + }), // Handle tool call errors gracefully by passing them through to our validation layer // instead of throwing (which would halt the agent). The only special case is when // the tool name matches a spawnable agent - transform those to spawn_agents calls. 
@@ -465,20 +508,22 @@ export async function* promptAiSdkStream( continue } - // Check if this is a Claude OAuth rate limit error - only fall back if no content yielded yet - if ( - isClaudeOAuth && - !params.skipClaudeOAuth && - !hasYieldedContent && - isClaudeOAuthRateLimitError(chunkValue.error) - ) { - logger.info( + const chatGptErrorPolicy = classifyChatGptOAuthStreamError({ + isChatGptOAuth, + skipChatGptOAuth: params.skipChatGptOAuth, + hasYieldedContent, + error: chunkValue.error, + }) + + if (chatGptErrorPolicy === 'fallback-rate-limit') { + const rateLimitErrorDetails = chunkValue.error instanceof Error ? chunkValue.error.message : String(chunkValue.error) + logger.warn( { error: getErrorObject(chunkValue.error) }, - 'Claude OAuth rate limited during stream, falling back to Codebuff backend', + 'ChatGPT OAuth rate limited during stream', ) - // Track the rate limit event + trackEvent({ - event: AnalyticsEvent.CLAUDE_OAUTH_RATE_LIMITED, + event: AnalyticsEvent.CHATGPT_OAUTH_RATE_LIMITED, userId: userId ?? '', properties: { model: requestedModel, @@ -486,38 +531,31 @@ export async function* promptAiSdkStream( }, logger, }) - // Try to get the actual reset time from the quota API, fall back to default cooldown - const credentials = await getValidClaudeOAuthCredentials() - const resetTime = credentials?.accessToken - ? await fetchClaudeOAuthResetTime(credentials.accessToken) - : null - // Mark as rate-limited so subsequent requests skip Claude OAuth - markClaudeOAuthRateLimited(resetTime ?? undefined) - if (params.onClaudeOAuthStatusChange) { - params.onClaudeOAuthStatusChange(false) + + markChatGptOAuthRateLimited() + + // In free mode, don't fall back to Codebuff backend — fail instead + if (isFreeMode(params.costMode)) { + throw new Error( + `ChatGPT rate limit reached. Please wait a few minutes and try again. 
(${rateLimitErrorDetails})`, + ) } - // Retry with Codebuff backend + const fallbackResult = yield* promptAiSdkStream({ ...params, - skipClaudeOAuth: true, + skipChatGptOAuth: true, }) return fallbackResult } - // Check if this is a Claude OAuth authentication error (expired token) - only fall back if no content yielded yet - if ( - isClaudeOAuth && - !params.skipClaudeOAuth && - !hasYieldedContent && - isClaudeOAuthAuthError(chunkValue.error) - ) { + if (chatGptErrorPolicy === 'fail-auth-reconnect') { logger.info( { error: getErrorObject(chunkValue.error) }, - 'Claude OAuth auth error during stream, falling back to Codebuff backend', + 'ChatGPT OAuth auth error during stream, attempting token refresh', ) - // Track the auth error event + trackEvent({ - event: AnalyticsEvent.CLAUDE_OAUTH_AUTH_ERROR, + event: AnalyticsEvent.CHATGPT_OAUTH_AUTH_ERROR, userId: userId ?? '', properties: { model: requestedModel, @@ -525,13 +563,33 @@ export async function* promptAiSdkStream( }, logger, }) - if (params.onClaudeOAuthStatusChange) { - params.onClaudeOAuthStatusChange(false) + + // Try refreshing the token and retrying once before failing/falling back + if (!params.chatGptOAuthRetried) { + const refreshed = await refreshChatGptOAuthToken() + if (refreshed) { + logger.info({ model: requestedModel }, 'ChatGPT OAuth token refreshed, retrying request') + const retryResult = yield* promptAiSdkStream({ + ...params, + chatGptOAuthRetried: true, + }) + return retryResult + } + logger.warn({ model: requestedModel }, 'ChatGPT OAuth token refresh failed, unable to recover') + } + + // Refresh failed or already retried + // In free mode, don't fall back to Codebuff backend — fail instead + if (isFreeMode(params.costMode)) { + throw new Error( + 'ChatGPT OAuth authentication failed. 
Please reconnect with /connect:chatgpt and try again.', + ) } - // Retry with Codebuff backend (skipClaudeOAuth will bypass the failed OAuth) + + // Fall back to Codebuff backend const fallbackResult = yield* promptAiSdkStream({ ...params, - skipClaudeOAuth: true, + skipChatGptOAuth: true, }) return fallbackResult } @@ -549,21 +607,20 @@ export async function* promptAiSdkStream( throw chunkValue.error } if (chunkValue.type === 'reasoning-delta') { - for (const provider of ['openrouter', 'codebuff'] as const) { - if ( + const reasoningExcluded = (['openrouter', 'codebuff'] as const).some( + (p) => ( - params.providerOptions?.[provider] as - | OpenRouterProviderOptions - | undefined - )?.reasoning?.exclude - ) { - continue + params.providerOptions?.[p] as + | OpenRouterProviderOptions + | undefined + )?.reasoning?.exclude, + ) + if (!reasoningExcluded) { + yield { + type: 'reasoning', + text: chunkValue.text, } } - yield { - type: 'reasoning', - text: chunkValue.text, - } } if (chunkValue.type === 'text-delta') { if (!params.stopSequences) { @@ -617,8 +674,8 @@ export async function* promptAiSdkStream( usage: usageResult, }) - // Skip cost tracking for Claude OAuth (user is on their own subscription) - if (!isClaudeOAuth) { + // Skip cost tracking for ChatGPT OAuth (user is on their own subscription) + if (!isChatGptOAuth) { const providerMetadataResult = await response.providerMetadata const providerMetadata = providerMetadataResult ?? 
{} @@ -664,7 +721,7 @@ export async function promptAiSdk( const modelParams: ModelRequestParams = { apiKey: params.apiKey, model: params.model, - skipClaudeOAuth: true, // Always use Codebuff backend for non-streaming + skipChatGptOAuth: true, // Always use Codebuff backend for non-streaming } const { model: aiSDKModel } = await getModelForRequest(modelParams) @@ -731,7 +788,7 @@ export async function promptAiSdkStructured( const modelParams: ModelRequestParams = { apiKey: params.apiKey, model: params.model, - skipClaudeOAuth: true, // Always use Codebuff backend for non-streaming + skipChatGptOAuth: true, // Always use Codebuff backend for non-streaming } const { model: aiSDKModel } = await getModelForRequest(modelParams) diff --git a/sdk/src/impl/model-provider.ts b/sdk/src/impl/model-provider.ts index 797d13daf3..83e016c611 100644 --- a/sdk/src/impl/model-provider.ts +++ b/sdk/src/impl/model-provider.ts @@ -2,132 +2,76 @@ * Model provider abstraction for routing requests to the appropriate LLM provider. 
* * This module handles: - * - Claude OAuth: Direct requests to Anthropic API using user's OAuth token + * - ChatGPT OAuth: Direct requests to OpenAI API using user's OAuth token * - Default: Requests through Codebuff backend (which routes to OpenRouter) */ import path from 'path' -import { createAnthropic } from '@ai-sdk/anthropic' import { BYOK_OPENROUTER_HEADER } from '@codebuff/common/constants/byok' +import { isFreeMode } from '@codebuff/common/constants/free-agents' import { - CLAUDE_CODE_SYSTEM_PROMPT_PREFIX, - CLAUDE_OAUTH_BETA_HEADERS, - CLAUDE_OAUTH_ENABLED, - isClaudeModel, - toAnthropicModelId, -} from '@codebuff/common/constants/claude-oauth' + CHATGPT_BACKEND_BASE_URL, + CHATGPT_OAUTH_ENABLED, + isChatGptOAuthModelAllowed, + isOpenAIProviderModel, + toOpenAIModelId, +} from '@codebuff/common/constants/chatgpt-oauth' import { OpenAICompatibleChatLanguageModel, VERSION, } from '@codebuff/internal/openai-compatible/index' import { WEBSITE_URL } from '../constants' -import { getValidClaudeOAuthCredentials } from '../credentials' +import { + getValidChatGptOAuthCredentials, +} from '../credentials' import { getByokOpenrouterApiKeyFromEnv } from '../env' +import { + createChatGptBackendFetch, + extractChatGptAccountId, +} from './chatgpt-backend-fetch' import type { LanguageModel } from 'ai' // ============================================================================ -// Claude OAuth Rate Limit Cache +// ChatGPT OAuth Rate Limit Cache // ============================================================================ -/** Timestamp (ms) when Claude OAuth rate limit expires, or null if not rate-limited */ -let claudeOAuthRateLimitedUntil: number | null = null +/** Timestamp (ms) when ChatGPT OAuth rate limit expires, or null if not rate-limited */ +let chatGptOAuthRateLimitedUntil: number | null = null /** - * Mark Claude OAuth as rate-limited. Subsequent requests will skip Claude OAuth + * Mark ChatGPT OAuth as rate-limited. 
Subsequent requests will skip direct ChatGPT OAuth * and use Codebuff backend until the reset time. - * @param resetAt - When the rate limit resets. If not provided, guesses 5 minutes from now. */ -export function markClaudeOAuthRateLimited(resetAt?: Date): void { +export function markChatGptOAuthRateLimited(resetAt?: Date): void { const fiveMinutesFromNow = Date.now() + 5 * 60 * 1000 - claudeOAuthRateLimitedUntil = resetAt ? resetAt.getTime() : fiveMinutesFromNow + chatGptOAuthRateLimitedUntil = resetAt + ? resetAt.getTime() + : fiveMinutesFromNow } /** - * Check if Claude OAuth is currently rate-limited. - * Returns true if rate-limited and reset time hasn't passed. + * Check if ChatGPT OAuth is currently rate-limited. */ -export function isClaudeOAuthRateLimited(): boolean { - if (claudeOAuthRateLimitedUntil === null) { +export function isChatGptOAuthRateLimited(): boolean { + if (chatGptOAuthRateLimitedUntil === null) { return false } - if (Date.now() >= claudeOAuthRateLimitedUntil) { - // Rate limit expired, clear the cache - claudeOAuthRateLimitedUntil = null + if (Date.now() >= chatGptOAuthRateLimitedUntil) { + chatGptOAuthRateLimitedUntil = null return false } return true } /** - * Reset the Claude OAuth rate limit cache. - * Call this when user reconnects their Claude subscription. - */ -export function resetClaudeOAuthRateLimit(): void { - claudeOAuthRateLimitedUntil = null -} - -// ============================================================================ -// Claude OAuth Quota Fetching -// ============================================================================ - -interface ClaudeQuotaWindow { - utilization: number - resets_at: string | null -} - -interface ClaudeQuotaResponse { - five_hour: ClaudeQuotaWindow | null - seven_day: ClaudeQuotaWindow | null - seven_day_oauth_apps: ClaudeQuotaWindow | null - seven_day_opus: ClaudeQuotaWindow | null -} - -/** - * Fetch the rate limit reset time from Anthropic's quota API. 
- * Returns the earliest reset time (whichever limit is more restrictive). - * Returns null if fetch fails or no reset time is available. + * Reset the ChatGPT OAuth rate-limit cache. + * Call this when user reconnects their ChatGPT subscription. */ -export async function fetchClaudeOAuthResetTime(accessToken: string): Promise { - try { - const response = await fetch('https://api.anthropic.com/api/oauth/usage', { - method: 'GET', - headers: { - Authorization: `Bearer ${accessToken}`, - Accept: 'application/json', - 'Content-Type': 'application/json', - 'anthropic-version': '2023-06-01', - 'anthropic-beta': 'oauth-2025-04-20,claude-code-20250219', - }, - }) - - if (!response.ok) { - return null - } - - const responseBody = await response.json() - const data = responseBody as ClaudeQuotaResponse - - // Parse reset times - const fiveHour = data.five_hour - const sevenDay = data.seven_day - - const fiveHourRemaining = fiveHour ? Math.max(0, 100 - fiveHour.utilization) : 100 - const sevenDayRemaining = sevenDay ? Math.max(0, 100 - sevenDay.utilization) : 100 - - // Return the reset time for whichever limit is more restrictive (lower remaining) - if (fiveHourRemaining <= sevenDayRemaining && fiveHour?.resets_at) { - return new Date(fiveHour.resets_at) - } else if (sevenDay?.resets_at) { - return new Date(sevenDay.resets_at) - } - - return null - } catch { - return null - } +export function resetChatGptOAuthRateLimit(): void { + chatGptOAuthRateLimitedUntil = null } /** @@ -138,8 +82,10 @@ export interface ModelRequestParams { apiKey: string /** Model ID (OpenRouter format, e.g., "anthropic/claude-sonnet-4") */ model: string - /** If true, skip Claude OAuth and use Codebuff backend (for fallback after rate limit) */ - skipClaudeOAuth?: boolean + /** If true, skip ChatGPT OAuth and use Codebuff backend (for fallback after rate limit) */ + skipChatGptOAuth?: boolean + /** Cost mode (e.g. 
'free') — affects fallback behavior for OAuth routes */ + costMode?: string } /** @@ -148,8 +94,8 @@ export interface ModelRequestParams { export interface ModelResult { /** The language model to use for requests */ model: LanguageModel - /** Whether this model uses Claude OAuth direct (affects cost tracking) */ - isClaudeOAuth: boolean + /** Whether this model uses ChatGPT OAuth direct (affects cost tracking) */ + isChatGptOAuth: boolean } // Usage accounting type for OpenRouter/Codebuff backend responses @@ -163,26 +109,45 @@ type OpenRouterUsageAccounting = { /** * Get the appropriate model for a request. * - * If Claude OAuth credentials are available and the model is a Claude model, - * returns an Anthropic direct model. Otherwise, returns the Codebuff backend model. + * If ChatGPT OAuth credentials are available and the model is an OpenAI model, + * returns an OpenAI direct model. Otherwise, returns the Codebuff backend model. * * This function is async because it may need to refresh the OAuth token. */ export async function getModelForRequest(params: ModelRequestParams): Promise { - const { apiKey, model, skipClaudeOAuth } = params + const { apiKey, model, skipChatGptOAuth, costMode } = params + + // Check if we should use ChatGPT OAuth direct + // Only attempt for allowlisted models; non-allowlisted models silently fall through to backend. + if ( + CHATGPT_OAUTH_ENABLED && + !skipChatGptOAuth && + isOpenAIProviderModel(model) && + isChatGptOAuthModelAllowed(model) + ) { + // In free mode, rate-limited ChatGPT OAuth must not silently fall through to + // the Codebuff backend — freebuff should only use the direct OpenAI route or fail. + if (isChatGptOAuthRateLimited()) { + if (isFreeMode(costMode)) { + throw new Error( + 'ChatGPT rate limit reached. 
Please wait a few minutes and try again.', + ) + } + } else { + const chatGptOAuthCredentials = await getValidChatGptOAuthCredentials() + + if (chatGptOAuthCredentials) { + return { + model: createOpenAIOAuthModel(model, chatGptOAuthCredentials.accessToken), + isChatGptOAuth: true, + } + } - // Check if we should use Claude OAuth direct - // Skip if feature disabled, explicitly requested, if rate-limited, or if not a Claude model - if (CLAUDE_OAUTH_ENABLED && !skipClaudeOAuth && !isClaudeOAuthRateLimited() && isClaudeModel(model)) { - // Get valid credentials (will refresh if needed) - const claudeOAuthCredentials = await getValidClaudeOAuthCredentials() - if (claudeOAuthCredentials) { - return { - model: createAnthropicOAuthModel( - model, - claudeOAuthCredentials.accessToken, - ), - isClaudeOAuth: true, + // In free mode, if credentials are unavailable, don't fall through to backend. + if (isFreeMode(costMode)) { + throw new Error( + 'ChatGPT OAuth credentials unavailable. Please reconnect with /connect:chatgpt.', + ) } } } @@ -190,107 +155,34 @@ export async function getModelForRequest(params: ModelRequestParams): Promise => { - const headers = new Headers(init?.headers) - - // Remove the x-api-key header that the SDK adds - headers.delete('x-api-key') - - // Add Bearer token authentication (for OAuth) - headers.set('Authorization', `Bearer ${oauthToken}`) - - // Add required beta headers for OAuth (same as opencode) - // These beta headers are required to access Claude 4+ models with OAuth - const existingBeta = headers.get('anthropic-beta') ?? 
'' - const betaList = existingBeta - .split(',') - .map((b) => b.trim()) - .filter(Boolean) - const mergedBetas = [ - ...new Set([...CLAUDE_OAUTH_BETA_HEADERS, ...betaList]), - ].join(',') - headers.set('anthropic-beta', mergedBetas) +function createOpenAIOAuthModel(model: string, oauthToken: string): LanguageModel { + const openAIModelId = toOpenAIModelId(model) + const accountId = extractChatGptAccountId(oauthToken) - // Transform the request body to use the correct system prompt format for Claude OAuth - // Anthropic requires the system prompt to be split into two separate blocks: - // 1. First block: Claude Code identifier (required for OAuth access) - // 2. Second block: The actual system prompt (if any) - let modifiedInit = init - if (init?.body && typeof init.body === 'string') { - try { - const body = JSON.parse(init.body) - // Always inject the Claude Code identifier for OAuth requests - // Extract existing system prompt if present - const existingSystem = body.system - ? Array.isArray(body.system) - ? body.system - .map( - (s: { text?: string; content?: string }) => - s.text ?? s.content ?? '', - ) - .join('\n\n') - : typeof body.system === 'string' - ? body.system - : '' - : '' - - // Build the system array with Claude Code identifier first - body.system = [ - { - type: 'text', - text: CLAUDE_CODE_SYSTEM_PROMPT_PREFIX, - }, - // Only add second block if there's actual content - ...(existingSystem - ? 
[ - { - type: 'text', - text: existingSystem, - }, - ] - : []), - ] - modifiedInit = { ...init, body: JSON.stringify(body) } - } catch { - // If parsing fails, continue with original body - } - } - - return globalThis.fetch(input, { - ...modifiedInit, - headers, - }) - } - - // Pass empty apiKey like opencode does - this prevents the SDK from adding x-api-key header - // The custom fetch will add the Bearer token instead - const anthropic = createAnthropic({ - apiKey: '', - fetch: customFetch as unknown as typeof globalThis.fetch, + return new OpenAICompatibleChatLanguageModel(openAIModelId, { + provider: 'openai', + url: () => `${CHATGPT_BACKEND_BASE_URL}/codex/responses`, + headers: () => ({ + Authorization: `Bearer ${oauthToken}`, + 'Content-Type': 'application/json', + 'OpenAI-Beta': 'responses=experimental', + originator: 'codex_cli_rs', + accept: 'text/event-stream', + 'user-agent': `ai-sdk/openai-compatible/${VERSION}/codebuff-chatgpt-oauth`, + ...(accountId ? { 'chatgpt-account-id': accountId } : {}), + }), + fetch: createChatGptBackendFetch(), + supportsStructuredOutputs: true, + includeUsage: undefined, }) - - // Cast to LanguageModel since the AI SDK types may be slightly different versions - // Using unknown as intermediate to handle V2 vs V3 differences - return anthropic(anthropicModelId) as unknown as LanguageModel } /** diff --git a/sdk/src/index.ts b/sdk/src/index.ts index bcd41e6af3..4b04f03af4 100644 --- a/sdk/src/index.ts +++ b/sdk/src/index.ts @@ -82,7 +82,11 @@ export { export type { CodebuffFileSystem } from '@codebuff/common/types/filesystem' // Tree-sitter / code-map exports -export { getFileTokenScores, setWasmDir } from '@codebuff/code-map' +export { + getFileTokenScores, + setWasmDir, + setTreeSitterWasmPath, +} from '@codebuff/code-map' export type { FileTokenData, TokenCallerMap } from '@codebuff/code-map' export { runTerminalCommand } from './tools/run-terminal-command' @@ -91,4 +95,6 @@ export { promptAiSdkStream, 
promptAiSdkStructured, } from './impl/llm' -export { resetClaudeOAuthRateLimit } from './impl/model-provider' +export { + resetChatGptOAuthRateLimit, +} from './impl/model-provider' diff --git a/sdk/src/run-state.ts b/sdk/src/run-state.ts index 7752c26fd2..7fcc35a42b 100644 --- a/sdk/src/run-state.ts +++ b/sdk/src/run-state.ts @@ -2,6 +2,7 @@ import * as os from 'os' import path from 'path' import { getFileTokenScores } from '@codebuff/code-map/parse' +import { getSystemInfo } from '@codebuff/common/util/system-info' import { KNOWLEDGE_FILE_NAMES_LOWERCASE, isKnowledgeFile, @@ -52,9 +53,7 @@ export function selectHighestPriorityKnowledgeFile( ): string | undefined { // Loop through priorities and find the first match directly for (const priorityName of KNOWLEDGE_FILE_NAMES_LOWERCASE) { - const match = candidates.find((f) => - f.toLowerCase().endsWith(priorityName), - ) + const match = candidates.find((f) => f.toLowerCase().endsWith(priorityName)) if (match) return match } return undefined @@ -63,6 +62,7 @@ export function selectHighestPriorityKnowledgeFile( export type RunState = { sessionState?: SessionState output: AgentOutput + traceSessionId: string } export type InitialSessionStateOptions = { @@ -135,26 +135,27 @@ function processCustomToolDefinitions( /** * Computes project file indexes (file tree and token scores) */ -async function computeProjectIndex( - cwd: string, - projectFiles: Record, -): Promise<{ +type ProjectIndexInput = { + cwd: string + fileTree: FileTreeNode[] + filePaths: string[] + readFile?: (filePath: string) => string | null | Promise +} + +const MAX_DISCOVERED_PROJECT_READ_BYTES = 1_000_000 + +async function computeProjectIndex(params: ProjectIndexInput): Promise<{ fileTree: FileTreeNode[] fileTokenScores: Record tokenCallers: Record }> { - const filePaths = Object.keys(projectFiles).sort() - const fileTree = buildFileTree(filePaths) + const { cwd, fileTree, filePaths, readFile } = params let fileTokenScores = {} let tokenCallers = {} if 
(filePaths.length > 0) { try { - const tokenData = await getFileTokenScores( - cwd, - filePaths, - (filePath: string) => projectFiles[filePath] || null, - ) + const tokenData = await getFileTokenScores(cwd, filePaths, readFile) fileTokenScores = tokenData.tokenScores tokenCallers = tokenData.tokenCallers } catch (error) { @@ -166,6 +167,68 @@ async function computeProjectIndex( return { fileTree, fileTokenScores, tokenCallers } } +function getProjectIndexInput(params: { + cwd: string + fs?: CodebuffFileSystem + logger?: Logger + projectFiles?: Record + discoveredProject?: { fileTree: FileTreeNode[]; filePaths: string[] } +}): ProjectIndexInput | undefined { + const { cwd, fs, logger, projectFiles, discoveredProject } = params + + if (projectFiles) { + const filePaths = Object.keys(projectFiles).sort() + return { + cwd, + fileTree: buildFileTree(filePaths), + filePaths, + readFile: (filePath: string) => projectFiles[filePath] || null, + } + } + + if (discoveredProject) { + if (!fs || !logger) return undefined + + return { + cwd, + fileTree: discoveredProject.fileTree, + filePaths: discoveredProject.filePaths.sort(), + readFile: createDiscoveredProjectReader({ cwd, fs, logger }), + } + } + + return undefined +} + +function createDiscoveredProjectReader(params: { + cwd: string + fs: CodebuffFileSystem + logger: Logger +}): (filePath: string) => Promise { + const { cwd, fs, logger } = params + + return async (filePath: string) => { + const fullPath = path.join(cwd, filePath) + try { + const stats = await fs.stat(fullPath) + if (getFileSize(stats) > MAX_DISCOVERED_PROJECT_READ_BYTES) { + return null + } + return await fs.readFile(fullPath, 'utf8') + } catch (error) { + logger.debug?.( + { filePath, error: getErrorObject(error) }, + 'Failed to read discovered project file for symbol scoring', + ) + return null + } + } +} + +function getFileSize(stats: Awaited>) { + return typeof stats.size === 'number' ? 
stats.size : 0 +} + /** * Helper to convert ChildProcess to Promise with stdout/stderr */ @@ -260,43 +323,20 @@ async function getGitChanges(params: { } /** - * Discovers project files using .gitignore patterns when projectFiles is undefined + * Discovers project paths using .gitignore patterns when projectFiles is undefined. + * This intentionally does not read every file into memory; large repositories can + * contain generated or binary files that are expensive to retain before parsing. */ -async function discoverProjectFiles(params: { +async function discoverProjectPaths(params: { cwd: string fs: CodebuffFileSystem - logger: Logger -}): Promise> { - const { cwd, fs, logger } = params +}): Promise<{ fileTree: FileTreeNode[]; filePaths: string[] }> { + const { cwd, fs } = params const fileTree = await getProjectFileTree({ projectRoot: cwd, fs }) const filePaths = getAllFilePaths(fileTree) - let error - - // Create projectFiles with empty content - the token scorer will read from disk - const projectFilePromises = Object.fromEntries( - filePaths.map((filePath) => [ - filePath, - fs.readFile(path.join(cwd, filePath), 'utf8').catch((err) => { - error = err - return '[ERROR_READING_FILE]' - }), - ]), - ) - if (error) { - logger.warn( - { error: getErrorObject(error) }, - 'Failed to discover some project files', - ) - } - const projectFilesResolved: Record = {} - for (const [filePath, contentPromise] of Object.entries( - projectFilePromises, - )) { - projectFilesResolved[filePath] = await contentPromise - } - return projectFilesResolved + return { fileTree, filePaths } } /** @@ -321,7 +361,10 @@ export async function loadUserKnowledgeFiles(params: { try { entries = await fs.readdir(homeDir) } catch (error) { - logger.debug?.({ homeDir, error: getErrorObject(error) }, 'Failed to read home directory') + logger.debug?.( + { homeDir, error: getErrorObject(error) }, + 'Failed to read home directory', + ) return userKnowledgeFiles } @@ -350,7 +393,10 @@ export async 
function loadUserKnowledgeFiles(params: { // Only use the first file found (highest priority) break } catch (error) { - logger.debug?.({ filePath, error: getErrorObject(error) }, 'Failed to read user knowledge file') + logger.debug?.( + { filePath, error: getErrorObject(error) }, + 'Failed to read user knowledge file', + ) } } } @@ -406,6 +452,32 @@ function deriveKnowledgeFiles( return knowledgeFiles } +async function loadKnowledgeFilesFromPaths(params: { + cwd: string + filePaths: string[] + fs: CodebuffFileSystem + logger: Logger +}): Promise> { + const { cwd, filePaths, fs, logger } = params + const selectedFilePaths = selectKnowledgeFilePaths(filePaths) + + const knowledgeFiles: Record = {} + for (const filePath of selectedFilePaths) { + try { + knowledgeFiles[filePath] = await fs.readFile( + path.join(cwd, filePath), + 'utf8', + ) + } catch (error) { + logger.debug?.( + { filePath, error: getErrorObject(error) }, + 'Failed to read project knowledge file', + ) + } + } + return knowledgeFiles +} + export async function initialSessionState( params: InitialSessionStateOptions, ): Promise { @@ -442,12 +514,27 @@ export async function initialSessionState( } } + let discoveredProject: + | { fileTree: FileTreeNode[]; filePaths: string[] } + | undefined + // Auto-discover project files if not provided and cwd is available if (projectFiles === undefined && cwd) { - projectFiles = await discoverProjectFiles({ cwd, fs, logger }) + discoveredProject = await discoverProjectPaths({ cwd, fs }) } if (knowledgeFiles === undefined) { - knowledgeFiles = projectFiles ? 
deriveKnowledgeFiles(projectFiles) : {} + if (projectFiles) { + knowledgeFiles = deriveKnowledgeFiles(projectFiles) + } else if (cwd && discoveredProject) { + knowledgeFiles = await loadKnowledgeFilesFromPaths({ + cwd, + filePaths: discoveredProject.filePaths, + fs, + logger, + }) + } else { + knowledgeFiles = {} + } } let processedAgentTemplates: Record = {} @@ -460,13 +547,15 @@ export async function initialSessionState( customToolDefinitions, ) - // Generate file tree and token scores from projectFiles if available let fileTree: FileTreeNode[] = [] let fileTokenScores: Record = {} let tokenCallers: Record = {} - if (cwd && projectFiles) { - const result = await computeProjectIndex(cwd, projectFiles) + const projectIndex = cwd + ? getProjectIndexInput({ cwd, fs, logger, projectFiles, discoveredProject }) + : undefined + if (projectIndex) { + const result = await computeProjectIndex(projectIndex) fileTree = result.fileTree fileTokenScores = result.fileTokenScores tokenCallers = result.tokenCallers @@ -490,7 +579,11 @@ export async function initialSessionState( } // Load skills from project and home directories - const skills = await loadSkills({ cwd: cwd ?? process.cwd(), skillsPath: skillsDir, verbose: false }) + const skills = await loadSkills({ + cwd: cwd ?? process.cwd(), + skillsPath: skillsDir, + verbose: false, + }) const initialState = getInitialSessionState({ projectRoot: cwd ?? process.cwd(), @@ -506,14 +599,7 @@ export async function initialSessionState( gitChanges, changesSinceLastChat: {}, shellConfigFiles: {}, - systemInfo: { - platform: process.platform, - shell: 'bash', - nodeVersion: process.version, - arch: process.arch, - homedir: os.homedir(), - cpus: os.cpus().length ?? 
1, - }, + systemInfo: getSystemInfo(), }) if (maxAgentSteps) { @@ -545,6 +631,7 @@ export async function generateInitialRunState({ fs: CodebuffFileSystem }): Promise { return { + traceSessionId: crypto.randomUUID(), sessionState: await initialSessionState({ cwd, skillsDir, @@ -624,11 +711,17 @@ export async function applyOverridesToSessionState( // Apply projectFiles override (recomputes file tree and token scores) if (overrides.projectFiles !== undefined) { if (cwd) { - const { fileTree, fileTokenScores, tokenCallers } = - await computeProjectIndex(cwd, overrides.projectFiles) - sessionState.fileContext.fileTree = fileTree - sessionState.fileContext.fileTokenScores = fileTokenScores - sessionState.fileContext.tokenCallers = tokenCallers + const projectIndex = getProjectIndexInput({ + cwd, + projectFiles: overrides.projectFiles, + }) + if (projectIndex) { + const { fileTree, fileTokenScores, tokenCallers } = + await computeProjectIndex(projectIndex) + sessionState.fileContext.fileTree = fileTree + sessionState.fileContext.fileTokenScores = fileTokenScores + sessionState.fileContext.tokenCallers = tokenCallers + } } else { // If projectFiles are provided but no cwd, reset file context fields sessionState.fileContext.fileTree = [] diff --git a/sdk/src/run.ts b/sdk/src/run.ts index 4db516a479..4014e85449 100644 --- a/sdk/src/run.ts +++ b/sdk/src/run.ts @@ -15,6 +15,7 @@ import { import { toolNames } from '@codebuff/common/tools/constants' import { clientToolCallSchema } from '@codebuff/common/tools/list' import { AgentOutputSchema } from '@codebuff/common/types/session-state' +import { extractApiErrorDetails } from '@codebuff/common/util/error' import { cloneDeep } from 'lodash' import { getErrorStatusCode } from './error-utils' @@ -26,7 +27,9 @@ import { applyPatchTool } from './tools/apply-patch' import { codeSearch } from './tools/code-search' import { glob } from './tools/glob' import { listDirectory } from './tools/list-directory' +import { 
getProjectPathLookupKeys } from './tools/path-utils' import { getFiles } from './tools/read-files' +import { readUrl } from './tools/read-url' import { runTerminalCommand } from './tools/run-terminal-command' import type { CustomToolDefinition } from './custom-tool' @@ -146,6 +149,10 @@ export type RunOptions = { extraToolResults?: ToolMessage[] signal?: AbortSignal costMode?: string + /** Extra key/values merged into each LLM request's `codebuff_metadata`. + * Used by hosts (e.g. the CLI) to forward client-scoped identifiers like + * `freebuff_instance_id` that server-side gates read from the request body. */ + extraCodebuffMetadata?: Record } const createAbortError = (signal?: AbortSignal) => { @@ -171,6 +178,8 @@ export async function run(options: RunExecutionOptions): Promise { const abortError = createAbortError(signal) return { sessionState: options.previousRun?.sessionState, + traceSessionId: + options.previousRun?.traceSessionId ?? crypto.randomUUID(), output: { type: 'error', message: abortError.message, @@ -212,6 +221,7 @@ async function runOnce({ extraToolResults, signal, costMode, + extraCodebuffMetadata, }: RunExecutionOptions): Promise { const fsSourceValue = typeof fsSource === 'function' ? fsSource() : fsSource const fs = await fsSourceValue @@ -262,6 +272,7 @@ async function runOnce({ logger, }) } + const traceSessionId = previousRun?.traceSessionId ?? crypto.randomUUID() let resolve: (value: RunReturnType) => any = () => {} let _reject: (error: any) => any = () => {} @@ -276,16 +287,27 @@ async function runOnce({ } } + // The agent runtime mutates sessionState.mainAgentState as it progresses, + // replacing messageHistory with a new array once it adds the user prompt. + // Comparing array identity detects progress more robustly than length: + // context pruning could shrink history below its starting length without + // meaning the runtime never ran. 
+ const initialMessageHistory = sessionState.mainAgentState.messageHistory + /** Calculates the current session state if cancelled. * - * This is used when callMainPrompt throws an error (the server never processed the request). - * We need to add the user's message here since the server didn't get a chance to add it. + * This is used when callMainPrompt throws an error. If the agent runtime made + * any progress (replaced the shared messageHistory), those messages are + * preserved. Otherwise the user's message is added so it isn't lost. */ function getCancelledSessionState(message: string): SessionState { + const runtimeMadeProgress = + sessionState.mainAgentState.messageHistory !== initialMessageHistory + const state = cloneDeep(sessionState) - // Add the user's message since the server never processed it - if (prompt || preparedContent) { + // Only add the user's message if the runtime didn't get a chance to add it. + if (!runtimeMadeProgress && (prompt || preparedContent)) { state.mainAgentState.messageHistory.push({ role: 'user' as const, content: buildUserMessageContent(prompt, params, preparedContent), @@ -304,6 +326,7 @@ async function runOnce({ message = message ?? 'Run cancelled by user.' return { sessionState: getCancelledSessionState(message), + traceSessionId, output: { type: 'error', message, @@ -393,7 +416,7 @@ async function runOnce({ filteredTools.push(tool) continue } - if (tool.name in toolNames) { + if (toolNames.includes(tool.name)) { filteredTools.push(tool) continue } @@ -417,7 +440,11 @@ async function runOnce({ cwd, fs, }) - return toOptionalFile(files[filePath] ?? null) + const lookupKeys = cwd + ? getProjectPathLookupKeys(cwd, filePath) + : [filePath] + const fileKey = lookupKeys.find((key) => key in files) + return toOptionalFile(fileKey === undefined ? null : files[fileKey]!) 
}, sendAction: ({ action }) => { if (action.type === 'action-error') { @@ -438,6 +465,7 @@ async function runOnce({ resolve, onError, initialSessionState: sessionState, + traceSessionId, }) return } @@ -447,6 +475,7 @@ async function runOnce({ resolve, onError, initialSessionState: sessionState, + traceSessionId, }) return } @@ -508,17 +537,38 @@ async function runOnce({ repoId: undefined, clientSessionId: promptId, userId, + extraCodebuffMetadata: { + ...(extraCodebuffMetadata ?? {}), + trace_session_id: traceSessionId, + }, signal: signal ?? new AbortController().signal, }).catch((error) => { - const errorMessage = + let errorMessage = error instanceof Error ? error.message : String(error ?? '') - const statusCode = getErrorStatusCode(error) + const apiErrorDetails = extractApiErrorDetails(error) + const statusCode = apiErrorDetails.statusCode ?? getErrorStatusCode(error) + const { + countryBlockReason, + countryCode, + errorCode, + ipPrivacySignals, + message: parsedMessage, + } = apiErrorDetails + if (parsedMessage) { + errorMessage = parsedMessage + } + resolve({ sessionState: getCancelledSessionState(errorMessage), + traceSessionId, output: { type: 'error', message: errorMessage, ...(statusCode !== undefined && { statusCode }), + ...(errorCode !== undefined && { error: errorCode }), + ...(countryCode !== undefined && { countryCode }), + ...(countryBlockReason !== undefined && { countryBlockReason }), + ...(ipPrivacySignals !== undefined && { ipPrivacySignals }), }, }) }) @@ -655,6 +705,8 @@ async function handleToolCall({ cwd: path.resolve(resolvedCwd, input.cwd ?? 
'.'), env, } as Parameters[0]) + } else if (toolName === 'read_url') { + result = await readUrl(input as Parameters[0]) } else if (toolName === 'code_search') { result = await codeSearch({ projectPath: requireCwd(cwd, 'code_search'), @@ -786,11 +838,13 @@ async function handlePromptResponse({ resolve, onError, initialSessionState, + traceSessionId, }: { action: ServerAction<'prompt-response'> | ServerAction<'prompt-error'> resolve: (value: RunReturnType) => any onError: (error: { message: string }) => void initialSessionState: SessionState + traceSessionId: string }) { if (action.type === 'prompt-error') { onError({ message: action.message }) @@ -798,6 +852,7 @@ async function handlePromptResponse({ const statusCode = extractStatusCodeFromMessage(action.message) resolve({ sessionState: initialSessionState, + traceSessionId, output: { type: 'error', message: action.message, @@ -817,6 +872,7 @@ async function handlePromptResponse({ onError({ message }) resolve({ sessionState: initialSessionState, + traceSessionId, output: { type: 'error', message, @@ -828,6 +884,7 @@ async function handlePromptResponse({ const state: RunState = { sessionState, + traceSessionId, output: output ?? 
{ type: 'error', message: 'No output from agent', @@ -841,6 +898,7 @@ async function handlePromptResponse({ }) resolve({ sessionState: initialSessionState, + traceSessionId, output: { type: 'error', message: 'Internal error: prompt response type not handled', diff --git a/sdk/src/tools/change-file.ts b/sdk/src/tools/change-file.ts index da372e7dbc..dbcb55effd 100644 --- a/sdk/src/tools/change-file.ts +++ b/sdk/src/tools/change-file.ts @@ -4,9 +4,11 @@ import { fileExists } from '@codebuff/common/util/file' import { applyPatch } from 'diff' import z from 'zod/v4' +import { resolveFilePathWithinProject } from './path-utils' import type { CodebuffToolOutput } from '@codebuff/common/tools/list' import type { CodebuffFileSystem } from '@codebuff/common/types/filesystem' +import type { ResolvedProjectPath } from './path-utils' const FileChangeSchema = z.object({ type: z.enum(['patch', 'file']), @@ -14,20 +16,12 @@ const FileChangeSchema = z.object({ content: z.string(), }) -function containsUpwardTraversal(dirPath: string): boolean { - const normalized = path.normalize(dirPath) - return normalized.includes('..') -} +type FileChange = z.infer -/** - * Checks if a path contains path traversal sequences that would escape the root. - * Uses proper path normalization to prevent traversal attacks. - */ -function containsPathTraversal(filePath: string): boolean { - const normalized = path.normalize(filePath) - // Check for absolute paths or paths starting with .. 
that escape root - return path.isAbsolute(normalized) || normalized.startsWith('..') -} +type ApplyChangeResult = + | { status: 'created' | 'modified'; file: string } + | { status: 'patchFailed'; file: string; patch: string } + | { status: 'invalid'; file: string } export async function changeFile(params: { parameters: unknown @@ -36,114 +30,78 @@ export async function changeFile(params: { }): Promise> { const { parameters, cwd, fs } = params - if (containsUpwardTraversal(cwd)) { - throw new Error('cwd contains invalid path traversal') - } const fileChange = FileChangeSchema.parse(parameters) - if (containsPathTraversal(fileChange.path)) { - throw new Error('file path contains invalid path traversal') - } - const lines = fileChange.content.split('\n') - - const { created, modified, invalid, patchFailed } = await applyChanges({ - projectRoot: cwd, - changes: [fileChange], - fs, - }) - - const results: CodebuffToolOutput<'str_replace'>[0]['value'][] = [] - - for (const file of created) { - results.push({ - file, - message: 'Created new file', - unifiedDiff: lines.join('\n'), - }) + const resolvedPath = resolveFilePathWithinProject(cwd, fileChange.path) + if (!resolvedPath) { + throw new Error('file path is outside the project directory') } - for (const file of modified) { - results.push({ - file, - message: 'Updated file', - unifiedDiff: lines.join('\n'), - }) - } + const result = await applyChange({ change: fileChange, resolvedPath, fs }) - for (const file of patchFailed) { - results.push({ - file, - errorMessage: `Failed to apply patch.`, - patch: lines.join('\n'), - }) - } + return [{ type: 'json', value: formatApplyChangeResult(result, fileChange) }] +} - for (const file of invalid) { - results.push({ - file, - errorMessage: - 'Failed to write to file: file path caused an error or file could not be written', - }) +function formatApplyChangeResult( + result: ApplyChangeResult, + fileChange: FileChange, +): CodebuffToolOutput<'str_replace'>[0]['value'] { + if 
(result.status === 'created' || result.status === 'modified') { + return { + file: result.file, + message: + fileChange.type === 'patch' + ? 'String replace applied successfully.' + : result.status === 'created' + ? 'Created file successfully.' + : 'Overwrote file successfully.', + } } - if (results.length !== 1) { - throw new Error( - `Internal error: Unexpected result length while modifying files: ${ - results.length - }`, - ) + if (result.status === 'patchFailed') { + return { + file: result.file, + errorMessage: `Failed to apply patch.`, + patch: result.patch, + } } - return [{ type: 'json', value: results[0] }] + return { + file: result.file, + errorMessage: + 'Failed to write to file: file path caused an error or file could not be written', + } } -async function applyChanges(params: { - projectRoot: string - changes: { - type: 'patch' | 'file' - path: string - content: string - }[] +async function applyChange(params: { + change: FileChange + resolvedPath: ResolvedProjectPath fs: CodebuffFileSystem -}) { - const { projectRoot, changes, fs } = params - - const created: string[] = [] - const modified: string[] = [] - const patchFailed: string[] = [] - const invalid: string[] = [] - - for (const change of changes) { - const { path: filePath, content, type } = change - try { - const fullPath = path.join(projectRoot, filePath) - const exists = await fileExists({ filePath: fullPath, fs }) - if (!exists) { - const dirPath = path.dirname(fullPath) - await fs.mkdir(dirPath, { recursive: true }) - } - - if (type === 'file') { - await fs.writeFile(fullPath, content) - } else { - const oldContent = await fs.readFile(fullPath, 'utf-8') - const newContent = applyPatch(oldContent, content) - if (newContent === false) { - patchFailed.push(filePath) - continue - } - await fs.writeFile(fullPath, newContent) - } +}): Promise { + const { change, resolvedPath, fs } = params + const { content, type } = change + const { fullPath, relativePath } = resolvedPath + + try { + const 
exists = await fileExists({ filePath: fullPath, fs }) + if (!exists) { + const dirPath = path.dirname(fullPath) + await fs.mkdir(dirPath, { recursive: true }) + } - if (exists) { - modified.push(filePath) - } else { - created.push(filePath) + if (type === 'file') { + await fs.writeFile(fullPath, content) + } else { + const oldContent = await fs.readFile(fullPath, 'utf-8') + const newContent = applyPatch(oldContent, content) + if (newContent === false) { + return { status: 'patchFailed', file: relativePath, patch: content } } - } catch (error) { - console.error(`Failed to apply patch to ${filePath}:`, error, content) - invalid.push(filePath) + await fs.writeFile(fullPath, newContent) } - } - return { created, modified, invalid, patchFailed } + return { status: exists ? 'modified' : 'created', file: relativePath } + } catch (error) { + console.error(`Failed to apply patch to ${relativePath}:`, error, content) + return { status: 'invalid', file: relativePath } + } } diff --git a/sdk/src/tools/code-search.ts b/sdk/src/tools/code-search.ts index 6bd656b6a4..2fa0286d5c 100644 --- a/sdk/src/tools/code-search.ts +++ b/sdk/src/tools/code-search.ts @@ -98,7 +98,10 @@ export function codeSearch({ const rgPath = getBundledRgPath(import.meta.url) if (logger) { - logger.info({ rgPath, args, searchCwd }, 'code-search: Spawning ripgrep process') + logger.info( + { rgPath, args, searchCwd }, + 'code-search: Spawning ripgrep process', + ) } const childProcess = spawn(rgPath, args, { cwd: searchCwd, @@ -111,6 +114,7 @@ export function codeSearch({ const fileGroups = new Map() // Track match count per file separately from total lines const fileMatchCounts = new Map() + const filesLimitedByMaxResults = new Set() let matchesGlobal = 0 let estimatedOutputLen = 0 let killedForLimit = false @@ -140,7 +144,7 @@ export function codeSearch({ const hardKill = () => { try { childProcess.kill('SIGTERM') - } catch { } + } catch {} // Store timeout reference so it can be cleared if process closes 
normally killTimeoutId = setTimeout(() => { try { @@ -148,12 +152,22 @@ export function codeSearch({ } catch { try { childProcess.kill() - } catch { } + } catch {} } killTimeoutId = null }, 1000) } + const formatCollectedOutput = (rawOutput: string) => + formatCodeSearchOutput(rawOutput, { + matchCount: matchesGlobal, + }) + + const truncateOutput = (output: string, maxLength: number) => + output.length > maxLength + ? output.substring(0, maxLength) + '\n\n[Output truncated]' + : output + const timeoutId = setTimeout(() => { if (isResolved) return hardKill() @@ -165,10 +179,10 @@ export function codeSearch({ } const partialOutput = collectedLines.join('\n') - const truncatedStdout = - partialOutput.length > 1000 - ? partialOutput.substring(0, 1000) + '\n\n[Output truncated]' - : partialOutput + const truncatedStdout = truncateOutput( + formatCollectedOutput(partialOutput), + 1000, + ) const truncatedStderr = stderrBuf.length > 1000 ? stderrBuf.substring(0, 1000) + '\n\n[Error output truncated]' @@ -228,6 +242,9 @@ export function codeSearch({ // For matches: only if we haven't hit the per-file limit // For context: always include (they don't count toward limit) const shouldInclude = !isMatch || fileMatchCount < maxResults + if (isMatch && !shouldInclude) { + filesLimitedByMaxResults.add(filePath) + } if (shouldInclude) { // Add the line to output @@ -253,13 +270,10 @@ export function codeSearch({ limitedLines.push(...lines) } const rawOutput = limitedLines.join('\n') - const formattedOutput = formatCodeSearchOutput(rawOutput) - - const finalOutput = - formattedOutput.length > maxOutputStringLength - ? 
formattedOutput.substring(0, maxOutputStringLength) + - '\n\n[Output truncated]' - : formattedOutput + const finalOutput = truncateOutput( + formatCollectedOutput(rawOutput), + maxOutputStringLength, + ) const limitReason = matchesGlobal >= globalMaxResults @@ -324,6 +338,13 @@ export function codeSearch({ !isMatch || (fileMatchCount < maxResults && matchesGlobal < globalMaxResults) + if ( + isMatch && + fileMatchCount >= maxResults && + matchesGlobal < globalMaxResults + ) { + filesLimitedByMaxResults.add(filePath) + } if (shouldInclude) { fileLines.push(formattedLine) @@ -335,10 +356,10 @@ export function codeSearch({ } } } - } catch { } + } catch {} } } - } catch { } + } catch {} // Build final output from collected matches const limitedLines: string[] = [] @@ -346,9 +367,7 @@ export function codeSearch({ for (const [filename, fileLines] of fileGroups) { limitedLines.push(...fileLines) - // Note if file was truncated (based on match count, not total lines) - const fileMatchCount = fileMatchCounts.get(filename) ?? 0 - if (fileMatchCount >= maxResults) { + if (filesLimitedByMaxResults.has(filename)) { truncatedFiles.push( `${filename}: limited to ${maxResults} results per file`, ) @@ -374,20 +393,17 @@ export function codeSearch({ rawOutput += `\n\n[${truncationMessages.join('\n\n')}]` } - const formattedOutput = formatCodeSearchOutput(rawOutput) - // Truncate output to prevent memory issues - const truncatedStdout = - formattedOutput.length > maxOutputStringLength - ? formattedOutput.substring(0, maxOutputStringLength) + - '\n\n[Output truncated]' - : formattedOutput + const truncatedStdout = truncateOutput( + formatCollectedOutput(rawOutput), + maxOutputStringLength, + ) const truncatedStderr = stderrBuf ? stderrBuf + - (stderrBuf.length >= Math.floor(maxOutputStringLength / 5) - ? '\n\n[Error output truncated]' - : '') + (stderrBuf.length >= Math.floor(maxOutputStringLength / 5) + ? 
'\n\n[Error output truncated]' + : '') : '' settle({ diff --git a/sdk/src/tools/path-utils.ts b/sdk/src/tools/path-utils.ts new file mode 100644 index 0000000000..92fe8a1325 --- /dev/null +++ b/sdk/src/tools/path-utils.ts @@ -0,0 +1,41 @@ +import path from 'path' + +export type ResolvedProjectPath = { + fullPath: string + relativePath: string +} + +function escapesProject(relativePath: string): boolean { + return ( + relativePath === '..' || + relativePath.startsWith(`..${path.sep}`) || + path.isAbsolute(relativePath) + ) +} + +export function resolveFilePathWithinProject( + projectRoot: string, + filePath: string, +): ResolvedProjectPath | null { + const resolvedRoot = path.resolve(projectRoot) + const fullPath = path.isAbsolute(filePath) + ? path.resolve(filePath) + : path.resolve(resolvedRoot, filePath) + const relativePath = path.relative(resolvedRoot, fullPath) + + if (relativePath === '' || escapesProject(relativePath)) { + return null + } + + return { fullPath, relativePath } +} + +export function getProjectPathLookupKeys( + projectRoot: string, + filePath: string, +): string[] { + const resolvedPath = resolveFilePathWithinProject(projectRoot, filePath) + const keys = resolvedPath ? 
[resolvedPath.relativePath, filePath] : [filePath] + + return [...new Set(keys)] +} diff --git a/sdk/src/tools/read-files.ts b/sdk/src/tools/read-files.ts index e2d68b95fe..a6462f1a24 100644 --- a/sdk/src/tools/read-files.ts +++ b/sdk/src/tools/read-files.ts @@ -1,8 +1,8 @@ -import path, { isAbsolute } from 'path' - import { FILE_READ_STATUS } from '@codebuff/common/old-constants' import { isFileIgnored } from '@codebuff/common/project-file-tree' +import { resolveFilePathWithinProject } from './path-utils' + import type { CodebuffFileSystem } from '@codebuff/common/types/filesystem' export type FileFilterResult = { @@ -28,22 +28,22 @@ export async function getFiles(params: { const hasCustomFilter = fileFilter !== undefined const result: Record = {} - const MAX_FILE_SIZE = 1024 * 1024 // 1MB in bytes + const MAX_FILE_BYTES = 10 * 1024 * 1024 // 10MB - skip reading entirely + const MAX_CHARS = 100_000 // 100k characters threshold + const numFmt = new Intl.NumberFormat('en-US') + const fmtNum = (n: number) => numFmt.format(n) for (const filePath of filePaths) { if (!filePath) { continue } - // Convert absolute paths within project to relative paths - const relativePath = filePath.startsWith(cwd) - ? 
path.relative(cwd, filePath) - : filePath - const fullPath = path.join(cwd, relativePath) - if (isAbsolute(relativePath) || !fullPath.startsWith(cwd)) { - result[relativePath] = FILE_READ_STATUS.OUTSIDE_PROJECT + const resolvedPath = resolveFilePathWithinProject(cwd, filePath) + if (!resolvedPath) { + result[filePath] = FILE_READ_STATUS.OUTSIDE_PROJECT continue } + const { relativePath, fullPath } = resolvedPath // Apply file filter if provided const filterResult = fileFilter?.(relativePath) @@ -68,13 +68,27 @@ export async function getFiles(params: { } try { + // Safety check: skip reading files over 10MB to avoid OOM const stats = await fs.stat(fullPath) - if (stats.size > MAX_FILE_SIZE) { + if (stats.size > MAX_FILE_BYTES) { result[relativePath] = FILE_READ_STATUS.TOO_LARGE + - ` [${(stats.size / (1024 * 1024)).toFixed(2)}MB]` + ` [${(stats.size / (1024 * 1024)).toFixed(1)}MB exceeds 10MB limit. Use code_search or glob to find specific content.]` + continue + } + + const content = await fs.readFile(fullPath, 'utf8') + + if (content.length > MAX_CHARS) { + const truncated = content.slice(0, MAX_CHARS) + result[relativePath] = + truncated + + '\n\n[FILE_TOO_LARGE: This file is ' + + fmtNum(content.length) + + ' chars, exceeding the ' + + fmtNum(MAX_CHARS) + + ' char limit. The content above has been truncated. Use other tools to read other sections of the file.]' } else { - const content = await fs.readFile(fullPath, 'utf8') // Prepend TEMPLATE marker for example files result[relativePath] = isExampleFile ? 
FILE_READ_STATUS.TEMPLATE + '\n' + content diff --git a/sdk/src/tools/read-url.ts b/sdk/src/tools/read-url.ts new file mode 100644 index 0000000000..9bd5c89f86 --- /dev/null +++ b/sdk/src/tools/read-url.ts @@ -0,0 +1,413 @@ +import type { CodebuffToolOutput } from '../../../common/src/tools/list' + +const DEFAULT_MAX_CHARS = 20_000 +const MAX_RESPONSE_BYTES = 2_000_000 +const FETCH_TIMEOUT_MS = 20_000 +const USER_AGENT = + 'Mozilla/5.0 (compatible; CodebuffResearchBot/1.0; +https://codebuff.com)' + +type ReadUrlOutput = CodebuffToolOutput<'read_url'> +type FetchLike = ( + input: string | URL | Request, + init?: RequestInit, +) => Promise + +function errorResult( + url: string | undefined, + errorMessage: string, +): ReadUrlOutput { + return [{ type: 'json', value: { ...(url ? { url } : {}), errorMessage } }] +} + +function isAllowedUrl(url: URL): boolean { + return url.protocol === 'http:' || url.protocol === 'https:' +} + +function getHeader(headers: Headers, name: string): string | undefined { + return headers.get(name) ?? 
undefined +} + +async function readResponseBody( + response: Response, + maxBytes: number, +): Promise { + const contentLength = getHeader(response.headers, 'content-length') + if (contentLength && Number(contentLength) > maxBytes) { + throw new Error(`Response is too large (${contentLength} bytes)`) + } + + if (!response.body) { + const buffer = await response.arrayBuffer() + if (buffer.byteLength > maxBytes) { + throw new Error(`Response is too large (${buffer.byteLength} bytes)`) + } + return new TextDecoder().decode(buffer) + } + + const reader = response.body.getReader() + const chunks: Uint8Array[] = [] + let totalBytes = 0 + + while (true) { + const { done, value } = await reader.read() + if (done) break + if (!value) continue + + totalBytes += value.byteLength + if (totalBytes > maxBytes) { + await reader.cancel() + throw new Error(`Response exceeded ${maxBytes} bytes`) + } + chunks.push(value) + } + + const body = new Uint8Array(totalBytes) + let offset = 0 + for (const chunk of chunks) { + body.set(chunk, offset) + offset += chunk.byteLength + } + + return new TextDecoder().decode(body) +} + +function decodeHtmlEntities(text: string): string { + const namedEntities: Record = { + amp: '&', + apos: "'", + copy: '(c)', + hellip: '...', + gt: '>', + lt: '<', + mdash: '-', + middot: '*', + nbsp: ' ', + ndash: '-', + quot: '"', + rsquo: "'", + } + + return text.replace(/&(#x?[0-9a-fA-F]+|[a-zA-Z]+);/g, (entity, body) => { + if (body[0] === '#') { + const isHex = body[1]?.toLowerCase() === 'x' + const value = Number.parseInt(body.slice(isHex ? 2 : 1), isHex ? 16 : 10) + return Number.isFinite(value) && value >= 0 && value <= 0x10ffff + ? String.fromCodePoint(value) + : entity + } + return namedEntities[body] ?? 
entity + }) +} + +function normalizeText(text: string): string { + return text + .replace(/\r/g, '') + .replace(/[ \t\f\v]+/g, ' ') + .replace(/ *\n */g, '\n') + .replace(/\n{3,}/g, '\n\n') + .split('\n') + .map((line) => line.trim()) + .filter(Boolean) + .join('\n') + .trim() +} + +function extractFirstMatch(html: string, pattern: RegExp): string | undefined { + const match = html.match(pattern) + if (!match?.[1]) return undefined + return normalizeText(decodeHtmlEntities(stripTags(match[1]))) +} + +function stripTags(html: string): string { + return html.replace(/<[^>]*>/g, ' ') +} + +function removeElement(html: string, tagName: string): string { + return html.replace( + new RegExp(`<${tagName}\\b[^>]*>[\\s\\S]*?<\\/${tagName}>`, 'gi'), + '\n', + ) +} + +function extractElementContents(html: string, tagName: string): string[] { + const matches = html.matchAll( + new RegExp(`<${tagName}\\b[^>]*>([\\s\\S]*?)<\\/${tagName}>`, 'gi'), + ) + return Array.from(matches, (match) => match[1]).filter(Boolean) +} + +function selectReadableHtml(html: string): string { + const articleCandidates = extractElementContents(html, 'article') + if (articleCandidates.length > 0) { + return articleCandidates.reduce((best, candidate) => + stripTags(candidate).length > stripTags(best).length ? candidate : best, + ) + } + + const mainCandidates = extractElementContents(html, 'main') + if (mainCandidates.length > 0) { + return mainCandidates.reduce((best, candidate) => + stripTags(candidate).length > stripTags(best).length ? 
candidate : best, + ) + } + + return html +} + +function extractMetaContent(html: string, name: string): string | undefined { + const escapedName = name.replace(/[.*+?^${}()|[\]\\]/g, '\\$&') + const patterns = [ + new RegExp( + `]*(?:name|property)=["']${escapedName}["'])(?=[^>]*content=["']([^"']*)["'])[^>]*>`, + 'i', + ), + new RegExp( + `]*content=["']([^"']*)["'])(?=[^>]*(?:name|property)=["']${escapedName}["'])[^>]*>`, + 'i', + ), + ] + + for (const pattern of patterns) { + const match = html.match(pattern) + if (match?.[1]) return normalizeText(decodeHtmlEntities(match[1])) + } + return undefined +} + +function extractHtml(html: string): { + title?: string + description?: string + text: string +} { + const title = extractFirstMatch(html, /]*>([\s\S]*?)<\/title>/i) + const description = + extractMetaContent(html, 'description') ?? + extractMetaContent(html, 'og:description') + + let readable = html + .replace(//g, '\n') + .replace(/]*>/gi, '\n') + + for (const tagName of [ + 'script', + 'style', + 'svg', + 'canvas', + 'iframe', + 'noscript', + 'nav', + 'header', + 'footer', + 'form', + 'button', + 'select', + ]) { + readable = removeElement(readable, tagName) + } + + readable = selectReadableHtml(readable) + + readable = readable + .replace(//gi, '\n') + .replace( + /<\/(p|div|section|article|main|aside|li|tr|td|th|h[1-6]|blockquote|pre)>/gi, + '\n', + ) + .replace(/<(li|tr|h[1-6])\b[^>]*>/gi, '\n') + .replace(/<[^>]*>/g, '') + + const text = normalizeText(decodeHtmlEntities(readable)) + return { title, description, text } +} + +function extractMarkdownFrontmatter(body: string): { + title?: string + description?: string + text: string +} { + const match = body.match(/^---\s*\r?\n([\s\S]*?)\r?\n---\s*\r?\n?/) + if (!match) { + return { text: normalizeText(decodeHtmlEntities(body)) } + } + + const frontmatter = match[1] + const getValue = (key: 'title' | 'description') => { + const valueMatch = frontmatter.match( + new 
RegExp(`^${key}:\\s*(?:"([^"]*)"|'([^']*)'|(.+))\\s*$`, 'm'), + ) + return normalizeText( + decodeHtmlEntities( + valueMatch?.[1] ?? valueMatch?.[2] ?? valueMatch?.[3] ?? '', + ), + ) + } + + return { + title: getValue('title') || undefined, + description: getValue('description') || undefined, + text: normalizeText(decodeHtmlEntities(body.slice(match[0].length))), + } +} + +function isJsonContentType(contentType: string): boolean { + return ( + contentType.includes('application/json') || contentType.includes('+json') + ) +} + +function isMarkdownContentType(contentType: string): boolean { + return contentType.includes('text/markdown') +} + +function isSupportedContentType(contentType: string): boolean { + return /^(text\/|application\/(json|[^;\s/]+\+json|xhtml\+xml|xml|rss\+xml|atom\+xml)\b)/i.test( + contentType, + ) +} + +function extractTextByContentType( + contentType: string, + body: string, +): { + title?: string + description?: string + text: string +} { + const lowerContentType = contentType.toLowerCase() + + if ( + lowerContentType.includes('text/html') || + lowerContentType.includes('application/xhtml') + ) { + return extractHtml(body) + } + + if (isJsonContentType(lowerContentType)) { + try { + return { text: JSON.stringify(JSON.parse(body), null, 2) } + } catch { + return { text: normalizeText(body) } + } + } + + if (isMarkdownContentType(lowerContentType)) { + return extractMarkdownFrontmatter(body) + } + + if ( + lowerContentType.startsWith('text/') || + lowerContentType.includes('application/xml') || + lowerContentType.includes('application/rss+xml') || + lowerContentType.includes('application/atom+xml') + ) { + return { text: normalizeText(body) } + } + + return { text: normalizeText(body) } +} + +function truncateText( + text: string, + maxChars: number, +): { + text: string + truncated: boolean +} { + if (text.length <= maxChars) { + return { text, truncated: false } + } + return { + text: `${text.slice(0, maxChars).trimEnd()}\n\n[Content 
truncated]`, + truncated: true, + } +} + +export async function readUrl({ + url, + max_chars = DEFAULT_MAX_CHARS, + fetch: fetchImpl = globalThis.fetch, +}: { + url: string + max_chars?: number + fetch?: FetchLike +}): Promise { + let parsedUrl: URL + try { + parsedUrl = new URL(url) + } catch { + return errorResult(url, 'Invalid URL') + } + + if (!isAllowedUrl(parsedUrl)) { + return errorResult(url, 'Only http:// and https:// URLs are supported') + } + + const controller = new AbortController() + const timeout = setTimeout(() => controller.abort(), FETCH_TIMEOUT_MS) + + try { + const response = await fetchImpl(parsedUrl.toString(), { + redirect: 'follow', + signal: controller.signal, + headers: { + accept: + 'text/html,application/xhtml+xml,application/json,text/plain;q=0.9,*/*;q=0.8', + 'accept-language': 'en-US,en;q=0.9', + 'user-agent': USER_AGENT, + }, + }) + + if (!response.ok) { + return errorResult( + url, + `Failed to fetch URL: ${response.status} ${response.statusText}`, + ) + } + + const contentType = getHeader(response.headers, 'content-type') ?? '' + if (contentType && !isSupportedContentType(contentType)) { + return errorResult( + url, + `Unsupported content type: ${contentType || 'unknown'}`, + ) + } + + const body = await readResponseBody(response, MAX_RESPONSE_BYTES) + const extracted = extractTextByContentType(contentType, body) + const truncated = truncateText(extracted.text, max_chars) + + if (!truncated.text) { + return errorResult(url, 'No readable text found at URL') + } + + return [ + { + type: 'json', + value: { + url, + finalUrl: response.url || parsedUrl.toString(), + status: response.status, + ...(contentType ? { contentType } : {}), + ...(extracted.title ? { title: extracted.title } : {}), + ...(extracted.description + ? 
{ description: extracted.description } + : {}), + text: truncated.text, + truncated: truncated.truncated, + }, + }, + ] + } catch (error) { + const isAbort = error instanceof Error && error.name === 'AbortError' + return errorResult( + url, + isAbort + ? `Timed out after ${FETCH_TIMEOUT_MS} ms` + : error instanceof Error + ? error.message + : 'Unknown error', + ) + } finally { + clearTimeout(timeout) + } +} diff --git a/sdk/test/setup-env.ts b/sdk/test/setup-env.ts index 45b4fa8148..381bb09691 100644 --- a/sdk/test/setup-env.ts +++ b/sdk/test/setup-env.ts @@ -18,7 +18,7 @@ const testDefaults: Record = { const serverDefaults: Record = { OPEN_ROUTER_API_KEY: 'test', OPENAI_API_KEY: 'test', - LINKUP_API_KEY: 'test', + SERPER_API_KEY: 'test', PORT: '4242', DATABASE_URL: 'postgres://user:pass@localhost:5432/db', CODEBUFF_GITHUB_ID: 'test-id', diff --git a/test/setup-scm-loader.ts b/test/setup-scm-loader.ts new file mode 100644 index 0000000000..6acafba756 --- /dev/null +++ b/test/setup-scm-loader.ts @@ -0,0 +1,15 @@ +import { plugin } from 'bun' +import { readFile } from 'fs/promises' + +plugin({ + name: 'scm-text-loader', + setup(build) { + build.onLoad({ filter: /\.scm$/ }, async (args) => { + const text = await readFile(args.path, 'utf8') + return { + exports: { default: text }, + loader: 'object', + } + }) + }, +}) diff --git a/web/instrumentation.ts b/web/instrumentation.ts index 6ce22befe4..422a11c9e0 100644 --- a/web/instrumentation.ts +++ b/web/instrumentation.ts @@ -10,7 +10,7 @@ import { logger } from '@/util/logger' -export function register() { +export async function register() { // Handle unhandled promise rejections (async errors that aren't caught) process.on( 'unhandledRejection', @@ -45,4 +45,14 @@ export function register() { }) logger.info({}, '[Instrumentation] Global error handlers registered') + + // DB-touching admission module uses `postgres`, which imports Node built-ins + // like `crypto`. 
Gate on NEXT_RUNTIME so the edge bundle doesn't try to + // resolve them. + if (process.env.NEXT_RUNTIME === 'nodejs') { + const { startFreeSessionAdmission } = await import( + '@/server/free-session/admission' + ) + startFreeSessionAdmission() + } } diff --git a/web/jest.config.cjs b/web/jest.config.cjs index ccbf30ee18..5736284c2d 100644 --- a/web/jest.config.cjs +++ b/web/jest.config.cjs @@ -13,8 +13,8 @@ const config = { '^@codebuff/internal/env$': '/../packages/internal/src/env.ts', '^@codebuff/internal/xml-parser$': '/src/test-stubs/xml-parser.ts', '^bun:test$': '/src/test-stubs/bun-test.ts', - '^react$': '/node_modules/react', - '^react-dom$': '/node_modules/react-dom', + '^react$': '/../node_modules/react', + '^react-dom$': '/../node_modules/react-dom', }, // Bun-specific tests that use top-level await or bun:test features testPathIgnorePatterns: [ diff --git a/web/jest.setup.js b/web/jest.setup.js index c44951a680..9f6d201bbb 100644 --- a/web/jest.setup.js +++ b/web/jest.setup.js @@ -1 +1,25 @@ import '@testing-library/jest-dom' +import { TextDecoder, TextEncoder } from 'node:util' +import { ReadableStream, WritableStream, TransformStream } from 'node:stream/web' + +// JSDOM lacks Node's Web API globals — undici (loaded transitively via +// `next/server` and `openai`) needs these at module-load time. 
+if (typeof globalThis.TextEncoder === 'undefined') { + globalThis.TextEncoder = TextEncoder +} +if (typeof globalThis.TextDecoder === 'undefined') { + globalThis.TextDecoder = TextDecoder +} +if (typeof globalThis.ReadableStream === 'undefined') { + globalThis.ReadableStream = ReadableStream + globalThis.WritableStream = WritableStream + globalThis.TransformStream = TransformStream +} +if (typeof globalThis.Request === 'undefined') { + const undici = require('undici') + globalThis.Request = undici.Request + globalThis.Response = undici.Response + globalThis.Headers = undici.Headers + globalThis.fetch = undici.fetch + globalThis.FormData = undici.FormData +} diff --git a/web/knowledge.md b/web/knowledge.md index f1316ec790..63dff2da40 100644 --- a/web/knowledge.md +++ b/web/knowledge.md @@ -92,22 +92,6 @@ Key files: - Store user_id as property for internal reference - Track events with consistent naming: `category.event_name` -## Referral System - -### Workflow - -1. Users get unique referral codes upon account creation -2. Share referral links: `${env.NEXT_PUBLIC_CODEBUFF_APP_URL}/redeem?referral_code=${referralCode}` -3. New users redeem codes during signup/onboarding -4. Both referrer and referred user receive `CREDITS_REFERRAL_BONUS` credits -5. 
Referrals tracked in database with limits - -### Key Components - -- `web/src/app/referrals/page.tsx`: Main referrals UI -- `web/src/app/api/referrals/route.ts`: API operations -- `web/src/app/onboard/page.tsx`: Referral code processing - ## Verifying Changes After changes, run type checking: diff --git a/web/next.config.mjs b/web/next.config.mjs index fce0f5658b..2927cf1816 100644 --- a/web/next.config.mjs +++ b/web/next.config.mjs @@ -36,6 +36,7 @@ const nextConfig = { 'encoding', 'perf_hooks', 'async_hooks', + 'geoip-lite', ) // Externalize code-map package to avoid bundling tree-sitter WASM files diff --git a/web/package.json b/web/package.json index 9b92c03529..830cbbdc36 100644 --- a/web/package.json +++ b/web/package.json @@ -35,7 +35,7 @@ }, "sideEffects": false, "engines": { - "bun": "^1.3.5" + "bun": "1.3.11" }, "dependencies": { "@codebuff/billing": "workspace:*", @@ -70,9 +70,10 @@ "discord.js": "^14.18.0", "dotenv": "^16.4.7", "framer-motion": "^11.13.3", + "geoip-lite": "^2.0.0", "lucide-react": "^0.487.0", "mermaid": "^11.8.1", - "next": "15.5.11", + "next": "15.5.16", "next-auth": "^4.24.11", "next-contentlayer2": "^0.5.8", "next-themes": "^0.4.6", @@ -97,6 +98,7 @@ "@tailwindcss/typography": "^0.5.15", "@testing-library/jest-dom": "^6.8.0", "@testing-library/react": "^16.3.0", + "@types/geoip-lite": "^1.4.4", "@types/jest": "^29.5.14", "@types/node": "^22.14.0", "@types/pg": "^8.11.11", diff --git a/web/src/__tests__/e2e/redirects.spec.ts b/web/src/__tests__/e2e/redirects.spec.ts index 7f119f5990..a2c2065d50 100644 --- a/web/src/__tests__/e2e/redirects.spec.ts +++ b/web/src/__tests__/e2e/redirects.spec.ts @@ -71,80 +71,5 @@ if (isBun) { }) }) - test.describe('Sponsee (affiliate link) redirect', () => { - test('shows error page for unknown sponsee', async ({ page }) => { - await page.goto('/unknown-sponsee-name-12345') - - // Should show the error message for unknown sponsee - await expect( - page.getByText("that link doesn't look right", { exact: 
false }), - ).toBeVisible() - await expect( - page.getByText('unknown-sponsee-name-12345', { exact: false }), - ).toBeVisible() - }) - - test('error page includes support email link', async ({ page }) => { - await page.goto('/nonexistent-referrer') - - // Should have a link to support email - const supportLink = page.locator('a[href^="mailto:"]') - await expect(supportLink).toBeVisible() - }) - - // Note: Testing the happy path (successful redirect with query param preservation) - // requires a valid sponsee in the database. This test documents the expected behavior - // and can be run against a seeded test database. - test.describe('with seeded database', { tag: '@seeded-db' }, () => { - test.skip( - () => !process.env.E2E_TEST_SPONSEE, - 'Requires E2E_TEST_SPONSEE env var with a valid sponsee handle', - ) - - test('preserves query parameters when redirecting to referral page', async ({ - request, - }) => { - const sponsee = process.env.E2E_TEST_SPONSEE! - const response = await request.get( - `/${sponsee}?utm_source=twitter&utm_campaign=test&custom=value`, - { - maxRedirects: 0, - }, - ) - - // Should redirect to /referrals/ - expect(response.status()).toBe(307) - const location = response.headers()['location'] - expect(location).toMatch(/^\/referrals\//) - - // Query params should be preserved - expect(location).toContain('utm_source=twitter') - expect(location).toContain('utm_campaign=test') - expect(location).toContain('custom=value') - - // Referrer param should be added - expect(location).toContain(`referrer=${sponsee}`) - }) - - test('referrer param overrides existing referrer in query', async ({ - request, - }) => { - const sponsee = process.env.E2E_TEST_SPONSEE! 
- const response = await request.get( - `/${sponsee}?referrer=should-be-overridden`, - { - maxRedirects: 0, - }, - ) - - expect(response.status()).toBe(307) - const location = response.headers()['location'] - - // The referrer should be the sponsee name, not the original value - expect(location).toContain(`referrer=${sponsee}`) - expect(location).not.toContain('should-be-overridden') - }) - }) - }) }) } diff --git a/web/src/__tests__/playwright-runner.e2e.ts b/web/src/__tests__/playwright-runner.e2e.ts index 28686d50bd..a107424668 100644 --- a/web/src/__tests__/playwright-runner.e2e.ts +++ b/web/src/__tests__/playwright-runner.e2e.ts @@ -22,7 +22,7 @@ describe('playwright e2e suite', () => { env.NEXT_PUBLIC_WEB_PORT ||= '3000' env.OPEN_ROUTER_API_KEY ||= 'test' env.OPENAI_API_KEY ||= 'test' - env.LINKUP_API_KEY ||= 'test' + env.SERPER_API_KEY ||= 'test' env.PORT = env.NEXT_PUBLIC_WEB_PORT env.DATABASE_URL = getE2EDatabaseUrl() env.CODEBUFF_GITHUB_ID ||= 'test-id' diff --git a/web/src/app/[sponsee]/page.tsx b/web/src/app/[sponsee]/page.tsx index 2c74d14e5a..e09eb7c00b 100644 --- a/web/src/app/[sponsee]/page.tsx +++ b/web/src/app/[sponsee]/page.tsx @@ -69,7 +69,6 @@ export default async function SponseePage({ ) } - // Build query string preserving all incoming params and adding/overriding referrer const queryParams = new URLSearchParams() for (const [key, value] of Object.entries(resolvedSearchParams)) { if (value !== undefined) { diff --git a/web/src/app/affiliates/actions.ts b/web/src/app/affiliates/actions.ts deleted file mode 100644 index d27c3d84b1..0000000000 --- a/web/src/app/affiliates/actions.ts +++ /dev/null @@ -1,135 +0,0 @@ -'use server' - -import { AFFILIATE_USER_REFFERAL_LIMIT } from '@codebuff/common/old-constants' -import db from '@codebuff/internal/db' -import * as schema from '@codebuff/internal/db/schema' -import { eq, and, ne } from 'drizzle-orm' -import { revalidatePath } from 'next/cache' -import { getServerSession } from 'next-auth' -import { z 
} from 'zod/v4' - -import { authOptions } from '@/app/api/auth/[...nextauth]/auth-options' - -const RESERVED_HANDLES = [ - 'api', - 'docs', - 'hackathon', - 'login', - 'onboard', - 'payment-change', - 'payment-success', - 'pricing', - 'privacy-policy', - 'referrals', - 'subscription', - 'terms-of-service', - 'usage', - 'affiliates', - 'discord', - 'ingest', - 'admin', - 'auth', - 'user', - 'profile', - 'settings', - 'support', - 'help', - 'contact', - 'root', - 'codebuff', - 'manicode', - 'status', - 'healthz', -].map((h) => h.toLowerCase()) - -const HandleSchema = z - .string() - .min(3, 'Handle must be at least 3 characters long.') - .max(20, 'Handle cannot be longer than 20 characters.') - .regex( - /^[a-zA-Z0-9_]+$/, - 'Handle can only contain letters, numbers, and underscores.', - ) - .transform((str) => str.toLowerCase()) - .refine((handle) => !RESERVED_HANDLES.includes(handle), { - message: 'This handle is reserved and cannot be used.', - }) - -export interface SetHandleFormState { - message: string - success: boolean - fieldErrors?: { - handle?: string[] - } -} - -export async function setAffiliateHandleAction( - prevState: SetHandleFormState, - formData: FormData, -): Promise { - const session = await getServerSession(authOptions) - - if (!session?.user?.id) { - return { success: false, message: 'Authentication required.' } - } - - const userId = session.user.id - const handleResult = HandleSchema.safeParse(formData.get('handle')) - - if (!handleResult.success) { - const formErrors = handleResult.error.flatten().formErrors - const message = - formErrors.find((err) => err.includes('reserved')) || - formErrors[0] || - 'Invalid handle format.' 
- return { - success: false, - message: message, - fieldErrors: { handle: formErrors }, - } - } - - const desiredHandle = handleResult.data - - try { - const currentUser = await db.query.user.findFirst({ - where: eq(schema.user.id, userId), - columns: { handle: true }, - }) - - if (currentUser?.handle) { - return { success: false, message: 'You already have a handle set.' } - } - - const existingUser = await db.query.user.findFirst({ - where: and( - eq(schema.user.handle, desiredHandle), - ne(schema.user.id, userId), - ), - columns: { id: true }, - }) - - if (existingUser) { - return { - success: false, - message: `Handle "${desiredHandle}" is already taken. Please choose another.`, - fieldErrors: { handle: ['This handle is already taken.'] }, - } - } - - await db - .update(schema.user) - .set({ - handle: desiredHandle, - referral_limit: AFFILIATE_USER_REFFERAL_LIMIT, - }) - .where(eq(schema.user.id, userId)) - - revalidatePath('/affiliates') - - return { success: true, message: 'Handle set successfully!' } - } catch (error) { - console.error('Error setting affiliate handle:', error) - return { success: false, message: 'An unexpected error occurred.' 
} - } -} diff --git a/web/src/app/affiliates/affiliates-client.tsx b/web/src/app/affiliates/affiliates-client.tsx deleted file mode 100644 index 4eff1907ec..0000000000 --- a/web/src/app/affiliates/affiliates-client.tsx +++ /dev/null @@ -1,265 +0,0 @@ -'use client' - -import { env } from '@codebuff/common/env' -import { - CREDITS_REFERRAL_BONUS, - AFFILIATE_USER_REFFERAL_LIMIT, -} from '@codebuff/common/old-constants' -import Link from 'next/link' -import { useSession } from 'next-auth/react' -import React, { useEffect, useState, useCallback, useActionState } from 'react' - -import { setAffiliateHandleAction } from './actions' - -import type { SetHandleFormState } from './actions' - -import CardWithBeams from '@/components/card-with-beams' -import { SignInCardFooter } from '@/components/sign-in/sign-in-card-footer' -import { Button } from '@/components/ui/button' -import { - Card, - CardContent, - CardDescription, - CardHeader, - CardTitle, -} from '@/components/ui/card' -import { Input } from '@/components/ui/input' -import { Label } from '@/components/ui/label' -import { Skeleton } from '@/components/ui/skeleton' -import { useToast } from '@/components/ui/use-toast' - -function SubmitButton({ pending }: { pending: boolean }) { - return ( - - ) -} - -function SetHandleForm({ - onHandleSetSuccess, -}: { - onHandleSetSuccess: () => void -}) { - const { toast } = useToast() - const initialState: SetHandleFormState = { - message: '', - success: false, - fieldErrors: {}, - } - const [state, formAction, isPending] = useActionState( - setAffiliateHandleAction, - initialState, - ) - - useEffect(() => { - if (state.message) { - toast({ - title: state.success ? 'Success!' : 'Error', - description: state.message, - variant: state.success ? 'default' : 'destructive', - }) - if (state.success) { - onHandleSetSuccess() - } - } - }, [state, toast, onHandleSetSuccess]) - - return ( -
-
- -

- This will be part of your referral link (e.g., - codebuff.com/your_unique_handle). -

-

- 3-20 chars. letters, numbers, underscores only. -

- - - {state.fieldErrors?.handle && ( -

- {state.fieldErrors.handle.join(', ')} -

- )} - {!state.success && state.message && !state.fieldErrors?.handle && ( -

{state.message}

- )} -
- - - ) -} - -export default function AffiliatesClient() { - const { status: sessionStatus } = useSession() - const [userProfile, setUserProfile] = useState< - { handle: string | null; referralCode: string | null } | undefined - >(undefined) - const [fetchError, setFetchError] = useState(null) - - const fetchUserProfile = useCallback(() => { - setFetchError(null) - fetch('/api/user/profile') - .then(async (res) => { - if (!res.ok) { - const errorData = await res.json().catch(() => ({})) - throw new Error( - errorData.error || `HTTP error! status: ${res.status}`, - ) - } - return res.json() - }) - .then((data) => { - setUserProfile({ - handle: data.handle ?? null, - referralCode: data.referral_code ?? null, - }) - }) - .catch((error) => { - console.error('Failed to fetch user profile:', error) - setFetchError(error.message || 'Failed to load profile data.') - setUserProfile({ handle: null, referralCode: null }) - }) - }, []) - - useEffect(() => { - if (sessionStatus === 'authenticated') { - fetchUserProfile() - } else if (sessionStatus === 'unauthenticated') { - setUserProfile({ handle: null, referralCode: null }) - } - }, [sessionStatus, fetchUserProfile]) - - if (sessionStatus === 'loading' || userProfile === undefined) { - return ( -
-
- - - - - - - - - - - -
-
- ) - } - - if (sessionStatus === 'unauthenticated') { - return ( - -

- Want to partner with Codebuff and earn rewards? Log in first! -

- - - } - /> - ) - } - - if (fetchError) { - return ( -
-
-

Error loading affiliate information: {fetchError}

-

Please try refreshing the page or contact support.

-
-
- ) - } - - const userHandle = userProfile?.handle - const _referralCode = userProfile?.referralCode - - return ( -
-
- - - - Codebuff Affiliate Program - - - Share Codebuff and earn credits! - - - - {userHandle === null && ( -
-

- Become an Affiliate -

-

- Generate your unique referral link, that grants you{' '} - {AFFILIATE_USER_REFFERAL_LIMIT.toLocaleString()} referrals for - your friends, colleagues, and followers. When they sign up - using your link, you'll both earn an extra{' '} - {CREDITS_REFERRAL_BONUS} credits! -

- - -
- )} - - {userHandle && ( -
-

- Your Affiliate Handle -

-

- Your affiliate handle is set to:{' '} - - {userHandle} - - . You can now refer up to{' '} - {AFFILIATE_USER_REFFERAL_LIMIT.toLocaleString()} new users! -

-

- Your referral link is:{' '} - {`${env.NEXT_PUBLIC_CODEBUFF_APP_URL}/${userHandle}`} -

-
- )} - -

- Questions? Contact us at{' '} - - {env.NEXT_PUBLIC_SUPPORT_EMAIL} - - . -

-
-
-
-
- ) -} diff --git a/web/src/app/affiliates/page.tsx b/web/src/app/affiliates/page.tsx deleted file mode 100644 index f51ea2de8b..0000000000 --- a/web/src/app/affiliates/page.tsx +++ /dev/null @@ -1,130 +0,0 @@ -import { env } from '@codebuff/common/env' - -import AffiliatesClient from './affiliates-client' - -import type { Metadata } from 'next' - - -export async function generateMetadata(): Promise { - const canonicalUrl = `${env.NEXT_PUBLIC_CODEBUFF_APP_URL}/affiliates` - - const title = 'Affiliate Program – Earn Credits by Referring | Codebuff' - const description = - 'Join the Codebuff Affiliate Program. Share your unique referral link and earn credits when friends sign up. Both you and your referrals get bonus credits!' - - return { - title, - description, - alternates: { - canonical: canonicalUrl, - }, - openGraph: { - title, - description, - url: canonicalUrl, - type: 'website', - siteName: 'Codebuff', - images: '/opengraph-image.png', - }, - twitter: { - card: 'summary_large_image', - title, - description, - images: '/opengraph-image.png', - }, - keywords: [ - 'affiliate program', - 'referral program', - 'earn credits', - 'Codebuff affiliate', - 'Codebuff referral', - 'AI coding assistant affiliate', - ], - } -} - -// WebPage JSON-LD schema describing the affiliate program -function WebPageJsonLd() { - const jsonLd = { - '@context': 'https://schema.org', - '@type': 'WebPage', - name: 'Codebuff Affiliate Program', - description: - 'Join the Codebuff Affiliate Program. 
Share your unique referral link and earn credits when friends sign up.', - url: `${env.NEXT_PUBLIC_CODEBUFF_APP_URL}/affiliates`, - mainEntity: { - '@type': 'Service', - name: 'Codebuff Affiliate Program', - description: - 'Referral program that rewards users with bonus credits for inviting new users to Codebuff.', - provider: { - '@type': 'Organization', - name: 'Codebuff', - url: env.NEXT_PUBLIC_CODEBUFF_APP_URL, - }, - serviceType: 'Affiliate/Referral Program', - areaServed: 'Worldwide', - offers: { - '@type': 'Offer', - price: '0', - priceCurrency: 'USD', - description: - 'Free to join. Earn bonus credits for both referrer and referee.', - }, - }, - isPartOf: { - '@type': 'WebSite', - name: 'Codebuff', - url: env.NEXT_PUBLIC_CODEBUFF_APP_URL, - }, - } - - return ( -