Ai_Assistant/server/main_chat_v9.py
2026-05-24 13:31:30 +02:00

759 lines
27 KiB
Python

#!/usr/bin/env python3
"""
main_chat_v9.py - voice loop with streaming LLM and optional MCP tool calls.
Two triggers run independently:
1. Main loop: blocks on speech via record_on_speech, then runs an LLM turn.
2. Background click dispatcher: polls /pop_pending_actions and runs an LLM
turn whenever the user touches the avatar.
Both triggers funnel through run_llm_turn() which is serialized by a single
lock so chat history stays consistent. Audio chunks from both triggers go to
the same playback worker, which serializes them.
"""
import json
import os
import shutil
import time
import uuid
import math
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from queue import Queue
from threading import Event, Lock, Thread
from typing import Any, Dict, List, Optional
import requests
import yaml
from dotenv import load_dotenv
from openai import OpenAI
from process.asr_func.asr_transcribe_groq import record_on_speech, transcribe_audio_groq
from process.tts_func.sovits_ping import get_wav_duration
from process.tts_func.elevenlabs_ping import elevenlabs_gen
from process.tts_func.tts_preprocess import clean_llm_output
from process.vrm_func.vrm_ping import vrm_animate, vrm_talk
from process.vrm_func.vrm_states_ping import set_vrm_state
# Server URL for fetching pending user actions (clicks).
SERVER_BASE_URL = "http://localhost:8001"
# Pull variables from a .env file into the environment before the os.getenv()
# calls below are evaluated.
load_dotenv()
# ==================== Paths & Config ====================
# Directory two levels above this file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
CONFIG_PATH = PROJECT_ROOT / "character_config.yaml"
# MCP config location; the MCP_CONFIG_PATH environment variable overrides the
# default under the user's home directory.
MCP_CONFIG_PATH = Path(
    os.getenv("MCP_CONFIG_PATH", str(Path.home() / "MCP_functions" / "mcp_config.json"))
)
# The character config is mandatory: importing this module fails without it.
if not CONFIG_PATH.exists():
    raise FileNotFoundError(f"[config] character_config.yaml not found at {CONFIG_PATH}")
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    char_config = yaml.safe_load(f)
# "history_file" and presets.default.system_prompt are required keys (KeyError
# when absent); "model" is optional and falls back to the literal "model".
HISTORY_FILE = char_config["history_file"]
MODEL = char_config.get("model", "model")
BASE_SYSTEM_PROMPT = char_config["presets"]["default"]["system_prompt"]
# Appended only when MCP is active so the model knows to interleave speech and tool calls.
TOOL_USE_RULES = """
You also have access to a tool that lets you perform physical actions (e.g. wave, walk, dropkick, backflip, kiss, etc.).
Call the tool only when performing actions. You can both speak (by outputting text) and perform physical actions (by calling the tool).
IMPORTANT: You MUST interleave your speech and actions naturally in order. For example:
- To say something, then do an action, then say something else, output text first, then call the tool, then output more text - all in the correct sequence.
- Do NOT batch all tool calls together. Alternate between text and tool calls as the situation requires.
- Do NOT use the tool for speaking - just write your words as normal text.
"""
# The OpenAI client is created at import time, so a missing key fails fast here.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise EnvironmentError("[config] Please set OPENAI_API_KEY in your environment")
client = OpenAI(api_key=OPENAI_API_KEY)
# ==================== emotional translator =================
def get_vrm_expression(text: str) -> str:
    """Pick a VRM expression name from emote markers found in *text*.

    Matching is case-insensitive substring search. Groups are tried in the
    order happy, angry, sad, surprised, neutral; the first group with any
    marker present wins. Text without a known marker maps to "relaxed".
    """
    lowered = text.lower()
    # Order matters: earlier groups take priority when several markers appear.
    marker_groups = (
        ("happy", ('*giggle', '*laugh', '*smile', '*smirk', '*cheer', '[laughs]')),
        ("angry", ('*angry', '*glare', '*mad', '*frustrat')),
        ("sad", ('*sad', '*cry', '*sigh', '*sorrow', '[sigh]')),
        ("surprised", ('*gasp', '*shock', '*surpris', '*wide', '[gasps]')),
        ("neutral", ('*neutral', '*stare')),
    )
    for expression, markers in marker_groups:
        for marker in markers:
            if marker in lowered:
                return expression
    return "relaxed"
