# Ai_Assistant/server/server.py
#
# (Source-viewer header: 559 lines, 16 KiB, Python —
#  Raw / Permalink / Normal View / History;
#  last modified 2026-05-24 13:31:30 +02:00)
# server.py
import asyncio
import json
import logging
import time
from typing import Optional, Set
from pathlib import Path
import os
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from click_reactions import build_click_reaction
# BASE_DIR = Path(__file__).resolve().parent.parent
# UPLOADS_DIR = Path(os.getenv("UPLOADS_DIR", BASE_DIR / "uploads"))
# Root logging at INFO so the broadcast/click log lines in this module are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# NOTE(review): nothing in the visible part of this file appends to `clients`;
# connected /ws sockets are tracked in `active_connections` instead.
# TODO confirm whether this list is still needed.
clients: list[WebSocket] = []
app = FastAPI()
# app.mount("/uploads", StaticFiles(directory=str(UPLOADS_DIR)), name="uploads")
# app.mount("/audio", StaticFiles(directory=str(UPLOADS_DIR)), name="audio")
# Enable CORS
# NOTE(review): the CORS spec does not allow a literal "*" origin together with
# credentials. Check how the installed Starlette version handles this
# combination, and restrict allow_origins if the server is ever reachable
# beyond localhost.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or replace with ["http://localhost:5173"] for tighter security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class AnimationPayload(BaseModel):
    """Request body for POST /animate; every field is forwarded to the WebSocket clients."""
    animate_type: str  # "start_vrma", "start_mixamo", or "auto" (auto-detect from extension)
    animation_url: str
    play_once: Optional[bool] = False
    crop_start: Optional[float] = 0.0  # seconds to crop from start
    crop_end: Optional[float] = 0.0  # seconds to crop from end
    lock_position: Optional[bool] = False  # If true, animation plays in place (no root motion)
    track_position: Optional[bool] = True  # If true, character stays at end position after animation
class CombinedPayload(BaseModel):
    """Request body for POST /animate_and_talk; all fields are relayed to the clients."""
    animation_url: str
    audio_path: str
    expression: str = "neutral"
    delay: float = 0.0  # seconds
class SetStateRequest(BaseModel):
    """Set the VRM avatar's animation state"""
    state: str  # idle, listening, thinking, talking (checked by the /set_state handler)
# --- Track connections ---
# Sockets connected on /ws; notify_clients() broadcasts to all of them.
active_connections: Set[WebSocket] = set()
# Sockets connected on /ws_status (the status page); they receive client-count updates.
status_connections: Set[WebSocket] = set()
# --- Pending user actions (clicks etc.) waiting to be merged into the next LLM prompt ---
# Appended by /send_click_interaction and drained by /pop_pending_actions.
# NOTE(review): grows without bound if nothing ever polls /pop_pending_actions.
_pending_user_actions: list[dict] = []
# --- Simple status page (optional) ---
# Minimal status page served at "/". It opens a WebSocket to /ws_status and
# displays the live /ws client count from "count_update" messages.
# The WebSocket scheme follows the page's scheme (wss:// when served over
# HTTPS), because browsers refuse plain ws:// connections from secure pages.
html = """
<!DOCTYPE html>
<html>
<head><title>VRM Trigger Server</title></head>
<body>
<h1>VRM Trigger Server</h1>
<p>WebSocket clients: <span id="count">0</span></p>
<script>
const wsScheme = location.protocol === 'https:' ? 'wss' : 'ws';
const ws = new WebSocket(`${wsScheme}://${location.host}/ws_status`);
ws.onmessage = e => {
const msg = JSON.parse(e.data);
if (msg.type === 'count_update') {
document.getElementById('count').textContent = msg.count;
}
};
</script>
</body>
</html>
"""
@app.get("/")
async def root():
    """Serve the built-in status page (the module-level ``html`` document)."""
    page = HTMLResponse(content=html)
    return page
# --- Models ---
class TalkRequest(BaseModel):
    """Request body for POST /talk; every field is forwarded to the WebSocket clients."""
    audio_path: str
    expression: str = "neutral"
    audio_text: str
    # The spelling "duraction" is part of the wire format: it is both the
    # request field name and the key forwarded to clients, so do not rename it.
    # NOTE(review): the unit (ms vs s) is not visible here; TODO confirm with the client.
    audio_duraction: int
