Browse Source

【功能新增】AI:聊天记录返回时,增加 segments

YunaiV 5 tháng trước cách đây
mục cha
commit
b709af11a1

+ 46 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.java

@@ -12,9 +12,13 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
 import cn.iocoder.yudao.module.ai.service.chat.AiChatMessageService;
+import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
+import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Parameter;
@@ -32,7 +36,7 @@ import java.util.List;
 import java.util.Map;
 
 import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
-import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.*;
 import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
 
 @Tag(name = "管理后台 - 聊天消息")
@@ -47,6 +51,10 @@ public class AiChatMessageController {
     private AiChatConversationService chatConversationService;
     @Resource
     private AiChatRoleService chatRoleService;
+    @Resource
+    private AiKnowledgeSegmentService knowledgeSegmentService;
+    @Resource
+    private AiKnowledgeDocumentService knowledgeDocumentService;
 
     @Operation(summary = "发送消息(段式)", description = "一次性返回,响应较慢")
     @PostMapping("/send")
@@ -56,7 +64,8 @@ public class AiChatMessageController {
 
     @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
     @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
-    public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(@Valid @RequestBody AiChatMessageSendReqVO sendReqVO) {
+    public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(
+            @Valid @RequestBody AiChatMessageSendReqVO sendReqVO) {
         return chatMessageService.sendChatMessageStream(sendReqVO, getLoginUserId());
     }
 
@@ -69,8 +78,38 @@ public class AiChatMessageController {
         if (conversation == null || ObjUtil.notEqual(conversation.getUserId(), getLoginUserId())) {
             return success(Collections.emptyList());
         }
+        // 1. 获取消息列表
         List<AiChatMessageDO> messageList = chatMessageService.getChatMessageListByConversationId(conversationId);
-        return success(BeanUtils.toBean(messageList, AiChatMessageRespVO.class));
+        if (CollUtil.isEmpty(messageList)) {
+            return success(Collections.emptyList());
+        }
+
+        // 2. 拼接数据,主要是知识库段落信息
+        Map<Long, AiKnowledgeSegmentDO> segmentMap = knowledgeSegmentService.getKnowledgeSegmentMap(convertListByFlatMap(messageList,
+                message -> CollUtil.isEmpty(message.getSegmentIds()) ? null : message.getSegmentIds().stream()));
+        Map<Long, AiKnowledgeDocumentDO> documentMap = knowledgeDocumentService.getKnowledgeDocumentMap(
+                convertList(segmentMap.values(), AiKnowledgeSegmentDO::getDocumentId));
+        List<AiChatMessageRespVO> messageVOList = BeanUtils.toBean(messageList, AiChatMessageRespVO.class);
+        for (int i = 0; i < messageList.size(); i++) {
+            AiChatMessageDO message = messageList.get(i);
+            if (CollUtil.isEmpty(message.getSegmentIds())) {
+                continue;
+            }
+            // 设置知识库段落信息
+            messageVOList.get(i).setSegments(convertList(message.getSegmentIds(), segmentId -> {
+                AiKnowledgeSegmentDO segment = segmentMap.get(segmentId);
+                if (segment == null) {
+                    return null;
+                }
+                AiKnowledgeDocumentDO document = documentMap.get(segment.getDocumentId());
+                if (document == null) {
+                    return null;
+                }
+                return new AiChatMessageRespVO.KnowledgeSegment().setId(segment.getId()).setContent(segment.getContent())
+                        .setDocumentId(segment.getDocumentId()).setDocumentName(document.getName());
+            }));
+        }
+        return success(messageVOList);
     }
 
     @Operation(summary = "删除消息")
@@ -84,7 +123,8 @@ public class AiChatMessageController {
     @Operation(summary = "删除指定对话的消息")
     @DeleteMapping("/delete-by-conversation-id")
     @Parameter(name = "conversationId", required = true, description = "对话编号", example = "1024")
-    public CommonResult<Boolean> deleteChatMessageByConversationId(@RequestParam("conversationId") Long conversationId) {
+    public CommonResult<Boolean> deleteChatMessageByConversationId(
+            @RequestParam("conversationId") Long conversationId) {
         chatMessageService.deleteChatMessageByConversationId(conversationId, getLoginUserId());
         return success(true);
     }
@@ -103,7 +143,8 @@ public class AiChatMessageController {
         Map<Long, AiChatRoleDO> roleMap = chatRoleService.getChatRoleMap(
                 convertSet(pageResult.getList(), AiChatMessageDO::getRoleId));
         return success(BeanUtils.toBean(pageResult, AiChatMessageRespVO.class,
-                respVO -> MapUtils.findAndThen(roleMap, respVO.getRoleId(), role -> respVO.setRoleName(role.getName()))));
+                respVO -> MapUtils.findAndThen(roleMap, respVO.getRoleId(),
+                        role -> respVO.setRoleName(role.getName()))));
     }
 
     @Operation(summary = "管理员删除消息")

+ 25 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java

@@ -4,6 +4,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
 import lombok.Data;
 
 import java.time.LocalDateTime;
+import java.util.List;
 
 @Schema(description = "管理后台 - AI 聊天消息 Response VO")
 @Data
@@ -39,6 +40,12 @@ public class AiChatMessageRespVO {
     @Schema(description = "是否携带上下文", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
     private Boolean useContext;
 
+    @Schema(description = "知识库段落编号数组", example = "[1,2,3]")
+    private List<Long> segmentIds;
+
+    @Schema(description = "知识库段落数组")
+    private List<KnowledgeSegment> segments;
+
     @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51")
     private LocalDateTime createTime;
 
@@ -47,4 +54,22 @@ public class AiChatMessageRespVO {
     @Schema(description = "角色名字", example = "小黄")
     private String roleName;
 
+    @Schema(description = "知识库段落", example = "Java 开发手册")
+    @Data
+    public static class KnowledgeSegment {
+
+        @Schema(description = "段落编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
+        private Long id;
+
+        @Schema(description = "切片内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 开发手册")
+        private String content;
+
+        @Schema(description = "文档编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "24790")
+        private Long documentId;
+
+        @Schema(description = "文档名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "产品使用手册")
+        private String documentName;
+
+    }
+
 }

+ 0 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/knowledge/AiKnowledgeController.java

@@ -23,7 +23,6 @@ import java.util.List;
 import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
 import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
 
-// TODO @芋艿:增加权限标识
 @Tag(name = "管理后台 - AI 知识库")
 @RestController
 @RequestMapping("/ai/knowledge")

+ 9 - 9
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java

@@ -20,7 +20,7 @@ import java.util.List;
  * @since 2024/4/14 17:35
  * @since 2024/4/14 17:35
  */
-@TableName("ai_chat_message")
+@TableName(value = "ai_chat_message", autoResultMap = true)
 @KeySequence("ai_chat_conversation_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
 @Data
 @EqualsAndHashCode(callSuper = true)
@@ -71,14 +71,6 @@ public class AiChatMessageDO extends BaseDO {
      */
     private Long roleId;
 
-    /**
-     * 知识库段落编号数组
-     *
-     * 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
-     */
-    @TableField(typeHandler = LongListTypeHandler.class)
-    private List<Long> segmentIds;
-
     /**
      * 模型标志
      *
@@ -102,4 +94,12 @@ public class AiChatMessageDO extends BaseDO {
      */
     private Boolean useContext;
 
+    /**
+     * 知识库段落编号数组
+     *
+     * 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
+     */
+    @TableField(typeHandler = LongListTypeHandler.class)
+    private List<Long> segmentIds;
+
 }

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -212,7 +212,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         // 1.4 知识库,通过 UserMessage 实现
         if (CollUtil.isNotEmpty(knowledgeSegments)) {
             String reference = knowledgeSegments.stream()
-                    .map(segment -> "<Reference>\n" + segment.getContent() + "</Reference>")
+                    .map(segment -> "<Reference>" + segment.getContent() + "</Reference>")
                     .collect(Collectors.joining("\n\n"));
             chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
         }

+ 3 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java

@@ -26,6 +26,7 @@ import org.springframework.transaction.annotation.Transactional;
 
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@@ -205,9 +206,9 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
     @Override
     public List<AiKnowledgeDocumentDO> getKnowledgeDocumentList(Collection<Long> ids) {
         if (CollUtil.isEmpty(ids)) {
-            return new ArrayList<>();
+            return Collections.emptyList();
         }
-        return knowledgeDocumentMapper.selectByIds(ids);
+        return knowledgeDocumentMapper.selectBatchIds(ids);
     }
 
 }

+ 22 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java

@@ -10,7 +10,11 @@ import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchR
 import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
 import org.springframework.scheduling.annotation.Async;
 
+import java.util.Collection;
 import java.util.List;
+import java.util.Map;
+
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
 
 /**
  * AI 知识库段落 Service 接口
@@ -27,6 +31,24 @@ public interface AiKnowledgeSegmentService {
      */
     AiKnowledgeSegmentDO getKnowledgeSegment(Long id);
 
+    /**
+     * 获取知识库段落列表
+     *
+     * @param ids 段落编号列表
+     * @return 段落列表
+     */
+    List<AiKnowledgeSegmentDO> getKnowledgeSegmentList(Collection<Long> ids);
+
+    /**
+     * 获取知识库段落 Map
+     *
+     * @param ids 段落编号列表
+     * @return 段落 Map
+     */
+    default Map<Long, AiKnowledgeSegmentDO> getKnowledgeSegmentMap(Collection<Long> ids) {
+        return convertMap(getKnowledgeSegmentList(ids), AiKnowledgeSegmentDO::getId);
+    }
+
     /**
      * 获取段落分页
      *

+ 10 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java

@@ -30,10 +30,7 @@ import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
 import org.springframework.context.annotation.Lazy;
 import org.springframework.stereotype.Service;
 
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
+import java.util.*;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@@ -322,7 +319,15 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
 
     @Override
     public AiKnowledgeSegmentDO getKnowledgeSegment(Long id) {
-        return validateKnowledgeSegmentExists(id);
+        return segmentMapper.selectById(id);
+    }
+
+    @Override
+    public List<AiKnowledgeSegmentDO> getKnowledgeSegmentList(Collection<Long> ids) {
+        if (CollUtil.isEmpty(ids)) {
+            return Collections.emptyList();
+        }
+        return segmentMapper.selectBatchIds(ids);
     }
 
 }