# ==================== MCP Optional Layer ====================
def load_mcp_config() -> Optional[dict]:
    """Read and parse the JSON file at MCP_CONFIG_PATH.

    Returns the parsed JSON value, or None when the file does not exist
    (silently) or cannot be read / decoded (after logging the reason).
    """
    if MCP_CONFIG_PATH.exists():
        try:
            raw_text = MCP_CONFIG_PATH.read_text(encoding="utf-8")
            return json.loads(raw_text)
        except (json.JSONDecodeError, OSError) as e:
            print(f"[mcp] failed to load config at {MCP_CONFIG_PATH}: {e}")
    return None
def is_mcp_available() -> bool:
    """Return True when load_mcp_config() currently yields a config."""
    config = load_mcp_config()
    if config is None:
        return False
    return True
def _mcp_jsonrpc(method: str, params: dict) -> Optional[Any]:
    """Send one JSON-RPC 2.0 request to the configured MCP server.

    Both plain JSON replies and ``text/event-stream`` replies are handled; for
    an event stream the first event carrying a JSON-RPC result or error ends
    the call.

    Args:
        method: JSON-RPC method name, e.g. "tools/list" or "tools/call".
        params: JSON-RPC params object sent with the request.

    Returns:
        The reply's ``result`` member (for a plain JSON reply without a
        "result" key, the whole body). None when MCP is not configured, the
        server reports an error, or the request fails for any reason; failures
        are logged, never raised.
    """
    cfg = load_mcp_config()
    if not cfg:
        return None
    try:
        endpoint = f"{cfg['url'].rstrip('/')}/mcp/"
        # With stream=True the connection stays checked out until the body is
        # fully consumed or the response is closed. The ``with`` block closes
        # it on every exit path, including the early returns taken from the
        # middle of an event stream.
        with requests.post(
            endpoint,
            headers={
                "Authorization": f"Bearer {cfg['token']}",
                "Content-Type": "application/json",
                "Accept": "application/json, text/event-stream",
            },
            json={"jsonrpc": "2.0", "id": 1, "method": method, "params": params},
            timeout=10,
            stream=True,
        ) as resp:
            resp.raise_for_status()
            if resp.headers.get("content-type", "").startswith("text/event-stream"):
                for line in resp.iter_lines():
                    if not line:
                        continue
                    decoded = line.decode("utf-8")
                    # SSE permits "data:" with or without a space after the
                    # colon; json.loads ignores the leading whitespace.
                    if not decoded.startswith("data:"):
                        continue
                    try:
                        event_data = json.loads(decoded[5:])
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(event_data, dict):
                        continue
                    if "error" in event_data:
                        print(f"[mcp] error from {method}: {event_data['error']}")
                        return None
                    if event_data.get("jsonrpc") == "2.0" and "result" in event_data:
                        return event_data["result"]
                # Stream ended without a result event.
                return None
            body = resp.json()
            if "error" in body:
                print(f"[mcp] error from {method}: {body['error']}")
                return None
            return body.get("result", body)
    except Exception as e:
        # Best-effort by design: MCP is optional, so callers only see None.
        print(f"[mcp] {method} failed: {e}")
        return None
def get_all_tools_metadata() -> Optional[List[Dict[str, Any]]]:
    """Fetch the tool listing from the MCP server via "tools/list".

    Returns None when the call yields nothing. A dict result is unwrapped
    through its "tools" key (the dict itself when that key is absent); any
    other result is returned unchanged.
    """
    listing = _mcp_jsonrpc("tools/list", {})
    if not listing:
        return None
    if isinstance(listing, dict):
        return listing.get("tools", listing)
    return listing