class MessageRequest(BaseModel):
    """Request body carrying a single text message."""
    # NOTE(review): no endpoint in the visible part of this file uses this model.
    message_text: str
class ClickInteractionRequest(BaseModel):
    """Request body for POST /send_click_interaction."""
    type: str  # interaction kind reported by the client; the visible handler only logs it
    bone: str  # bone name reported by the client; passed to build_click_reaction
    region: str  # body region reported by the client; passed to build_click_reaction
# --- Notification logic ---
async def notify_clients(message: dict):
    """Broadcast JSON `message` to every active WS client.

    The recipients are snapshotted once. A client that connects or disconnects
    while the sends are in flight therefore cannot shift the pairing between
    sockets and send results. Any socket whose send fails is removed from
    `active_connections`.

    Args:
        message: JSON-serialisable payload; serialised once and sent as text.
    """
    if not active_connections:
        logger.info("No clients connected; skipping notify.")
        return
    data = json.dumps(message)
    # Snapshot exactly once: the set can change during the await below, and
    # gather() results must be matched to the sockets they were sent to.
    targets = list(active_connections)
    logger.info(f"Broadcasting to {len(targets)} client(s): {data}")
    results = await asyncio.gather(
        *(ws.send_text(data) for ws in targets), return_exceptions=True
    )
    for ws, res in zip(targets, results):
        if isinstance(res, Exception):
            logger.error(f"Failed to send to {ws.client}: {res}")
            active_connections.discard(ws)
async def broadcast_status(count: int):
    """Push a client count to every status-page socket on /ws_status.

    Sockets whose send fails are removed from `status_connections`. This
    matches notify_clients() and stops dead status pages from being retried
    on every broadcast.

    Args:
        count: Value published as the "count" field of a "count_update" message.
    """
    msg = json.dumps({"type": "count_update", "count": count})
    # Snapshot once so results line up with the sockets actually sent to.
    targets = list(status_connections)
    results = await asyncio.gather(
        *(ws.send_text(msg) for ws in targets), return_exceptions=True
    )
    for ws, res in zip(targets, results):
        if isinstance(res, Exception):
            logger.info(f"Dropping status client {ws.client}: {res}")
            status_connections.discard(ws)
# Send transcription results back to the clients so the user can edit them.
async def send_transcription_result(text: str):
    """Broadcast a finished transcription to the WS clients so the user can edit it."""
    await notify_clients({"type": "transcription_result", "text": text})
# --- WebSocket endpoints ---
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """Register a VRM client for broadcasts until its socket closes.

    Incoming text is read and discarded. The socket is removed from
    `active_connections` on every exit path: disconnect, error, or task
    cancellation. After a disconnect or error, status pages are sent the
    updated count.
    """
    await ws.accept()
    active_connections.add(ws)
    logger.info(f"Client connected: {ws.client} (total {len(active_connections)})")
    await broadcast_status(len(active_connections))
    try:
        while True:
            # Keep-alive or handle incoming if needed
            await ws.receive_text()
    except WebSocketDisconnect:
        active_connections.discard(ws)
        logger.info(f"Client disconnected: {ws.client} (total {len(active_connections)})")
    except Exception as e:
        logger.error(f"WS error: {e}")
    finally:
        # Also runs on CancelledError (e.g. server shutdown), which the except
        # clauses above do not catch, so no stale socket is left registered.
        active_connections.discard(ws)
    await broadcast_status(len(active_connections))
@app.websocket("/ws_status")
async def ws_status(ws: WebSocket):
    """Serve the status page's live client counter.

    Sends the current /ws client count immediately, then answers "ping" with
    "pong". While connected, the socket is registered in `status_connections`
    so broadcast_status() can push later updates. It is removed again on every
    exit path, including a failure of the initial send.
    """
    await ws.accept()
    status_connections.add(ws)
    try:
        # send initial count (inside the try so a client that vanishes
        # immediately is still unregistered by the finally below)
        await ws.send_text(json.dumps({"type": "count_update", "count": len(active_connections)}))
        while True:
            msg = await ws.receive_text()
            if msg == "ping":
                await ws.send_text("pong")
    except WebSocketDisconnect:
        pass  # normal close; cleanup happens in finally
    except Exception as e:
        logger.warning(f"Status WS error: {e}")
    finally:
        status_connections.discard(ws)
