api_agent_v0_impl.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # -*- coding: utf-8 -*-
  2. from typing import List, Dict
  3. from digitalHuman.agent import AgentPool
  4. from digitalHuman.utils import config
  5. from digitalHuman.protocol import *
  6. from digitalHuman.server.models import AgentEngineInput
# Module-level engine pool shared by every helper in this module; engines are
# looked up by name via ``agentPool.get(...)`` and enumerated via ``agentPool.list()``.
agentPool = AgentPool()
  8. def get_agent_list() -> List[EngineDesc]:
  9. agents = agentPool.list()
  10. return [agentPool.get(agent).desc() for agent in agents]
  11. def get_agent_default() -> EngineDesc:
  12. return agentPool.get(config.SERVER.AGENTS.DEFAULT).desc()
  13. def get_agent_param(name: str) -> List[ParamDesc]:
  14. engine = agentPool.get(name)
  15. return engine.parameters()
  16. async def create_agent_conversation(name: str, param: Dict) -> str:
  17. engine = agentPool.get(name)
  18. id = await engine.createConversation(**param)
  19. return id
  20. def agent_infer_stream(user: UserDesc, items: AgentEngineInput):
  21. # 检查是否是按钮触发的对话(包含 [BUTTON_TRIGGERED] 标记)
  22. # 如果是按钮触发,添加 persona 前缀;否则直接使用用户输入
  23. BUTTON_MARKER = "[BUTTON_TRIGGERED]"
  24. if items.data.startswith(BUTTON_MARKER):
  25. # 移除标记,添加 persona 前缀
  26. user_message = items.data[len(BUTTON_MARKER):]
  27. persona_prefix = "你现在是永天科技展厅的智能客服,请介绍永天科技的产品和解决方案:\n"
  28. user_input = persona_prefix + user_message
  29. else:
  30. # 普通对话,不添加 persona 前缀
  31. user_input = items.data
  32. input = TextMessage(data=user_input)
  33. streamContent = agentPool.get(items.engine).run(input=input, user=user, streaming=True, conversation_id=items.conversation_id, **items.config)
  34. return streamContent