# server.py
"""WebSocket relay server for a VRM avatar front-end.

HTTP endpoints accept commands (talk, animate, move, set state, click
reactions, VR position) and broadcast them as JSON messages to every browser
client connected on ``/ws``.  A second socket, ``/ws_status``, reports the
number of connected ``/ws`` clients.
"""
import asyncio
import json
import logging
import os
import time
from pathlib import Path
from typing import Optional, Set

import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from click_reactions import build_click_reaction

# BASE_DIR = Path(__file__).resolve().parent.parent
# UPLOADS_DIR = Path(os.getenv("UPLOADS_DIR", BASE_DIR / "uploads"))

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Legacy registry kept only so external `from server import clients` keeps
# working.  Nothing in this module adds sockets to it; live sockets are
# tracked in `active_connections` below.
clients: list[WebSocket] = []

app = FastAPI()

# app.mount("/uploads", StaticFiles(directory=str(UPLOADS_DIR)), name="uploads")
# app.mount("/audio", StaticFiles(directory=str(UPLOADS_DIR)), name="audio")

# Enable CORS
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; restrict allow_origins before exposing this beyond localhost.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or replace with ["http://localhost:5173"] for tighter security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AnimationPayload(BaseModel):
    """Body of POST /animate."""

    animate_type: str  # "start_vrma", "start_mixamo", or "auto" (auto-detect from extension)
    animation_url: str
    play_once: Optional[bool] = False
    crop_start: Optional[float] = 0.0  # seconds to crop from start
    crop_end: Optional[float] = 0.0  # seconds to crop from end
    lock_position: Optional[bool] = False  # If true, animation plays in place (no root motion)
    track_position: Optional[bool] = True  # If true, character stays at end position after animation


class CombinedPayload(BaseModel):
    """Body of POST /animate_and_talk."""

    animation_url: str
    audio_path: str
    expression: str = "neutral"
    delay: float = 0.0  # seconds


class SetStateRequest(BaseModel):
    """Set the VRM avatar's animation state"""

    state: str  # idle, listening, thinking, talking


# States accepted by POST /set_state.
_VALID_STATES = ["idle", "listening", "thinking", "talking"]

# --- Track connections ---
active_connections: Set[WebSocket] = set()
status_connections: Set[WebSocket] = set()

# --- Pending user actions (clicks etc.) waiting to be merged into the next LLM prompt ---
_pending_user_actions: list[dict] = []

# Strong references to fire-and-forget tasks.  The event loop only keeps weak
# references, so a task nobody holds may be garbage-collected before it runs.
_background_tasks: Set[asyncio.Task] = set()

# --- Simple status page (optional) ---
html = """
WebSocket clients: 0
"""


@app.get("/")
async def root():
    """Serve the static status page."""
    return HTMLResponse(html)


# --- Models ---
class TalkRequest(BaseModel):
    """Body of POST /talk."""

    audio_path: str
    expression: str = "neutral"
    audio_text: str
    # Spelling is part of the wire format sent to and by existing clients.
    audio_duraction: int


class MessageRequest(BaseModel):
    """Body carrying a single text message."""

    message_text: str


class ClickInteractionRequest(BaseModel):
    """Body of POST /send_click_interaction."""

    type: str
    bone: str
    region: str


# --- Notification logic ---
async def notify_clients(message: dict) -> None:
    """Broadcast JSON `message` to every active WS client.

    Sockets whose send raises are logged and removed from
    `active_connections`.  Does nothing (beyond logging) when no client is
    connected.
    """
    if not active_connections:
        logger.info("No clients connected; skipping notify.")
        return

    data = json.dumps(message)
    # Take ONE snapshot: the set may change while we await the sends, and the
    # results must be paired with exactly the sockets they were sent to.
    targets = list(active_connections)
    logger.info("Broadcasting to %d client(s): %s", len(targets), data)

    results = await asyncio.gather(
        *(ws.send_text(data) for ws in targets),
        return_exceptions=True,
    )
    for ws, res in zip(targets, results):
        if isinstance(res, Exception):
            logger.error("Failed to send to %s: %s", ws.client, res)
            active_connections.discard(ws)