# --- HTTP trigger endpoint ---
@app.post("/talk")
async def talk(req: TalkRequest):
    """Relay a speech clip (audio path, text, duration, expression) to the VRM clients.

    The response echoes the exact message that was broadcast.
    """
    message = dict(
        type="start_animation",
        audio_path=req.audio_path,
        expression=req.expression,
        audio_text=req.audio_text,
        audio_duraction=req.audio_duraction,
    )
    await notify_clients(message)
    return {"status": "sent", "payload": message}
@app.post("/animate")
async def animate(payload: AnimationPayload):
    """Forward an animation request to the VRM clients.

    When ``animate_type`` is "auto", the message type comes from the URL. A
    ``.vrma`` suffix (case-insensitive) selects "start_vrma". Every other URL,
    ``.fbx`` included, selects "start_mixamo". Any other ``animate_type`` is
    forwarded unchanged.
    """
    msg_type = payload.animate_type
    if msg_type == "auto":
        is_vrma = payload.animation_url.lower().endswith(".vrma")
        # .fbx and unknown extensions both fall back to Mixamo
        msg_type = "start_vrma" if is_vrma else "start_mixamo"
        logger.info(f"Auto-detected animation type: {msg_type} for {payload.animation_url}")
    forwarded = {"type": msg_type}
    for name in (
        "animation_url",
        "play_once",
        "crop_start",
        "crop_end",
        "lock_position",
        "track_position",
    ):
        forwarded[name] = getattr(payload, name)
    await notify_clients(forwarded)
    return {"status": "sent", "payload": forwarded}
@app.post("/animate_and_talk")
async def animate_and_talk(payload: CombinedPayload):
    """Broadcast a combined "play animation and speak" command to the VRM clients.

    The command goes through notify_clients(), so it reaches the same /ws
    sockets as every other command and failed sockets are pruned.

    Returns:
        ``{"status": "combined sent", "payload": <broadcast message>}``
    """
    message = {
        "type": "start_vrma_and_talk",
        "animation_url": payload.animation_url,
        "audio_path": payload.audio_path,
        "expression": payload.expression,
        "delay": payload.delay,
    }
    await notify_clients(message)
    return {"status": "combined sent", "payload": message}
# ============ STATE CONTROL ============
@app.post("/set_state")
async def set_state(req: SetStateRequest):
    """
    Broadcast a new animation state for the VRM avatar.

    The state controls the client's head microexpressions and animations:
      - idle: Avatar looks around naturally with eye leading
      - listening: Avatar nods and tilts head while listening
      - thinking: Avatar looks away with pauses while thinking
      - talking: Avatar nods frequently while talking

    An unknown state is not broadcast. The response body then has
    ``"status": "error"`` plus the list of valid states; no HTTP error status
    is set.

    Example:
        POST /set_state
        {"state": "idle"}
    """
    allowed = ("idle", "listening", "thinking", "talking")
    if req.state in allowed:
        await notify_clients({"type": "set_state", "state": req.state})
        return {"status": "state_set", "state": req.state}
    return {
        "status": "error",
        "message": f"Invalid state: {req.state}",
        "valid_states": list(allowed),
    }
# Movement models and endpoints. (The two imports below duplicate imports at the top of the file and are redundant.)
from pydantic import BaseModel
from typing import Optional
# ============ MOVEMENT MODELS ============
class WalkToRequest(BaseModel):
    """Request body for POST /walk_to: target coordinates plus walking speed."""
    x: float
    y: float
    z: float
    speed: Optional[float] = 1.5  # Default walking speed
class TeleportRequest(BaseModel):
    """Request body for POST /teleport_to: destination coordinates."""
    x: float
    y: float
    z: float
class SetSpeedRequest(BaseModel):
    """Request body for POST /set_movement_speed."""
    # NOTE(review): not range-checked here; zero or negative values are forwarded as-is.
    speed: float
