llm_api_v0.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # -*- coding: utf-8 -*-
  2. from fastapi import APIRouter
  3. from fastapi.responses import JSONResponse, StreamingResponse
  4. from digitalHuman.protocol import TextMessage
  5. from digitalHuman.engine import EnginePool
  6. from digitalHuman.server.reponse import Response, streamInteralError
  7. from digitalHuman.server.header import HeaderInfo
  8. from digitalHuman.server.models import *
  9. from digitalHuman.server.core.api_llm_v0_impl import *
# All LLM v0 endpoints are mounted under /llm/v0.
router = APIRouter(prefix="/llm/v0")
# Shared engine-pool singleton used to look up LLM engine instances at request time.
enginePool = EnginePool()
  12. # ========================= 获取asr支持列表 ===========================
  13. @router.get("/engine", response_model=EngineListResp, summary="Get LLM Engine List")
  14. def api_get_llm_list():
  15. """
  16. 获取asr支持引擎列表
  17. """
  18. response = Response()
  19. try:
  20. response.data = get_llm_list()
  21. except Exception as e:
  22. response.data = []
  23. response.error(str(e))
  24. return JSONResponse(content=response.validate(EngineListResp), status_code=200)
  25. # ========================= 获取asr默认引擎 ===========================
  26. @router.get("/engine/default", response_model=EngineDefaultResp, summary="Get Default LLM Engine")
  27. def api_get_asr_default():
  28. """
  29. 获取默认asr引擎
  30. """
  31. response = Response()
  32. try:
  33. response.data = get_llm_default()
  34. except Exception as e:
  35. response.data = ""
  36. response.error(str(e))
  37. return JSONResponse(content=response.validate(EngineDefaultResp), status_code=200)
  38. # ========================= 获取asr引擎参数列表 ===========================
  39. @router.get("/engine/{engine}", response_model=EngineParam, summary="Get LLM Engine param")
  40. def api_get_asr_param(engine: str):
  41. """
  42. 获取asr引擎配置参数列表
  43. """
  44. response = Response()
  45. try:
  46. response.data = get_llm_param(engine)
  47. except Exception as e:
  48. response.data = []
  49. response.error(str(e))
  50. return JSONResponse(content=response.validate(EngineParam), status_code=200)
  51. # ========================= 执行asr引擎 ===========================
  52. @router.post("/engine", response_model=ASREngineOutput, summary="LLM Inference")
  53. async def api_agent_infer(item: LLMEngineInput, header: HeaderInfo):
  54. if item.engine.lower() == "default":
  55. item.engine = config.SERVER.LLM.DEFAULT
  56. response = Response()
  57. try:
  58. input = TextMessage(data=item.data)
  59. return StreamingResponse(enginePool.getEngine(ENGINE_TYPE.LLM, item.engine).run(input=input, user=header, **item.config), media_type="text/event-stream")
  60. except Exception as e:
  61. response.error(str(e))
  62. return StreamingResponse(streamInteralError("Interal Error"), media_type="text/event-stream")