api_asr_v0_impl.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # -*- coding: utf-8 -*-
  2. import json
  3. from typing import List
  4. from digitalHuman.engine import EnginePool
  5. from digitalHuman.utils import config
  6. from digitalHuman.protocol import *
  7. from digitalHuman.server.models import *
  8. from digitalHuman.server.ws import *
# Module-level engine pool used by every ASR handler below; engines are
# looked up with enginePool.getEngine(ENGINE_TYPE.ASR, <engine name>).
# NOTE(review): whether EnginePool() returns a process-wide shared instance
# (singleton) is not visible here — confirm in digitalHuman.engine.
enginePool = EnginePool()
  10. def get_asr_list() -> List[EngineDesc]:
  11. engines = enginePool.listEngine(ENGINE_TYPE.ASR)
  12. return [enginePool.getEngine(ENGINE_TYPE.ASR, engine).desc() for engine in engines]
  13. def get_asr_default() -> EngineDesc:
  14. return enginePool.getEngine(ENGINE_TYPE.ASR, config.SERVER.ENGINES.ASR.DEFAULT).desc()
  15. def get_asr_param(name: str) -> List[ParamDesc]:
  16. engine = enginePool.getEngine(ENGINE_TYPE.ASR, name)
  17. return engine.parameters()
  18. async def asr_infer(user: UserDesc, items: ASREngineInput) -> TextMessage:
  19. if items.engine.lower() == "default":
  20. items.engine = config.SERVER.ENGINES.ASR.DEFAULT
  21. input = AudioMessage(data=items.data, sampleRate=items.sampleRate, sampleWidth=items.sampleWidth, type=items.type)
  22. engine = enginePool.getEngine(ENGINE_TYPE.ASR, items.engine)
  23. if engine.inferType != INFER_TYPE.NORMAL:
  24. raise Exception("ASR engine {} not support infer type {}".format(items.engine, engine.inferType))
  25. output: TextMessage = await engine.run(input=input, user=user, **items.config)
  26. return output
  27. async def asr_stream_infer(user: UserDesc, websocket: WebSocket):
  28. await websocket.accept()
  29. client_waitting = True
  30. while client_waitting:
  31. action, payload = await WebSocketHandler.recv_message(websocket)
  32. match action:
  33. case WS_RECV_ACTION_TYPE.PING:
  34. await WebSocketHandler.send_message(websocket, WS_SEND_ACTION_TYPE.PONG, b'')
  35. case WS_RECV_ACTION_TYPE.ENGINE_START:
  36. # 解析payload
  37. items = EngineInput.model_validate_json(payload)
  38. client_waitting = False
  39. case _:
  40. await WebSocketHandler.send_message(websocket, WS_SEND_ACTION_TYPE.ERROR, 'First action must be ENGINE_START | PING')
  41. return
  42. if items.engine.lower() == "default":
  43. items.engine = config.SERVER.ENGINES.ASR.DEFAULT
  44. engine = enginePool.getEngine(ENGINE_TYPE.ASR, items.engine)
  45. if engine.inferType != INFER_TYPE.STREAM:
  46. raise Exception("ASR engine {} not support infer type {}".format(items.engine, engine.inferType))
  47. await engine.run(websocket=websocket, user=user, **items.config)