# llmFactory.py (773 B)
  1. # -*- coding: utf-8 -*-
  2. from ..builder import LLMEngines
  3. from ..engineBase import BaseEngine
  4. from typing import List
  5. from yacs.config import CfgNode as CN
  6. from digitalHuman.protocol import ENGINE_TYPE
  7. from digitalHuman.utils import logger
  8. __all__ = ["LLMFactory"]
  9. class LLMFactory():
  10. """
  11. Large Language Model Factory
  12. """
  13. @staticmethod
  14. def create(config: CN) -> BaseEngine:
  15. if config.NAME in LLMEngines.list():
  16. logger.info(f"[LLMFactory] Create engine: {config.NAME}")
  17. return LLMEngines.get(config.NAME)(config, ENGINE_TYPE.LLM)
  18. else:
  19. raise RuntimeError(f"[LLMFactory] Please check config, support LLM: {LLMEngines.list()}")
  20. @staticmethod
  21. def list() -> List:
  22. return LLMEngines.list()