| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232 |
- # -*- coding: utf-8 -*-
- import os
- import json
- import asyncio
- from http import HTTPStatus
- from fastapi import WebSocket, WebSocketDisconnect
- from dashscope.audio.asr import RecognitionCallback, Recognition
- from digitalHuman.utils import logger
- from digitalHuman.engine.builder import ASREngines
- from digitalHuman.protocol import *
- from digitalHuman.engine.engineBase import StreamBaseEngine
- __all__ = ["DashscopeStreamingASR"]
class ASRCallback(RecognitionCallback):
    """Buffer DashScope recognition results until the sender task polls them.

    An instance is passed to ``Recognition(callback=...)``; the SDK then calls
    the ``on_*`` hooks, while the engine polls ``get_partial_result`` and
    ``get_final_result`` to forward text to the client.

    NOTE(review): the hooks are presumably invoked from an SDK worker thread
    rather than from the asyncio loop -- confirm against the DashScope SDK.
    """

    def __init__(self, websocket: WebSocket):
        """Create an empty result buffer.

        Args:
            websocket: Client connection this session belongs to. It is only
                stored; nothing in this class sends on it.
        """
        # Imported at function scope so this class does not rely on an extra
        # module-level import.
        from collections import deque

        self.websocket = websocket
        # Latest interim text of the sentence currently being recognized.
        self.partial_text = ""
        # Most recent finalized sentence that has not been fully consumed yet.
        self.final_text = ""
        # FIFO of finalized sentences: several sentences may finish between
        # two polls, and a single slot would keep only the last one.
        self._final_queue = deque()

    def on_open(self):
        """Log that the recognition connection was opened."""
        logger.debug("[DashscopeStreamingASR] Connection opened")

    def on_close(self):
        """Log that the recognition connection was closed."""
        logger.debug("[DashscopeStreamingASR] Connection closed")

    def on_event(self, result):
        """Store the sentence carried by a recognition event.

        A sentence with a truthy ``end_time`` is treated as final and queued;
        any other sentence replaces the current partial text. Errors are
        logged and never propagated back into the SDK.

        Args:
            result: Recognition result exposing ``status_code``, ``message``
                and ``get_sentence()``.
        """
        try:
            if result.status_code != HTTPStatus.OK:
                logger.error(f"[DashscopeStreamingASR] Error: {result.message}")
                return

            sentence = result.get_sentence()
            if not sentence:
                return

            text = sentence.get('text', '')
            if sentence.get('end_time'):
                # Sentence finished: queue it and drop its interim text so a
                # stale partial is not mistaken for the next sentence.
                self.final_text = text
                self.partial_text = ""
                if text:
                    self._final_queue.append(text)
                logger.debug(f"[DashscopeStreamingASR] Final: {text}")
            else:
                self.partial_text = text
                logger.debug(f"[DashscopeStreamingASR] Partial: {text}")
        except Exception as e:
            logger.error(f"[DashscopeStreamingASR] Callback error: {e}")

    def on_error(self, error):
        """Log an error reported by the recognition service."""
        logger.error(f"[DashscopeStreamingASR] Error: {error}")

    async def get_partial_result(self):
        """Return the current partial text, or ``""`` when there is none."""
        return self.partial_text or ""

    async def get_final_result(self):
        """Pop and return the oldest pending final sentence.

        Returns:
            The next finalized sentence in arrival order, or ``""`` when no
            final result is pending.
        """
        try:
            text = self._final_queue.popleft()
        except IndexError:
            return ""
        if not self._final_queue:
            # Everything has been handed out; reset the "pending" mirror.
            self.final_text = ""
        return text
