#!/usr/bin/env python3
"""
Captain Claude Mobile v2 - Backend API

Architecture:
- SessionManager handles all screen communication
- WebSocket clients subscribe to session output via queues
"""

import asyncio
import fcntl
import logging
import os
import pty
import re
import secrets
import select
import sqlite3
import subprocess
import sys
import threading
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, Optional, Set

import jwt
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

# Configure logging to stderr (captured by journald)
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stderr)]
)
logger = logging.getLogger(__name__)

# ============================================================================
# Configuration
# ============================================================================

DATA_DIR = Path("/home/architect/captain-claude/apps/captain-mobile-v2/data")
DB_PATH = DATA_DIR / "captain.db"
CLAUDE_CMD = "/home/architect/.npm-global/bin/claude"
WORKING_DIR = "/home/architect/captain-claude"

JWT_SECRET_FILE = DATA_DIR / ".jwt_secret"
JWT_ALGORITHM = "HS256"
JWT_EXPIRY_DAYS = 7


def _write_private_file(path: Path, text: str) -> None:
    """Write *text* to *path*, creating the file with mode 0o600.

    The file is opened with restrictive permissions from the start so the
    secret is never readable by other users between creation and chmod.
    """
    fd = os.open(str(path), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as handle:
        handle.write(text)
    # Also tighten permissions when the file already existed with a wider mode.
    path.chmod(0o600)


def get_jwt_secret():
    """Return the JWT signing secret.

    Resolution order: the ``JWT_SECRET`` environment variable, then the
    contents of ``JWT_SECRET_FILE``, otherwise a freshly generated random
    secret that is persisted to ``JWT_SECRET_FILE``.
    """
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    env_secret = os.environ.get("JWT_SECRET")
    if env_secret:
        return env_secret
    if JWT_SECRET_FILE.exists():
        return JWT_SECRET_FILE.read_text().strip()
    secret = secrets.token_hex(32)
    _write_private_file(JWT_SECRET_FILE, secret)
    return secret


JWT_SECRET = get_jwt_secret()

ADMIN_PASSWORD = os.environ.get("ADMIN_PASSWORD")
if not ADMIN_PASSWORD:
    _pass_file = DATA_DIR / ".admin_password"
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    if _pass_file.exists():
        ADMIN_PASSWORD = _pass_file.read_text().strip()
    else:
        # SECURITY: the fallback password is the well-known value "admin".
        # It is kept for backward compatibility; set ADMIN_PASSWORD or edit
        # the password file before exposing this service.
        ADMIN_PASSWORD = "admin"
        _write_private_file(_pass_file, ADMIN_PASSWORD)
        logger.warning(
            "No ADMIN_PASSWORD configured; using the default password. "
            "Change it in %s or set the ADMIN_PASSWORD environment variable.",
            _pass_file,
        )

VALID_USERS = {"admin": ADMIN_PASSWORD}

ALLOWED_ORIGINS = [
    "http://localhost:3000",
    "http://localhost:8080",
    "http://127.0.0.1:3000",
    "http://127.0.0.1:8080",
    "https://captain.tzzrarchitect.me",
    "capacitor://localhost",
    "ionic://localhost"
]

security = HTTPBearer(auto_error=False)

# ============================================================================
# Pydantic Models
# ============================================================================


class LoginRequest(BaseModel):
    username: str
    password: str


class LoginResponse(BaseModel):
    token: str
    expires_at: str


class ScreenSession(BaseModel):
    name: str       # Short name (e.g., "captain")
    pid: str        # Process ID (e.g., "400970")
    full_name: str  # Full name for commands (e.g., "400970.captain")
    attached: bool


class CreateSessionRequest(BaseModel):
    name: str


# ============================================================================
# Database
# ============================================================================

_local = threading.local()


def get_db() -> sqlite3.Connection:
    """Return this thread's SQLite connection, opening it on first use."""
    if not hasattr(_local, 'conn'):
        _local.conn = sqlite3.connect(str(DB_PATH), check_same_thread=False)
        _local.conn.row_factory = sqlite3.Row
    return _local.conn


def init_db():
    """Create the data directory and the ``conversations``/``messages`` tables."""
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(DB_PATH))
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS conversations (
                id TEXT PRIMARY KEY,
                user_id TEXT NOT NULL,
                session_name TEXT,
                title TEXT,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                updated_at TEXT DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        conn.execute('''
            CREATE TABLE IF NOT EXISTS messages (
                id TEXT PRIMARY KEY,
                conversation_id TEXT REFERENCES conversations(id) ON DELETE CASCADE,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        conn.commit()
    finally:
        # Close the connection even when a statement fails.
        conn.close()
    logger.info(f"Database initialized at {DB_PATH}")


# ============================================================================
# Auth Helpers
# ============================================================================


def create_token(username: str) -> tuple[str, datetime]:
    """Issue a signed JWT for *username*.

    Returns:
        ``(token, expires)`` where ``expires`` is a naive UTC datetime
        ``JWT_EXPIRY_DAYS`` days in the future.
    """
    # NOTE: datetime.utcnow() is deprecated since Python 3.12. It is kept so
    # the naive-UTC ``expires_at`` string returned to clients does not change.
    now = datetime.utcnow()
    expires = now + timedelta(days=JWT_EXPIRY_DAYS)
    payload = {
        "sub": username,
        "exp": expires,
        "iat": now
    }
    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
    return token, expires


def verify_token(token: str) -> Optional[str]:
    """Return the ``sub`` claim of a valid token, or None if invalid/expired."""
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.InvalidTokenError:
        # ExpiredSignatureError is a subclass of InvalidTokenError.
        return None
    return payload.get("sub")


async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """FastAPI dependency: resolve the bearer token to a username.

    Raises:
        HTTPException: 401 when no credentials are supplied or the token
            does not verify.
    """
    if not credentials:
        raise HTTPException(status_code=401, detail="Not authenticated")
    username = verify_token(credentials.credentials)
    if not username:
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    return username


# ============================================================================
# Screen Session Utilities
# ============================================================================

_ANSI_ESCAPE_RE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')


def _screen_ls_output() -> str:
    """Run ``screen -ls`` and return its stdout (exit status is not checked)."""
    result = subprocess.run(["screen", "-ls"], capture_output=True, text=True)
    return result.stdout


def _parse_screen_ls(output: str, *, require_state: bool = True) -> list[ScreenSession]:
    """Parse ``screen -ls`` output into ScreenSession objects.

    Only tab-containing lines whose first tab-separated field has the form
    ``PID.name`` are considered.

    Args:
        output: Raw stdout of ``screen -ls``.
        require_state: When true, keep only lines that contain "Attached"
            or "Detached"; when false, keep every session line.
    """
    sessions = []
    for line in output.split("\n"):
        if "\t" not in line:
            continue
        attached = "Attached" in line
        if require_state and not (attached or "Detached" in line):
            continue
        parts = line.strip().split("\t")
        if len(parts) < 2:
            continue
        session_info = parts[0]  # e.g., "400970.captain"
        pid, dot, short_name = session_info.partition(".")
        if not dot:
            continue
        sessions.append(ScreenSession(
            pid=pid,
            name=short_name,
            full_name=session_info,  # Use this for all commands
            attached=attached
        ))
    return sessions


def list_screen_sessions() -> list[ScreenSession]:
    """List active screen sessions; returns [] if ``screen -ls`` cannot be run."""
    try:
        return _parse_screen_ls(_screen_ls_output())
    except Exception:
        logger.exception("list_screen_sessions: failed to list sessions")
        return []


def get_full_session_name(session_name: str) -> Optional[str]:
    """Get full session name (PID.name) from short name.

    The short name must match exactly; a substring test would let "cap"
    resolve to a session named "captain".
    """
    for session in _parse_screen_ls(_screen_ls_output(), require_state=False):
        if session.name == session_name:
            return session.full_name
    return None


def create_screen_session(name: str) -> ScreenSession:
    """Create a new detached screen session running Claude.

    The requested name is lower-cased, spaces become hyphens and every
    character outside ``[a-z0-9_-]`` is dropped.

    This function blocks for about one second; call it from a worker thread
    when inside a coroutine.

    Raises:
        ValueError: The sanitized name is invalid, a session with that name
            already exists, or the new session cannot be found afterwards.
        subprocess.CalledProcessError: ``screen`` exited with an error.
    """
    clean_name = re.sub(r'[^a-z0-9_-]', '', name.replace(" ", "-").lower())
    if not clean_name or len(clean_name) < 2 or len(clean_name) > 50:
        raise ValueError("Invalid session name (2-50 chars, alphanumeric, hyphens, underscores)")

    existing = _parse_screen_ls(_screen_ls_output(), require_state=False)
    if any(session.name == clean_name for session in existing):
        raise ValueError(f"Session '{clean_name}' already exists")

    subprocess.run(
        ["screen", "-dmS", clean_name, CLAUDE_CMD, "--dangerously-skip-permissions"],
        cwd=WORKING_DIR,
        check=True
    )

    # Give screen a moment to register the session before looking it up.
    time.sleep(1)

    for session in _parse_screen_ls(_screen_ls_output(), require_state=False):
        if session.name == clean_name:
            return ScreenSession(
                pid=session.pid,
                name=clean_name,
                full_name=session.full_name,
                attached=False
            )

    raise ValueError("Session created but could not find PID")


def strip_ansi(text: str) -> str:
    """Remove ANSI escape codes from text"""
    return _ANSI_ESCAPE_RE.sub('', text)


# ============================================================================
# Hardcopy filtering helpers (used by the polling reader)
# ============================================================================

# UI filter patterns - Claude Code interface elements
_UI_PATTERNS = (
    # Permissions and controls
    'bypass permissions', 'ctrl+', 'shift+tab', 'Esc to', 'to cycle',
    'on (shift+tab', 'to interrupt', 'press ctrl', 'again to exit',
    'What should Claude do', 'Interrupted',
    # Claude Code branding/header
    'Claude Code', 'Claude Max', 'Opus 4', 'Sonnet', 'claude-',
    '~/captain-claude', '/home/architect',  # Working directory display
    # Box-drawing characters
    '───', '═══', '│', '┌', '└', '├', '┐', '┘', '┤', '╭', '╮', '╯', '╰',
    '║', '╔', '╗', '╚', '╝', '╠', '╣', '╦', '╩', '╬',
    # Block characters
    '▐', '▛', '▜', '▌', '▝', '█', '▀', '▄', '▖', '▗', '▘', '▙', '▚', '▞', '▟',
    '░', '▒', '▓', '■', '□',
    # Spinners
    '◐', '◑', '◒', '◓', '⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏',
    # UI elements
    '❯', '●', '○', '◆', '◇', '✓', '✗', '→', '←', '↑', '↓',
    # Status indicators
    'Baking', 'Thinking', 'Working', 'Loading',
    # Tool usage indicators (prefixed by U+FFFD REPLACEMENT CHARACTER).
    # NOTE: _clean_line() strips U+FFFD before _is_ui_line() runs in the
    # polling path, so these six entries cannot match there.
    '\ufffd Bash', '\ufffd Read', '\ufffd Edit', '\ufffd Write',
    '\ufffd Glob', '\ufffd Grep',
)

# A line made only of these characters is treated as UI decoration.
_UI_SPECIAL_CHARS = frozenset('─═│┌└├┐┘┤╭╮╯╰║╔╗╚╝╠╣╦╩╬▐▛▜▌▝█▀▄░▒▓●○◆◇✓✗→←↑↓❯ \t')

_BASH_PATTERNS = (
    'Bash(', 'bash(', '$ ', '# ',  # Command invocations
    'Exit code', 'exit code',      # Exit status
    '+ ', '++ ',                   # Shell trace output
)

_LS_OUTPUT_PREFIXES = ('total ', 'drwx', '-rw', 'lrwx')


def _clean_line(line: str) -> str:
    """Drop U+FFFD and control characters (except newline/tab), then strip."""
    cleaned = line.replace('\ufffd', '')
    cleaned = ''.join(c for c in cleaned if ord(c) >= 32 or c in '\n\t')
    return cleaned.strip()


def _is_ui_line(line: str) -> bool:
    """Check if line is a Claude Code UI element"""
    stripped = line.strip()
    # Empty or very short
    if len(stripped) <= 2:
        return True
    # Lines starting with 'o ' are user prompts in Claude Code
    if stripped.startswith('o '):
        return True
    # Lines containing UI patterns
    if any(pattern in line for pattern in _UI_PATTERNS):
        return True
    # Lines that consist only of special characters (box drawing, etc)
    return all(c in _UI_SPECIAL_CHARS for c in stripped)


def _is_bash_line(line: str) -> bool:
    """Check if line is a Bash command or its output"""
    if any(pattern in line for pattern in _BASH_PATTERNS):
        return True
    # Lines that look like `ls -l` output
    return line.startswith(_LS_OUTPUT_PREFIXES)


def _filter_content_lines(raw_lines: list) -> list:
    """Filter out UI lines and collapse Bash lines into one progress line.

    Returns the cleaned content lines. If any Bash-looking lines were
    dropped, a ``procesando . . .`` line (one dot per dropped line, at most
    ten) is inserted at the front.
    """
    result = []
    bash_lines = 0
    for line in raw_lines:
        cleaned = _clean_line(line)
        if not cleaned or _is_ui_line(cleaned):
            continue
        if _is_bash_line(cleaned):
            bash_lines += 1  # Represented by the progress line instead
            continue
        result.append(cleaned)

    if bash_lines > 0:
        dots = ' .' * min(bash_lines, 10)  # Max 10 dots
        result.insert(0, f"procesando{dots}")
    return result


def _find_overlap(old_lines: list, new_lines: list) -> int:
    """Find where the tail of *old_lines* ends inside *new_lines*.

    Tries the last 10 lines of the old content first, then progressively
    shorter tails, and uses the first position where the tail occurs.

    Returns:
        Index in *new_lines* where new content starts
        (``len(new_lines)`` when there is nothing new), or 0 when no
        overlap is found and everything should be treated as new.
    """
    if not old_lines or not new_lines:
        return 0
    search_window = min(len(old_lines), 10)
    for size in range(search_window, 0, -1):
        old_tail = old_lines[-size:]
        for start in range(len(new_lines) - size + 1):
            if new_lines[start:start + size] == old_tail:
                return start + size
    return 0


def _session_payload(session: ScreenSession) -> dict:
    """Serialize a ScreenSession for WebSocket messages."""
    return {
        "name": session.name,
        "pid": session.pid,
        "full_name": session.full_name,
        "attached": session.attached,
    }


# ============================================================================
# Session Manager - Central coordinator for screen sessions
# ============================================================================


@dataclass
class ManagedSession:
    """Represents a managed connection to a screen session"""
    session_name: str
    full_name: str
    master_fd: int = -1
    process: Optional[subprocess.Popen] = None
    subscribers: Set[asyncio.Queue] = field(default_factory=set)
    subscribers_lock: asyncio.Lock = field(default_factory=asyncio.Lock)  # Protect subscribers set
    reader_task: Optional[asyncio.Task] = None
    is_running: bool = False
    output_buffer: str = ""
    last_content: str = ""  # For polling-based change detection


class SessionManager:
    """
    Centralized manager for screen session connections.

    - Maintains ONE reader per session (not per client)
    - Clients subscribe to session output via queues
    - Input is sent via screen -X stuff (reliable)
    """

    def __init__(self):
        self._sessions: Dict[str, ManagedSession] = {}
        self._lock = asyncio.Lock()

    async def subscribe(self, full_name: str) -> asyncio.Queue:
        """
        Subscribe to a session's output using full_name (PID.name).

        Returns:
            A queue (max 100 items) that receives the session's output.

        Raises:
            ValueError: The session does not exist or its reader could not
                be started.
        """
        logger.debug(f"subscribe: called with full_name={full_name}")

        async with self._lock:
            # Validate full_name exists (exact match, off the event loop)
            output = await asyncio.to_thread(_screen_ls_output)
            known = {s.full_name for s in _parse_screen_ls(output, require_state=False)}
            exists = full_name in known
            logger.debug(f"subscribe: screen -ls output contains '{full_name}': {exists}")
            if not exists:
                raise ValueError(f"Session '{full_name}' not found")

            # Get or create managed session (keyed by full_name)
            managed = self._sessions.get(full_name)
            if managed is None:
                managed = ManagedSession(
                    session_name=full_name.split(".", 1)[1] if "." in full_name else full_name,
                    full_name=full_name
                )
                await self._start_reader(managed)
                if not managed.is_running:
                    # Do not register a session nobody is reading from.
                    raise ValueError(f"Could not attach to session '{full_name}'")
                self._sessions[full_name] = managed

            # Create subscriber queue with lock protection
            queue: asyncio.Queue = asyncio.Queue(maxsize=100)
            async with managed.subscribers_lock:
                managed.subscribers.add(queue)
                logger.info(f"subscribe: added queue, now {len(managed.subscribers)} subscribers for {full_name}")

            return queue

    async def unsubscribe(self, full_name: str, queue: asyncio.Queue):
        """Unsubscribe from a session's output; stops the reader when nobody is left."""
        async with self._lock:
            managed = self._sessions.get(full_name)
            if managed is None:
                return

            async with managed.subscribers_lock:
                managed.subscribers.discard(queue)
                remaining = len(managed.subscribers)
                logger.info(f"unsubscribe: removed queue, {remaining} subscribers remaining for {full_name}")

            # If no more subscribers, stop the reader
            if remaining == 0:
                await self._stop_reader(managed)
                del self._sessions[full_name]

    @staticmethod
    async def _stuff(full_name: str, payload: str) -> int:
        """Send *payload* to window 0 of a session via ``screen -X stuff``.

        Runs in a worker thread so the event loop is not blocked.
        Returns the exit status of ``screen``.
        """
        cmd = ["screen", "-S", full_name, "-p", "0", "-X", "stuff", payload]
        result = await asyncio.to_thread(
            subprocess.run, cmd, capture_output=True, timeout=2
        )
        return result.returncode

    async def send_input(self, full_name: str, content: str) -> bool:
        """
        Send input to a session.

        CRITICAL: ESC and Enter must be sent SEPARATELY with delays.

        Returns:
            True when every step succeeded, False otherwise (errors are
            logged, never raised).
        """
        try:
            # Use screen -X stuff with SEPARATE commands for ESC and Enter
            # This is the ONLY method that works reliably with Claude Code TUI
            logger.debug(f"send_input: full_name={full_name}, content={content[:50]}")

            # Step 1: Clear any pending input with Ctrl+C (result ignored)
            await self._stuff(full_name, "\x03")
            await asyncio.sleep(0.3)

            # Step 2: Send the message content
            rc = await self._stuff(full_name, content)
            if rc != 0:
                logger.error(f"send_input: failed to send content, rc={rc}")
                return False
            await asyncio.sleep(0.3)

            # Step 3: Send ESC separately (CRITICAL - must be separate!)
            rc = await self._stuff(full_name, "\x1b")
            if rc != 0:
                logger.error(f"send_input: failed to send ESC, rc={rc}")
                return False
            await asyncio.sleep(0.2)

            # Step 4: Send Enter separately (CRITICAL - must be separate!)
            rc = await self._stuff(full_name, "\r")
            if rc != 0:
                logger.error(f"send_input: failed to send Enter, rc={rc}")
                return False

            logger.info(f"send_input: message sent successfully to {full_name}")
            return True

        except Exception as e:
            logger.error(f"Error sending to {full_name}: {e}")
            return False

    async def _start_reader(self, managed: ManagedSession):
        """Attach to the screen session through a PTY and start the polling reader.

        On failure the error is logged, ``managed.is_running`` stays False and
        any file descriptor or child process opened so far is released.
        """
        if managed.is_running:
            return

        master_fd = slave_fd = -1
        process: Optional[subprocess.Popen] = None
        try:
            # Create PTY for bidirectional communication
            master_fd, slave_fd = pty.openpty()

            env = os.environ.copy()
            env["TERM"] = "xterm-256color"

            # Attach to screen session for reading/writing, in its own
            # session (setsid) so it is detached from our terminal.
            process = subprocess.Popen(
                ["screen", "-x", managed.full_name],
                stdin=slave_fd,
                stdout=slave_fd,
                stderr=slave_fd,
                start_new_session=True,
                env=env
            )

            os.close(slave_fd)
            slave_fd = -1

            # Set non-blocking for reads
            flags = fcntl.fcntl(master_fd, fcntl.F_GETFL)
            fcntl.fcntl(master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

            managed.master_fd = master_fd
            managed.process = process
            managed.is_running = True
            managed.last_content = ""

            logger.debug(f"_start_reader: PTY established, fd={master_fd}, screen PID={process.pid}")

            # Start reader task - uses hardcopy polling with UI filtering
            managed.reader_task = asyncio.create_task(
                self._reader_loop_polling(managed)
            )

            logger.info(f"Started reader for session: {managed.session_name}")

        except Exception as e:
            logger.error(f"Error starting reader for {managed.session_name}: {e}")
            managed.is_running = False
            managed.master_fd = -1
            managed.process = None
            # Release whatever was opened before the failure.
            for fd in (master_fd, slave_fd):
                if fd >= 0:
                    try:
                        os.close(fd)
                    except OSError:
                        pass
            if process is not None:
                try:
                    process.terminate()
                    process.wait(timeout=2)
                except (OSError, subprocess.TimeoutExpired):
                    pass

    async def _stop_reader(self, managed: ManagedSession):
        """Stop reading from a session and release its task, PTY and process."""
        managed.is_running = False

        if managed.reader_task:
            managed.reader_task.cancel()
            try:
                await managed.reader_task
            except asyncio.CancelledError:
                pass
            managed.reader_task = None

        if managed.master_fd >= 0:
            try:
                os.close(managed.master_fd)
            except OSError:
                pass
            managed.master_fd = -1

        if managed.process:
            try:
                managed.process.terminate()
                managed.process.wait(timeout=2)
            except (OSError, subprocess.TimeoutExpired):
                # Did not exit in time (or terminate failed): force it.
                try:
                    managed.process.kill()
                except OSError:
                    pass
            managed.process = None

        logger.info(f"Stopped reader for session: {managed.session_name}")

    async def _reader_loop(self, managed: ManagedSession):
        """Read output from PTY and broadcast to subscribers.

        Not started by _start_reader, which uses _reader_loop_polling.
        """
        logger.debug(f"_reader_loop: STARTED for {managed.full_name}")
        buffer = ""
        clock = asyncio.get_event_loop().time
        last_broadcast = clock()

        while managed.is_running:
            try:
                await asyncio.sleep(0.05)

                if managed.master_fd < 0:
                    break

                # Check for data
                readable, _, _ = select.select([managed.master_fd], [], [], 0)
                if readable:
                    try:
                        data = os.read(managed.master_fd, 4096)
                        if data:
                            text = strip_ansi(data.decode("utf-8", errors="replace"))
                            buffer += text
                            logger.debug(f"_reader_loop: read {len(data)} bytes, buffer now {len(buffer)}")
                    except OSError as e:
                        logger.debug(f"_reader_loop: OSError reading: {e}")
                        break

                # Broadcast buffer periodically
                now = clock()
                if buffer and (now - last_broadcast > 0.1):
                    await self._broadcast(managed, buffer)
                    buffer = ""
                    last_broadcast = now

                # Also broadcast if buffer is getting large
                if len(buffer) > 1000:
                    await self._broadcast(managed, buffer)
                    buffer = ""
                    last_broadcast = clock()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Reader error for {managed.session_name}: {e}")
                await asyncio.sleep(1)

        # Final flush
        if buffer:
            await self._broadcast(managed, buffer)

    async def _reader_loop_hybrid(self, managed: ManagedSession):
        """
        PTY-based reader variant with larger reads and EAGAIN tolerance.

        Not started by _start_reader, which uses _reader_loop_polling.
        Despite the name, this loop only reads from the PTY; the hardcopy
        temp file is merely removed on exit.
        """
        logger.debug(f"_reader_loop_hybrid: STARTED for {managed.full_name}")
        buffer = ""
        clock = asyncio.get_event_loop().time
        last_broadcast = clock()
        tmp_file = f"/tmp/screen_hc_{managed.full_name.replace('.', '_')}.txt"

        while managed.is_running:
            try:
                await asyncio.sleep(0.05)  # 50ms for responsiveness

                if managed.master_fd < 0:
                    break

                # Try to read from PTY (non-blocking)
                readable, _, _ = select.select([managed.master_fd], [], [], 0)
                if readable:
                    try:
                        data = os.read(managed.master_fd, 8192)
                        if data:
                            text = strip_ansi(data.decode("utf-8", errors="replace"))
                            buffer += text
                            logger.debug(f"_reader_loop_hybrid: PTY read {len(data)} bytes")
                    except OSError as e:
                        if e.errno not in (11, 35):  # EAGAIN/EWOULDBLOCK
                            logger.error(f"_reader_loop_hybrid: PTY read error: {e}")
                            break

                # Broadcast buffer periodically
                now = clock()
                if buffer and (now - last_broadcast > 0.15):
                    await self._broadcast(managed, buffer)
                    buffer = ""
                    last_broadcast = now

                # Broadcast large buffers immediately
                if len(buffer) > 2000:
                    await self._broadcast(managed, buffer)
                    buffer = ""
                    last_broadcast = clock()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Hybrid reader error for {managed.session_name}: {e}")
                await asyncio.sleep(1)

        # Final flush
        if buffer:
            await self._broadcast(managed, buffer)

        # Cleanup temp file
        try:
            os.remove(tmp_file)
        except OSError:
            pass

        logger.debug(f"_reader_loop_hybrid: STOPPED for {managed.full_name}")

    async def _broadcast(self, managed: ManagedSession, content: str):
        """Broadcast content to all subscribers.

        When a subscriber's queue is full, its oldest item is dropped to
        make room for the new content.
        """
        # Get a snapshot of subscribers under lock
        async with managed.subscribers_lock:
            subscribers_snapshot = list(managed.subscribers)

        num_subscribers = len(subscribers_snapshot)
        logger.debug(f"_broadcast: {managed.full_name} -> {num_subscribers} subscribers, {len(content)} chars")

        if num_subscribers == 0:
            logger.warning(f"_broadcast: NO SUBSCRIBERS for {managed.full_name}")
            return

        dead_queues = []
        for queue in subscribers_snapshot:
            try:
                queue.put_nowait(content)
            except asyncio.QueueFull:
                # Drop oldest message and try again
                try:
                    queue.get_nowait()
                    queue.put_nowait(content)
                except (asyncio.QueueEmpty, asyncio.QueueFull):
                    pass
            except Exception:
                dead_queues.append(queue)

        # Clean up dead queues under lock
        if dead_queues:
            async with managed.subscribers_lock:
                for dead in dead_queues:
                    managed.subscribers.discard(dead)

    async def _reader_loop_polling(self, managed: ManagedSession):
        """
        Read output from screen session using hardcopy polling.

        Strategy: Track the last seen content and detect incremental changes.
        - hardcopy captures the visible screen (last N lines)
        - Claude output appears at the bottom and scrolls up
        - We detect new lines at the end by comparing with previous content

        The first snapshot is recorded as a baseline and never broadcast.
        """
        logger.debug(f"_reader_loop_polling: STARTED for {managed.full_name}")

        # Temp file that `screen -X hardcopy` writes the snapshot into
        tmp_file = f"/tmp/screen_hc_{managed.full_name.replace('.', '_')}.txt"
        poll_interval = 0.25  # 250ms polling for responsiveness
        hardcopy_cmd = [
            "screen", "-S", managed.full_name, "-p", "0", "-X", "hardcopy", tmp_file
        ]

        # Track last content for incremental detection
        last_lines: list = []
        last_content_hash: Optional[int] = None
        # Explicit flag: an empty filtered screen is a valid baseline, so
        # "last_lines is empty" must not be used to mean "first poll".
        baseline_taken = False

        while managed.is_running:
            try:
                await asyncio.sleep(poll_interval)

                # Get current screen content via hardcopy (-p 0 selects window 0).
                # Run in a worker thread so the event loop keeps serving clients.
                result = await asyncio.to_thread(
                    subprocess.run, hardcopy_cmd, capture_output=True, timeout=2
                )
                if result.returncode != 0:
                    continue

                # Read the hardcopy file
                try:
                    with open(tmp_file, 'r', encoding='utf-8', errors='replace') as f:
                        raw_content = f.read()
                except FileNotFoundError:
                    continue

                # Process and filter lines
                current_lines = _filter_content_lines(raw_content.split('\n'))

                # Quick hash check - skip if nothing changed
                content_hash = hash(tuple(current_lines))
                if content_hash == last_content_hash:
                    continue
                last_content_hash = content_hash

                # First poll - don't send initial content (it's usually just UI)
                # We only want to send NEW responses after the user sends a message
                if not baseline_taken:
                    logger.debug(f"_reader_loop_polling: initial poll, {len(current_lines)} lines tracked")
                    last_lines = current_lines.copy()
                    baseline_taken = True
                    continue

                # Detect if this is a screen refresh (completely different content)
                # vs incremental output (new lines at the end)
                overlap_start = _find_overlap(last_lines, current_lines)

                if overlap_start == 0 and last_lines:
                    # Compare first few lines - if different, it's a refresh
                    compare_len = min(3, len(last_lines), len(current_lines))
                    if compare_len > 0 and last_lines[:compare_len] != current_lines[:compare_len]:
                        # Screen refresh - just update tracking, don't send everything
                        logger.debug(f"_reader_loop_polling: screen refresh detected, skipping")
                        last_lines = current_lines.copy()
                        continue

                # Get new lines (incremental output)
                new_lines = current_lines[overlap_start:] if overlap_start < len(current_lines) else []

                # Filter out any lines that are duplicates of very recent content
                # (handles edge case of partial overlap detection)
                if new_lines and last_lines:
                    recent_set = set(last_lines[-5:])
                    new_lines = [line for line in new_lines if line not in recent_set]

                # Broadcast new lines
                if new_lines:
                    new_content = '\n'.join(new_lines)
                    logger.debug(f"_reader_loop_polling: {len(new_lines)} new lines at end")
                    await self._broadcast(managed, new_content)

                # Update tracking
                last_lines = current_lines.copy()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Polling error for {managed.session_name}: {e}")
                await asyncio.sleep(1)

        # Cleanup temp file
        try:
            os.remove(tmp_file)
        except OSError:
            pass

        logger.debug(f"_reader_loop_polling: STOPPED for {managed.full_name}")

    async def cleanup(self):
        """Stop all sessions"""
        async with self._lock:
            for managed in list(self._sessions.values()):
                await self._stop_reader(managed)
            self._sessions.clear()


# Global session manager
session_manager = SessionManager()

# ============================================================================
# App Setup
# ============================================================================


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the database on startup and stop all readers on shutdown."""
    init_db()
    yield
    await session_manager.cleanup()


app = FastAPI(
    title="Captain Claude Mobile v2",
    description="Chat API with centralized session management",
    version="2.3.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ============================================================================
# REST Endpoints
# ============================================================================


@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok", "version": "2.3.0"}


@app.post("/auth/login", response_model=LoginResponse)
async def login(request: LoginRequest):
    """Exchange username/password for a JWT.

    Raises:
        HTTPException: 401 on unknown user or wrong password.
    """
    expected = VALID_USERS.get(request.username)
    # Constant-time comparison; the comparison also runs for unknown users
    # so both failure paths do the same work.
    reference = expected if expected is not None else ""
    password_ok = secrets.compare_digest(
        request.password.encode("utf-8"), reference.encode("utf-8")
    )
    if expected is None or not password_ok:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    token, expires = create_token(request.username)
    return LoginResponse(token=token, expires_at=expires.isoformat())


@app.get("/sessions", response_model=list[ScreenSession])
async def get_sessions(user: str = Depends(get_current_user)):
    """List screen sessions (runs `screen -ls` off the event loop)."""
    return await asyncio.to_thread(list_screen_sessions)


@app.post("/sessions", response_model=ScreenSession)
async def create_session(request: CreateSessionRequest, user: str = Depends(get_current_user)):
    """Create a screen session; 400 on invalid/duplicate name, 500 otherwise."""
    try:
        # create_screen_session blocks (subprocess + sleep): use a thread.
        return await asyncio.to_thread(create_screen_session, request.name)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# ============================================================================
# WebSocket Chat
# ============================================================================


@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
    """WebSocket endpoint for chat.

    Protocol: the first client message must be ``{"token": ...}`` (10 s
    timeout). The server then sends an ``init`` message and handles
    ``connect_session``, ``create_session``, ``message``, ``list_sessions``
    and ``ping`` messages until the client disconnects.
    """
    await websocket.accept()

    # Auth handshake
    try:
        auth_msg = await asyncio.wait_for(websocket.receive_json(), timeout=10.0)
        token = auth_msg.get("token", "")
        username = verify_token(token)
        if not username:
            await websocket.send_json({"type": "error", "message": "Invalid token"})
            await websocket.close(code=4001)
            return
    except asyncio.TimeoutError:
        await websocket.send_json({"type": "error", "message": "Auth timeout"})
        await websocket.close(code=4001)
        return
    except WebSocketDisconnect:
        # Client went away before authenticating: nothing to send to.
        return
    except Exception as e:
        try:
            await websocket.send_json({"type": "error", "message": f"Auth error: {str(e)}"})
            await websocket.close(code=4001)
        except Exception:
            logger.debug("websocket_chat: could not report auth error to client")
        return

    # Send init - include full_name for unique identification
    sessions = await asyncio.to_thread(list_screen_sessions)
    await websocket.send_json({
        "type": "init",
        "user": username,
        "sessions": [_session_payload(s) for s in sessions]
    })

    # Current subscription state - use full_name as key
    current_full_name: Optional[str] = None
    current_queue: Optional[asyncio.Queue] = None
    output_task: Optional[asyncio.Task] = None

    async def output_forwarder(queue: asyncio.Queue):
        """Forward output from queue to websocket"""
        logger.debug("output_forwarder: STARTED")
        while True:
            try:
                content = await queue.get()
                logger.debug(f"output_forwarder: sending {len(content)} chars to websocket")
                await websocket.send_json({
                    "type": "output",
                    "content": content
                })
                logger.debug("output_forwarder: sent OK")
            except asyncio.CancelledError:
                logger.debug("output_forwarder: cancelled")
                break
            except Exception as e:
                logger.error(f"output_forwarder: error {e}")
                break

    async def stop_forwarder():
        """Cancel the running output forwarder, if any, and wait for it."""
        if output_task:
            output_task.cancel()
            try:
                await output_task
            except asyncio.CancelledError:
                pass

    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                # Valid JSON but not an object: reject without dropping the socket.
                await websocket.send_json({"type": "error", "message": "Invalid message"})
                continue
            msg_type = data.get("type", "")

            if msg_type == "connect_session":
                # Accept full_name (PID.name) for unique session identification
                full_name = data.get("full_name") or data.get("session_name", "")

                # Validate format: either "PID.name" or just "name"
                if not isinstance(full_name, str) or not re.fullmatch(
                    r'[0-9]+\.[a-z0-9_-]+|[a-z0-9_-]+', full_name
                ):
                    await websocket.send_json({"type": "error", "message": "Invalid session name"})
                    continue

                # If short name provided, try to resolve it (backward compat)
                if "." not in full_name:
                    resolved = await asyncio.to_thread(get_full_session_name, full_name)
                    if not resolved:
                        await websocket.send_json({"type": "error", "message": f"Session '{full_name}' not found"})
                        continue
                    full_name = resolved

                # Unsubscribe from previous session
                if current_full_name and current_queue:
                    await stop_forwarder()
                    output_task = None
                    await session_manager.unsubscribe(current_full_name, current_queue)
                    current_full_name = None
                    current_queue = None

                # Subscribe to new session
                try:
                    current_queue = await session_manager.subscribe(full_name)
                    current_full_name = full_name

                    # Start output forwarder
                    output_task = asyncio.create_task(output_forwarder(current_queue))

                    await websocket.send_json({
                        "type": "session_connected",
                        "session_name": full_name.split(".", 1)[1] if "." in full_name else full_name,
                        "full_name": full_name
                    })
                except ValueError as e:
                    await websocket.send_json({"type": "error", "message": str(e)})
                    current_full_name = None
                    current_queue = None

            elif msg_type == "create_session":
                session_name = data.get("session_name", "")
                if not isinstance(session_name, str):
                    await websocket.send_json({"type": "error", "message": "Invalid session name"})
                    continue
                try:
                    new_session = await asyncio.to_thread(create_screen_session, session_name)
                    await websocket.send_json({
                        "type": "session_created",
                        "session": _session_payload(new_session)
                    })
                except (ValueError, subprocess.SubprocessError, OSError) as e:
                    # Report the failure instead of tearing down the connection.
                    await websocket.send_json({"type": "error", "message": str(e)})

            elif msg_type == "message":
                content = data.get("content", "")
                if not isinstance(content, str):
                    await websocket.send_json({"type": "error", "message": "Invalid message content"})
                    continue
                logger.debug(f"message: current_full_name={current_full_name}, content={content[:30] if content else 'empty'}")
                if current_full_name and content:
                    success = await session_manager.send_input(current_full_name, content)
                    if not success:
                        await websocket.send_json({
                            "type": "error",
                            "message": "Failed to send message to session"
                        })

            elif msg_type == "list_sessions":
                sessions = await asyncio.to_thread(list_screen_sessions)
                await websocket.send_json({
                    "type": "sessions_list",
                    "sessions": [_session_payload(s) for s in sessions]
                })

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        pass
    except Exception as e:
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception:
            # Socket is already unusable; nothing more to report to the client.
            pass
    finally:
        # Cleanup
        await stop_forwarder()
        if current_full_name and current_queue:
            await session_manager.unsubscribe(current_full_name, current_queue)


# ============================================================================
# Main
# ============================================================================

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=3030)