def get_metadata_field(function_name: str, field_path: str) -> Optional[Any]:
    """Look up a dotted path inside a tool's ``_meta`` mapping.

    Args:
        function_name: Tool name to search for in the current tool listing.
        field_path: Dot-separated keys, e.g. "tool_type" or "a.b.c".

    Returns:
        The value at that path, or None when the listing is unavailable, the
        tool is unknown, or any step of the path is missing.
    """
    tools = get_all_tools_metadata()
    if not tools:
        return None
    match = None
    for candidate in tools:
        if candidate.get("name") == function_name:
            match = candidate
            break
    if not match:
        return None
    node = match.get("_meta", {})
    for key in field_path.split("."):
        # Stop as soon as the path leaves dict territory or a key is absent.
        if not isinstance(node, dict) or key not in node:
            return None
        node = node[key]
    return node
def call_mcp_tool(tool_name: str, arguments: dict) -> Optional[Any]:
    """Invoke *tool_name* on the MCP server through the "tools/call" method.

    Returns whatever _mcp_jsonrpc returns for that request.
    """
    request_params = {"name": tool_name, "arguments": arguments}
    return _mcp_jsonrpc("tools/call", request_params)
# ==================== Playback Worker ====================
# Fire the next chunk's vrm_talk this many seconds before the previous chunk's
# audio is supposed to end, to mask client-side load + play() latency.
CLIENT_LATENCY_LEAD = 0.5
# Seconds. Presumably the lower bound for the pause between two consecutive
# chunks, so that very short clips never yield a zero or negative wait
# - TODO confirm against the playback worker.
MIN_SLEEP_BETWEEN_CHUNKS = 0.05
class ItemType(Enum):
    """Kinds of work items the playback worker dispatches on."""

    # A text chunk handed to the talk path of the worker.
    TEXT = "text"
    # A tool invocation handed to the function path of the worker.
    FUNCTION_CALL = "function_call"
@dataclass
class PlaybackItem:
    """One unit of work for OrderedPlaybackWorker."""

    # Selects which worker path handles the item.
    item_type: ItemType
    # TEXT items: the text passed to vrm_talk. FUNCTION_CALL items: the tool name.
    content: str
    # Tool arguments for FUNCTION_CALL items; unused by the text path.
    arguments: Optional[Dict] = None
    # Audio file passed to vrm_talk for TEXT items.
    audio_path: Optional[Path] = None
    # Expression name passed to vrm_talk.
    expression: str = "relaxed"
    # Clip length in seconds; drives how long the worker waits after vrm_talk.
    duration: float = 0.0
    # When True the worker calls the MCP tool itself at playback time.
    needs_sync: bool = False
