Ai_Assistant/_Backup/server/server.py

261 lines
7.9 KiB
Python
Raw Normal View History

2026-05-24 13:31:30 +02:00
# server.py
import asyncio
import json
import logging
import os
from pathlib import Path
from typing import Optional, Set
from urllib.parse import urlsplit

import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# BASE_DIR = Path(__file__).resolve().parent.parent
# UPLOADS_DIR = Path(os.getenv("UPLOADS_DIR", BASE_DIR / "uploads"))
# Module-level wiring: logging, the FastAPI app, and its CORS policy.
# INFO level so the connect/disconnect/broadcast messages logged below show up.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# NOTE(review): only read by /animate_and_talk; nothing in the visible code
# appends to it (live avatar sockets are tracked in `active_connections`).
clients: list[WebSocket] = []
app = FastAPI()
# app.mount("/uploads", StaticFiles(directory=str(UPLOADS_DIR)), name="uploads")
# app.mount("/audio", StaticFiles(directory=str(UPLOADS_DIR)), name="audio")
# Enable CORS for every origin, method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any web page send credentialed requests here; confirm this is intended if the
# server is ever reachable beyond localhost.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or replace with ["http://localhost:5173"] for tighter security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class AnimationPayload(BaseModel):
    """Request body for POST /animate; every field is forwarded to the WebSocket clients."""

    # Used verbatim as the outgoing message type, except "auto", which /animate
    # resolves from the animation_url extension.
    animate_type: str  # "start_vrma", "start_mixamo", or "auto" (auto-detect from extension)
    animation_url: str
    # The Optional[...] fields below also accept an explicit null, which /animate
    # forwards unchanged. Their effect is interpreted by the client (not shown here).
    play_once: Optional[bool] = False
    crop_start: Optional[float] = 0.0  # seconds to crop from start
    crop_end: Optional[float] = 0.0  # seconds to crop from end
    lock_position: Optional[bool] = False  # If true, animation plays in place (no root motion)
    track_position: Optional[bool] = True  # If true, character stays at end position after animation
class CombinedPayload(BaseModel):
    """Request body for POST /animate_and_talk; copied into a start_vrma_and_talk message."""

    animation_url: str
    audio_path: str
    expression: str = "neutral"
    # Seconds; forwarded as-is. Presumably the client waits this long between
    # animation and audio — TODO confirm against the client.
    delay: float = 0.0  # seconds
class SetStateRequest(BaseModel):
    """Request body for POST /set_state: the VRM avatar's target animation state."""

    # Any string passes model validation; /set_state itself rejects values
    # outside the four states below with an error payload.
    state: str  # idle, listening, thinking, talking
# --- Track connections ---
# Avatar sockets registered by /ws; notify_clients() broadcasts to these and
# removes any socket whose send fails.
active_connections: Set[WebSocket] = set()
# Status-page sockets registered by /ws_status; they receive "count_update"
# messages carrying len(active_connections), plus "pong" replies to "ping".
status_connections: Set[WebSocket] = set()
# --- Simple status page (optional) ---
# Served at GET /. The page opens a WebSocket to /ws_status and shows how many
# avatar clients are connected on /ws. The WebSocket scheme follows the page's
# own protocol: a hard-coded ws:// is blocked as mixed content when the page is
# served over HTTPS, so wss:// is used there.
html = """
<!DOCTYPE html>
<html>
<head><title>VRM Trigger Server</title></head>
<body>
<h1>VRM Trigger Server</h1>
<p>WebSocket clients: <span id="count">0</span></p>
<script>
const scheme = location.protocol === 'https:' ? 'wss' : 'ws';
const ws = new WebSocket(`${scheme}://${location.host}/ws_status`);
ws.onmessage = e => {
const msg = JSON.parse(e.data);
if (msg.type === 'count_update') {
document.getElementById('count').textContent = msg.count;
}
};
</script>
</body>
</html>
"""
@app.get("/")
async def root():
    """Serve the built-in status page held in the module-level ``html`` string."""
    return HTMLResponse(content=html)
