瀏覽代碼

【功能新增】AI:聊天时,增加知识库的拼接

YunaiV 5 月之前
父節點
當前提交
cddaca5863

+ 1 - 5
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java

@@ -35,11 +35,7 @@ public enum AiChatRoleEnum {
              ### 微信
             除此之外不要任何解释性语句。
             """),
-
-    AI_KNOWLEDGE_ROLE("知识库助手", """
-                给你提供一些数据参考:{info},请回答我的问题。
-                请你跟进数据参考与工具返回结果回复用户的请求。
-                """);
+    ;
 
     /**
      * 角色名

+ 1 - 10
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java

@@ -1,9 +1,8 @@
 package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
 
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import com.baomidou.mybatisplus.annotation.KeySequence;
 import com.baomidou.mybatisplus.annotation.TableId;
 import com.baomidou.mybatisplus.annotation.TableName;
@@ -65,14 +64,6 @@ public class AiChatConversationDO extends BaseDO {
      */
     private Long roleId;
 
-    // TODO @芋艿:可优化,绑定多个知识库。前提,spring ai 支持 RerankModel 的封装
-    /**
-     * 知识库编号
-     * <p>
-     * 关联 {@link AiKnowledgeDO#getId()}
-     */
-    private Long knowledgeId;
-
     /**
      * 模型编号
      *

+ 4 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java

@@ -1,14 +1,14 @@
 package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
 
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
 import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import com.baomidou.mybatisplus.annotation.KeySequence;
 import com.baomidou.mybatisplus.annotation.TableField;
 import com.baomidou.mybatisplus.annotation.TableId;
 import com.baomidou.mybatisplus.annotation.TableName;
-import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
 import lombok.*;
 import org.springframework.ai.chat.messages.MessageType;
 
@@ -71,13 +71,12 @@ public class AiChatMessageDO extends BaseDO {
      */
     private Long roleId;
 
-
     /**
-     * 段落编号数组
+     * 知识库段落编号数组
      *
      * 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
      */
-    @TableField(typeHandler = JacksonTypeHandler.class)
+    @TableField(typeHandler = LongListTypeHandler.class)
     private List<Long> segmentIds;
 
     /**

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java

@@ -68,7 +68,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
 
         // 2. 创建 AiChatConversationDO 聊天对话
         AiChatConversationDO conversation = new AiChatConversationDO().setUserId(userId).setPinned(false)
-                .setModelId(model.getId()).setModel(model.getModel()).setKnowledgeId(createReqVO.getKnowledgeId())
+                .setModelId(model.getId()).setModel(model.getModel())
                 .setTemperature(model.getTemperature()).setMaxTokens(model.getMaxTokens()).setMaxContexts(model.getMaxContexts());
         if (role != null) {
             conversation.setTitle(role.getName()).setRoleId(role.getId()).setSystemMessage(role.getSystemMessage());

+ 103 - 64
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -14,12 +14,14 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
-import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
+import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
+import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
+import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
 import cn.iocoder.yudao.module.ai.service.model.AiModelService;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
@@ -32,13 +34,13 @@ import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.StreamingChatModel;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.chat.prompt.PromptTemplate;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
 
 import java.time.LocalDateTime;
 import java.util.*;
+import java.util.stream.Collectors;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
@@ -56,12 +58,21 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_N
 @Slf4j
 public class AiChatMessageServiceImpl implements AiChatMessageService {
 
+    /**
+     * 知识库转 {@link UserMessage} 的内容模版
+     */
+    private static final String KNOWLEDGE_USER_MESSAGE_TEMPLATE = "使用 <Reference></Reference> 标记中的内容作为本次对话的参考:\n\n" +
+            "%s\n\n" + // 多个 <Reference></Reference> 的拼接
+            "回答要求:\n- 避免提及你是从 <Reference></Reference> 获取的知识。";
+
     @Resource
     private AiChatMessageMapper chatMessageMapper;
 
     @Resource
     private AiChatConversationService chatConversationService;
     @Resource
+    private AiChatRoleService chatRoleService;
+    @Resource
     private AiModelService modalService;
     @Resource
     private AiKnowledgeSegmentService knowledgeSegmentService;
@@ -69,118 +80,143 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     @Transactional(rollbackFor = Exception.class)
     public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
         // 1.1 校验对话存在
-        AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
+        AiChatConversationDO conversation = chatConversationService
+                .validateChatConversationExists(sendReqVO.getConversationId());
         if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
             throw exception(CHAT_CONVERSATION_NOT_EXISTS);
         }
         List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
         // 1.2 校验模型
         AiModelDO model = modalService.validateModel(conversation.getModelId());
-        ChatModel chatModel = modalService.getChatModel(model.getKeyId());
+        ChatModel chatModel = modalService.getChatModel(model.getId());
+
+        // 2. 知识库找回
+        List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
+                conversation);
 
-        // 2. 插入 user 发送消息
+        // 3. 插入 user 发送消息
         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
-                userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
+                userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
+                null);
 
         // 3.1 插入 assistant 接收消息
         AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
-                userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
+                userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
+                knowledgeSegments);
 
-        // 3.2 召回段落
-        List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
-
-        // 3.3 创建 chat 需要的 Prompt
-        Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
+        // 3.2 创建 chat 需要的 Prompt
+        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
         ChatResponse chatResponse = chatModel.call(prompt);
 
-        // 3.4 段式返回
+        // 3.3 段式返回
         String newContent = chatResponse.getResult().getOutput().getText();
-        chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent));
-        return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
-                .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
+        chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
+        return new AiChatMessageSendRespVO()
+                .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
+                .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
+                        .setContent(newContent));
     }
 
     @Override
-    public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) {
+    public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO,
+            Long userId) {
         // 1.1 校验对话存在
-        AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
+        AiChatConversationDO conversation = chatConversationService
+                .validateChatConversationExists(sendReqVO.getConversationId());
         if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
             throw exception(CHAT_CONVERSATION_NOT_EXISTS);
         }
         List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
         // 1.2 校验模型
         AiModelDO model = modalService.validateModel(conversation.getModelId());
-        StreamingChatModel chatModel = modalService.getChatModel(model.getKeyId());
+        StreamingChatModel chatModel = modalService.getChatModel(model.getId());
+
+        // 2. 知识库找回
+        List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
+                conversation);
 
-        // 2. 插入 user 发送消息
+        // 3. 插入 user 发送消息
         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
-                userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
+                userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
+                null);
 
-        // 3.1 插入 assistant 接收消息
+        // 4.1 插入 assistant 接收消息
         AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
-                userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
+                userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
+                knowledgeSegments);
 
-        // 3.2 召回段落
-        List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
-
-        // 3.3 构建 Prompt,并进行调用
-        Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
+        // 4.2 构建 Prompt,并进行调用
+        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
-        // 3.4 流式返回
+        // 4.3 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
             String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null;
             newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况
             contentBuffer.append(newContent);
             // 响应结果
-            return success(new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
-                    .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent)));
+            return success(new AiChatMessageSendRespVO()
+                    .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
+                    .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
+                            .setContent(newContent)));
         }).doOnComplete(() -> {
             // 忽略租户,因为 Flux 异步无法透传租户
-            TenantUtils.executeIgnore(() ->
-                    chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId))
-                            .setContent(contentBuffer.toString())));
+            TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
+                    new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString())));
         }).doOnError(throwable -> {
             log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
             // 忽略租户,因为 Flux 异步无法透传租户
-            TenantUtils.executeIgnore(() ->
-                    chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())));
+            TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
+                    new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())));
         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
     }
 
-    private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) {
-        if (Objects.isNull(knowledgeId)) {
+    private List<AiKnowledgeSegmentSearchRespBO> recallKnowledgeSegment(String content,
+            AiChatConversationDO conversation) {
+        // 1. 查询聊天角色
+        if (conversation == null || conversation.getRoleId() == null) {
+            return Collections.emptyList();
+        }
+        AiChatRoleDO role = chatRoleService.getChatRole(conversation.getRoleId());
+        if (role == null || CollUtil.isEmpty(role.getKnowledgeIds())) {
             return Collections.emptyList();
         }
-//        return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
-        return null;
-    }
-
-    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, List<AiKnowledgeSegmentDO> segmentList,
-                               AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
-        // 1. 构建 Prompt Message 列表
-        List<Message> chatMessages = new ArrayList<>();
 
-        // 1.1 召回内容消息构建
-        if (CollUtil.isNotEmpty(segmentList)) {
-            PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
-            StringBuilder infoBuilder = StrUtil.builder();
-            segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent()));
-            Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
-            chatMessages.add(message);
+        // 2. 遍历找回
+        List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = new ArrayList<>();
+        for (Long knowledgeId : role.getKnowledgeIds()) {
+            knowledgeSegments.addAll(knowledgeSegmentService.searchKnowledgeSegment(new AiKnowledgeSegmentSearchReqBO()
+                    .setKnowledgeId(knowledgeId).setContent(content)));
         }
+        return knowledgeSegments;
+    }
 
-        // 1.2 system context 角色设定
+    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
+            List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
+            AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
+        List<Message> chatMessages = new ArrayList<>();
+        // 1.1 System Context 角色设定
         if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
             chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
         }
-        // 1.3 history message 历史消息
+
+        // 1.2 历史 history message 历史消息
         List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
-        contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
-        // 1.4 user message 新发送消息
+        contextMessages
+                .forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
+
+        // 1.3 当前 user message 新发送消息
         chatMessages.add(new UserMessage(sendReqVO.getContent()));
 
+        // 1.4 知识库,通过 UserMessage 实现
+        if (CollUtil.isNotEmpty(knowledgeSegments)) {
+            String reference = knowledgeSegments.stream()
+                    .map(segment -> "<Reference>\n" + segment.getContent() + "</Reference>")
+                    .collect(Collectors.joining("\n\n"));
+            chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
+        }
+
         // 2. 构建 ChatOptions 对象
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
@@ -199,8 +235,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
      * @return 消息上下文
      */
     private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages,
-                                                        AiChatConversationDO conversation,
-                                                        AiChatMessageSendReqVO sendReqVO) {
+            AiChatConversationDO conversation,
+            AiChatMessageSendReqVO sendReqVO) {
         if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
             return Collections.emptyList();
         }
@@ -211,7 +247,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                 continue;
             }
             AiChatMessageDO userMessage = CollUtil.get(messages, i - 1);
-            if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
+            if (userMessage == null
+                    || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
                     || StrUtil.isEmpty(assistantMessage.getContent())) {
                 continue;
             }
@@ -228,11 +265,13 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     }
 
     private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
-                                              AiModelDO model, Long userId, Long roleId,
-                                              MessageType messageType, String content, Boolean useContext) {
+            AiModelDO model, Long userId, Long roleId,
+            MessageType messageType, String content, Boolean useContext,
+            List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments) {
         AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
                 .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
-                .setType(messageType.getValue()).setContent(content).setUseContext(useContext);
+                .setType(messageType.getValue()).setContent(content).setUseContext(useContext)
+                .setSegmentIds(convertList(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getId));
         message.setCreateTime(LocalDateTime.now());
         chatMessageMapper.insert(message);
         return message;