class OrderedPlaybackWorker:
    """Sequentially process text + function items so audio never overlaps.

    Items are consumed from a FIFO queue by one daemon thread.
    ``queue_finished_event`` is cleared by every ``enqueue`` and set again
    whenever the worker finds the queue empty after finishing an item, or when
    the worker thread exits.
    """

    def __init__(self):
        self.queue: Queue = Queue()
        self.thread = Thread(target=self._run, daemon=True)
        self._running = False
        self._talking = False
        # Starts set: an idle worker with nothing queued counts as finished.
        self.queue_finished_event = Event()
        self.queue_finished_event.set()

    def start(self):
        """Start the worker thread; repeated calls are no-ops."""
        if not self._running:
            self._running = True
            self.thread.start()
            print("[playback] worker started")

    def enqueue(self, item: "PlaybackItem"):
        """Queue *item* for playback and mark the queue as busy."""
        self.queue_finished_event.clear()
        self.queue.put(item)

    def wait_until_finished(self, timeout=None) -> bool:
        """Block until the queue is drained; return False on timeout."""
        return self.queue_finished_event.wait(timeout)

    def stop(self):
        """Ask the worker to exit after the items already queued, then join it."""
        # Joining a thread that was never started raises RuntimeError, and the
        # sentinel would linger in the queue; there is nothing to stop anyway.
        if not self.thread.is_alive():
            return
        self.queue.put(None)
        self.thread.join()

    def _run(self):
        """Worker loop: handle items in order until the None sentinel arrives."""
        try:
            while True:
                item = self.queue.get()
                if item is None:
                    break
                try:
                    if item.item_type == ItemType.TEXT:
                        self._process_text_item(item)
                    elif item.item_type == ItemType.FUNCTION_CALL:
                        self._process_function_item(item)
                except Exception as e:
                    # One bad item must not kill the thread: a dead worker
                    # would never set the finished event again.
                    print(f"[playback] item failed: {e}")
                if self.queue.empty():
                    # Animation calls intentionally left commented (option c).
                    self._talking = False
                    self.queue_finished_event.set()
                    print("[playback] queue drained")
        finally:
            # Release anyone blocked in wait_until_finished() once we exit.
            self.queue_finished_event.set()

    def _process_text_item(self, item: "PlaybackItem"):
        """Send one text chunk to the avatar and wait roughly for its audio."""
        # Computed before the try so the sleep below always has a value, even
        # when building the preview or calling vrm_talk raises.
        try:
            safe_duration = math.ceil(item.duration)
        except (TypeError, ValueError, OverflowError):
            # None, NaN or infinite duration: fall back to the minimum wait.
            safe_duration = 0
        try:
            preview = item.content.replace("\n", " ")
            print(f"[playback] talk ({item.duration:.2f}s): {preview!r}")
            vrm_talk(str(item.audio_path), item.expression, item.content, safe_duration)
        except Exception as e:
            print(f"[playback] vrm_talk failed: {e}")
        # Fire the next vrm_talk slightly before the current audio actually
        # ends on the client - the client has its own load + play() latency.
        sleep_for = max(MIN_SLEEP_BETWEEN_CHUNKS, safe_duration - CLIENT_LATENCY_LEAD)
        time.sleep(sleep_for)

    def _process_function_item(self, item: "PlaybackItem"):
        """Run a tool call at its queue position when it needs syncing."""
        print(f"[playback] function: {item.content}({item.arguments}) needs_sync={item.needs_sync}")
        if item.needs_sync:
            # Copy so the caller's dict is not mutated by the extra flag.
            tool_args = dict(item.arguments or {})
            tool_args["manual_call"] = True
            result = call_mcp_tool(item.content, tool_args)
            print(f"[playback] function {item.content} result: {result}")
        else:
            print(f"[playback] function {item.content} already executed during streaming")
# ==================== History ====================
def _initial_system_message() -> dict:
    """Build the system message that seeds a brand-new history.

    The tool-use rules are appended to the base prompt only when MCP is
    currently available.
    """
    with_tools = is_mcp_available()
    prompt = BASE_SYSTEM_PROMPT + TOOL_USE_RULES if with_tools else BASE_SYSTEM_PROMPT
    content_part = {"type": "input_text", "text": prompt}
    return {"role": "system", "content": [content_part]}
def load_history() -> List[dict]:
    """Return the saved chat history, or a fresh one holding only the system message."""
    if not os.path.exists(HISTORY_FILE):
        return [_initial_system_message()]
    with open(HISTORY_FILE, "r", encoding="utf-8") as history_fp:
        return json.load(history_fp)
def save_history(history: List[dict]):
    """Write *history* to HISTORY_FILE as indented JSON, replacing it atomically.

    The data is serialised to a sibling temp file first and swapped in with
    os.replace. Opening HISTORY_FILE directly in "w" mode truncates it before
    json.dump runs, so a serialisation error or an interrupted write would
    leave an unparseable history file behind.

    Raises:
        TypeError / ValueError: *history* is not JSON-serialisable.
        OSError: the temp file cannot be written or swapped in.
    """
    tmp_path = f"{HISTORY_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)
    except Exception:
        # Drop the partial temp file; the previous history stays untouched.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    os.replace(tmp_path, HISTORY_FILE)
