Эх сурвалжийг харах

【功能新增】AI:聊天记录返回时,增加 segments

YunaiV 5 сар өмнө
parent
commit
32e1ef4da8

+ 7 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java

@@ -4,6 +4,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
 import lombok.Data;
 
 import java.time.LocalDateTime;
+import java.util.List;
 
 @Schema(description = "管理后台 - AI 聊天消息发送 Response VO")
 @Data
@@ -28,6 +29,12 @@ public class AiChatMessageSendRespVO {
         @Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好,你好啊")
         private String content;
 
+        @Schema(description = "知识库段落编号数组", example = "[1,2,3]")
+        private List<Long> segmentIds;
+
+        @Schema(description = "知识库段落数组")
+        private List<AiChatMessageRespVO.KnowledgeSegment> segments;
+
         @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
         private LocalDateTime createTime;
 

+ 34 - 16
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -10,14 +10,17 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
+import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
 import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
 import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
 import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
@@ -76,6 +79,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     private AiModelService modalService;
     @Resource
     private AiKnowledgeSegmentService knowledgeSegmentService;
+    @Resource
+    private AiKnowledgeDocumentService knowledgeDocumentService;
 
     @Transactional(rollbackFor = Exception.class)
     public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
@@ -108,18 +113,23 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
         ChatResponse chatResponse = chatModel.call(prompt);
 
-        // 3.3 段式返回
+        // 3.3 更新响应内容
         String newContent = chatResponse.getResult().getOutput().getText();
         chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
+        // 3.4 响应结果
+        List<AiChatMessageRespVO.KnowledgeSegment> segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class,
+                segment -> {
+            AiKnowledgeDocumentDO document = knowledgeDocumentService.getKnowledgeDocument(segment.getDocumentId());
+            segment.setDocumentName(document != null ? document.getName() : null);
+        });
         return new AiChatMessageSendRespVO()
                 .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
                 .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
-                        .setContent(newContent));
+                        .setContent(newContent).setSegments(segments));
     }
 
     @Override
-    public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO,
-            Long userId) {
+    public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) {
         // 1.1 校验对话存在
         AiChatConversationDO conversation = chatConversationService
                 .validateChatConversationExists(sendReqVO.getConversationId());
@@ -132,8 +142,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         StreamingChatModel chatModel = modalService.getChatModel(model.getId());
 
         // 2. 知识库找回
-        List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
-                conversation);
+        List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), conversation);
 
         // 3. 插入 user 发送消息
         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@@ -152,14 +161,23 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         // 4.3 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
+            // 处理知识库的返回,只有首次才有
+            List<AiChatMessageRespVO.KnowledgeSegment> segments = null;
+            if (StrUtil.isEmpty(contentBuffer)) {
+                segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class,
+                        segment -> TenantUtils.executeIgnore(() -> {
+                            AiKnowledgeDocumentDO document = knowledgeDocumentService.getKnowledgeDocument(segment.getDocumentId());
+                            segment.setDocumentName(document != null ? document.getName() : null);
+                        }));
+            }
+            // 响应结果
             String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null;
             newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况
             contentBuffer.append(newContent);
-            // 响应结果
             return success(new AiChatMessageSendRespVO()
                     .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
                     .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
-                            .setContent(newContent)));
+                            .setContent(newContent).setSegments(segments)));
         }).doOnComplete(() -> {
             // 忽略租户,因为 Flux 异步无法透传租户
             TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
@@ -173,7 +191,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     }
 
     private List<AiKnowledgeSegmentSearchRespBO> recallKnowledgeSegment(String content,
-            AiChatConversationDO conversation) {
+                                                                        AiChatConversationDO conversation) {
         // 1. 查询聊天角色
         if (conversation == null || conversation.getRoleId() == null) {
             return Collections.emptyList();
@@ -193,8 +211,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     }
 
     private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
-            List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
-            AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
+                               List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
+                               AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
         List<Message> chatMessages = new ArrayList<>();
         // 1.1 System Context 角色设定
         if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
@@ -235,8 +253,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
      * @return 消息上下文
      */
     private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages,
-            AiChatConversationDO conversation,
-            AiChatMessageSendReqVO sendReqVO) {
+                                                        AiChatConversationDO conversation,
+                                                        AiChatMessageSendReqVO sendReqVO) {
         if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
             return Collections.emptyList();
         }
@@ -265,9 +283,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     }
 
     private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
-            AiModelDO model, Long userId, Long roleId,
-            MessageType messageType, String content, Boolean useContext,
-            List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments) {
+                                              AiModelDO model, Long userId, Long roleId,
+                                              MessageType messageType, String content, Boolean useContext,
+                                              List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments) {
         AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
                 .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
                 .setType(messageType.getValue()).setContent(content).setUseContext(useContext)