class LoadAnimationRequest(BaseModel):
    """Request body for POST /load_movement_animation."""
    url: str
    anim_type: str  # 'walk' or 'idle' (any other value is rejected by the handler)
class MovementStatusRequest(BaseModel):
    """Empty request body (no fields)."""
    # NOTE(review): no endpoint in the visible part of this file uses this model.
    pass  # Just a trigger to get status
# ============ MOVEMENT ENDPOINTS ============
@app.post("/walk_to")
async def walk_to(req: WalkToRequest):
    """
    Tell the VRM character to walk to the given coordinates at the given speed.

    Example:
        POST /walk_to
        {"x": 2.0, "y": 0.0, "z": 3.0, "speed": 1.5}
    """
    target = {"x": req.x, "y": req.y, "z": req.z}
    await notify_clients({"type": "walk_to", **target, "speed": req.speed})
    return {"status": "walking", "target": target, "speed": req.speed}
@app.post("/stop_movement")
async def stop_movement():
    """
    Broadcast a stop command for the character's current movement.

    The client is expected to return the character to its idle animation.
    """
    await notify_clients({"type": "stop_movement"})
    return {"status": "stopped"}
@app.post("/teleport_to")
async def teleport_to(req: TeleportRequest):
    """
    Move the VRM character instantly to the given coordinates (no walk animation).

    Example:
        POST /teleport_to
        {"x": 5.0, "y": 0.0, "z": -2.0}
    """
    position = {"x": req.x, "y": req.y, "z": req.z}
    await notify_clients({"type": "teleport_to", **position})
    return {"status": "teleported", "position": position}
@app.post("/set_movement_speed")
async def set_movement_speed(req: SetSpeedRequest):
    """
    Broadcast a new walking speed; the client adjusts its animation speed to match.

    Example:
        POST /set_movement_speed
        {"speed": 2.5}   (faster)
        {"speed": 0.8}   (slower)
    """
    new_speed = req.speed
    await notify_clients({"type": "set_speed", "speed": new_speed})
    return {"status": "speed_updated", "speed": new_speed}
@app.post("/load_movement_animation")
async def load_movement_animation(req: LoadAnimationRequest):
    """
    Ask the clients to load a walk or idle clip for the movement system.

    Only ``anim_type`` values "walk" and "idle" are accepted. Anything else
    gets an error body and nothing is broadcast.

    Example:
        POST /load_movement_animation
        {"url": "/animations/walk.glb", "anim_type": "walk"}
        {"url": "/animations/idle.glb", "anim_type": "idle"}
    """
    message_types = {
        "walk": "load_walk_animation",
        "idle": "load_idle_animation",
    }
    message_type = message_types.get(req.anim_type)
    if message_type is None:
        return {"status": "error", "message": "anim_type must be 'walk' or 'idle'"}
    await notify_clients({"type": message_type, "url": req.url})
    return {"status": "loading", "anim_type": req.anim_type, "url": req.url}
# ============ CLICK INTERACTION ============
async def _broadcast_idle_after(payload: dict, delay: float):
    """Sleep for `delay` seconds, then broadcast `payload` (the idle animation).

    Cancellation is logged and re-raised. Any other failure is logged and
    swallowed, because this runs as a fire-and-forget background task.
    """
    try:
        await asyncio.sleep(delay)
        await notify_clients(payload)
    except asyncio.CancelledError:
        logger.info("[click_interact] idle scheduler cancelled")
        raise
    except Exception as e:
        logger.error(f"[click_interact] idle broadcast failed: {e}")
    else:
        logger.info(f"[click_interact] idle broadcast after {delay}s")
# Return-to-idle broadcast scheduled by the most recent click, if any.
# Holding the Task here also keeps it alive: the event loop keeps only weak
# references to tasks, so an unreferenced create_task() result can be
# garbage-collected before it finishes.
_idle_task: Optional[asyncio.Task] = None