# ==================== Streaming ====================
# A buffered text chunk is flushed at a sentence end once it is at least this long...
MIN_CHUNK_LEN = 60
# ...and unconditionally once it reaches this length.
MAX_CHUNK_LEN = 350
# Sentence-ending characters. No empty string here: str.endswith("") is always
# True, which would flush every chunk as soon as it reaches MIN_CHUNK_LEN.
CHUNK_PUNCT = (".", "?", "!")
def _do_single_stream(stream_kwargs):
    """
    Run one streamed Responses API request and relay its output live.

    Text deltas are buffered per output index and flushed as a chunk once the
    buffer is at least MIN_CHUNK_LEN long and ends a sentence, or once it
    reaches MAX_CHUNK_LEN. MCP calls are yielded the moment their
    ``response.mcp_call.completed`` event arrives. Calls whose name was not
    known during streaming are recovered from the final response afterwards.

    Args:
        stream_kwargs: Keyword arguments forwarded to ``client.responses.stream``.

    Yields:
        ("text", str), ("function", (name, args, needs_sync)),
        and a terminal ("final", (had_tool_calls, final_response)).
    """
    text_buffers: Dict[int, str] = {}
    mcp_names: Dict[int, str] = {}
    mcp_args_buf: Dict[int, str] = {}
    # Output indexes of completed calls that could not be yielded live.
    nameless_calls = set()
    saw_tool_call = False
    # str.endswith("") is always True, so an empty entry in CHUNK_PUNCT would
    # make every buffer look like a finished sentence; ignore such entries.
    sentence_ends = tuple(p for p in CHUNK_PUNCT if p)
    print(
        f"[stream] start model={stream_kwargs.get('model')} "
        f"continuation={'previous_response_id' in stream_kwargs}"
    )
    with client.responses.stream(**stream_kwargs) as stream:
        for event in stream:
            event_type = getattr(event, "type", "unknown")
            output_index = getattr(event, "output_index", None)
            if event_type == "response.output_item.added":
                item = getattr(event, "item", None)
                if item is not None and getattr(item, "type", None) == "mcp_call":
                    name = getattr(item, "name", None)
                    if output_index is not None and name:
                        mcp_names[output_index] = name
                        mcp_args_buf.setdefault(output_index, "")
                        print(f"[stream] mcp_call started idx={output_index} name={name}")
            elif event_type == "response.output_text.delta":
                buf = text_buffers.get(output_index, "") + event.delta
                # Trailing whitespace after the punctuation still ends a sentence.
                at_sentence_end = buf.rstrip().endswith(sentence_ends)
                if (at_sentence_end and len(buf) >= MIN_CHUNK_LEN) or len(buf) >= MAX_CHUNK_LEN:
                    yield "text", buf.strip()
                    text_buffers[output_index] = ""
                else:
                    text_buffers[output_index] = buf
            elif event_type == "response.output_text.done":
                # Flush whatever is left of this output item's text.
                buf = text_buffers.get(output_index, "")
                if buf.strip():
                    yield "text", buf.strip()
                text_buffers[output_index] = ""
                print(f"[stream] text done idx={output_index}")
            elif event_type == "response.mcp_call_arguments.delta":
                if output_index is not None:
                    mcp_args_buf[output_index] = (
                        mcp_args_buf.get(output_index, "") + getattr(event, "delta", "")
                    )
            elif event_type == "response.mcp_call.completed":
                saw_tool_call = True
                if output_index is None:
                    continue
                tool_name = mcp_names.get(output_index)
                tool_args = _parse_mcp_arguments(mcp_args_buf.get(output_index, ""))
                if tool_name:
                    tool_type = get_metadata_field(tool_name, "tool_type")
                    needs_sync = tool_type == "needs_sync"
                    print(
                        f"[stream] mcp_call done idx={output_index} "
                        f"name={tool_name} args={tool_args} needs_sync={needs_sync}"
                    )
                    yield "function", (tool_name, tool_args, needs_sync)
                else:
                    nameless_calls.add(output_index)
                    print(f"[stream] mcp_call done idx={output_index} but tool name missing")
        final_response = stream.get_final_response()
    unresolved = nameless_calls | {idx for idx in mcp_args_buf if idx not in mcp_names}
    if unresolved:
        print(f"[stream] resolving {len(unresolved)} mcp item(s) via final_response")
        for idx, output_item in enumerate(final_response.output):
            # Replay only the calls that were not yielded live; the others were
            # already emitted and would otherwise be enqueued a second time.
            if idx not in unresolved:
                continue
            if getattr(output_item, "type", None) != "mcp_call":
                continue
            tool_name = getattr(output_item, "name", None)
            if not tool_name:
                continue
            tool_args = _parse_mcp_arguments(getattr(output_item, "arguments", None))
            tool_type = get_metadata_field(tool_name, "tool_type")
            yield "function", (tool_name, tool_args, tool_type == "needs_sync")
    yield "final", (saw_tool_call, final_response)