- @ASREngines.register("dashscopeStreamingASR")
- class DashscopeStreamingASR(StreamBaseEngine):
- def setup(self):
- """初始化配置"""
- try:
- import dashscope
- # 从配置或环境变量获取 API Key
- api_key = self.cfg.get('CUSTOM', {}).get('api_key') or os.getenv('DASHSCOPE_API_KEY')
- if api_key:
- dashscope.api_key = api_key
- logger.info("[DashscopeStreamingASR] API Key configured successfully")
- else:
- logger.warning("[DashscopeStreamingASR] No API Key found, please set DASHSCOPE_API_KEY environment variable or configure in yaml")
- except ImportError:
- logger.error("[DashscopeStreamingASR] Please install dashscope: pip install dashscope")
- raise
- except Exception as e:
- logger.error(f"[DashscopeStreamingASR] Setup error: {e}")
- raise
- async def _task_send(self, adhWebsocket: WebSocket, asr_callback: ASRCallback):
- """
- 发送识别结果到前端
- """
- try:
- last_partial = ""
- while True:
- await asyncio.sleep(0.1) # 100ms 检查一次
-
- # 检查是否有最终结果
- final_text = await asr_callback.get_final_result()
- if final_text:
- await WebSocketHandler.send_message(
- adhWebsocket,
- WS_SEND_ACTION_TYPE.ENGINE_FINAL_OUTPUT,
- final_text
- )
- last_partial = ""
- continue
-
- # 检查是否有部分结果
- partial_text = await asr_callback.get_partial_result()
- if partial_text and partial_text != last_partial:
- await WebSocketHandler.send_message(
- adhWebsocket,
- WS_SEND_ACTION_TYPE.ENGINE_PARTIAL_OUTPUT,
- partial_text
- )
- last_partial = partial_text
-
- except WebSocketDisconnect:
- logger.debug("[DashscopeStreamingASR] adhWebsocket closed, task_send exit")
- except Exception as e:
- logger.error(f"[DashscopeStreamingASR] task_send error: {e}")
- await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.ERROR, str(e))
- async def _task_recv(self, adhWebsocket: WebSocket, recognition: Recognition):
- """
- 接收前端音频数据并发送到识别服务
- """
- try:
- await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.ENGINE_STARTED)
-
- while True:
- action, payload = await WebSocketHandler.recv_message(adhWebsocket)
-
- match action:
- case WS_RECV_ACTION_TYPE.PING:
- await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.PONG, b"")
-
- case WS_RECV_ACTION_TYPE.ENGINE_START:
- raise RuntimeError("[DashscopeStreamingASR] Engine has been started")
-
- case WS_RECV_ACTION_TYPE.ENGINE_PARTIAL_INPUT:
- # 发送音频数据到识别服务
- await asyncio.get_event_loop().run_in_executor(
- None, recognition.send_audio_frame, payload
- )
-
- case WS_RECV_ACTION_TYPE.ENGINE_FINAL_INPUT:
- # 发送最后的音频数据
- await asyncio.get_event_loop().run_in_executor(
- None, recognition.send_audio_frame, payload
- )
-
- case WS_RECV_ACTION_TYPE.ENGINE_STOP:
- # 停止识别
- await asyncio.get_event_loop().run_in_executor(
- None, recognition.stop
- )
- await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.ENGINE_STOPPED)
- return
-
- case _:
- raise RuntimeError(f"[DashscopeStreamingASR] Unknown action: {action}")
-
- except WebSocketDisconnect:
- logger.debug("[DashscopeStreamingASR] adhWebsocket closed, task_recv exit")
- except Exception as e:
- logger.error(f"[DashscopeStreamingASR] task_recv error: {e}")
- await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.ERROR, str(e))
- async def run(self, websocket: WebSocket, **kwargs) -> None:
- """运行流式识别"""
- # 参数校验
- paramters = self.checkParameter(**kwargs)
- model = paramters.get("model", "fun-asr-realtime")
- sample_rate = paramters.get("sample_rate", 16000)
- format_type = paramters.get("format", "pcm")
- language_hints = paramters.get("language_hints", ["zh", "en"])
-
- await WebSocketHandler.send_message(websocket, WS_SEND_ACTION_TYPE.ENGINE_INITIALZING)
-
- try:
- # 创建回调对象
- asr_callback = ASRCallback(websocket)
-
- # 创建识别对象
- # 注意:language_hints 只支持 paraformer-realtime-v2 和 paraformer-v2 模型
- if model in ['paraformer-realtime-v2', 'paraformer-v2']:
- recognition = Recognition(
- model=model,
- format=format_type,
- sample_rate=sample_rate,
- language_hints=language_hints,
- callback=asr_callback
- )
- else:
- # fun-asr-realtime 等模型不支持 language_hints
- recognition = Recognition(
- model=model,
- format=format_type,
- sample_rate=sample_rate,
- callback=asr_callback
- )
-
- # 启动识别
- await asyncio.get_event_loop().run_in_executor(
- None, recognition.start
- )
-
- # 创建发送和接收任务
- task_recv = asyncio.create_task(self._task_recv(websocket, recognition))
- task_send = asyncio.create_task(self._task_send(websocket, asr_callback))
-
- # 等待任务完成
- await asyncio.gather(task_recv, task_send)
-
- except Exception as e:
- logger.error(f"[DashscopeStreamingASR] Run error: {e}")
- await WebSocketHandler.send_message(websocket, WS_SEND_ACTION_TYPE.ERROR, str(e))
- finally:
- # 清理资源
- try:
- if recognition:
- await asyncio.get_event_loop().run_in_executor(
- None, recognition.stop
- )
- except:
- pass
|