async def broadcast_status(count: int) -> None:
    """Send a `count_update` message with `count` to every status socket.

    Status sockets whose send fails are dropped so they are not retried on
    every subsequent update.
    """
    targets = list(status_connections)
    if not targets:
        return

    msg = json.dumps({"type": "count_update", "count": count})
    results = await asyncio.gather(
        *(ws.send_text(msg) for ws in targets),
        return_exceptions=True,
    )
    for ws, res in zip(targets, results):
        if isinstance(res, Exception):
            logger.warning("Dropping status socket %s: %s", ws.client, res)
            status_connections.discard(ws)


# Function to send transcription results back to client
async def send_transcription_result(text: str) -> None:
    """Send transcription result back to clients for editing."""
    message = {"type": "transcription_result", "text": text}
    await notify_clients(message)


# --- WebSocket endpoints ---
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """Register an avatar client and keep its socket open until it drops."""
    await ws.accept()
    active_connections.add(ws)
    logger.info("Client connected: %s (total %d)", ws.client, len(active_connections))
    await broadcast_status(len(active_connections))

    try:
        while True:
            # Keep-alive or handle incoming if needed
            await ws.receive_text()
    except WebSocketDisconnect:
        active_connections.discard(ws)
        logger.info(
            "Client disconnected: %s (total %d)", ws.client, len(active_connections)
        )
    except Exception as e:
        logger.error("WS error: %s", e)
        active_connections.discard(ws)

    # The receive loop only ends via one of the handlers above.
    await broadcast_status(len(active_connections))


@app.websocket("/ws_status")
async def ws_status(ws: WebSocket):
    """Register a status listener; answers "ping" with "pong"."""
    await ws.accept()
    status_connections.add(ws)
    try:
        # send initial count
        await ws.send_text(
            json.dumps({"type": "count_update", "count": len(active_connections)})
        )
        while True:
            msg = await ws.receive_text()
            if msg == "ping":
                await ws.send_text("pong")
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.warning("Status WS error: %s", e)
    finally:
        status_connections.discard(ws)


# --- HTTP trigger endpoint ---
@app.post("/talk")
async def talk(req: TalkRequest):
    """Receive audio_path & optional expression, broadcast to VRM clients."""
    payload = {
        "type": "start_animation",
        "audio_path": req.audio_path,
        "expression": req.expression,
        "audio_text": req.audio_text,
        "audio_duraction": req.audio_duraction,
    }
    await notify_clients(payload)
    return {"status": "sent", "payload": payload}


@app.post("/animate")
async def animate(payload: AnimationPayload):
    """Broadcast an animation command, resolving `animate_type == "auto"`.

    With "auto", a URL ending in ".vrma" (case-insensitive) maps to
    "start_vrma"; ".fbx" and any unknown extension map to "start_mixamo".
    """
    # Auto-detect animation type from file extension if set to "auto"
    anim_type = payload.animate_type
    if anim_type == "auto":
        if payload.animation_url.lower().endswith(".vrma"):
            anim_type = "start_vrma"
        else:
            # ".fbx" and unknown extensions both default to mixamo
            anim_type = "start_mixamo"
        logger.info(
            "Auto-detected animation type: %s for %s", anim_type, payload.animation_url
        )

    # forward these fields to clients
    forwarded = {
        "type": anim_type,
        "animation_url": payload.animation_url,
        "play_once": payload.play_once,
        "crop_start": payload.crop_start,
        "crop_end": payload.crop_end,
        "lock_position": payload.lock_position,
        "track_position": payload.track_position,
    }
    await notify_clients(forwarded)
    return {"status": "sent", "payload": forwarded}


@app.post("/animate_and_talk")
async def animate_and_talk(payload: CombinedPayload):
    """Broadcast a combined animation + speech command to all avatar clients."""
    message = {
        "type": "start_vrma_and_talk",
        "animation_url": payload.animation_url,
        "audio_path": payload.audio_path,
        "expression": payload.expression,
        "delay": payload.delay,
    }
    # Go through notify_clients so the message reaches the sockets registered
    # in `active_connections`; the legacy `clients` list is never populated.
    await notify_clients(message)
    return {"status": "combined sent"}


# ============ STATE CONTROL ============

