funasrStreamingASR.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import json
  2. import asyncio
  3. import time
  4. import websockets
  5. from fastapi import WebSocket, WebSocketDisconnect
  6. from digitalHuman.utils import logger
  7. from digitalHuman.engine.builder import ASREngines
  8. from digitalHuman.protocol import *
  9. from digitalHuman.engine.engineBase import StreamBaseEngine
  10. __all__ = ["FunasrStreamingAsr"]
  11. @ASREngines.register("funasrStreaming")
  12. class FunasrStreamingAsr(StreamBaseEngine):
  13. async def _reset_sentence(self, funasrWebsocket: websockets.ClientConnection):
  14. """重置说话识别, 防止连续识别添加标点符号"""
  15. message = json.dumps(
  16. {
  17. "is_speaking": False,
  18. }
  19. )
  20. await funasrWebsocket.send(message)
  21. message = json.dumps(
  22. {
  23. "is_speaking": True,
  24. }
  25. )
  26. await funasrWebsocket.send(message)
  27. async def _task_send(self, adhWebsocket: WebSocket, funasrWebsocket: websockets.ClientConnection):
  28. """
  29. funasr server -> adh server -> adh web
  30. """
  31. text_send = ""
  32. text_send_2pass_online = ""
  33. text_send_2pass_offline = ""
  34. wake_word = "小天小天"
  35. is_awake = False
  36. inactivity_deadline = time.monotonic() + 300 # 5分钟超时
  37. def process_text_for_wake(text: str) -> tuple[bool, str]:
  38. nonlocal is_awake
  39. if not is_awake:
  40. if wake_word in text:
  41. is_awake = True
  42. return True, text.replace(wake_word, "").strip()
  43. return False, ""
  44. return True, text.replace(wake_word, "").strip()
  45. try:
  46. while True:
  47. # 超时检查
  48. if time.monotonic() > inactivity_deadline:
  49. await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.ENGINE_STOPPED, "inactivity_timeout")
  50. return
  51. meg = await funasrWebsocket.recv()
  52. meg = json.loads(meg)
  53. wav_name = meg.get("wav_name", "demo")
  54. text = meg["text"]
  55. timestamp = ""
  56. offline_msg_done = meg.get("is_final", False)
  57. if "timestamp" in meg:
  58. timestamp = meg["timestamp"]
  59. if "mode" not in meg:
  60. continue
  61. if meg["mode"] == "online":
  62. text_send += text
  63. elif meg["mode"] == "offline":
  64. text_send += text
  65. offline_msg_done = True
  66. else:
  67. if meg["mode"] == "2pass-online":
  68. text_send_2pass_online += text
  69. text_send = text_send_2pass_offline + text_send_2pass_online
  70. else:
  71. offline_msg_done = True
  72. text_send_2pass_online = ""
  73. text_send = text_send_2pass_offline + text
  74. text_send_2pass_offline += text
  75. if offline_msg_done:
  76. awakened, cleaned = process_text_for_wake(text_send)
  77. if awakened and cleaned:
  78. await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.ENGINE_FINAL_OUTPUT, cleaned)
  79. inactivity_deadline = time.monotonic() + 300
  80. text_send = ""
  81. text_send_2pass_online = ""
  82. text_send_2pass_offline = ""
  83. await self._reset_sentence(funasrWebsocket)
  84. else:
  85. awakened, cleaned = process_text_for_wake(text_send)
  86. if awakened and cleaned:
  87. await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.ENGINE_PARTIAL_OUTPUT, cleaned)
  88. inactivity_deadline = time.monotonic() + 300
  89. except WebSocketDisconnect:
  90. logger.debug("adhWebsocket closed, task_send exit")
  91. except websockets.ConnectionClosed:
  92. logger.debug("funasrWebsocket closed, task_send exit")
  93. except Exception as e:
  94. logger.error(f"FunasrStreamingAsr task_send error: {e}")
  95. await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.ERROR, str(e))
  96. async def _task_recv(self, adhWebsocket: WebSocket, funasrWebsocket: websockets.ClientConnection, mode: str):
  97. """
  98. adh web -> adh server -> funasr server
  99. """
  100. try:
  101. message = json.dumps(
  102. {
  103. "mode": mode,
  104. "chunk_size": [5, 10, 5], # chunk_size: 60 * 10 ms. 左看300ms, 右看300ms
  105. "chunk_interval": 10,
  106. "encoder_chunk_look_back": 4,
  107. "decoder_chunk_look_back": 0,
  108. "wav_name": "adh",
  109. "is_speaking": True,
  110. "hotwords": "",
  111. "itn": True,
  112. }
  113. )
  114. await funasrWebsocket.send(message)
  115. await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.ENGINE_STARTED)
  116. while True:
  117. action, payload = await WebSocketHandler.recv_message(adhWebsocket)
  118. match action:
  119. case WS_RECV_ACTION_TYPE.PING:
  120. await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.PONG.value, b"")
  121. case WS_RECV_ACTION_TYPE.ENGINE_START:
  122. raise RuntimeError("FunasrStreamingAsr has benn started")
  123. case WS_RECV_ACTION_TYPE.ENGINE_PARTIAL_INPUT:
  124. await funasrWebsocket.send(payload)
  125. case WS_RECV_ACTION_TYPE.ENGINE_FINAL_INPUT:
  126. message = json.dumps(
  127. {
  128. "is_speaking": False
  129. }
  130. )
  131. await funasrWebsocket.send(message)
  132. await funasrWebsocket.send(payload)
  133. case WS_RECV_ACTION_TYPE.ENGINE_STOP:
  134. await funasrWebsocket.close()
  135. await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.ENGINE_STOPPED)
  136. return
  137. case _:
  138. raise RuntimeError(f"FunasrStreamingAsr task_recv error: {action} not found")
  139. except WebSocketDisconnect:
  140. logger.debug("funasrWebsocket closed, task_recv exit")
  141. except Exception as e:
  142. logger.error(f"FunasrStreamingAsr task_recv error: {e}")
  143. await WebSocketHandler.send_message(adhWebsocket, WS_SEND_ACTION_TYPE.ERROR, str(e))
  144. async def run(self, websocket: WebSocket, **kwargs) -> None:
  145. # 参数校验
  146. paramters = self.checkParameter(**kwargs)
  147. API_URL = paramters["api_url"]
  148. MODE = paramters["mode"]
  149. await WebSocketHandler.send_message(websocket, WS_SEND_ACTION_TYPE.ENGINE_INITIALZING)
  150. # 连接服务器
  151. try:
  152. async with websockets.connect(API_URL, subprotocols=["binary"], ping_interval=None) as funasrWebsocket:
  153. # adh web -> adh server -> funasr server
  154. task_recv = asyncio.create_task(self._task_recv(websocket, funasrWebsocket, MODE))
  155. # funasr server -> adh server -> adh web
  156. task_send = asyncio.create_task(self._task_send(websocket, funasrWebsocket))
  157. await asyncio.gather(task_recv, task_send)
  158. except Exception as e:
  159. logger.error(f"FunasrStreamingAsr run error: {e}")
  160. # 异常会被 async with 自动处理,这里只记录错误