# --- Models ---
class TalkRequest(BaseModel):
    """Request body for POST /talk; every field is forwarded in a start_animation message."""

    audio_path: str
    expression: str = "neutral"
    audio_text: str
    # "duraction" is a misspelling of "duration", but it is the public JSON field
    # name and the key /talk forwards to clients, so renaming it would break both.
    # Unit is not established here (seconds? milliseconds?) — TODO confirm.
    audio_duraction: int
# --- Notification logic ---
async def notify_clients(message: dict):
    """Broadcast JSON ``message`` to every socket in ``active_connections``.

    All sends run concurrently. A socket whose send fails is logged and
    removed from ``active_connections``; failures are never raised to the
    caller.

    Args:
        message: JSON-serializable payload, sent as one text frame per client.
    """
    if not active_connections:
        logger.info("No clients connected; skipping notify.")
        return
    data = json.dumps(message)
    # Snapshot exactly once: the set can change while the sends are awaited
    # (clients connecting/disconnecting), so results must be paired with the
    # same list the coroutines were built from, not a copy taken afterwards.
    targets = list(active_connections)
    logger.info(f"Broadcasting to {len(targets)} client(s): {data}")
    results = await asyncio.gather(
        *(ws.send_text(data) for ws in targets), return_exceptions=True
    )
    for ws, res in zip(targets, results):
        if isinstance(res, Exception):
            logger.error(f"Failed to send to {ws.client}: {res}")
            active_connections.discard(ws)
async def broadcast_status(count: int):
    """Push an avatar-client count to every status-page socket.

    Args:
        count: Number to publish (callers pass ``len(active_connections)``).

    Sends run concurrently. A status socket whose send fails is logged and
    dropped from ``status_connections`` (mirroring notify_clients) instead of
    failing again on every later broadcast. Nothing is raised to the caller.
    """
    msg = json.dumps({"type": "count_update", "count": count})
    targets = list(status_connections)
    results = await asyncio.gather(
        *(ws.send_text(msg) for ws in targets), return_exceptions=True
    )
    for ws, res in zip(targets, results):
        if isinstance(res, Exception):
            logger.warning(f"Dropping status client {ws.client}: {res}")
            status_connections.discard(ws)
# --- WebSocket endpoints ---
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """Register an avatar client on /ws until it goes away.

    Incoming text frames are read and discarded; the loop only keeps the
    connection open so notify_clients() can push messages to it. Status pages
    receive the updated client count on connect and on every kind of exit.
    """
    await ws.accept()
    active_connections.add(ws)
    logger.info(f"Client connected: {ws.client} (total {len(active_connections)})")
    try:
        await broadcast_status(len(active_connections))
        while True:
            # Keep-alive or handle incoming if needed
            await ws.receive_text()
    except WebSocketDisconnect:
        active_connections.discard(ws)
        logger.info(f"Client disconnected: {ws.client} (total {len(active_connections)})")
    except Exception as e:
        logger.exception(f"WS error: {e}")
    finally:
        # Also runs on task cancellation (e.g. server shutdown), which raises a
        # BaseException that the handlers above do not catch; without this the
        # socket would stay in active_connections. discard() is idempotent.
        active_connections.discard(ws)
        await broadcast_status(len(active_connections))
@app.websocket("/ws_status")
async def ws_status(ws: WebSocket):
    """Serve a status-page client on /ws_status.

    The socket gets the current avatar-client count immediately and then
    every ``count_update`` pushed by broadcast_status(). A text frame that is
    exactly ``"ping"`` is answered with ``"pong"``; other text is ignored.
    The socket is always removed from ``status_connections`` on exit.
    """
    await ws.accept()
    status_connections.add(ws)
    try:
        # send initial count (inside the try so a failed first send still
        # unregisters the socket below)
        await ws.send_text(json.dumps({"type": "count_update", "count": len(active_connections)}))
        while True:
            msg = await ws.receive_text()
            if msg == "ping":
                await ws.send_text("pong")
    except WebSocketDisconnect:
        pass  # normal close by the status page
    except Exception as e:
        logger.exception(f"Status WS error: {e}")
    finally:
        status_connections.discard(ws)