def _parse_mcp_arguments(raw) -> dict:
    """Decode an MCP call's JSON argument string; anything unusable becomes {}."""
    if isinstance(raw, dict):
        return raw
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return {}
    return parsed if isinstance(parsed, dict) else {}
def stream_with_functions(messages):
    """Multi-turn streaming with previous_response_id continuation.

    The first request sends *messages*. While a request reports tool calls, a
    follow-up request with empty input continues from the previous response
    id, up to ``max_continuations`` follow-ups.

    Args:
        messages: Conversation history used as input of the first request.

    Yields:
        (str, "text") for each text chunk, ((name, args, needs_sync),
        "function") for each tool call, and finally (list of final responses,
        "final_responses").
    """
    mcp_config = load_mcp_config()
    # Settings shared by the first request and every continuation.
    shared_kwargs = {
        "model": MODEL,
        "temperature": 1,
        "top_p": 1,
        "max_output_tokens": 2048,
    }
    if mcp_config:
        shared_kwargs["tools"] = [
            {
                "type": "mcp",
                "server_label": mcp_config["server_name"],
                # Strip a trailing slash like _mcp_jsonrpc does, so a config
                # URL ending in "/" does not produce "//mcp".
                "server_url": f"{mcp_config['url'].rstrip('/')}/mcp",
                "require_approval": "never",
            }
        ]
        print(f"[stream] MCP active: {mcp_config['server_name']}")
    else:
        print("[stream] MCP not configured - running plain streaming chat")
    max_continuations = 10
    continuation = 0
    previous_response_id = None
    all_final_responses: list = []
    while continuation <= max_continuations:
        if continuation == 0:
            stream_kwargs = {**shared_kwargs, "input": messages}
        else:
            # Follow-up turn: no new input, the server resumes from the
            # previous response.
            stream_kwargs = {
                **shared_kwargs,
                "previous_response_id": previous_response_id,
                "input": [],
            }
        had_tool_calls = False
        final_response = None
        for kind, payload in _do_single_stream(stream_kwargs):
            if kind == "text":
                yield payload, "text"
            elif kind == "function":
                yield payload, "function"
            elif kind == "final":
                had_tool_calls, final_response = payload
        if final_response is not None:
            all_final_responses.append(final_response)
            previous_response_id = final_response.id
        if not had_tool_calls:
            break
        print(f"[stream] continuing for more output (turn {continuation + 1})")
        continuation += 1
    if continuation > max_continuations:
        print(f"[stream] WARNING: hit max continuations ({max_continuations})")
    yield all_final_responses, "final_responses"
# ==================== Helpers ====================
def ensure_dirs():
    """Create the relative directories client/audio and audio when missing."""
    for rel_dir in ("client/audio", "audio"):
        Path(rel_dir).mkdir(parents=True, exist_ok=True)
def _safe_set_state(state: str):
    """Set the avatar state, logging any failure instead of raising it."""
    try:
        set_vrm_state(state)
    except Exception as e:
        failure_note = f"[main] set_vrm_state({state}) failed: {e}"
        print(failure_note)
def _action_to_text(action: dict) -> str:
"""Render a click/touch action as a bracketed prompt for the LLM."""
region = (action.get("region") or "").strip()
bone = (action.get("bone") or "").strip()
if region and region.lower() != "body":
label = region.replace("_", " ")
elif bone:
label = bone.replace("_", " ")
else:
label = "body"
return f"[the user touched your {label}]"
def fetch_pending_user_actions() -> List[dict]:
    """Drain any clicks/touches buffered server-side since the last call.

    Returns the "actions" entry of the server's JSON reply, or an empty list
    when the request or decoding fails (the failure is logged).
    """
    try:
        response = requests.get(f"{SERVER_BASE_URL}/pop_pending_actions", timeout=2)
        response.raise_for_status()
        payload = response.json()
        return payload.get("actions", [])
    except Exception as e:
        print(f"[main] fetch_pending_user_actions failed: {e}")
        return []
