api_tts_v0_impl.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334
  1. # -*- coding: utf-8 -*-
  2. from typing import List, Dict
  3. from digitalHuman.engine import EnginePool, BaseTTSEngine
  4. from digitalHuman.utils import config
  5. from digitalHuman.protocol import ParamDesc, EngineDesc, ENGINE_TYPE, UserDesc, AudioMessage, TextMessage, VoiceDesc
  6. from digitalHuman.server.models import TTSEngineInput
# Module-level engine pool; every function in this module resolves TTS engines through it.
enginePool = EnginePool()
  8. def get_tts_list() -> List[EngineDesc]:
  9. engines = enginePool.listEngine(ENGINE_TYPE.TTS)
  10. return [enginePool.getEngine(ENGINE_TYPE.TTS, engine).desc() for engine in engines]
  11. def get_tts_default() -> EngineDesc:
  12. return enginePool.getEngine(ENGINE_TYPE.TTS, config.SERVER.ENGINES.TTS.DEFAULT).desc()
  13. async def get_tts_voice(name: str, **kwargs) -> List[VoiceDesc]:
  14. engine: BaseTTSEngine = enginePool.getEngine(ENGINE_TYPE.TTS, name)
  15. voices = await engine.voices(**kwargs)
  16. return voices
  17. def get_tts_param(name: str) -> List[ParamDesc]:
  18. engine = enginePool.getEngine(ENGINE_TYPE.TTS, name)
  19. return engine.parameters()
  20. async def tts_infer(user: UserDesc, item: TTSEngineInput) -> AudioMessage:
  21. if item.engine.lower() == "default":
  22. item.engine = config.SERVER.ENGINES.TTS.DEFAULT
  23. input = TextMessage(data=item.data)
  24. engine = enginePool.getEngine(ENGINE_TYPE.TTS, item.engine)
  25. output: AudioMessage = await engine.run(input=input, user=user, **item.config)
  26. return output