@app.post("/set_state")
async def set_state(req: SetStateRequest):
    """
    Set the VRM avatar's animation state.
    This controls head microexpressions and animations.

    Valid states:
    - idle: Avatar looks around naturally with eye leading
    - listening: Avatar nods and tilts head while listening
    - thinking: Avatar looks away with pauses while thinking
    - talking: Avatar nods frequently while talking

    Example:
    POST /set_state
    { "state": "idle" }

    POST /set_state
    { "state": "talking" }
    """
    if req.state not in _VALID_STATES:
        return {
            "status": "error",
            "message": f"Invalid state: {req.state}",
            "valid_states": _VALID_STATES,
        }

    await notify_clients({"type": "set_state", "state": req.state})
    return {"status": "state_set", "state": req.state}


# ============ MOVEMENT MODELS ============

class WalkToRequest(BaseModel):
    """Body of POST /walk_to."""

    x: float
    y: float
    z: float
    speed: Optional[float] = 1.5  # Default walking speed


class TeleportRequest(BaseModel):
    """Body of POST /teleport_to."""

    x: float
    y: float
    z: float


class SetSpeedRequest(BaseModel):
    """Body of POST /set_movement_speed."""

    speed: float


class LoadAnimationRequest(BaseModel):
    """Body of POST /load_movement_animation."""

    url: str
    anim_type: str  # 'walk' or 'idle'


class MovementStatusRequest(BaseModel):
    pass  # Just a trigger to get status


# ============ MOVEMENT ENDPOINTS ============

@app.post("/walk_to")
async def walk_to(req: WalkToRequest):
    """
    Command VRM character to walk to specified coordinates.

    Example:
    POST /walk_to
    { "x": 2.0, "y": 0.0, "z": 3.0, "speed": 1.5 }
    """
    payload = {
        "type": "walk_to",
        "x": req.x,
        "y": req.y,
        "z": req.z,
        "speed": req.speed,
    }
    await notify_clients(payload)
    return {
        "status": "walking",
        "target": {"x": req.x, "y": req.y, "z": req.z},
        "speed": req.speed,
    }


@app.post("/stop_movement")
async def stop_movement():
    """
    Stop the VRM character's current movement.
    Character will return to idle animation.
    """
    await notify_clients({"type": "stop_movement"})
    return {"status": "stopped"}


@app.post("/teleport_to")
async def teleport_to(req: TeleportRequest):
    """
    Instantly teleport VRM character to specified coordinates.
    No walking animation, instant position change.

    Example:
    POST /teleport_to
    { "x": 5.0, "y": 0.0, "z": -2.0 }
    """
    payload = {
        "type": "teleport_to",
        "x": req.x,
        "y": req.y,
        "z": req.z,
    }
    await notify_clients(payload)
    return {
        "status": "teleported",
        "position": {"x": req.x, "y": req.y, "z": req.z},
    }


@app.post("/set_movement_speed")
async def set_movement_speed(req: SetSpeedRequest):
    """
    Change the walking speed of the character.
    Animation speed will adjust accordingly.

    Example:
    POST /set_movement_speed
    {"speed": 2.5}  # Faster
    {"speed": 0.8}  # Slower
    """
    await notify_clients({"type": "set_speed", "speed": req.speed})
    return {"status": "speed_updated", "speed": req.speed}


@app.post("/load_movement_animation")
async def load_movement_animation(req: LoadAnimationRequest):
    """
    Load a walk or idle animation for the movement system.

    Example:
    POST /load_movement_animation
    { "url": "/animations/walk.glb", "anim_type": "walk" }
    { "url": "/animations/idle.glb", "anim_type": "idle" }
    """
    if req.anim_type == "walk":
        payload = {"type": "load_walk_animation", "url": req.url}
    elif req.anim_type == "idle":
        payload = {"type": "load_idle_animation", "url": req.url}
    else:
        return {"status": "error", "message": "anim_type must be 'walk' or 'idle'"}

    await notify_clients(payload)
    return {
        "status": "loading",
        "anim_type": req.anim_type,
        "url": req.url,
    }


# ============ CLICK INTERACTION ============

