Browse Source

修改历史会话查询返回内容格式

zhaojinyu 4 weeks ago
parent
commit
2ef4bc5b0d

+ 9 - 0
base-modules/service-ai/service-ai-biz/src/main/java/com/usky/ai/controller/web/AiChatController.java

@@ -47,6 +47,9 @@ public class AiChatController {
     @Value("${aliDpsk.model}")
     private String dpskModel;
 
+    @Value("${ai.historyLimit}")
+    private int Limit;
+
     @Resource
     private Generation generation;
 
@@ -96,6 +99,12 @@ public class AiChatController {
         // 获取当前用户的对话历史
         List<AiQuestion> conversationHistory = aiQuestionMapper.findByUserIdAndSessionId(sessionId, userId);
 
+        // 只保留最近的几轮对话
+        int historyLimit = Limit;
+        if (conversationHistory.size() > historyLimit) {
+            conversationHistory = conversationHistory.subList(conversationHistory.size() - historyLimit, conversationHistory.size());
+        }
+
         // 构建对话历史消息
         List<Message> messages = conversationHistory.stream()
                 .map(q -> Message.builder()

+ 38 - 3
base-modules/service-ai/service-ai-biz/src/main/java/com/usky/ai/controller/web/AiSessionController.java

@@ -1,6 +1,9 @@
 package com.usky.ai.controller.web;
 
+import com.usky.ai.mapper.AiQuestionMapper;
 import com.usky.ai.mapper.AiSessionMapper;
+import com.usky.ai.service.AiQuestion;
+import com.usky.ai.service.AiQuestionItem;
 import com.usky.ai.service.AiSession;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.annotation.Autowired;
@@ -9,6 +12,7 @@ import org.springframework.web.bind.annotation.PathVariable;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
 
+import java.util.ArrayList;
 import java.util.List;
 
 @Slf4j
@@ -19,13 +23,44 @@ public class AiSessionController {
     @Autowired
     private AiSessionMapper aiSessionMapper;
 
+    @Autowired
+    private AiQuestionMapper aiQuestionMapper;
+
     @GetMapping("/all")
     public List<AiSession> getAllSessions() {
         return aiSessionMapper.findAll();
     }
 
-    @GetMapping("/{sessionId}")
-    public List<AiSession> getSessionsBySessionId(@PathVariable String sessionId) {
-        return aiSessionMapper.findBySessionId(sessionId);
+    @GetMapping("/{userId}")
+    public List<AiSession> getSessionsByUserId(@PathVariable Long userId) {
+        List<AiSession> sessions = aiSessionMapper.findByUserId(userId);
+
+        for (AiSession session : sessions) {
+            List<AiQuestion> questions = aiQuestionMapper.findQuestionsBySessionId(session.getSessionId());
+            List<AiQuestionItem> itemList = new ArrayList<>();
+
+            for (AiQuestion question : questions) {
+                AiQuestionItem userItem = new AiQuestionItem("user", question.getQuestion());
+                userItem.setId(question.getId());
+                userItem.setSessionId(question.getSessionId());
+                userItem.setUserId(question.getUserId());
+                userItem.setUserName(question.getUserName());
+                userItem.setAskTime(question.getAskTime());
+
+                AiQuestionItem assistantItem = new AiQuestionItem("assistant", question.getAnswer());
+                assistantItem.setId(question.getId());
+                assistantItem.setSessionId(question.getSessionId());
+                assistantItem.setUserId(question.getUserId());
+                assistantItem.setUserName(question.getUserName());
+                assistantItem.setAskTime(question.getAskTime());
+
+                itemList.add(userItem);
+                itemList.add(assistantItem);
+            }
+
+            session.setItemList(itemList);
+        }
+
+        return sessions;
     }
 }

+ 4 - 0
base-modules/service-ai/service-ai-biz/src/main/java/com/usky/ai/mapper/AiQuestionMapper.java

@@ -30,4 +30,8 @@ public interface AiQuestionMapper {
     //根据 sessionId查询数据
     @Select("SELECT * FROM ai_questions WHERE session_id = #{sessionId} ORDER BY ask_time ASC")
     List<AiQuestion> findBySessionId(String sessionId);
+
+    // 根据 sessionId 查询 ai_questions 表中的数据
+    @Select("SELECT * FROM ai_questions WHERE session_id = #{sessionId} ORDER BY ask_time ASC")
+    List<AiQuestion> findQuestionsBySessionId(String sessionId);
 }

+ 4 - 4
base-modules/service-ai/service-ai-biz/src/main/java/com/usky/ai/mapper/AiSessionMapper.java

@@ -1,5 +1,6 @@
 package com.usky.ai.mapper;
 
+import com.usky.ai.service.AiQuestion;
 import com.usky.ai.service.AiSession;
 import org.apache.ibatis.annotations.Insert;
 import org.apache.ibatis.annotations.Mapper;
@@ -17,10 +18,6 @@ public interface AiSessionMapper {
     @Select("SELECT * FROM ai_sessions ORDER BY ask_time ASC")
     List<AiSession> findAll();
 
-    // 根据 session_id 查询数据
-    @Select("SELECT * FROM ai_sessions WHERE session_id = #{sessionId} ORDER BY ask_time ASC")
-    List<AiSession> findBySessionId(String sessionId);
-
     //根据user_id 查询数据
     @Select("SELECT * FROM ai_sessions WHERE user_id = #{userId} ORDER BY ask_time ASC")
     List<AiSession> findByUserId(Long userId);
@@ -28,4 +25,7 @@ public interface AiSessionMapper {
     // 检查是否存在指定的 session_id
     @Select("SELECT COUNT(*) FROM ai_sessions WHERE session_id = #{sessionId}")
     boolean existsBySessionId(String sessionId);
+
+    @Select("SELECT * FROM ai_questions WHERE session_id = #{sessionId} ORDER BY ask_time ASC")
+    List<AiQuestion> findQuestionsBySessionId(String sessionId);
 }

+ 46 - 0
base-modules/service-ai/service-ai-biz/src/main/java/com/usky/ai/service/AiQuestionItem.java

@@ -0,0 +1,46 @@
+package com.usky.ai.service;
+
+import java.time.LocalDateTime;
+
+public class AiQuestionItem {
+    private Long id;
+    private String sessionId;
+    private Long userId;
+    private String userName;
+    private String role;
+    private String content;
+    private LocalDateTime askTime;
+
+    public AiQuestionItem(String role, String content) {
+        this.role = role;
+        this.content = content;
+    }
+
+    public Long getId() { return id; }
+
+    public void setId(Long id) { this.id = id; }
+
+    public String getSessionId() { return sessionId; }
+
+    public void setSessionId(String sessionId) { this.sessionId = sessionId; }
+
+    public Long getUserId() { return userId; }
+
+    public void setUserId(Long userId) { this.userId = userId; }
+
+    public String getUserName() { return userName; }
+
+    public void setUserName(String userName) { this.userName = userName; }
+
+    public String getRole() { return role; }
+
+    public void setRole(String role) { this.role = role; }
+
+    public String getContent() { return content; }
+
+    public void setContent(String content) { this.content = content; }
+
+    public LocalDateTime getAskTime() { return askTime; }
+
+    public void setAskTime(LocalDateTime askTime) { this.askTime = askTime; }
+}

+ 6 - 0
base-modules/service-ai/service-ai-biz/src/main/java/com/usky/ai/service/AiSession.java

@@ -1,6 +1,7 @@
 package com.usky.ai.service;
 
 import java.time.LocalDateTime;
+import java.util.List;
 
 public class AiSession {
 
@@ -10,6 +11,7 @@ public class AiSession {
     private String userName; // 添加用户名字段
     private String question;
     private LocalDateTime askTime;
+    private List<AiQuestionItem> itemList;
 
     public Long getId() { return id; }
 
@@ -34,4 +36,8 @@ public class AiSession {
     public LocalDateTime getAskTime() { return askTime; }
 
     public void setAskTime(LocalDateTime askTime) { this.askTime = askTime; }
+
+    public List<AiQuestionItem> getItemList() { return itemList; }
+
+    public void setItemList(List<AiQuestionItem> itemList) { this.itemList = itemList; }
 }

+ 1 - 1
base-modules/service-ai/service-ai-biz/src/main/resources/bootstrap.yml

@@ -1,6 +1,6 @@
 # Tomcat
 server:
-  port: 8080
+  port: 9893
 # Spring
 spring: 
   application: