Browse Source

【新增】AI:聊天接入知识库

xiaoxin 10 months ago
parent
commit
0700c3f15e

+ 6 - 1
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java

@@ -34,7 +34,12 @@ public enum AiChatRoleEnum {
              ### 支付宝
              ### 微信
             除此之外不要任何解释性语句。
-            """);
+            """),
+
+    AI_KNOWLEDGE_ROLE("知识库助手", """
+                给你提供一些数据参考:{info},请回答我的问题。
+                请你跟进数据参考与工具返回结果回复用户的请求。
+                """);
 
     /**
      * 角色名

+ 3 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationCreateMyReqVO.java

@@ -10,4 +10,7 @@ public class AiChatConversationCreateMyReqVO {
     @Schema(description = "聊天角色编号", example = "666")
     private Long roleId;
 
+    @Schema(description = "知识库编号", example = "1204")
+    private Long knowledgeId;
+
 }

+ 3 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationUpdateMyReqVO.java

@@ -21,6 +21,9 @@ public class AiChatConversationUpdateMyReqVO {
     @Schema(description = "模型编号", example = "1")
     private Long modelId;
 
+    @Schema(description = "知识库编号", example = "1")
+    private Long knowledgeId;
+
     @Schema(description = "角色设定", example = "一个快乐的程序员")
     private String systemMessage;
 

+ 8 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
 
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import com.baomidou.mybatisplus.annotation.KeySequence;
@@ -64,6 +65,13 @@ public class AiChatConversationDO extends BaseDO {
      */
     private Long roleId;
 
+    /**
+     * 知识库编号
+     * <p>
+     * 关联 {@link AiKnowledgeDO#getId()}
+     */
+    private Long knowledgeId;
+
     /**
      * 模型编号
      *

+ 15 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java

@@ -13,6 +13,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
+import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
 import jakarta.annotation.Resource;
@@ -22,6 +23,7 @@ import org.springframework.validation.annotation.Validated;
 
 import java.time.LocalDateTime;
 import java.util.List;
+import java.util.Objects;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@@ -45,6 +47,8 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
     private AiChatModelService chatModalService;
     @Resource
     private AiChatRoleService chatRoleService;
+    @Resource
+    private AiKnowledgeService knowledgeService;
 
     @Override
     public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) {
@@ -56,9 +60,14 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
         Assert.notNull(model, "必须找到默认模型");
         validateChatModel(model);
 
+        // 1.3 校验知识库
+        if (Objects.nonNull(createReqVO.getKnowledgeId())) {
+            knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
+        }
+
         // 2. 创建 AiChatConversationDO 聊天对话
         AiChatConversationDO conversation = new AiChatConversationDO().setUserId(userId).setPinned(false)
-                .setModelId(model.getId()).setModel(model.getModel())
+                .setModelId(model.getId()).setModel(model.getModel()).setKnowledgeId(createReqVO.getKnowledgeId())
                 .setTemperature(model.getTemperature()).setMaxTokens(model.getMaxTokens()).setMaxContexts(model.getMaxContexts());
         if (role != null) {
             conversation.setTitle(role.getName()).setRoleId(role.getId()).setSystemMessage(role.getSystemMessage());
@@ -82,6 +91,11 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
             model = chatModalService.validateChatModel(updateReqVO.getModelId());
         }
 
+        // 1.3 校验知识库是否存在
+        if (updateReqVO.getKnowledgeId() != null) {
+            knowledgeService.validateKnowledgeExists(updateReqVO.getKnowledgeId());
+        }
+
         // 2. 更新对话信息
         AiChatConversationDO updateObj = BeanUtils.toBean(updateReqVO, AiChatConversationDO.class);
         if (Boolean.TRUE.equals(updateReqVO.getPinned())) {

+ 31 - 8
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -12,21 +12,29 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
+import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
+import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
 import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.ai.chat.messages.*;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.MessageType;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.StreamingChatModel;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.chat.prompt.PromptTemplate;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
@@ -59,6 +67,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     private AiChatModelService chatModalService;
     @Resource
     private AiApiKeyService apiKeyService;
+    @Resource
+    private AiKnowledgeSegmentService knowledgeSegmentService;
 
     @Transactional(rollbackFor = Exception.class)
     public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
@@ -141,14 +151,27 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                                AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
         // 1. 构建 Prompt Message 列表
         List<Message> chatMessages = new ArrayList<>();
-        // 1.1 system context 角色设定
+
+        // 1.1 知识库召回
+        if (Objects.nonNull(conversation.getKnowledgeId())) {
+            List<AiKnowledgeSegmentDO> segmentList = knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(conversation.getKnowledgeId()).setContent(sendReqVO.getContent()));
+            if (CollUtil.isNotEmpty(segmentList)) {
+                PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
+                StringBuilder infoBuilder = StrUtil.builder();
+                segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent()));
+                Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
+                chatMessages.add(message);
+            }
+        }
+
+        // 1.2 system context 角色设定
         if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
             chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
         }
-        // 1.2 history message 历史消息
+        // 1.3 history message 历史消息
         List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
         contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
-        // 1.3 user message 新发送消息
+        // 1.4 user message 新发送消息
         chatMessages.add(new UserMessage(sendReqVO.getContent()));
 
         // 2. 构建 ChatOptions 对象
@@ -160,12 +183,12 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
 
     /**
      * 从历史消息中,获得倒序的 n 组消息作为消息上下文
-     *
+     * <p>
      * n 组:指的是 user + assistant 形成一组
      *
-     * @param messages 消息列表
+     * @param messages     消息列表
      * @param conversation 对话
-     * @param sendReqVO 发送请求
+     * @param sendReqVO    发送请求
      * @return 消息上下文
      */
     private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages,
@@ -182,7 +205,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
             }
             AiChatMessageDO userMessage = CollUtil.get(messages, i - 1);
             if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
-                || StrUtil.isEmpty(assistantMessage.getContent())) {
+                    || StrUtil.isEmpty(assistantMessage.getContent())) {
                 continue;
             }
             // 由于后续要 reverse 反转,所以先添加 assistantMessage