# aliNLSTTS.py
import asyncio
import random
import threading
from io import BytesIO
from typing import Optional # Added for type hinting
from xml.sax.saxutils import escape
from digitalHuman.protocol import *
from digitalHuman.utils import logger
import nls # Alibaba NLS SDK, when need to be installed
from ..builder import TTSEngines
from ..engineBase import BaseEngine
from yacs.config import CfgNode as CN
  12. __all__ = ["AliNLSTTS"]
  13. VOICE_LIST = [
  14. VoiceDesc(name="zhifeng_emo", gender=GENDER_TYPE.MALE),
  15. VoiceDesc(name="zhibing_emo", gender=GENDER_TYPE.MALE),
  16. VoiceDesc(name="zhitian_emo", gender=GENDER_TYPE.FEMALE),
  17. VoiceDesc(name="zhibei_emo", gender=GENDER_TYPE.FEMALE),
  18. VoiceDesc(name="zhiyan_emo", gender=GENDER_TYPE.FEMALE),
  19. VoiceDesc(name="zhimi_emo", gender=GENDER_TYPE.FEMALE),
  20. VoiceDesc(name="zhimiao_emo", gender=GENDER_TYPE.FEMALE),
  21. ]
  22. @TTSEngines.register("AliNLSTTS")
  23. class AliNLSTTS(BaseEngine):
  24. EMOTION_LIST = ['angry', 'fear', 'happy', 'hate', 'neutral', 'sad', 'surprise']
  25. def generate_remotion_ssml_text(self, text: str) -> str:
  26. return f'<speak><emotion category="{random.choice(self.EMOTION_LIST)}" intensity="1.0">{text}</emotion></speak>'
  27. async def voices(self) -> List[VoiceDesc]:
  28. return VOICE_LIST
  29. class NlsWorker:
  30. def __init__(
  31. self,
  32. text: str,
  33. config: CN,
  34. voice: str,
  35. token: str,
  36. api_key: str,
  37. ):
  38. self._text = text
  39. self._config = config
  40. self._voice = voice
  41. self._token = token
  42. self._api_key = api_key
  43. self._audio_buffer = BytesIO()
  44. self._completion_event = threading.Event()
  45. self._error_occurred = False
  46. self._error_message = ""
  47. # Configure NLS SDK debugging based on environment or config
  48. # nls.enableTrace(True) # Enable for debugging if needed
  49. def on_error(self, message, *args):
  50. logger.error(f"[{self._config.NAME}] On error: {message}, args: {args}")
  51. self._error_message = str(message)
  52. self._error_occurred = True
  53. self._completion_event.set() # Signal completion even on error
  54. def on_close(self, *args):
  55. logger.debug(f"[{self._config.NAME}] On close: args: {args}")
  56. self._completion_event.set() # Ensure completion is signaled
  57. def on_data(self, data, *args):
  58. if data:
  59. self._audio_buffer.write(data)
  60. def on_completed(self, message, *args):
  61. logger.debug(f"[{self._config.NAME}] On completed: {message}")
  62. self._completion_event.set()
  63. def synthesize(self) -> Optional[bytes]:
  64. tts = nls.NlsSpeechSynthesizer(
  65. url=self._config.URL,
  66. appkey=self._api_key,
  67. token=self._token,
  68. on_data=self.on_data,
  69. on_completed=self.on_completed,
  70. on_error=self.on_error,
  71. on_close=self.on_close,
  72. callback_args=[]
  73. )
  74. logger.debug(f"[{self._config.NAME}] Starting TTS synthesis for text: {self._text[:50]}...")
  75. # The NLS SDK's start method expects parameters like voice, format, sample_rate.
  76. # Make sure these are correctly passed from the config.
  77. # The text input here is expected to be SSML.
  78. logger.info(f"{self._text=}")
  79. tts.start(
  80. self._text,
  81. voice=self._voice,
  82. aformat=self._config.FORMAT.lower(), # SDK expects 'pcm', 'mp3', 'wav'
  83. sample_rate=self._config.SAMPLE_RATE
  84. )
  85. self._completion_event.wait() # Wait for callbacks to complete
  86. if self._error_occurred:
  87. logger.error(f"[{self._config.NAME}] Synthesis failed: {self._error_message}")
  88. return None
  89. self._audio_buffer.seek(0)
  90. return self._audio_buffer.getvalue()
  91. async def run(self, input: TextMessage, **kwargs) -> Optional[AudioMessage]:
  92. logger.info(f"[{self.cfg.NAME}] Received text for TTS: {input.data[:50]}...")
  93. # 参数校验
  94. paramters = self.checkParameter(**kwargs)
  95. voice = paramters["voice"]
  96. token = paramters["token"]
  97. api_key = paramters["api_key"]
  98. if not input.data:
  99. logger.warning(f"[{self.cfg.NAME}] Received empty text for TTS.")
  100. return None
  101. worker = self.NlsWorker(
  102. text=self.generate_remotion_ssml_text(input.data),
  103. config=self.cfg,
  104. voice=voice,
  105. token=token,
  106. api_key=api_key
  107. )
  108. # change to async function
  109. loop = asyncio.get_event_loop()
  110. audio_content = await loop.run_in_executor(None, worker.synthesize)
  111. config_audio_out_format = self.cfg.FORMAT.lower()
  112. if audio_content:
  113. if config_audio_out_format == "mp3":
  114. audio_format = AUDIO_TYPE.MP3
  115. elif config_audio_out_format == "wav":
  116. audio_format = AUDIO_TYPE.WAV
  117. else:
  118. raise ValueError(f"Unsupported {config_audio_out_format} for ALI NLS tts")
  119. logger.info(f"[{self.cfg.NAME}] TTS synthesis successful. Audio size: {len(audio_content)} bytes")
  120. return AudioMessage(
  121. data=audio_content,
  122. format=audio_format,
  123. sampleRate=self.cfg.SAMPLE_RATE,
  124. sampleWidth=0, # This might need adjustment based on format
  125. desc="Alibaba NLS TTS"
  126. )
  127. else:
  128. logger.error(f"[{self.cfg.NAME}] TTS synthesis failed to produce audio content.")
  129. return None