gabriel / musehub public
mcp.py python
632 lines 23.3 KB
9396568e fix: HIGH security patch (H1–H8) — dev → main (#25) Gabriel Cardona <cgcardona@gmail.com> 2d ago
1 """MuseHub MCP — Full Streamable HTTP transport (MCP 2025-11-25).
2
3 Implements the complete Streamable HTTP transport spec:
4
5 POST /mcp — client → server messages (requests, notifications, responses).
6 - Returns ``application/json`` for most requests (resources, prompts, ping).
7 - Returns ``text/event-stream`` when a tool needs SSE (elicitation, progress).
8 - Returns 202 Accepted for notifications (no body).
9 - Issues ``Mcp-Session-Id`` header on successful ``initialize``.
    - Validates ``Mcp-Session-Id`` on subsequent requests when the header is present (unknown/expired → 404); requests without it proceed session-less.
11 - Validates ``MCP-Protocol-Version`` header on non-initialize requests.
12 - Validates ``Origin`` header to prevent DNS-rebinding attacks.
13
14 GET /mcp — server → client SSE push channel.
15 - Opens a persistent ``text/event-stream`` for server-initiated messages.
16 - Supports ``Last-Event-ID`` for reconnection replay.
17 - Injects heartbeat comments every 15 s to keep proxies alive.
18
19 DELETE /mcp — client-initiated session termination.
20 - Returns 200 on success, 404 if session unknown.
21
22 Auth model (unchanged from 2025-03-26):
23 - ``Authorization: Bearer <jwt>`` → authenticated; user_id from JWT ``sub``.
24 - No token → anonymous; read-only tools work, write tools return isError=true.
25 - Invalid/expired token → 401.
26
27 Security:
28 - Origin header validated on all POST/GET/DELETE requests.
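  - Allowed origins are configured via the ``MUSEHUB_ALLOWED_ORIGINS`` env var
    (comma-separated; default: ``http://localhost``, ``http://127.0.0.1`` and
    ``https://musehub.app``). ``http://localhost`` and ``http://127.0.0.1`` are
    always allowed regardless of the env var.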
31 """
32
33 import json
34 import logging
35 import os
36 from collections.abc import AsyncIterator
37 from urllib.parse import urlparse
38
39 from fastapi import APIRouter, Request, Response
40 from fastapi.responses import JSONResponse, StreamingResponse
41 from musehub.rate_limits import limiter, MCP_LIMIT
42
43 from musehub.contracts.json_types import JSONObject, JSONValue
44 from musehub.mcp.dispatcher import handle_batch, handle_request
45 from musehub.mcp.session import (
46 MCPSession,
47 SessionCapacityError,
48 create_session,
49 delete_session,
50 get_session,
51 push_to_session,
52 register_sse_queue,
53 resolve_elicitation,
54 )
55 from musehub.mcp.sse import SSE_CONTENT_TYPE, heartbeat_stream, sse_heartbeat
56
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Every /mcp route below is registered on this router under the "MCP" tag.
router = APIRouter(tags=["MCP"])
60
61 _PROTOCOL_VERSION = "2025-11-25"
62
63 # ── Origin validation ─────────────────────────────────────────────────────────
64
65 _ALLOWED_ORIGINS: frozenset[str] = frozenset(
66 o.strip()
67 for o in os.environ.get(
68 "MUSEHUB_ALLOWED_ORIGINS",
69 "http://localhost,http://127.0.0.1,https://musehub.app",
70 ).split(",")
71 if o.strip()
72 )
73
74 _ALWAYS_ALLOW_ORIGINS: frozenset[str] = frozenset({
75 "http://localhost",
76 "http://127.0.0.1",
77 })
78
79
def _validate_origin(request: Request) -> bool:
    """Return True if the request's ``Origin`` header is allowed.

    Per the Streamable HTTP spec, servers MUST validate ``Origin`` to prevent
    DNS-rebinding attacks. Requests without an ``Origin`` header are allowed
    (e.g. curl, Postman, stdio bridge tools), because non-browser clients do
    not send one.

    The header value is reduced to ``scheme://netloc`` before it is looked up
    in ``_ALLOWED_ORIGINS`` / ``_ALWAYS_ALLOW_ORIGINS``, so a stray path does
    not cause a mismatch. The following are rejected:

    * values ``urlparse`` cannot parse (it raises ``ValueError``, e.g. for a
      malformed bracketed IPv6 host);
    * values with no scheme or host, such as the opaque ``null`` origin.
    """
    origin = request.headers.get("Origin")
    if origin is None:
        return True  # Non-browser clients don't send Origin.

    try:
        parsed = urlparse(origin)
    except ValueError:
        return False

    # Without both parts the normalised form would be a meaningless "://...".
    if not parsed.scheme or not parsed.netloc:
        return False

    normalised = f"{parsed.scheme}://{parsed.netloc}"
    return normalised in _ALLOWED_ORIGINS or normalised in _ALWAYS_ALLOW_ORIGINS
99
100
101 # ── Auth helper ───────────────────────────────────────────────────────────────
102
103
104 class _AuthResult:
105 """Parsed authentication result from a Bearer JWT."""
106
107 __slots__ = ("user_id", "is_agent", "agent_name")
108
109 def __init__(
110 self,
111 user_id: str | None,
112 is_agent: bool = False,
113 agent_name: str | None = None,
114 ) -> None:
115 self.user_id = user_id
116 self.is_agent = is_agent
117 self.agent_name = agent_name
118
119
async def _extract_auth(request: Request) -> _AuthResult | Response:
    """Return an :class:`_AuthResult` from JWT, anonymous result, or a 401 Response.

    Possible outcomes:

    * No ``Authorization`` header, or a scheme other than Bearer: an anonymous
      result (``user_id=None``).
    * A Bearer token that fails validation: a 401 ``JSONResponse`` carrying
      ``WWW-Authenticate: Bearer``. Callers must check
      ``isinstance(result, Response)`` and return early.

    The auth scheme is matched case-insensitively (RFC 7235 §2.1), so
    ``bearer <token>`` is treated the same as ``Bearer <token>``. Surrounding
    whitespace is stripped from the token.

    Agent tokens (``token_type: "agent"`` claim) set ``is_agent`` and carry
    the ``agent_name`` claim. The route handlers forward both fields to the
    dispatcher.
    """
    auth_header = request.headers.get("Authorization", "")
    scheme, sep, credentials = auth_header.partition(" ")
    if not sep or scheme.lower() != "bearer":
        return _AuthResult(user_id=None)

    token_str = credentials.strip()
    try:
        from musehub.auth.tokens import validate_access_code
        claims = validate_access_code(token_str)
        is_agent = claims.get("token_type") == "agent"
        return _AuthResult(
            user_id=claims.get("sub"),
            is_agent=is_agent,
            agent_name=claims.get("agent_name") if is_agent else None,
        )
    except Exception as exc:
        # Deliberately broad: any failure to validate or read the claims means
        # the credentials cannot be trusted. Log the exception type only, never
        # the token itself.
        logger.debug("MCP bearer token rejected: %s", type(exc).__name__)
        return JSONResponse(
            status_code=401,
            content={"error": "Invalid or expired access token."},
            headers={"WWW-Authenticate": "Bearer"},
        )
151
152
153 # ── POST /mcp ─────────────────────────────────────────────────────────────────
154
155
@router.post(
    "/mcp",
    operation_id="mcpEndpoint",
    summary="MCP Streamable HTTP — POST endpoint (2025-11-25)",
    include_in_schema=True,
)
@limiter.limit(MCP_LIMIT)
async def mcp_post(request: Request) -> Response:
    """MCP Streamable HTTP POST endpoint.

    Handles all client→server JSON-RPC messages. The response type depends on
    the message:

    * ``application/json`` for most requests.
    * ``text/event-stream`` for calls to the tools in ``_SSE_TOOL_NAMES`` made
      on an established session.
    * ``202 Accepted`` for notifications and for client responses routed to a
      session.

    A successful ``initialize`` issues an ``Mcp-Session-Id`` response header.
    A failed ``initialize`` (JSON-RPC error response) does not create a
    session. Later requests may send that header; when present it must name a
    live session, otherwise 404 is returned and the client must re-initialize.
    Requests without the header proceed without a session.

    Error statuses:

    * 403 — disallowed Origin.
    * 401 — invalid token (from ``_extract_auth``).
    * 400 — parse error, unsupported ``MCP-Protocol-Version``, or a body that
      is neither an object nor an array.
    * 404 — unknown session.
    * 503 — session capacity exceeded.
    * 500 — unexpected error; details are logged server-side and not returned.
    """
    # ── Security: Origin validation ───────────────────────────────────────────
    if not _validate_origin(request):
        return JSONResponse(
            status_code=403,
            content={
                "jsonrpc": "2.0",
                "id": None,
                "error": {"code": -32600, "message": "Forbidden: invalid Origin header"},
            },
        )

    # ── Auth ──────────────────────────────────────────────────────────────────
    auth_or_resp = await _extract_auth(request)
    if isinstance(auth_or_resp, Response):
        return auth_or_resp
    auth: _AuthResult = auth_or_resp
    user_id: str | None = auth.user_id

    # ── Parse body ────────────────────────────────────────────────────────────
    try:
        body = await request.body()
        raw = json.loads(body)
    except (json.JSONDecodeError, ValueError) as exc:
        return JSONResponse(
            status_code=400,
            content={
                "jsonrpc": "2.0",
                "id": None,
                "error": {"code": -32700, "message": f"Parse error: {exc}"},
            },
        )

    # JSON-RPC 2.0: an error response should carry the request's id when it can
    # be determined. Batches and non-object bodies fall back to null.
    request_id: str | int | None = None
    if isinstance(raw, dict):
        raw_id = raw.get("id")
        if isinstance(raw_id, (str, int)):
            request_id = raw_id

    # ── Determine if this is initialize or a subsequent request ──────────────
    # For batches, use the method of the first request.
    def _first_method(r: object) -> str | None:
        first = r[0] if isinstance(r, list) and r else r
        if not isinstance(first, dict):
            return None
        m = first.get("method")
        return m if isinstance(m, str) else None

    first_method = _first_method(raw)
    is_initialize = first_method == "initialize"

    # ── Session management ────────────────────────────────────────────────────
    session: MCPSession | None = None
    session_id_header = request.headers.get("Mcp-Session-Id")

    if not is_initialize:
        # Validate MCP-Protocol-Version on non-initialize requests.
        proto_ver = request.headers.get("MCP-Protocol-Version")
        if proto_ver and proto_ver not in ("2025-11-25", "2025-03-26"):
            return JSONResponse(
                status_code=400,
                content={
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32600,
                        "message": f"Unsupported MCP-Protocol-Version: {proto_ver!r}",
                    },
                },
            )

        if session_id_header:
            session = get_session(session_id_header)
            if session is None:
                # Session expired or unknown → client must re-initialize.
                return JSONResponse(
                    status_code=404,
                    content={
                        "jsonrpc": "2.0",
                        "id": None,
                        "error": {
                            "code": -32600,
                            "message": "Session not found. Send a new InitializeRequest without Mcp-Session-Id.",
                        },
                    },
                )

    # ── Check if this is an elicitation response (client→server) ─────────────
    # When the client sends the result of an elicitation/create back, it's a
    # JSON-RPC *response* (has "result" or "error" but no "method"). Route it
    # to the session's pending Future resolver.
    if session and isinstance(raw, dict) and "method" not in raw and "id" in raw:
        req_id = raw.get("id")
        if "result" in raw and req_id is not None:
            resolved = resolve_elicitation(session, req_id, raw["result"])
            if resolved:
                return Response(status_code=202)
        # Unknown response — ignore per spec.
        return Response(status_code=202)

    # ── Dispatch ──────────────────────────────────────────────────────────────
    try:
        if isinstance(raw, list):
            responses = await handle_batch(
                raw, user_id=user_id, session=session,
                is_agent=auth.is_agent, agent_name=auth.agent_name,
            )
            if not responses:
                return Response(status_code=202)
            return JSONResponse(content=responses)

        elif isinstance(raw, dict):
            # Detect if tool needs SSE streaming (elicitation tools set this).
            needs_sse = _method_needs_sse(raw) and session is not None

            if needs_sse and session is not None:
                return _make_sse_tool_response(
                    raw, user_id=user_id, session=session,
                    is_agent=auth.is_agent, agent_name=auth.agent_name,
                )

            resp = await handle_request(
                raw, user_id=user_id, session=session,
                is_agent=auth.is_agent, agent_name=auth.agent_name,
            )
            if resp is None:
                return Response(status_code=202)

            if is_initialize:
                # The session id accompanies an InitializeResult. A failed
                # initialize must not mint a session or consume session capacity.
                if isinstance(resp, dict) and "error" in resp:
                    return JSONResponse(content=resp)
                try:
                    new_session = _create_session_from_initialize(raw, user_id)
                except SessionCapacityError as cap_exc:
                    logger.warning("MCP session capacity exceeded: %s", cap_exc)
                    return JSONResponse(
                        status_code=503,
                        content={
                            "jsonrpc": "2.0",
                            "id": request_id,
                            "error": {
                                "code": -32000,
                                "message": str(cap_exc),
                            },
                        },
                        headers={"Retry-After": "5"},
                    )
                return JSONResponse(
                    content=resp,
                    headers={"Mcp-Session-Id": new_session.session_id},
                )

            return JSONResponse(content=resp)

        else:
            return JSONResponse(
                status_code=400,
                content={
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32600,
                        "message": "Request must be an object or array",
                    },
                },
            )

    except Exception as exc:
        # Full details (with traceback) go to the server log only. Echoing
        # exception text to an unauthenticated caller can leak internals.
        logger.exception("Unhandled error in POST /mcp: %s", exc)
        return JSONResponse(
            status_code=500,
            content={
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {"code": -32603, "message": "Internal error"},
            },
        )
344
345
346 # ── GET /mcp ──────────────────────────────────────────────────────────────────
347
348
@router.get(
    "/mcp",
    operation_id="mcpSseStream",
    summary="MCP Streamable HTTP — GET SSE push channel (2025-11-25)",
    include_in_schema=True,
)
async def mcp_get(request: Request) -> Response:
    """Open the long-lived SSE channel used for server → client messages.

    Preconditions, checked in order:

    1. ``Origin`` must pass ``_validate_origin`` (else 403).
    2. ``Accept`` must mention ``text/event-stream`` (else 405).
    3. ``Mcp-Session-Id`` must be present (else 400) and name a live session
       (else 404).

    The optional ``Last-Event-ID`` header is handed to ``register_sse_queue``
    so a reconnecting client can resume. The queue is wrapped in
    ``heartbeat_stream`` with a 15 s interval so idle proxies keep the
    connection open. Server-initiated messages are delivered on this stream,
    such as ``notifications/progress``, ``elicitation/create`` and
    ``notifications/elicitation/complete``.
    """
    if not _validate_origin(request):
        return Response(status_code=403)

    headers = request.headers
    wants_sse = "text/event-stream" in headers.get("Accept", "")
    if not wants_sse:
        return Response(status_code=405, content="SSE requires Accept: text/event-stream")

    sid = headers.get("Mcp-Session-Id")
    if not sid:
        return JSONResponse(
            status_code=400,
            content={"error": "Mcp-Session-Id header required for GET /mcp SSE stream"},
        )

    session = get_session(sid)
    if session is None:
        return Response(status_code=404, content="Session not found or expired")

    resume_after = headers.get("Last-Event-ID")

    async def _relay() -> AsyncIterator[str]:
        # The queue is registered lazily, when the response body starts streaming.
        source = register_sse_queue(session, resume_after)
        async for chunk in heartbeat_stream(source, interval_seconds=15.0):
            yield chunk

    no_buffering = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",  # Disable nginx buffering.
    }
    return StreamingResponse(_relay(), media_type=SSE_CONTENT_TYPE, headers=no_buffering)
404
405
406 # ── DELETE /mcp ───────────────────────────────────────────────────────────────
407
408
@router.delete(
    "/mcp",
    operation_id="mcpDeleteSession",
    summary="MCP Streamable HTTP — DELETE session (2025-11-25)",
    include_in_schema=True,
)
async def mcp_delete(request: Request) -> Response:
    """Terminate an MCP session at the client's request.

    Responses:

    * 403 — disallowed ``Origin``.
    * 400 — ``Mcp-Session-Id`` header absent.
    * 404 — ``delete_session`` reports no such session.
    * 200 — the session was removed.

    ``delete_session`` is expected to close the session's SSE streams and
    cancel any pending elicitation Futures.
    """
    if not _validate_origin(request):
        return Response(status_code=403)

    sid = request.headers.get("Mcp-Session-Id")
    if not sid:
        return JSONResponse(
            status_code=400,
            content={"error": "Mcp-Session-Id header required"},
        )

    if delete_session(sid):
        # Only the first 8 characters of the session id are logged.
        logger.info("MCP session terminated by client: %.8s...", sid)
        return Response(status_code=200)

    return Response(status_code=404, content="Session not found")
437
438
439 # ── GET /mcp/docs.json ────────────────────────────────────────────────────────
440
441
@router.get(
    "/mcp/docs.json",
    operation_id="mcpDocsJson",
    summary="MCP capability manifest — machine-readable JSON",
    include_in_schema=True,
)
async def mcp_docs_json() -> JSONResponse:
    """Serve the MCP capability manifest as JSON.

    The manifest has five top-level keys:

    * ``server`` — name, version, protocol version and endpoint URLs.
    * ``tools`` — every tool definition minus its internal ``server_side`` key.
    * ``resources`` — static resources, projected to ``uri``, ``name``,
      ``description`` and ``mimeType``.
    * ``resourceTemplates`` — resource templates, projected to
      ``uriTemplate``, ``name``, ``description`` and ``mimeType``.
    * ``prompts`` — each prompt's name, description and arguments (default
      ``[]``).

    It is the machine-readable counterpart of ``GET /mcp/docs``. The endpoint
    is deliberately public so agents can discover capabilities before they
    hold credentials. Responses may be cached publicly for 5 minutes.
    """
    from musehub.mcp.tools import MCP_TOOLS
    from musehub.mcp.resources import STATIC_RESOURCES, RESOURCE_TEMPLATES
    from musehub.mcp.prompts import PROMPT_CATALOGUE

    resource_fields = ("uri", "name", "description", "mimeType")
    template_fields = ("uriTemplate", "name", "description", "mimeType")

    manifest = {
        "server": {
            "name": "musehub-mcp",
            "version": "1.1.0",
            "protocolVersion": _PROTOCOL_VERSION,
            "endpoint": "/mcp",
            "docsUrl": "/mcp/docs",
        },
        "tools": [
            {key: value for key, value in tool.items() if key != "server_side"}
            for tool in MCP_TOOLS
        ],
        "resources": [
            {field: resource.get(field) for field in resource_fields}
            for resource in STATIC_RESOURCES
        ],
        "resourceTemplates": [
            {field: template.get(field) for field in template_fields}
            for template in RESOURCE_TEMPLATES
        ],
        "prompts": [
            {
                "name": prompt["name"],
                "description": prompt["description"],
                "arguments": prompt.get("arguments", []),
            }
            for prompt in PROMPT_CATALOGUE
        ],
    }
    return JSONResponse(content=manifest, headers={"Cache-Control": "public, max-age=300"})
512
513
514 # ── GET /mcp/docs ─────────────────────────────────────────────────────────────
515
516
@router.get(
    "/mcp/docs",
    operation_id="mcpDocs",
    summary="MCP reference — human-readable documentation page",
    include_in_schema=True,
)
async def mcp_docs(request: Request) -> Response:
    """Serve the human-readable MCP reference page.

    The ``musehub/pages/mcp_docs.html`` template receives the following
    context:

    * the tool, static-resource, resource-template and prompt catalogues;
    * the protocol version;
    * the request itself.

    Best effort: if importing the catalogues or building the template
    response raises, the error is logged as a warning and the client is
    redirected to ``/mcp/docs.json`` instead. No authentication required.
    """
    try:
        from musehub.api.routes.musehub._templates import templates
        from musehub.mcp.tools import MCP_TOOLS
        from musehub.mcp.resources import STATIC_RESOURCES, RESOURCE_TEMPLATES
        from musehub.mcp.prompts import PROMPT_CATALOGUE

        page_context = dict(
            request=request,
            tools=MCP_TOOLS,
            static_resources=STATIC_RESOURCES,
            resource_templates=RESOURCE_TEMPLATES,
            prompts=PROMPT_CATALOGUE,
            protocol_version=_PROTOCOL_VERSION,
        )
        return templates.TemplateResponse(
            request,
            "musehub/pages/mcp_docs.html",
            page_context,
        )
    except Exception as exc:
        logger.warning("MCP docs template missing, falling back to JSON redirect: %s", exc)
        from fastapi.responses import RedirectResponse

        return RedirectResponse(url="/mcp/docs.json")
553
554
555 # ── Helpers ───────────────────────────────────────────────────────────────────
556
557
def _create_session_from_initialize(
    raw: JSONObject,
    user_id: str | None,
) -> MCPSession:
    """Create a session for *user_id*, seeded with the client's declared capabilities.

    ``params.capabilities`` from the ``initialize`` request is used when it is
    a JSON object. Anything else (missing, null, wrong type) falls back to an
    empty dict. Exceptions from ``create_session`` propagate; the POST handler
    catches ``SessionCapacityError``.
    """
    params = raw.get("params") or {}
    declared = params.get("capabilities") if isinstance(params, dict) else None
    client_caps: JSONObject = declared if isinstance(declared, dict) else {}
    return create_session(user_id, client_capabilities=client_caps)
570
571
572 # Tools that may use elicitation or progress streaming (SSE required).
573 _SSE_TOOL_NAMES: frozenset[str] = frozenset({
574 "musehub_create_with_preferences",
575 "musehub_review_pr_interactive",
576 "musehub_connect_streaming_platform",
577 "musehub_connect_daw_cloud",
578 "musehub_create_release_interactive",
579 })
580
581
582 def _method_needs_sse(raw: JSONObject) -> bool:
583 """Return True if this request should be streamed as SSE."""
584 if raw.get("method") != "tools/call":
585 return False
586 params = raw.get("params")
587 if not isinstance(params, dict):
588 return False
589 name = params.get("name")
590 return name in _SSE_TOOL_NAMES
591
592
def _make_sse_tool_response(
    raw: JSONObject,
    *,
    user_id: str | None,
    session: MCPSession,
    is_agent: bool = False,
    agent_name: str | None = None,
) -> StreamingResponse:
    """Return a StreamingResponse that runs the tool and streams results via SSE.

    The tool call is dispatched through ``handle_request`` once the response
    body starts streaming. A non-None result is emitted as a single
    ``sse_response`` event carrying the request's id. The id is kept only when
    it is a str or int; otherwise it is None.

    If dispatch or event encoding raises, the traceback is logged server-side
    and the client receives a generic JSON-RPC internal error (-32603). The
    exception text is not sent to the client, so internal details are not
    leaked.
    """
    from musehub.mcp.sse import sse_event, sse_response

    raw_id = raw.get("id")
    req_id: str | int | None = raw_id if isinstance(raw_id, (str, int)) else None

    async def generator() -> AsyncIterator[str]:
        try:
            result = await handle_request(
                raw, user_id=user_id, session=session,
                is_agent=is_agent, agent_name=agent_name,
            )
            if result is not None:
                yield sse_response(req_id, result)
        except Exception as exc:
            logger.exception("SSE tool call error: %s", exc)
            error_payload: dict[str, JSONValue] = {
                "jsonrpc": "2.0",
                "id": req_id,
                "error": {"code": -32603, "message": "Internal error"},
            }
            yield sse_event(error_payload)

    return StreamingResponse(
        generator(),
        media_type=SSE_CONTENT_TYPE,
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )