runner.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # -*- coding: utf-8 -*-
  2. from typing import List, Dict
  3. from yacs.config import CfgNode as CN
  4. from abc import ABC, abstractmethod
  5. from digitalHuman.protocol import BaseMessage, ParamDesc, EngineDesc, ENGINE_TYPE, INFER_TYPE
# Public API of this module: only BaseRunner is exported by `from ... import *`.
__all__ = ["BaseRunner"]
  7. class BaseRunner(ABC):
  8. def __init__(self, config: CN, type: ENGINE_TYPE):
  9. self.cfg = config
  10. self._engineType = type
  11. self.setup()
  12. def __del__(self):
  13. self.release()
  14. @property
  15. def name(self) -> str:
  16. return self.cfg.NAME
  17. @property
  18. def type(self) -> ENGINE_TYPE:
  19. return self._engineType
  20. @property
  21. def inferType(self) -> INFER_TYPE:
  22. if "infer_type" not in self.meta(): return INFER_TYPE.NORMAL
  23. if self.meta()['infer_type'] == 'stream':
  24. return INFER_TYPE.STREAM
  25. elif self.meta()['infer_type'] == 'normal':
  26. return INFER_TYPE.NORMAL
  27. else:
  28. raise RuntimeError(f"Invalid infer type: {self.meta()['infer_type']}")
  29. def desc(self) -> EngineDesc:
  30. return EngineDesc(
  31. name=self.name,
  32. type=self.type,
  33. infer_type=self.inferType,
  34. desc=self.cfg.DESC if "DESC" in self.cfg else "",
  35. meta=self.meta()
  36. )
  37. def meta(self) -> Dict:
  38. if "META" not in self.cfg: return {}
  39. return self.cfg.META
  40. def custom(self) -> Dict:
  41. if "CUSTOM" not in self.cfg: return {}
  42. return self.cfg.CUSTOM
  43. def parameters(self) -> List[ParamDesc]:
  44. if "PARAMETERS" not in self.cfg: return []
  45. params = []
  46. for param in self.cfg.PARAMETERS:
  47. params.append(ParamDesc.model_validate(param))
  48. return params
  49. def checkParameter(self, **kwargs) -> Dict:
  50. paramters = {}
  51. for paramter in self.parameters():
  52. if paramter.name not in kwargs:
  53. if not paramter.required:
  54. paramters[paramter.name] = paramter.default
  55. continue
  56. raise RuntimeError(f"Missing parameter: {paramter.name}")
  57. paramters[paramter.name] = kwargs[paramter.name]
  58. # 额外参数填充
  59. for k, v in kwargs.items():
  60. if k not in paramters:
  61. paramters[k] = v
  62. return paramters
  63. def setup(self):
  64. pass
  65. def release(self):
  66. pass
  67. @abstractmethod
  68. async def run(self, input: BaseMessage, **kwargs):
  69. raise NotImplementedError