def _persist_assistant_turn(messages: List[dict], all_final_responses: list, fallback_text: str, mcp_on: bool):
if all_final_responses and mcp_on:
for resp in all_final_responses:
if not hasattr(resp, "output"):
continue
for resp_item in resp.output:
if resp_item.type == "mcp_list_tools":
tool_list = []
for tool in resp_item.tools:
tool_info = {
"name": tool.name,
"description": tool.description,
"input_schema": tool.input_schema,
}
if hasattr(tool, "annotations"):
tool_info["annotations"] = tool.annotations
tool_list.append(tool_info)
messages.append(
{
"type": "mcp_list_tools",
"server_label": resp_item.server_label,
"tools": tool_list,
}
)
elif resp_item.type == "mcp_call":
messages.append(
{
"type": "mcp_call",
"name": resp_item.name,
"arguments": resp_item.arguments,
"server_label": resp_item.server_label,
"output": resp_item.output,
}
)
elif resp_item.type == "message":
content_list = []
for c in resp_item.content:
if c.type == "output_text":
content_list.append({"type": "output_text", "text": c.text})
messages.append({"role": resp_item.role, "content": content_list})
else:
messages.append(
{
"role": "assistant",
"content": [{"type": "output_text", "text": fallback_text}],
}
)
# ==================== LLM Turn ====================
# One lock shared by both triggers (speech and click) so that turns run one at
# a time and never interleave their history reads and writes.
_llm_turn_lock = Lock()


def run_llm_turn(user_text: str, playback: OrderedPlaybackWorker, mcp_on: bool):
    """
    Run one complete LLM turn while holding _llm_turn_lock.

    The user text is appended to the loaded history, the response is streamed,
    every text chunk is synthesised and enqueued for playback together with
    any tool calls, and the updated history is saved at the end.

    Args:
        user_text: What the user said, or the rendered touch actions.
        playback: Worker that receives the resulting playback items.
        mcp_on: Forwarded to _persist_assistant_turn to select how the
            assistant output is stored.
    """
    with _llm_turn_lock:
        print(f"\n[turn] >>> {user_text!r}")
        _safe_set_state("talking")
        messages = load_history()
        user_message = {
            "role": "user",
            "content": [{"type": "input_text", "text": user_text}],
        }
        messages.append(user_message)
        spoken_chunks: List[str] = []
        all_final_responses: list = []
        for item, item_type in stream_with_functions(messages):
            if item_type == "text":
                # Record the raw chunk even when it turns out to be silent.
                spoken_chunks.append(item)
                speech_item = _synthesize_text_chunk(item)
                if speech_item is None:
                    continue
                playback.enqueue(speech_item)
                print(f"[turn] enqueued text chunk ({speech_item.duration:.2f}s)")
            elif item_type == "function":
                tool_name, tool_args, needs_sync = item
                playback.enqueue(
                    PlaybackItem(
                        item_type=ItemType.FUNCTION_CALL,
                        content=tool_name,
                        arguments=tool_args,
                        needs_sync=needs_sync,
                    )
                )
                print(f"[turn] enqueued function {tool_name} needs_sync={needs_sync}")
            elif item_type == "final_responses":
                all_final_responses = item
        final_text = " ".join(spoken_chunks).strip()
        print(f"[turn] full response: {final_text!r}")
        _persist_assistant_turn(messages, all_final_responses, final_text, mcp_on)
        save_history(messages)
        print("[turn] history saved")


