Browse Source

修改AI大模型内容返回方式,添加html文件进行测试实现流式输出效果

zhaojinyu 1 month ago
parent
commit
5cc0ee5356

+ 111 - 109
base-modules/service-ai/service-ai-biz/src/main/java/com/usky/ai/controller/web/AiChatController.java

@@ -5,6 +5,7 @@ import com.alibaba.dashscope.aigc.generation.GenerationParam;
 import com.alibaba.dashscope.aigc.generation.GenerationResult;
 import com.alibaba.dashscope.common.Message;
 import com.alibaba.dashscope.common.Role;
+import com.alibaba.dashscope.exception.ApiException;
 import com.alibaba.dashscope.exception.InputRequiredException;
 import com.alibaba.dashscope.exception.NoApiKeyException;
 import com.fasterxml.jackson.databind.JsonNode;
@@ -58,34 +59,33 @@ public class AiChatController {
         // 如果没有传入 sessionId,则创建一个新的会话ID
         if (sessionId == null || sessionId.isEmpty()) {
             sessionId = java.util.UUID.randomUUID().toString();
+        }
 
+        // 解析 JSON 并提取 "content" 字段的值
+        String questionText;
+        try {
+            JsonNode jsonNode = objectMapper.readTree(content);
+            questionText = jsonNode.get("content").asText(); // 提取 "content" 字段的值
+        } catch (IOException e) {
+            log.error("Error parsing JSON content", e);
+            return ResponseEntity.badRequest().body(outputStream -> {
+                outputStream.write("data: Invalid JSON format\n\n".getBytes(StandardCharsets.UTF_8));
+                outputStream.flush();
+            });
+        }
 
-            // 解析 JSON 并提取 "content" 字段的值
-            String questionText;
-            try {
-                JsonNode jsonNode = objectMapper.readTree(content);
-                questionText = jsonNode.get("content").asText(); // 提取 "content" 字段的值
-            } catch (IOException e) {
-                log.error("Error parsing JSON content", e);
-                return ResponseEntity.badRequest().body(outputStream -> {
-                    outputStream.write("Invalid JSON format".getBytes(StandardCharsets.UTF_8));
-                    outputStream.flush();
-                });
-            }
-
-            // 检查是否已经存在相同的 sessionId
-            boolean exists = aiSessionMapper.existsBySessionId(sessionId);
-
-            if (!exists) {
-                // 创建新的 AiSession 实体并存入数据库
-                AiSession aiSession = new AiSession();
-                aiSession.setSessionId(sessionId);
-                aiSession.setUserId(userId);
-                aiSession.setUserName(userName);
-                aiSession.setQuestion(questionText);
-                aiSession.setAskTime(LocalDateTime.now());
-                aiSessionMapper.save(aiSession);
-            }
+        // 检查是否已经存在相同的 sessionId
+        boolean exists = aiSessionMapper.existsBySessionId(sessionId);
+
+        if (!exists) {
+            // 创建新的 AiSession 实体并存入数据库
+            AiSession aiSession = new AiSession();
+            aiSession.setSessionId(sessionId);
+            aiSession.setUserId(userId);
+            aiSession.setUserName(userName);
+            aiSession.setQuestion(questionText);
+            aiSession.setAskTime(LocalDateTime.now());
+            aiSessionMapper.save(aiSession);
         }
 
         // 获取当前用户的对话历史
@@ -99,18 +99,13 @@ public class AiChatController {
                         .build())
                 .collect(Collectors.toList());
 
-        // 解析 JSON 并提取 "content" 字段的值
-        String questionText;
-        try {
-            JsonNode jsonNode = objectMapper.readTree(content);
-            questionText = jsonNode.get("content").asText(); // 提取 "content" 字段的值
-        } catch (IOException e) {
-            log.error("Error parsing JSON content", e);
-            return ResponseEntity.badRequest().body(outputStream -> {
-                outputStream.write("Invalid JSON format".getBytes(StandardCharsets.UTF_8));
-                outputStream.flush();
-            });
-        }
+        // 插入角色定义(在对话历史的开头)
+        Message roleDefinition = Message.builder()
+                .role(Role.SYSTEM.getValue()) // 使用系统角色
+//                .content("你是一个名为'永天小天Ai'的智能助手,擅长幽默和简洁的回答。") // 定义角色的行为和风格
+                .content("你是一个名为'小天Ai'的智能助手,擅长解答编程和技术问题。回答时请保持专业、清晰且简洁。") // 定义角色的行为和风格
+                .build();
+        messages.add(0, roleDefinition); // 将角色定义插入到对话历史的开头
 
         // 添加用户的新消息
         Message userMessage = Message.builder()
@@ -124,41 +119,46 @@ public class AiChatController {
                 .model("qwen-turbo")
                 .messages(messages)
                 .resultFormat(GenerationParam.ResultFormat.MESSAGE)
-                .topP(0.8)
                 .apiKey(apiKey)
-                .enableSearch(true)
+                .incrementalOutput(true) // 开启增量输出[^1^]
                 .build();
 
         String finalSessionId = sessionId;
+        StringBuilder completeAnswer = new StringBuilder(); // 用于收集完整的回答内容
         return ResponseEntity.ok()
                 .contentType(MediaType.TEXT_EVENT_STREAM)
                 .body(outputStream -> {
                     try {
-                        GenerationResult generationResult = generation.call(param);
-
-                        // 获取回答内容
-                        String answer = generationResult.getOutput().getChoices().get(0).getMessage().getContent();
-
-                        // 将回答内容写入输出流
-                        outputStream.write(answer.getBytes(StandardCharsets.UTF_8));
-                        outputStream.flush();
-
-                        // 创建实体并保存到数据库
+                        // 调用流式接口
+                        generation.streamCall(param).blockingForEach(chunk -> {
+                            // 获取每次生成的内容
+                            String partialAnswer = chunk.getOutput().getChoices().get(0).getMessage().getContent();
+                            // 将部分内容写入输出流
+                            outputStream.write(("data: " + partialAnswer + "\n\n").getBytes(StandardCharsets.UTF_8));
+                            outputStream.flush();
+                            // 累加到完整回答内容中
+                            completeAnswer.append(partialAnswer);
+                        });
+
+                        // 流式接口调用完成后,将完整回答存入数据库
                         AiQuestion aiQuestion = new AiQuestion();
                         aiQuestion.setModel("qwen-turbo");
                         aiQuestion.setSessionId(finalSessionId);
                         aiQuestion.setUserId(userId);
                         aiQuestion.setUserName(userName);
                         aiQuestion.setQuestion(questionText); // 存入提取的文本
-                        aiQuestion.setAnswer(answer);
+                        aiQuestion.setAnswer(completeAnswer.toString());
                         aiQuestion.setAskTime(LocalDateTime.now());
                         aiQuestionMapper.save(aiQuestion);
-                    } catch (IOException | NoApiKeyException | InputRequiredException e) {
+
+                    } catch (ApiException e) {
                         log.error("Error processing request", e);
-                        outputStream.write("Error processing request".getBytes(StandardCharsets.UTF_8));
+                        outputStream.write(("data: Error processing request\n\n").getBytes(StandardCharsets.UTF_8));
                         outputStream.flush();
-                    } finally {
-                        outputStream.close();
+                    } catch (NoApiKeyException e) {
+                        throw new RuntimeException(e);
+                    } catch (InputRequiredException e) {
+                        throw new RuntimeException(e);
                     }
                 });
     }
@@ -173,33 +173,33 @@ public class AiChatController {
         // 如果没有传入 sessionId,则创建一个新的会话ID
         if (sessionId == null || sessionId.isEmpty()) {
             sessionId = java.util.UUID.randomUUID().toString();
+        }
 
-            // 解析 JSON 并提取 "content" 字段的值
-            String questionText;
-            try {
-                JsonNode jsonNode = objectMapper.readTree(content);
-                questionText = jsonNode.get("content").asText(); // 提取 "content" 字段的值
-            } catch (IOException e) {
-                log.error("Error parsing JSON content", e);
-                return ResponseEntity.badRequest().body(outputStream -> {
-                    outputStream.write("Invalid JSON format".getBytes(StandardCharsets.UTF_8));
-                    outputStream.flush();
-                });
-            }
-
-            // 检查是否已经存在相同的 sessionId
-            boolean exists = aiSessionMapper.existsBySessionId(sessionId);
-
-            if (!exists) {
-                // 创建新的 AiSession 实体并存入数据库
-                AiSession aiSession = new AiSession();
-                aiSession.setSessionId(sessionId);
-                aiSession.setUserId(userId);
-                aiSession.setUserName(userName);
-                aiSession.setQuestion(questionText);
-                aiSession.setAskTime(LocalDateTime.now());
-                aiSessionMapper.save(aiSession);
-            }
+        // 解析 JSON 并提取 "content" 字段的值
+        String questionText;
+        try {
+            JsonNode jsonNode = objectMapper.readTree(content);
+            questionText = jsonNode.get("content").asText(); // 提取 "content" 字段的值
+        } catch (IOException e) {
+            log.error("Error parsing JSON content", e);
+            return ResponseEntity.badRequest().body(outputStream -> {
+                outputStream.write("data: Invalid JSON format\n\n".getBytes(StandardCharsets.UTF_8));
+                outputStream.flush();
+            });
+        }
+
+        // 检查是否已经存在相同的 sessionId
+        boolean exists = aiSessionMapper.existsBySessionId(sessionId);
+
+        if (!exists) {
+            // 创建新的 AiSession 实体并存入数据库
+            AiSession aiSession = new AiSession();
+            aiSession.setSessionId(sessionId);
+            aiSession.setUserId(userId);
+            aiSession.setUserName(userName);
+            aiSession.setQuestion(questionText);
+            aiSession.setAskTime(LocalDateTime.now());
+            aiSessionMapper.save(aiSession);
         }
 
         // 获取当前用户的对话历史
@@ -213,18 +213,13 @@ public class AiChatController {
                         .build())
                 .collect(Collectors.toList());
 
-        // 解析 JSON 并提取 "content" 字段的值
-        String questionText;
-        try {
-            JsonNode jsonNode = objectMapper.readTree(content);
-            questionText = jsonNode.get("content").asText(); // 提取 "content" 字段的值
-        } catch (IOException e) {
-            log.error("Error parsing JSON content", e);
-            return ResponseEntity.badRequest().body(outputStream -> {
-                outputStream.write("Invalid JSON format".getBytes(StandardCharsets.UTF_8));
-                outputStream.flush();
-            });
-        }
+        // 插入角色定义(在对话历史的开头)
+        Message roleDefinition = Message.builder()
+                .role(Role.SYSTEM.getValue()) // 使用系统角色
+//                .content("你是一个名为'永天小天Ai'的智能助手,擅长幽默和简洁的回答。") // 定义角色的行为和风格
+                .content("你是一个名为'小天Ai'的智能助手,擅长解答编程和技术问题。回答时请保持专业、清晰且简洁。") // 定义角色的行为和风格
+                .build();
+        messages.add(0, roleDefinition); // 将角色定义插入到对话历史的开头
 
         // 添加用户的新消息
         Message userMessage = Message.builder()
@@ -239,38 +234,45 @@ public class AiChatController {
                 .messages(messages)
                 .resultFormat(GenerationParam.ResultFormat.MESSAGE)
                 .apiKey(apiKey)
+                .incrementalOutput(true) // 开启增量输出[^1^]
                 .build();
 
         String finalSessionId = sessionId;
+        StringBuilder completeAnswer = new StringBuilder(); // 用于收集完整的回答内容
         return ResponseEntity.ok()
                 .contentType(MediaType.TEXT_EVENT_STREAM)
                 .body(outputStream -> {
                     try {
-                        GenerationResult generationResult = generation.call(param);
-
-                        // 获取回答内容
-                        String answer = generationResult.getOutput().getChoices().get(0).getMessage().getContent();
-
-                        // 将回答内容写入输出流
-                        outputStream.write(answer.getBytes(StandardCharsets.UTF_8));
-                        outputStream.flush();
-
-                        // 创建实体并保存到数据库
+                        // 调用流式接口
+                        generation.streamCall(param).blockingForEach(chunk -> {
+                            // 获取每次生成的内容
+                            String partialAnswer = chunk.getOutput().getChoices().get(0).getMessage().getContent();
+                            // 将部分内容写入输出流
+                            outputStream.write(("data: " + partialAnswer + "\n\n").getBytes(StandardCharsets.UTF_8));
+                            outputStream.flush();
+                            // 累加到完整回答内容中
+                            completeAnswer.append(partialAnswer);
+                        });
+
+                        // 流式接口调用完成后,将完整回答存入数据库
                         AiQuestion aiQuestion = new AiQuestion();
-                        aiQuestion.setModel("deepseek-v3");
+                        aiQuestion.setModel("qwen-turbo");
                         aiQuestion.setSessionId(finalSessionId);
                         aiQuestion.setUserId(userId);
                         aiQuestion.setUserName(userName);
                         aiQuestion.setQuestion(questionText); // 存入提取的文本
-                        aiQuestion.setAnswer(answer);
+                        aiQuestion.setAnswer(completeAnswer.toString());
                         aiQuestion.setAskTime(LocalDateTime.now());
                         aiQuestionMapper.save(aiQuestion);
-                    } catch (IOException | NoApiKeyException | InputRequiredException e) {
+
+                    } catch (ApiException e) {
                         log.error("Error processing request", e);
-                        outputStream.write("Error processing request".getBytes(StandardCharsets.UTF_8));
+                        outputStream.write(("data: Error processing request\n\n").getBytes(StandardCharsets.UTF_8));
                         outputStream.flush();
-                    } finally {
-                        outputStream.close();
+                    } catch (NoApiKeyException e) {
+                        throw new RuntimeException(e);
+                    } catch (InputRequiredException e) {
+                        throw new RuntimeException(e);
                     }
                 });
     }

+ 19 - 0
base-modules/service-ai/service-ai-biz/src/main/java/com/usky/ai/service/config/MyGlobalCorsConfig.java

@@ -0,0 +1,19 @@
+package com.usky.ai.service.config;
+
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.servlet.config.annotation.CorsRegistry;
+import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
+
+@Configuration
+public class MyGlobalCorsConfig implements WebMvcConfigurer {
+
+    @Override
+    public void addCorsMappings(CorsRegistry registry) {
+        registry.addMapping("/**") // 对所有的路径允许跨域请求
+                .allowedOrigins("*") // 允许来自任何源的请求
+                .allowedMethods("GET", "POST", "PUT", "DELETE") // 允许的请求方法
+                .allowedHeaders("*") // 允许的请求头
+                .allowCredentials(false) // 是否允许证书(cookies),根据需要设置
+                .maxAge(3600); // 预检请求的缓存时间(秒)
+    }
+}

+ 14 - 0
base-modules/service-ai/service-ai-biz/src/main/java/com/usky/ai/service/config/WebConfig.java

@@ -0,0 +1,14 @@
+package com.usky.ai.service.config;
+
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.servlet.config.annotation.AsyncSupportConfigurer;
+import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
+
+@Configuration
+public class WebConfig implements WebMvcConfigurer {
+    @Override
+    public void configureAsyncSupport(AsyncSupportConfigurer configurer) {
+        // 设置默认的异步请求超时时间为300秒(单位:毫秒)
+        configurer.setDefaultTimeout(300000);
+    }
+}

+ 1 - 1
base-modules/service-ai/service-ai-biz/src/main/resources/bootstrap.yml

@@ -1,6 +1,6 @@
 # Tomcat
 server:
-  port: 6666
+  port: 8080
 # Spring
 spring: 
   application:

+ 132 - 0
base-modules/service-ai/service-ai-biz/src/main/resources/static/dpsk.html

@@ -0,0 +1,132 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>小天-AI</title>
+    <style>
+        body {
+            font-family: 'Roboto', Arial, sans-serif;
+            margin: 0;
+            padding: 0;
+            background-color: #f4f4f9;
+            display: flex;
+            justify-content: center;
+            align-items: center;
+            height: 100vh;
+        }
+        .container {
+            width: 500px;
+            background-color: #fff;
+            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+            border-radius: 8px;
+            overflow: hidden;
+        }
+        h1 {
+            background-color: #007bff;
+            color: #fff;
+            text-align: center;
+            padding: 15px;
+            margin: 0;
+            font-size: 1.5em;
+        }
+        form {
+            padding: 20px;
+        }
+        label {
+            display: block;
+            margin-bottom: 10px;
+            font-weight: bold;
+        }
+        textarea {
+            width: 100%;
+            height: 100px;
+            border: 1px solid #ccc;
+            border-radius: 4px;
+            padding: 10px;
+            resize: none;
+        }
+        button {
+            width: 100%;
+            padding: 10px;
+            background-color: #007bff;
+            color: #fff;
+            border: none;
+            border-radius: 4px;
+            cursor: pointer;
+            transition: background-color 0.3s ease;
+        }
+        button:hover {
+            background-color: #0056b3;
+        }
+        #response {
+            margin-top: 20px;
+            padding: 10px;
+            background-color: #f9f9f9;
+            border-top: 1px solid #ccc;
+            max-height: 200px;
+            overflow-y: auto;
+            font-family: monospace;
+            white-space: pre-wrap;
+        }
+    </style>
+</head>
+<body>
+<div class="container">
+    <h1>小天-AI</h1>
+    <form id="chatForm">
+        <label for="content">你的问题:</label>
+        <textarea id="content" name="content" placeholder="请输入你的问题在这里..."></textarea>
+        <button type="submit">发送</button>
+    </form>
+    <div id="response"></div>
+</div>
+
+<script>
+    document.getElementById('chatForm').addEventListener('submit', function(event) {
+        event.preventDefault();
+
+        const content = document.getElementById('content').value;
+
+        const requestBody = JSON.stringify({ content: content });
+
+        const token = "eyJhbGciOiJIUzUxMiJ9.eyIiOjEwMDMsInVzZXJfaWQiOjIxMywidXNlcl9rZXkiOiJlYzUxODMzNjdmYTk0ODgwOGQwZjEwODEyOWVmNjgwOSIsInVzZXJuYW1lIjoi6LW16YeR6ZuoIn0.zWulXcesI1TRcDmiAHuQ9P2WHDE2l7mDmuunx13TmVl6E5Yvs8nZvu1ddtINdw0lrnnR3Q5lZaRH3mJJTaDhig";
+
+        fetch('/ai/aliDeepSeek', {
+            method: 'POST',
+            headers: {
+                'Content-Type': 'application/json',
+                'Authorization': `Bearer ${token}`
+            },
+            body: requestBody
+        })
+            .then(response => {
+                if (!response.ok) {
+                    throw new Error('Network response was not ok');
+                }
+                return response.text();
+            })
+            .then(data => {
+                document.getElementById('response').innerText = '';
+
+                const lines = data.split('\n');
+                const fullResponse = lines.map(line => line.replace('data: ', '')).join('');
+
+                let index = 0;
+                const responseElement = document.getElementById('response');
+                const interval = setInterval(() => {
+                    if (index < fullResponse.length) {
+                        responseElement.innerText += fullResponse[index];
+                        index++;
+                    } else {
+                        clearInterval(interval);
+                    }
+                }, 50);
+            })
+            .catch(error => {
+                document.getElementById('response').innerText = 'Error: ' + error.message;
+            });
+    });
+</script>
+</body>
+</html>

+ 132 - 0
base-modules/service-ai/service-ai-biz/src/main/resources/static/tyqw.html

@@ -0,0 +1,132 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>小天-AI</title>
+    <style>
+        body {
+            font-family: 'Roboto', Arial, sans-serif;
+            margin: 0;
+            padding: 0;
+            background-color: #f4f4f9;
+            display: flex;
+            justify-content: center;
+            align-items: center;
+            height: 100vh;
+        }
+        .container {
+            width: 500px;
+            background-color: #fff;
+            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+            border-radius: 8px;
+            overflow: hidden;
+        }
+        h1 {
+            background-color: #007bff;
+            color: #fff;
+            text-align: center;
+            padding: 15px;
+            margin: 0;
+            font-size: 1.5em;
+        }
+        form {
+            padding: 20px;
+        }
+        label {
+            display: block;
+            margin-bottom: 10px;
+            font-weight: bold;
+        }
+        textarea {
+            width: 100%;
+            height: 100px;
+            border: 1px solid #ccc;
+            border-radius: 4px;
+            padding: 10px;
+            resize: none;
+        }
+        button {
+            width: 100%;
+            padding: 10px;
+            background-color: #007bff;
+            color: #fff;
+            border: none;
+            border-radius: 4px;
+            cursor: pointer;
+            transition: background-color 0.3s ease;
+        }
+        button:hover {
+            background-color: #0056b3;
+        }
+        #response {
+            margin-top: 20px;
+            padding: 10px;
+            background-color: #f9f9f9;
+            border-top: 1px solid #ccc;
+            max-height: 200px;
+            overflow-y: auto;
+            font-family: monospace;
+            white-space: pre-wrap;
+        }
+    </style>
+</head>
+<body>
+<div class="container">
+    <h1>小天-AI</h1>
+    <form id="chatForm">
+        <label for="content">你的问题:</label>
+        <textarea id="content" name="content" placeholder="请输入你的问题在这里..."></textarea>
+        <button type="submit">发送</button>
+    </form>
+    <div id="response"></div>
+</div>
+
+<script>
+    document.getElementById('chatForm').addEventListener('submit', function(event) {
+        event.preventDefault();
+
+        const content = document.getElementById('content').value;
+
+        const requestBody = JSON.stringify({ content: content });
+
+        const token = "eyJhbGciOiJIUzUxMiJ9.eyIiOjEwMDMsInVzZXJfaWQiOjIxMywidXNlcl9rZXkiOiJlYzUxODMzNjdmYTk0ODgwOGQwZjEwODEyOWVmNjgwOSIsInVzZXJuYW1lIjoi6LW16YeR6ZuoIn0.zWulXcesI1TRcDmiAHuQ9P2WHDE2l7mDmuunx13TmVl6E5Yvs8nZvu1ddtINdw0lrnnR3Q5lZaRH3mJJTaDhig";
+
+        fetch('/ai/aliTyqw', {
+            method: 'POST',
+            headers: {
+                'Content-Type': 'application/json',
+                'Authorization': `Bearer ${token}`
+            },
+            body: requestBody
+        })
+            .then(response => {
+                if (!response.ok) {
+                    throw new Error('Network response was not ok');
+                }
+                return response.text();
+            })
+            .then(data => {
+                document.getElementById('response').innerText = '';
+
+                const lines = data.split('\n');
+                const fullResponse = lines.map(line => line.replace('data: ', '')).join('');
+
+                let index = 0;
+                const responseElement = document.getElementById('response');
+                const interval = setInterval(() => {
+                    if (index < fullResponse.length) {
+                        responseElement.innerText += fullResponse[index];
+                        index++;
+                    } else {
+                        clearInterval(interval);
+                    }
+                }, 50);
+            })
+            .catch(error => {
+                document.getElementById('response').innerText = 'Error: ' + error.message;
+            });
+    });
+</script>
+</body>
+</html>