||
- # -*- coding: utf-8 -*-
- import struct
- import asyncio
- from enum import Enum
- from uuid import uuid4
- from typing import Optional, Union, List, Dict, Tuple
- from datetime import datetime
- from pydantic import BaseModel, Field
- from fastapi import WebSocket
- # ======================= 枚举类型 =======================
class StrEnum(str, Enum):
    """Enum base whose members are also ``str`` instances.

    ``str(member)`` yields the member's value instead of the default
    ``ClassName.MEMBER`` representation.
    """

    def __str__(self):
        text = self.value
        return str(text)
class IntEnum(int, Enum):
    """Enum base whose members are also ``int`` instances.

    ``str(member)`` yields the decimal text of the member's value instead
    of the default ``ClassName.MEMBER`` representation.
    """

    def __str__(self):
        number = self.value
        return str(number)
class ENGINE_TYPE(StrEnum):
    """String-valued tags identifying the kind of an engine."""

    ASR = 'ASR'
    TTS = 'TTS'
    LLM = 'LLM'
    AGENT = 'AGENT'
class GENDER_TYPE(StrEnum):
    """String-valued gender tags (used by ``VoiceDesc.gender``)."""

    MALE = "MALE"
    FEMALE = "FEMALE"
-
class EVENT_TYPE(StrEnum):
    """Event names written on the ``event:`` line by the eventStream* helpers."""

    CONVERSATION_ID = "CONVERSATION_ID"
    MESSAGE_ID = "MESSAGE_ID"
    TEXT = "TEXT"
    THINK = "THINK"
    TASK = "TASK"
    DONE = "DONE"
    ERROR = "ERROR"
class PARAM_TYPE(StrEnum):
    """String-valued tags naming the value type of a parameter."""

    STRING = "string"
    INT = "int"
    FLOAT = "float"
    BOOL = "bool"
    LIST = "list"
class AUDIO_TYPE(StrEnum):
    """String-valued tags naming an audio container format."""

    MP3 = "mp3"
    WAV = "wav"
class DATA_TYPE(StrEnum):
    """String-valued tags describing what a message's ``data`` field holds."""

    TEXT = "text"
    AUDIO_URL = "audio_url"
    AUDIO_STREAM = "audio_stream"
