asr_api_v0.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # -*- coding: utf-8 -*-
  2. import json
  3. from fastapi import APIRouter, UploadFile, Form
  4. from fastapi.responses import JSONResponse
  5. from digitalHuman.server.reponse import Response
  6. from digitalHuman.server.header import HeaderInfo
  7. from digitalHuman.server.models import *
  8. from digitalHuman.server.core.api_asr_v0_impl import *
  # Router for every ASR v0 endpoint below; all routes are mounted under the "/asr/v0" prefix.
router = APIRouter(prefix="/asr/v0")
# NOTE(review): ``enginePool`` is not referenced anywhere in this visible module, and ``EnginePool``
# is not explicitly imported here (presumably it arrives via one of the star imports). It may be
# kept for construction side effects or used elsewhere — TODO confirm before removing.
enginePool = EnginePool()
  # ========================= Get supported ASR engine list ===========================
@router.get("/engine", response_model=EngineListResp, summary="Get ASR Engine List")
def api_get_asr_list():
    """
    List the ASR engines supported by the server.

    The list comes from ``get_asr_list``. If that call raises, the payload
    falls back to an empty list and the exception text is passed to
    ``Response.error``. The HTTP status is 200 in both cases; failures are
    reported inside the body validated against ``EngineListResp``.
    """
    result = Response()
    try:
        engines = get_asr_list()
        result.data = engines
    except Exception as exc:
        # Report the failure in the body instead of an HTTP error status.
        result.data = []
        result.error(str(exc))
    body = result.validate(EngineListResp)
    return JSONResponse(content=body, status_code=200)
  # ========================= Get default ASR engine ===========================
@router.get("/engine/default", response_model=EngineDefaultResp, summary="Get Default ASR Engine")
def api_get_asr_default():
    """
    Return the default ASR engine.

    The value comes from ``get_asr_default``. If that call raises, the
    payload falls back to an empty string and the exception text is passed
    to ``Response.error``. The HTTP status is always 200; the body is
    validated against ``EngineDefaultResp``.
    """
    result = Response()
    try:
        default_engine = get_asr_default()
        result.data = default_engine
    except Exception as exc:
        # Report the failure in the body instead of an HTTP error status.
        result.data = ""
        result.error(str(exc))
    body = result.validate(EngineDefaultResp)
    return JSONResponse(content=body, status_code=200)
  # ========================= Get ASR engine parameter list ===========================
# Keep this route after "/engine/default": FastAPI matches paths in declaration
# order, so the literal "/engine/default" route must be registered first or it
# would be captured by this "{engine}" template.
@router.get("/engine/{engine}", response_model=EngineParam, summary="Get ASR Engine param")
def api_get_asr_param(engine: str):
    """
    Return the configuration parameter list of one ASR engine.

    Args:
        engine: engine name taken from the URL path.

    If ``get_asr_param`` raises, the payload falls back to an empty list and
    the exception text is passed to ``Response.error``. The HTTP status is
    always 200; the body is validated against ``EngineParam``.
    """
    result = Response()
    try:
        params = get_asr_param(engine)
        result.data = params
    except Exception as exc:
        # Report the failure in the body instead of an HTTP error status.
        result.data = []
        result.error(str(exc))
    body = result.validate(EngineParam)
    return JSONResponse(content=body, status_code=200)
  # ========================= Run ASR engine ===========================
# Audio sent inside the ASREngineInput request model (wav binary)
@router.post("/engine", response_model=ASREngineOutput, summary="Speech To Text Inference (wav binary)")
async def api_asr_infer(header: HeaderInfo, items: ASREngineInput):
    """
    Run speech-to-text on the audio carried by ``items``.

    Args:
        header: request header info, forwarded unchanged to ``asr_infer``.
        items: ASR request model (engine, audio settings and audio data),
            forwarded unchanged to ``asr_infer``.

    The payload is the ``data`` attribute of the message returned by
    ``asr_infer``. On any exception the payload is an empty string and the
    exception text is passed to ``Response.error``. The HTTP status is
    always 200; the body is validated against ``ASREngineOutput``.
    """
    result = Response()
    try:
        message = await asr_infer(header, items)
        result.data = message.data
    except Exception as exc:
        # Report the failure in the body instead of an HTTP error status.
        result.data = ""
        result.error(str(exc))
    body = result.validate(ASREngineOutput)
    return JSONResponse(content=body, status_code=200)
  # mp3 file upload (multipart form)
@router.post("/engine/file", response_model=ASREngineOutput, summary="Speech To Text Inference (mp3 file)")
async def api_asr_infer_file(
    header: HeaderInfo,
    file: UploadFile,
    engine: str = Form(...),
    type: AUDIO_TYPE = Form(...),  # shadows the builtin, but it is the public form-field name
    config: str = Form(...),
    sampleRate: int = Form(...),
    sampleWidth: int = Form(...)
):
    """
    Run speech-to-text on an uploaded audio file plus form fields.

    Args:
        header: request header info, forwarded unchanged to ``asr_infer``.
        file: uploaded audio; its full contents become the ``data`` field.
        engine: ASR engine name.
        type: audio type of the upload.
        config: engine configuration as a JSON-encoded string.
        sampleRate: audio sample rate.
        sampleWidth: audio sample width.

    The form fields and file bytes are assembled into an ``ASREngineInput``
    and passed to ``asr_infer``; the payload is the ``data`` of its result.
    Any failure, including invalid JSON in ``config``, yields an empty-string
    payload with the exception text passed to ``Response.error``. The HTTP
    status is always 200; the body is validated against ``ASREngineOutput``.
    """
    result = Response()
    try:
        # The upload is read first, then the JSON config is parsed.
        audio_bytes = await file.read()
        parsed_config = json.loads(config)
        fields = {
            "engine": engine,
            "type": type,
            "config": parsed_config,
            "sampleRate": sampleRate,
            "sampleWidth": sampleWidth,
            "data": audio_bytes,
        }
        request_items = ASREngineInput(**fields)
        message = await asr_infer(header, request_items)
        result.data = message.data
    except Exception as exc:
        # Report the failure in the body instead of an HTTP error status.
        result.data = ""
        result.error(str(exc))
    body = result.validate(ASREngineOutput)
    return JSONResponse(content=body, status_code=200)
  # Streaming recognition over a WebSocket
@router.websocket("/engine/stream")
async def api_asr_infer_stream(header: HeaderInfo, websocket: WebSocket):
    """
    Streaming ASR endpoint.

    The whole socket session is delegated to ``asr_stream_infer``; this
    handler adds no logic of its own.
    """
    # NOTE(review): ``WebSocket`` is not explicitly imported in the visible
    # import lines; it presumably arrives via one of the star imports — TODO confirm.
    session = asr_stream_infer(header, websocket)
    await session