# --- HTTP trigger endpoint ---
@app.post("/talk")
async def talk(req: TalkRequest):
    """Relay a talk request to every connected VRM client.

    The request fields are wrapped in a ``start_animation`` message, broadcast
    via notify_clients(), and echoed back alongside ``status: "sent"``.
    """
    message = dict(
        type="start_animation",
        audio_path=req.audio_path,
        expression=req.expression,
        audio_text=req.audio_text,
        audio_duraction=req.audio_duraction,
    )
    await notify_clients(message)
    return dict(status="sent", payload=message)
def _detect_animation_type(animation_url: str) -> str:
    """Return "start_vrma" for a ``.vrma`` URL, otherwise "start_mixamo".

    Only the URL's path component is inspected (case-insensitively), so a
    query string or fragment such as ``walk.vrma?v=2`` no longer hides the
    extension. ``.fbx`` and unknown extensions both map to Mixamo.
    """
    try:
        path = urlsplit(animation_url).path
    except ValueError:
        # urlsplit rejects some malformed URLs (e.g. an unbalanced "[" in the
        # host part); fall back to checking the raw string.
        path = animation_url
    return "start_vrma" if path.lower().endswith(".vrma") else "start_mixamo"


@app.post("/animate")
async def animate(payload: AnimationPayload):
    """Broadcast an animation command to every connected VRM client.

    ``animate_type`` is used verbatim as the message type unless it is
    ``"auto"``, in which case it is derived from ``animation_url``'s extension.
    The forwarded message is echoed back with ``status: "sent"``.
    """
    anim_type = payload.animate_type
    if anim_type == "auto":
        anim_type = _detect_animation_type(payload.animation_url)
        logger.info(f"Auto-detected animation type: {anim_type} for {payload.animation_url}")
    # forward these fields to clients
    forwarded = {
        "type": anim_type,
        "animation_url": payload.animation_url,
        "play_once": payload.play_once,
        "crop_start": payload.crop_start,
        "crop_end": payload.crop_end,
        "lock_position": payload.lock_position,
        "track_position": payload.track_position,
    }
    await notify_clients(forwarded)
    return {"status": "sent", "payload": forwarded}
@app.post("/animate_and_talk")
async def animate_and_talk(payload: CombinedPayload):
    """Broadcast a combined "play VRMA animation and speak" command.

    The request fields are wrapped in a ``start_vrma_and_talk`` message and
    sent to every connected VRM client; the response echoes the message.
    """
    message = {
        "type": "start_vrma_and_talk",
        "animation_url": payload.animation_url,
        "audio_path": payload.audio_path,
        "expression": payload.expression,
        "delay": payload.delay,
    }
    # Route through notify_clients(): the old loop over the module-level
    # `clients` list reached nobody (nothing visible appends to it; /ws
    # registers sockets in `active_connections`), and one dead socket would
    # have aborted the request instead of being pruned.
    await notify_clients(message)
    return {"status": "combined sent", "payload": message}
# ============ STATE CONTROL ============
@app.post("/set_state")
async def set_state(req: SetStateRequest):
    """
    Switch the VRM avatar's animation state on every connected client.

    The state drives the avatar's head microexpressions and animations:
      - idle: Avatar looks around naturally with eye leading
      - listening: Avatar nods and tilts head while listening
      - thinking: Avatar looks away with pauses while thinking
      - talking: Avatar nods frequently while talking

    An unknown state is not broadcast; the response then carries
    "status": "error" together with the list of accepted states.

    Examples:
        POST /set_state
        {"state": "idle"}

        POST /set_state
        {"state": "talking"}
    """
    valid_states = ["idle", "listening", "thinking", "talking"]
    requested = req.state
    if requested in valid_states:
        await notify_clients({"type": "set_state", "state": requested})
        return {"status": "state_set", "state": requested}
    return {
        "status": "error",
        "message": f"Invalid state: {requested}",
        "valid_states": valid_states,
    }
# --- Run with: python server.py ---
if __name__ == "__main__":
    # Host/port keep the original defaults but can be overridden with the
    # VRM_SERVER_HOST / VRM_SERVER_PORT environment variables without editing
    # the file. The "server:app" import string (which uvicorn requires for
    # reload=True) assumes this module is importable as `server`.
    uvicorn.run(
        "server:app",
        host=os.getenv("VRM_SERVER_HOST", "127.0.0.1"),
        port=int(os.getenv("VRM_SERVER_PORT", "8001")),
        reload=True,
    )