class ROLE_TYPE(StrEnum):
    """String-valued speaker roles (used by ``RoleMessage.role``)."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
class INFER_TYPE(StrEnum):
    """String-valued inference mode tags (used by ``EngineDesc.infer_type``)."""

    NORMAL = "normal"
    STREAM = "stream"
class RESPONSE_CODE(IntEnum):
    """Integer status codes carried in ``BaseResponse.code``."""

    OK = 0
    ERROR = -1
- # ========================== Message =============================
class BaseMessage(BaseModel):
    """
    Base Protocol
    """
    # Common parent of the message models in this module. It declares no
    # fields of its own; the auto-generated ``id`` field below is disabled.
    # id: str = Field(default_factory=lambda: str(uuid4()))

    def __str__(self) -> str:
        # Render as ``Message({...})`` using the dict form of all fields.
        dumped = self.model_dump()
        return 'Message(' + str(dumped) + ')'
class AudioMessage(BaseMessage):
    # Message carrying audio, either inline or by reference.
    data: Optional[Union[str, bytes]] = None
    # Kind of content held in ``data``: audio stream, audio URL, etc.
    dataType: DATA_TYPE = DATA_TYPE.AUDIO_STREAM
    # Audio format: WAV, MP3, etc.
    type: AUDIO_TYPE = AUDIO_TYPE.WAV
    sampleRate: int = 16000
    sampleWidth: int = 2
class TextMessage(BaseMessage):
    # Message carrying plain text.
    data: Optional[str] = None
    # Kind of content held in ``data``; defaults to text.
    dataType: DATA_TYPE = DATA_TYPE.TEXT
class RoleMessage(BaseMessage):
    # A piece of text attributed to a role; both fields are required.
    role: ROLE_TYPE
    content: str
- # ========================== server =============================
class BaseResponse(BaseModel):
    # Minimal response envelope: a status code plus a text message.
    code: RESPONSE_CODE
    message: str
- # ========================== voice =============================
class VoiceDesc(BaseModel):
    # Description of a voice: its name and gender tag.
    name: str
    gender: GENDER_TYPE
- # ========================== param =============================
class ParamDesc(BaseModel):
    # Description of a single parameter.
    name: str
    description: str
    type: PARAM_TYPE
    required: bool
    # ``range`` and ``choices`` are optional and default to empty lists.
    range: List[Union[str, int, float]] = []
    choices: List[Union[str, int, float]] = []
    # No default is declared for ``default`` itself, so it must be supplied.
    default: Union[str, int, float, bool, List]
- # ========================== engine =============================
class EngineDesc(BaseModel):
    # Description of an engine: identity, kind and inference mode.
    name: str
    type: ENGINE_TYPE
    infer_type: INFER_TYPE
    # Optional free-form text; empty by default.
    desc: str = ''
    # Optional extra metadata; empty dict by default.
    meta: Dict = {}
class EngineConfig(BaseModel):
    # Configuration for one engine; all three fields are required.
    name: str
    type: ENGINE_TYPE
    # Untyped key/value settings; the expected keys are not defined here.
    config: Dict
- # ========================== user =============================
class UserDesc(BaseModel):
    # Identifiers attached to a user request; all fields are required strings.
    user_id: str
    request_id: str
    cookie: str
-
- # ========================== func =============================
def eventStreamResponse(event: EVENT_TYPE, data: str) -> str:
    """Build one event-stream frame of the form ``event: <event>\\ndata: <data>\\n\\n``.

    Args:
        event: Event name; converted with ``str()``.
        data: Payload text. Each newline is replaced by the two characters
            backslash + ``n`` so the payload stays on a single ``data:`` line.

    Returns:
        The frame text, terminated by a blank line.
    """
    escaped = data.replace("\n", "\\n")
    return f"event: {event!s}\ndata: {escaped}\n\n"
def eventStreamText(data: str) -> str:
    """Return a ``TEXT`` event frame carrying ``data``."""
    frame = eventStreamResponse(EVENT_TYPE.TEXT, data)
    return frame
def eventStreamTask(task_id: str) -> str:
    """Return a ``TASK`` event frame carrying ``task_id``."""
    frame = eventStreamResponse(EVENT_TYPE.TASK, task_id)
    return frame
def eventStreamThink(data: str) -> str:
    """Return a ``THINK`` event frame carrying ``data``."""
    frame = eventStreamResponse(EVENT_TYPE.THINK, data)
    return frame
def eventStreamConversationId(conversation_id: str) -> str:
    """Return a ``CONVERSATION_ID`` event frame carrying ``conversation_id``."""
    frame = eventStreamResponse(EVENT_TYPE.CONVERSATION_ID, conversation_id)
    return frame
def eventStreamMessageId(message_id: str) -> str:
    """Return a ``MESSAGE_ID`` event frame carrying ``message_id``."""
    frame = eventStreamResponse(EVENT_TYPE.MESSAGE_ID, message_id)
    return frame
def eventStreamDone() -> str:
    """Return the ``DONE`` event frame, whose data line is the fixed text ``Done``."""
    # Same bytes as the hand-written frame: "event: DONE\ndata: Done\n\n".
    return eventStreamResponse(EVENT_TYPE.DONE, "Done")
def eventStreamError(error: str) -> str:
    """Return an ``ERROR`` event frame carrying the text ``error``.

    Args:
        error: Error description; newlines are escaped by
            ``eventStreamResponse``.

    Returns:
        The frame text, terminated by a blank line.
    """
    return eventStreamResponse(EVENT_TYPE.ERROR, error)
def isEventStreamResponse(message: str) -> bool:
    """Tell whether ``message`` begins with the ``event:`` prefix."""
    prefix = "event:"
    return message.startswith(prefix)
- # ========================== websocket =============================
# Wire-protocol constants.
# Frame layout: [action (18 bytes)] + [payload size (4 bytes)] + [payload (variable length)]
ACTION_HEADER_SIZE = 18  # width of the action field, in bytes
# Big-endian header: fixed-width action bytes followed by an unsigned 4-byte payload size.
PROTOCOL_HEADER_FORMAT = f">{ACTION_HEADER_SIZE}sI"
PROTOCOL_HEADER_SIZE = struct.calcsize(PROTOCOL_HEADER_FORMAT)  # 22 bytes
class WS_RECV_ACTION_TYPE(StrEnum):
    """Action names sent by the client (request types)."""

    # Heartbeat packet.
    PING = "PING"
    # Start the engine.
    ENGINE_START = "ENGINE_START"
    # Engine input (partial).
    ENGINE_PARTIAL_INPUT = "PARTIAL_INPUT"
    # Engine input (final).
    ENGINE_FINAL_INPUT = "FINAL_INPUT"
    # Stop the engine.
    ENGINE_STOP = "ENGINE_STOP"
class WS_SEND_ACTION_TYPE(StrEnum):
    """Action names sent by the server (response types)."""

    # Heartbeat reply.
    PONG = "PONG"
    # Engine is initializing. The spelling is part of the wire value and is
    # exactly 18 characters, i.e. the full width of the action field.
    ENGINE_INITIALZING = "ENGINE_INITIALZING"
    # Engine is ready.
    ENGINE_STARTED = "ENGINE_STARTED"
    # Engine output (partial).
    ENGINE_PARTIAL_OUTPUT = "PARTIAL_OUTPUT"
    # Engine output (final).
    ENGINE_FINAL_OUTPUT = "FINAL_OUTPUT"
    # Engine has been shut down.
    ENGINE_STOPPED = "ENGINE_STOPPED"
    # Error response.
    ERROR = "ERROR"
def _format_action(action_name: str) -> bytes:
    """Encode an action name as exactly ``ACTION_HEADER_SIZE`` bytes.

    The name is UTF-8 encoded and right-padded with ASCII spaces.

    Args:
        action_name: Action name to place in the frame header.

    Returns:
        The encoded name, exactly ``ACTION_HEADER_SIZE`` bytes long.

    Raises:
        ValueError: If the UTF-8 encoding of ``action_name`` is longer than
            ``ACTION_HEADER_SIZE`` bytes.
    """
    # Measure the encoded form: a non-ASCII name can fit in 18 characters yet
    # exceed 18 bytes, and struct's fixed-width "s" field would then silently
    # truncate it instead of reporting the problem.
    encoded = action_name.encode("utf-8")
    if len(encoded) > ACTION_HEADER_SIZE:
        raise ValueError(
            f"Action name '{action_name}' exceeds {ACTION_HEADER_SIZE} bytes"
        )
    return encoded.ljust(ACTION_HEADER_SIZE)
def struct_message(action: str, message: str | bytes) -> bytes:
    """Build an outgoing frame: packed header followed by the payload.

    Args:
        action: Action name for the header's fixed-width action field.
        message: Payload; a ``str`` is UTF-8 encoded, bytes are used as given.

    Returns:
        Header (action field + payload size) concatenated with the payload.

    Raises:
        ValueError: If ``_format_action`` rejects ``action`` as too long.
    """
    payload = message.encode("utf-8") if isinstance(message, str) else message
    # Header = padded action bytes + payload length, per PROTOCOL_HEADER_FORMAT.
    header = struct.pack(
        PROTOCOL_HEADER_FORMAT,
        _format_action(action),
        len(payload),
    )
    return header + payload
def parse_message(message: bytes) -> Tuple[str, bytes]:
    """Split a received frame into its action name and payload.

    Args:
        message: One complete frame (header immediately followed by payload).

    Returns:
        ``(action, payload)`` where ``action`` is the decoded header field
        with surrounding whitespace stripped and ``payload`` is the remaining
        bytes (``b""`` when the declared payload size is zero).

    Raises:
        ValueError: If the frame is shorter than the header, or its total
            length does not equal header size plus the declared payload size.
    """
    if len(message) < PROTOCOL_HEADER_SIZE:
        raise ValueError(
            f"Message too short: {len(message)} bytes, expected at least {PROTOCOL_HEADER_SIZE}"
        )

    # Read the header in place from the start of the buffer:
    # action field + payload size.
    raw_action, declared_size = struct.unpack_from(PROTOCOL_HEADER_FORMAT, message)

    expected_total_size = PROTOCOL_HEADER_SIZE + declared_size
    if len(message) != expected_total_size:
        raise ValueError(
            f"Message size mismatch: got {len(message)} bytes, expected {expected_total_size}"
        )

    # Everything after the header is payload.
    if declared_size > 0:
        payload = message[PROTOCOL_HEADER_SIZE:expected_total_size]
    else:
        payload = b""

    action_name = raw_action.decode("utf-8").strip()
    return (action_name, payload)
class WebSocketHandler():
    """
    Stateless WebSocket helpers (protocol control).

    Every method is static and works on the ``WebSocket`` passed in; frames
    are built and parsed with ``struct_message`` / ``parse_message``.
    """

    @staticmethod
    async def connect(ws: WebSocket) -> None:
        """Accept the incoming WebSocket connection."""
        await ws.accept()

    @staticmethod
    async def disconnect(ws: WebSocket):
        """Close the WebSocket connection, ignoring any error raised while closing."""
        try:
            await ws.close()
        except Exception:
            # Deliberate best effort: errors during close are ignored so we do
            # not try to close a connection after the event loop has shut down.
            # This is a known issue with ProactorEventLoop on Windows: once the
            # loop is closed, the connection's destructor still attempts to
            # close the connection.
            pass

    @staticmethod
    async def send_message(ws: WebSocket, action: str, message: str | bytes = b'') -> None:
        """Frame ``action`` and ``message`` and send them as one binary message."""
        frame = struct_message(action, message)
        await ws.send_bytes(frame)

    @staticmethod
    async def recv_message(ws: WebSocket) -> Tuple[str, bytes]:
        """Receive one binary message and return its ``(action, payload)``."""
        frame = await ws.receive_bytes()
        return parse_message(frame)
|