@app.post("/send_click_interaction")
async def send_click_interaction(req: ClickInteractionRequest):
    """
    Handle a click/touch interaction from the client.

    The client posts {type, bone, region}. build_click_reaction(region, bone)
    picks a feedback sound and a reaction animation. Both are broadcast to the
    active WebSocket clients, and a return-to-idle broadcast is scheduled after
    the per-animation delay. A newer click cancels the idle broadcast still
    pending from an earlier click, so it cannot cut the newer reaction short.
    The click is also queued for /pop_pending_actions.

    Returns:
        dict with the status, clicked region/bone, reaction animation URL,
        idle delay, and the number of queued user actions.
    """
    global _idle_task
    logger.info(
        f"[click_interact] received type={req.type!r} region={req.region!r} bone={req.bone!r}"
    )
    reaction = build_click_reaction(req.region, req.bone)
    await notify_clients(reaction["sound"])
    logger.info(
        f"[click_interact] sound broadcast: {reaction['sound']['audio_path']}"
    )
    await notify_clients(reaction["animation"])
    logger.info(
        f"[click_interact] animation broadcast: {reaction['animation']['animation_url']}"
    )
    # Supersede the idle broadcast left over from a previous click (if it is
    # still pending); _broadcast_idle_after logs and re-raises the cancellation.
    if _idle_task is not None and not _idle_task.done():
        _idle_task.cancel()
    _idle_task = asyncio.create_task(
        _broadcast_idle_after(reaction["idle"], reaction["idle_delay"])
    )
    logger.info(f"[click_interact] idle scheduled in {reaction['idle_delay']}s")
    # Buffer the action so the next LLM turn can mention it.
    _pending_user_actions.append(
        {"region": req.region, "bone": req.bone, "ts": time.time()}
    )
    logger.info(
        f"[click_interact] queued user action; pending={len(_pending_user_actions)}"
    )
    return {
        "status": "click_handled",
        "region": req.region,
        "bone": req.bone,
        "animation": reaction["animation"]["animation_url"],
        "idle_in": reaction["idle_delay"],
        "pending_actions": len(_pending_user_actions),
    }
@app.get("/pop_pending_actions")
async def pop_pending_actions():
    """
    Return and clear all user actions (e.g. clicks) buffered since the last call.

    Intended caller: main_chat_v9, after recording, to fold the actions into
    the next LLM prompt (e.g. "[the user touched your bust]").
    """
    global _pending_user_actions
    drained, _pending_user_actions = _pending_user_actions, []
    count = len(drained)
    if count:
        logger.info(f"[pop_pending_actions] drained {count} action(s)")
    return {"actions": drained, "count": count}
# ============ VR POSITION TRACKING ============
class VRPositionUpdate(BaseModel):
    """Request body for POST /vr/position: the headset position plus rx/ry/rz."""
    x: float
    y: float
    z: float
    # Presumably rotation about each axis. The unit (radians vs degrees) and
    # Euler convention are not visible here; TODO confirm against vrPositionTracker.js.
    rx: Optional[float] = 0.0
    ry: Optional[float] = 0.0
    rz: Optional[float] = 0.0
    # Client-supplied timestamp; the handler stores 0 when it is omitted.
    timestamp: Optional[int] = None
# In-memory latest VR position.
# Written by POST /vr/position and returned by GET /vr/position; all zeros
# until the first update. Lives only in process memory, so it resets on restart.
_vr_position = {
    "x": 0.0, "y": 0.0, "z": 0.0,
    "rx": 0.0, "ry": 0.0, "rz": 0.0,
    "timestamp": 0,
}
@app.post("/vr/position")
async def update_vr_position(pos: VRPositionUpdate):
    """
    Store the latest VR headset pose reported by the client.

    Posted periodically by vrPositionTracker.js. A missing timestamp is stored as 0.
    """
    for axis in ("x", "y", "z", "rx", "ry", "rz"):
        _vr_position[axis] = getattr(pos, axis)
    _vr_position["timestamp"] = pos.timestamp or 0
    return {"status": "ok"}
@app.get("/vr/position")
async def get_vr_position():
    """
    Report the most recently stored VR headset pose.

    Returns:
        {x, y, z, rx, ry, rz, timestamp}
    """
    return dict(_vr_position)
# --- Run with: python server.py ---
if __name__ == "__main__":
    # Development server: loopback only, auto-reload on code changes.
    # reload=True requires the app as an import string, so this file must be
    # importable as module `server` from the working directory.
    uvicorn.run("server:app", host="127.0.0.1", port=8001, reload=True)