def _synthesize_text_chunk(text_chunk: str) -> Optional[PlaybackItem]:
    """Turn one streamed text chunk into a ready-to-play TEXT item, or None when it is silent."""
    tts_text = clean_llm_output(text_chunk)
    # Chunks that contain only silent actions have nothing to speak.
    if not tts_text.strip():
        print(f"[Skip] Chunk enthielt nur stumme Aktionen: {text_chunk}")
        return None
    # The expression is picked from the raw chunk, before TTS clean-up.
    expression = get_vrm_expression(text_chunk)
    filename = f"output_{uuid.uuid4().hex}.wav"
    client_out = Path("client") / "audio" / filename
    public_out = Path("audio") / filename
    client_out.parent.mkdir(parents=True, exist_ok=True)
    # TTS is routed through ElevenLabs. The keyword form is tried first and
    # the positional form is the retry when it raises TypeError.
    try:
        elevenlabs_gen(tts_text, output_wav_pth=str(client_out))
    except TypeError:
        elevenlabs_gen(tts_text, str(client_out))
    shutil.copy2(client_out, public_out)
    try:
        # Estimate the clip length from the file size in bytes:
        # 128 kbps = 16,000 bytes per second.
        clip_bytes = os.path.getsize(public_out)
        duration = clip_bytes / 16000.0
    except Exception as e:
        print(f"Fehler bei der Längenberechnung: {e}")
        duration = 3.0
    return PlaybackItem(
        item_type=ItemType.TEXT,
        content=text_chunk,
        audio_path=public_out,
        expression=expression,
        duration=duration,
    )
# ==================== Click Dispatcher (Background) ====================
def click_dispatcher(playback: OrderedPlaybackWorker, mcp_on: bool, poll_interval: float = 0.5):
    """
    Poll for pending touch actions forever and turn each batch into an LLM turn.

    All actions returned by one poll are rendered and joined into a single
    user message. Errors are logged and polling continues. Meant to run on a
    background thread; it never returns.
    """
    print(f"[click_dispatcher] started (poll every {poll_interval}s)")
    while True:
        try:
            pending = fetch_pending_user_actions()
            if pending:
                rendered = [_action_to_text(action) for action in pending]
                user_text = " ".join(rendered)
                print(f"[click_dispatcher] firing turn for {len(pending)} action(s): {user_text}")
                run_llm_turn(user_text, playback, mcp_on)
        except Exception as e:
            print(f"[click_dispatcher] error: {e}")
        time.sleep(poll_interval)
# ==================== Main Loop ====================
def main_loop():
    """Run the voice loop: wait for playback, record, transcribe, run a turn.

    Starts the playback worker and the background click dispatcher first.
    Ctrl+C stops the playback worker and leaves the loop; any other error is
    logged with a traceback and the loop resumes after one second.
    """
    ensure_dirs()
    mcp_on = is_mcp_available()
    print(f"\n========= Starting main_chat_v9 (MCP={'ON' if mcp_on else 'OFF'}) =========\n")
    playback = OrderedPlaybackWorker()
    playback.start()
    dispatcher = Thread(target=click_dispatcher, args=(playback, mcp_on), daemon=True)
    dispatcher.start()
    # The recording target never changes, so it is built once up front.
    conversation_recording = Path("audio") / "conversation.wav"
    recording_path = str(conversation_recording)
    while True:
        try:
            # Do not listen while the avatar is still speaking.
            print("\n[main] waiting for playback queue to finish...")
            playback.wait_until_finished()
            print("[main] queue finished - ready for input")
            _safe_set_state("idle")
            print("[main] recording - speak when ready")
            conversation_recording.parent.mkdir(parents=True, exist_ok=True)
            record_on_speech(
                output_file=recording_path,
                samplerate=44100,
                channels=1,
                silence_threshold=0.22,
                silence_duration=2,
                device="default",
            )
            _safe_set_state("thinking")
            user_spoken_text = transcribe_audio_groq(aud_path=recording_path)
            print(f"[main] transcribed: {user_spoken_text!r}")
            if not (user_spoken_text and user_spoken_text.strip()):
                print("[main] empty transcription, skipping turn")
                continue
            run_llm_turn(user_spoken_text, playback, mcp_on)
            time.sleep(0.1)
        except KeyboardInterrupt:
            print("\n[main] interrupted, stopping playback")
            playback.stop()
            break
        except Exception as e:
            print(f"[main] error in main loop: {e}")
            import traceback
            traceback.print_exc()
            time.sleep(1)
# Start the voice loop only when this file is executed as a script.
if __name__ == "__main__":
    main_loop()