| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- # -*- coding: utf-8 -*-
- from typing import List, Dict
- from yacs.config import CfgNode as CN
- from abc import ABC, abstractmethod
- from digitalHuman.protocol import BaseMessage, ParamDesc, EngineDesc, ENGINE_TYPE, INFER_TYPE
- __all__ = ["BaseRunner"]
class BaseRunner(ABC):
    """Abstract base for engine runners driven by a config node.

    The constructor stores the config and engine type and then calls
    ``setup()``. Subclasses must implement the ``run`` coroutine and may
    override ``setup()`` / ``release()``.

    Config keys read by this class: ``NAME`` (required by ``name``) and the
    optional ``DESC``, ``META``, ``CUSTOM`` and ``PARAMETERS``.
    """

    # NOTE: the ``type`` parameter shadows the builtin; the name is kept so
    # that existing keyword callers keep working.
    def __init__(self, config: CN, type: ENGINE_TYPE):
        """Store ``config`` and the engine ``type``, then run ``setup()``."""
        self.cfg = config
        self._engineType = type
        self.setup()

    def __del__(self):
        """Call ``release()`` when the instance is finalized."""
        self.release()

    def _configValue(self, key: str, default):
        """Return ``self.cfg.<key>``, or ``default`` if missing or ``None``.

        A key that is present but holds ``None`` is treated like a missing
        key, so callers always receive a usable container.
        """
        if key not in self.cfg:
            return default
        value = getattr(self.cfg, key)
        return default if value is None else value

    @property
    def name(self) -> str:
        """Engine name, read from ``cfg.NAME``."""
        return self.cfg.NAME

    @property
    def type(self) -> ENGINE_TYPE:
        """Engine type passed to the constructor."""
        return self._engineType

    @property
    def inferType(self) -> INFER_TYPE:
        """Inference mode derived from ``meta()['infer_type']``.

        Returns ``INFER_TYPE.NORMAL`` when the key is absent or equals
        ``'normal'`` and ``INFER_TYPE.STREAM`` when it equals ``'stream'``.

        Raises:
            RuntimeError: if ``infer_type`` holds any other value.
        """
        meta = self.meta()  # read once instead of once per comparison
        if "infer_type" not in meta:
            return INFER_TYPE.NORMAL
        inferType = meta["infer_type"]
        if inferType == "stream":
            return INFER_TYPE.STREAM
        if inferType == "normal":
            return INFER_TYPE.NORMAL
        raise RuntimeError(f"Invalid infer type: {inferType}")

    def desc(self) -> EngineDesc:
        """Build an ``EngineDesc`` from name, type, infer type, DESC and meta."""
        return EngineDesc(
            name=self.name,
            type=self.type,
            infer_type=self.inferType,
            desc=self.cfg.DESC if "DESC" in self.cfg else "",
            meta=self.meta(),
        )

    def meta(self) -> Dict:
        """Return ``cfg.META``, or ``{}`` when it is missing or ``None``."""
        return self._configValue("META", {})

    def custom(self) -> Dict:
        """Return ``cfg.CUSTOM``, or ``{}`` when it is missing or ``None``."""
        return self._configValue("CUSTOM", {})

    def parameters(self) -> List[ParamDesc]:
        """Validate each ``cfg.PARAMETERS`` entry into a ``ParamDesc``.

        Returns an empty list when ``PARAMETERS`` is missing or ``None``.
        """
        return [
            ParamDesc.model_validate(param)
            for param in self._configValue("PARAMETERS", [])
        ]

    def checkParameter(self, **kwargs) -> Dict:
        """Resolve call arguments against the declared parameters.

        Declared parameters come first, in declaration order: the supplied
        value if given, otherwise the declared default for optional ones.
        Keyword arguments that are not declared are appended unchanged.

        Raises:
            RuntimeError: if a required parameter is not supplied.
        """
        resolved = {}
        for param in self.parameters():
            if param.name in kwargs:
                resolved[param.name] = kwargs[param.name]
            elif param.required:
                raise RuntimeError(f"Missing parameter: {param.name}")
            else:
                resolved[param.name] = param.default
        # Fill in extra (undeclared) arguments without overriding the above.
        for key, value in kwargs.items():
            resolved.setdefault(key, value)
        return resolved

    def setup(self):
        """Initialization hook called at the end of ``__init__``; no-op here."""
        pass

    def release(self):
        """Cleanup hook called from ``__del__``; no-op here."""
        pass

    @abstractmethod
    async def run(self, input: BaseMessage, **kwargs):
        """Run the engine on ``input``; must be implemented by subclasses."""
        raise NotImplementedError
|