async def _broadcast_idle_after(payload: dict, delay: float) -> None:
    """Wait `delay` seconds, then broadcast the idle animation payload."""
    try:
        await asyncio.sleep(delay)
        await notify_clients(payload)
        logger.info("[click_interact] idle broadcast after %ss", delay)
    except asyncio.CancelledError:
        logger.info("[click_interact] idle scheduler cancelled")
        raise
    except Exception as e:
        logger.error("[click_interact] idle broadcast failed: %s", e)


@app.post("/send_click_interaction")
async def send_click_interaction(req: ClickInteractionRequest):
    """
    Handle a click/touch interaction from the client.

    The client posts {type, bone, region}. We pick a feedback sound + reaction
    animation based on region/bone, broadcast both to active WebSocket clients,
    then schedule a return-to-idle broadcast after the per-animation delay.
    """
    logger.info(
        "[click_interact] received type=%r region=%r bone=%r",
        req.type, req.region, req.bone,
    )

    reaction = build_click_reaction(req.region, req.bone)

    await notify_clients(reaction["sound"])
    logger.info(
        "[click_interact] sound broadcast: %s", reaction["sound"]["audio_path"]
    )

    await notify_clients(reaction["animation"])
    logger.info(
        "[click_interact] animation broadcast: %s",
        reaction["animation"]["animation_url"],
    )

    # Hold a reference until the task finishes so it cannot be collected
    # while it is still sleeping.
    idle_task = asyncio.create_task(
        _broadcast_idle_after(reaction["idle"], reaction["idle_delay"])
    )
    _background_tasks.add(idle_task)
    idle_task.add_done_callback(_background_tasks.discard)
    logger.info("[click_interact] idle scheduled in %ss", reaction["idle_delay"])

    # Buffer the action so the next LLM turn can mention it.
    _pending_user_actions.append(
        {"region": req.region, "bone": req.bone, "ts": time.time()}
    )
    logger.info(
        "[click_interact] queued user action; pending=%d", len(_pending_user_actions)
    )

    return {
        "status": "click_handled",
        "region": req.region,
        "bone": req.bone,
        "animation": reaction["animation"]["animation_url"],
        "idle_in": reaction["idle_delay"],
        "pending_actions": len(_pending_user_actions),
    }


@app.get("/pop_pending_actions")
async def pop_pending_actions():
    """
    Drain and return any pending user actions (e.g. clicks) buffered since
    the last call. main_chat_v9 calls this after recording to fold the
    actions into the next LLM prompt as e.g. "[the user touched your bust]".
    """
    global _pending_user_actions
    # Swap in a fresh list so the returned one is no longer appended to.
    actions, _pending_user_actions = _pending_user_actions, []
    if actions:
        logger.info("[pop_pending_actions] drained %d action(s)", len(actions))
    return {"actions": actions, "count": len(actions)}


# ============ VR POSITION TRACKING ============

class VRPositionUpdate(BaseModel):
    """Body of POST /vr/position."""

    x: float
    y: float
    z: float
    rx: Optional[float] = 0.0
    ry: Optional[float] = 0.0
    rz: Optional[float] = 0.0
    timestamp: Optional[int] = None


# In-memory latest VR position
_vr_position = {
    "x": 0.0,
    "y": 0.0,
    "z": 0.0,
    "rx": 0.0,
    "ry": 0.0,
    "rz": 0.0,
    "timestamp": 0,
}


@app.post("/vr/position")
async def update_vr_position(pos: VRPositionUpdate):
    """
    Receive VR headset position update from the client.
    Called periodically by vrPositionTracker.js.
    """
    _vr_position.update({
        "x": pos.x,
        "y": pos.y,
        "z": pos.z,
        "rx": pos.rx,
        "ry": pos.ry,
        "rz": pos.rz,
        "timestamp": pos.timestamp or 0,  # missing timestamp is stored as 0
    })
    return {"status": "ok"}


@app.get("/vr/position")
async def get_vr_position():
    """
    Get the latest VR headset position.
    Returns: {x, y, z, rx, ry, rz, timestamp}
    """
    return _vr_position


# --- Run with: python server.py ---
if __name__ == "__main__":
    uvicorn.run("server:app", host="127.0.0.1", port=8001, reload=True)