Browse Source

流式输出并将提取content字段的值存入数据库

unknown 1 month ago
parent
commit
e34bd3f2d3

+ 111 - 50
base-modules/service-ai/service-ai-biz/src/main/java/com/usky/ai/controller/web/AiQuestionController.java

@@ -7,15 +7,22 @@ import com.alibaba.dashscope.common.Message;
 import com.alibaba.dashscope.common.Role;
 import com.alibaba.dashscope.exception.InputRequiredException;
 import com.alibaba.dashscope.exception.NoApiKeyException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
 import com.usky.ai.mapper.AiQuestionMapper;
 import com.usky.ai.service.AiQuestion;
 import com.usky.common.security.utils.SecurityUtils;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
+import org.springframework.http.MediaType;
+import org.springframework.http.ResponseEntity;
 import org.springframework.web.bind.annotation.*;
+import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
 
 import javax.annotation.Resource;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.time.LocalDateTime;
 import java.util.List;
 import java.util.stream.Collectors;
@@ -34,6 +41,8 @@ public class AiQuestionController {
     @Autowired
     private AiQuestionMapper aiQuestionMapper;
 
+    private final ObjectMapper objectMapper = new ObjectMapper();
+
     // 查询所有数据
     @GetMapping("/all")
     public List<AiQuestion> getAllQuestions() {
@@ -59,11 +68,10 @@ public class AiQuestionController {
         return aiQuestionMapper.findByUserIdAndSessionId(sessionId, userId);
     }
 
-    //阿里百炼通义千问大模型
-    @PostMapping(value = "/aliTyqw")
-    public String send1(@RequestBody String content, @RequestParam(required = false) String sessionId) throws NoApiKeyException, InputRequiredException {
-
-        //获取当前登录用户的信息
+    // 阿里百炼通义千问大模型
+    @PostMapping(value = "/aliTyqw", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
+    public ResponseEntity<StreamingResponseBody> send1(@RequestBody String content, @RequestParam(required = false) String sessionId) throws NoApiKeyException, InputRequiredException {
+        // 获取当前登录用户的信息
         Long userId = SecurityUtils.getUserId();
         String userName = SecurityUtils.getLoginUser().getSysUser().getNickName();
 
@@ -73,7 +81,6 @@ public class AiQuestionController {
         }
 
         // 获取当前用户的对话历史
-        log.info("会话ID: {}, 用户ID: {}", sessionId, userId);
         List<AiQuestion> conversationHistory = aiQuestionMapper.findByUserIdAndSessionId(sessionId, userId);
 
         // 构建对话历史消息
@@ -84,14 +91,27 @@ public class AiQuestionController {
                         .build())
                 .collect(Collectors.toList());
 
-        // 用户与模型的对话历史
+        // 解析 JSON 并提取 "content" 字段的值
+        String questionText;
+        try {
+            JsonNode jsonNode = objectMapper.readTree(content);
+            questionText = jsonNode.get("content").asText(); // 提取 "content" 字段的值
+        } catch (IOException e) {
+            log.error("Error parsing JSON content", e);
+            return ResponseEntity.badRequest().body(outputStream -> {
+                outputStream.write("Invalid JSON format".getBytes(StandardCharsets.UTF_8));
+                outputStream.flush();
+            });
+        }
+
         // 添加用户的新消息
         Message userMessage = Message.builder()
                 .role(Role.USER.getValue())
-                .content(content)
+                .content(questionText) // 使用提取的文本
                 .build();
         messages.add(userMessage);
 
+        // 构建模型调用参数
         GenerationParam param = GenerationParam.builder()
                 .model("qwen-turbo")
                 .messages(messages)
@@ -101,30 +121,44 @@ public class AiQuestionController {
                 .enableSearch(true)
                 .build();
 
-        GenerationResult generationResult = generation.call(param);
-
-        // 获取回答内容
-        String answer = generationResult.getOutput().getChoices().get(0).getMessage().getContent();
-
-        // 创建实体并保存到数据库
-        AiQuestion aiQuestion = new AiQuestion();
-        aiQuestion.setModel("qwen-turbo");
-        aiQuestion.setSessionId(sessionId);
-        aiQuestion.setUserId(userId);
-        aiQuestion.setUserName(userName);
-        aiQuestion.setQuestion(content);
-        aiQuestion.setAnswer(answer);
-        aiQuestion.setAskTime(LocalDateTime.now());
-        aiQuestionMapper.save(aiQuestion);
-
-        return answer;
+        String finalSessionId = sessionId;
+        return ResponseEntity.ok()
+                .contentType(MediaType.TEXT_EVENT_STREAM)
+                .body(outputStream -> {
+                    try {
+                        GenerationResult generationResult = generation.call(param);
+
+                        // 获取回答内容
+                        String answer = generationResult.getOutput().getChoices().get(0).getMessage().getContent();
+
+                        // 将回答内容写入输出流
+                        outputStream.write(answer.getBytes(StandardCharsets.UTF_8));
+                        outputStream.flush();
+
+                        // 创建实体并保存到数据库
+                        AiQuestion aiQuestion = new AiQuestion();
+                        aiQuestion.setModel("qwen-turbo");
+                        aiQuestion.setSessionId(finalSessionId);
+                        aiQuestion.setUserId(userId);
+                        aiQuestion.setUserName(userName);
+                        aiQuestion.setQuestion(questionText); // 存入提取的文本
+                        aiQuestion.setAnswer(answer);
+                        aiQuestion.setAskTime(LocalDateTime.now());
+                        aiQuestionMapper.save(aiQuestion);
+                    } catch (IOException | NoApiKeyException | InputRequiredException e) {
+                        log.error("Error processing request", e);
+                        outputStream.write("Error processing request".getBytes(StandardCharsets.UTF_8));
+                        outputStream.flush();
+                    } finally {
+                        outputStream.close();
+                    }
+                });
     }
 
-    //阿里百炼DeepSeek大模型
-    @PostMapping(value = "/aliDeepSeek")
-    public String send2(@RequestBody String content, @RequestParam(required = false) String sessionId) throws NoApiKeyException, InputRequiredException {
-
-        //获取当前登录用户的信息
+    // 阿里百炼DeepSeek大模型
+    @PostMapping(value = "/aliDeepSeek", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
+    public ResponseEntity<StreamingResponseBody> send2(@RequestBody String content, @RequestParam(required = false) String sessionId) throws NoApiKeyException, InputRequiredException {
+        // 获取当前登录用户的信息
         Long userId = SecurityUtils.getUserId();
         String userName = SecurityUtils.getLoginUser().getSysUser().getNickName();
 
@@ -134,7 +168,6 @@ public class AiQuestionController {
         }
 
         // 获取当前用户的对话历史
-        log.info("会话ID: {}, 用户ID: {}", sessionId, userId);
         List<AiQuestion> conversationHistory = aiQuestionMapper.findByUserIdAndSessionId(sessionId, userId);
 
         // 构建对话历史消息
@@ -145,14 +178,27 @@ public class AiQuestionController {
                         .build())
                 .collect(Collectors.toList());
 
-        // 用户与模型的对话历史
+        // 解析 JSON 并提取 "content" 字段的值
+        String questionText;
+        try {
+            JsonNode jsonNode = objectMapper.readTree(content);
+            questionText = jsonNode.get("content").asText(); // 提取 "content" 字段的值
+        } catch (IOException e) {
+            log.error("Error parsing JSON content", e);
+            return ResponseEntity.badRequest().body(outputStream -> {
+                outputStream.write("Invalid JSON format".getBytes(StandardCharsets.UTF_8));
+                outputStream.flush();
+            });
+        }
+
         // 添加用户的新消息
         Message userMessage = Message.builder()
                 .role(Role.USER.getValue())
-                .content(content)
+                .content(questionText) // 使用提取的文本
                 .build();
         messages.add(userMessage);
 
+        // 构建模型调用参数
         GenerationParam param = GenerationParam.builder()
                 .model("deepseek-v3")
                 .messages(messages)
@@ -160,22 +206,37 @@ public class AiQuestionController {
                 .apiKey(apiKey)
                 .build();
 
-        GenerationResult generationResult = generation.call(param);
-
-        // 获取回答内容
-        String answer = generationResult.getOutput().getChoices().get(0).getMessage().getContent();
-
-        // 创建实体并保存到数据库
-        AiQuestion aiQuestion = new AiQuestion();
-        aiQuestion.setModel("deepseek-v3");
-        aiQuestion.setSessionId(sessionId);
-        aiQuestion.setUserId(userId);
-        aiQuestion.setUserName(userName);
-        aiQuestion.setQuestion(content);
-        aiQuestion.setAnswer(answer);
-        aiQuestion.setAskTime(LocalDateTime.now());
-        aiQuestionMapper.save(aiQuestion);
-
-        return answer;
+        String finalSessionId = sessionId;
+        return ResponseEntity.ok()
+                .contentType(MediaType.TEXT_EVENT_STREAM)
+                .body(outputStream -> {
+                    try {
+                        GenerationResult generationResult = generation.call(param);
+
+                        // 获取回答内容
+                        String answer = generationResult.getOutput().getChoices().get(0).getMessage().getContent();
+
+                        // 将回答内容写入输出流
+                        outputStream.write(answer.getBytes(StandardCharsets.UTF_8));
+                        outputStream.flush();
+
+                        // 创建实体并保存到数据库
+                        AiQuestion aiQuestion = new AiQuestion();
+                        aiQuestion.setModel("deepseek-v3");
+                        aiQuestion.setSessionId(finalSessionId);
+                        aiQuestion.setUserId(userId);
+                        aiQuestion.setUserName(userName);
+                        aiQuestion.setQuestion(questionText); // 存入提取的文本
+                        aiQuestion.setAnswer(answer);
+                        aiQuestion.setAskTime(LocalDateTime.now());
+                        aiQuestionMapper.save(aiQuestion);
+                    } catch (IOException | NoApiKeyException | InputRequiredException e) {
+                        log.error("Error processing request", e);
+                        outputStream.write("Error processing request".getBytes(StandardCharsets.UTF_8));
+                        outputStream.flush();
+                    } finally {
+                        outputStream.close();
+                    }
+                });
     }
 }