enginePool.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # -*- coding: utf-8 -*-
  2. from threading import RLock
  3. from typing import List
  4. from collections import defaultdict
  5. from yacs.config import CfgNode as CN
  6. from digitalHuman.utils import logger
  7. from digitalHuman.protocol import ENGINE_TYPE
  8. from .engineBase import BaseEngine
  9. from .asr import ASRFactory
  10. from .tts import TTSFactory
  11. from .llm import LLMFactory
  12. __all__ = ["EnginePool"]
  13. class EnginePool():
  14. singleLock = RLock()
  15. _init = False
  16. def __init__(self):
  17. if not self._init:
  18. self._pool = defaultdict(dict)
  19. self._init = True
  20. # Single Instance
  21. def __new__(cls, *args, **kwargs):
  22. with EnginePool.singleLock:
  23. if not hasattr(cls, '_instance'):
  24. EnginePool._instance = super().__new__(cls)
  25. return EnginePool._instance
  26. def __del__(self):
  27. self._pool.clear()
  28. self._init = False
  29. def setup(self, config: CN):
  30. # asr
  31. for asrCfg in config.ASR.SUPPORT_LIST:
  32. self._pool[ENGINE_TYPE.ASR][asrCfg.NAME] = ASRFactory.create(asrCfg)
  33. logger.info(f"[EnginePool] ASR Engine {asrCfg.NAME} is created.")
  34. logger.info(f"[EnginePool] ASR Engine default is {config.ASR.DEFAULT}.")
  35. # tts
  36. for ttsCfg in config.TTS.SUPPORT_LIST:
  37. self._pool[ENGINE_TYPE.TTS][ttsCfg.NAME] = TTSFactory.create(ttsCfg)
  38. logger.info(f"[EnginePool] TTS Engine {ttsCfg.NAME} is created.")
  39. logger.info(f"[EnginePool] TTS Engine default is {config.TTS.DEFAULT}.")
  40. # llm
  41. for llmCfg in config.LLM.SUPPORT_LIST:
  42. self._pool[ENGINE_TYPE.LLM][llmCfg.NAME] = LLMFactory.create(llmCfg)
  43. logger.info(f"[EnginePool] LLM Engine {llmCfg.NAME} is created.")
  44. logger.info(f"[EnginePool] LLM Engine default is {config.LLM.DEFAULT}.")
  45. def listEngine(self, engineType: ENGINE_TYPE) -> List[str]:
  46. if engineType not in self._pool: return []
  47. return self._pool[engineType].keys()
  48. def getEngine(self, engineType: ENGINE_TYPE, engineName: str) -> BaseEngine:
  49. if engineType not in self._pool:
  50. raise KeyError(f"[EnginePool] No such engine type: {engineType}")
  51. if engineName not in self._pool[engineType]:
  52. raise KeyError(f"[EnginePool] No such engine: {engineName}")
  53. return self._pool[engineType][engineName]