Browse Source

【功能新增】AI:增加 AI 对话,与 tool 的打通

YunaiV 5 months ago
parent
commit
1e8845ce6d

+ 17 - 6
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiToolController.java

@@ -1,5 +1,6 @@
 package cn.iocoder.yudao.module.ai.controller.admin.model;
 
+import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@@ -17,8 +18,10 @@ import org.springframework.security.access.prepost.PreAuthorize;
 import org.springframework.validation.annotation.Validated;
 import org.springframework.web.bind.annotation.*;
 
-import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
+import java.util.List;
 
+import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
 
 @Tag(name = "管理后台 - AI 工具")
 @RestController
@@ -30,14 +33,14 @@ public class AiToolController {
     private AiToolService toolService;
 
     @PostMapping("/create")
-    @Operation(summary = "创建AI 工具")
+    @Operation(summary = "创建工具")
     @PreAuthorize("@ss.hasPermission('ai:tool:create')")
     public CommonResult<Long> createTool(@Valid @RequestBody AiToolSaveReqVO createReqVO) {
         return success(toolService.createTool(createReqVO));
     }
 
     @PutMapping("/update")
-    @Operation(summary = "更新AI 工具")
+    @Operation(summary = "更新工具")
     @PreAuthorize("@ss.hasPermission('ai:tool:update')")
     public CommonResult<Boolean> updateTool(@Valid @RequestBody AiToolSaveReqVO updateReqVO) {
         toolService.updateTool(updateReqVO);
@@ -45,7 +48,7 @@ public class AiToolController {
     }
 
     @DeleteMapping("/delete")
-    @Operation(summary = "删除AI 工具")
+    @Operation(summary = "删除工具")
     @Parameter(name = "id", description = "编号", required = true)
     @PreAuthorize("@ss.hasPermission('ai:tool:delete')")
     public CommonResult<Boolean> deleteTool(@RequestParam("id") Long id) {
@@ -54,7 +57,7 @@ public class AiToolController {
     }
 
     @GetMapping("/get")
-    @Operation(summary = "获得AI 工具")
+    @Operation(summary = "获得工具")
     @Parameter(name = "id", description = "编号", required = true, example = "1024")
     @PreAuthorize("@ss.hasPermission('ai:tool:query')")
     public CommonResult<AiToolRespVO> getTool(@RequestParam("id") Long id) {
@@ -63,11 +66,19 @@ public class AiToolController {
     }
 
     @GetMapping("/page")
-    @Operation(summary = "获得AI 工具分页")
+    @Operation(summary = "获得工具分页")
     @PreAuthorize("@ss.hasPermission('ai:tool:query')")
     public CommonResult<PageResult<AiToolRespVO>> getToolPage(@Valid AiToolPageReqVO pageReqVO) {
         PageResult<AiToolDO> pageResult = toolService.getToolPage(pageReqVO);
         return success(BeanUtils.toBean(pageResult, AiToolRespVO.class));
     }
 
+    @GetMapping("/simple-list")
+    @Operation(summary = "获得工具列表")
+    public CommonResult<List<AiToolRespVO>> getToolSimpleList() {
+        List<AiToolDO> list = toolService.getToolListByStatus(CommonStatusEnum.ENABLE.getStatus());
+        return success(convertList(list, tool -> new AiToolRespVO()
+                .setId(tool.getId()).setName(tool.getName())));
+    }
+
 }

+ 3 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleRespVO.java

@@ -49,6 +49,9 @@ public class AiChatRoleRespVO implements VO {
     @Schema(description = "引用的知识库编号列表", example = "1,2,3")
     private List<Long> knowledgeIds;
 
+    @Schema(description = "引用的工具编号列表", example = "1,2,3")
+    private List<Long> toolIds;
+
     @Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
     private Boolean publicStatus;
 

+ 3 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleSaveMyReqVO.java

@@ -34,4 +34,7 @@ public class AiChatRoleSaveMyReqVO {
     @Schema(description = "引用的知识库编号列表", example = "1,2,3")
     private List<Long> knowledgeIds;
 
+    @Schema(description = "引用的工具编号列表", example = "1,2,3")
+    private List<Long> toolIds;
+
 }

+ 3 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleSaveReqVO.java

@@ -47,6 +47,9 @@ public class AiChatRoleSaveReqVO {
     @Schema(description = "引用的知识库编号列表", example = "1,2,3")
     private List<Long> knowledgeIds;
 
+    @Schema(description = "引用的工具编号列表", example = "1,2,3")
+    private List<Long> toolIds;
+
     @Schema(description = "是否公开", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
     @NotNull(message = "是否公开不能为空")
     private Boolean publicStatus;

+ 7 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java

@@ -74,6 +74,13 @@ public class AiChatRoleDO extends BaseDO {
      */
     @TableField(typeHandler = LongListTypeHandler.class)
     private List<Long> knowledgeIds;
+    /**
+     * 引用的工具编号列表
+     *
+     * 关联 {@link AiToolDO#getId()} 字段
+     */
+    @TableField(typeHandler = LongListTypeHandler.class)
+    private List<Long> toolIds;
 
     /**
      * 是否公开

+ 8 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiToolMapper.java

@@ -7,6 +7,8 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.tool.AiToolPageReqVO
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
 import org.apache.ibatis.annotations.Mapper;
 
+import java.util.List;
+
 /**
  * AI 工具 Mapper
  *
@@ -24,4 +26,10 @@ public interface AiToolMapper extends BaseMapperX<AiToolDO> {
                 .orderByDesc(AiToolDO::getId));
     }
 
+    default List<AiToolDO> selectListByStatus(Integer status) {
+        return selectList(new LambdaQueryWrapperX<AiToolDO>()
+                .eq(AiToolDO::getStatus, status)
+                .orderByDesc(AiToolDO::getId));
+    }
+
 }

+ 35 - 19
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -7,7 +7,6 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
@@ -19,6 +18,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
@@ -27,6 +27,7 @@ import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchR
 import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
 import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
 import cn.iocoder.yudao.module.ai.service.model.AiModelService;
+import cn.iocoder.yudao.module.ai.service.model.AiToolService;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.chat.messages.Message;
@@ -50,6 +51,7 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU
 import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
 import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
 import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
 import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
 import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_NOT_EXIST;
 
@@ -82,6 +84,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     private AiKnowledgeSegmentService knowledgeSegmentService;
     @Resource
     private AiKnowledgeDocumentService knowledgeDocumentService;
+    @Resource
+    private AiToolService toolService;
 
     @Transactional(rollbackFor = Exception.class)
     public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
@@ -118,11 +122,13 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         String newContent = chatResponse.getResult().getOutput().getText();
         chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
         // 3.4 响应结果
-        List<AiChatMessageRespVO.KnowledgeSegment> segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class,
+        List<AiChatMessageRespVO.KnowledgeSegment> segments = BeanUtils.toBean(knowledgeSegments,
+                AiChatMessageRespVO.KnowledgeSegment.class,
                 segment -> {
-            AiKnowledgeDocumentDO document = knowledgeDocumentService.getKnowledgeDocument(segment.getDocumentId());
-            segment.setDocumentName(document != null ? document.getName() : null);
-        });
+                    AiKnowledgeDocumentDO document = knowledgeDocumentService
+                            .getKnowledgeDocument(segment.getDocumentId());
+                    segment.setDocumentName(document != null ? document.getName() : null);
+                });
         return new AiChatMessageSendRespVO()
                 .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
                 .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
@@ -130,7 +136,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     }
 
     @Override
-    public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) {
+    public Flux<CommonResult<AiChatMessageSendRespVO>> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO,
+            Long userId) {
         // 1.1 校验对话存在
         AiChatConversationDO conversation = chatConversationService
                 .validateChatConversationExists(sendReqVO.getConversationId());
@@ -143,7 +150,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         StreamingChatModel chatModel = modalService.getChatModel(model.getId());
 
         // 2. 知识库找回
-        List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), conversation);
+        List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
+                conversation);
 
         // 3. 插入 user 发送消息
         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@@ -167,7 +175,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
             if (StrUtil.isEmpty(contentBuffer)) {
                 segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class,
                         segment -> TenantUtils.executeIgnore(() -> {
-                            AiKnowledgeDocumentDO document = knowledgeDocumentService.getKnowledgeDocument(segment.getDocumentId());
+                            AiKnowledgeDocumentDO document = knowledgeDocumentService
+                                    .getKnowledgeDocument(segment.getDocumentId());
                             segment.setDocumentName(document != null ? document.getName() : null);
                         }));
             }
@@ -192,7 +201,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     }
 
     private List<AiKnowledgeSegmentSearchRespBO> recallKnowledgeSegment(String content,
-                                                                        AiChatConversationDO conversation) {
+            AiChatConversationDO conversation) {
         // 1. 查询聊天角色
         if (conversation == null || conversation.getRoleId() == null) {
             return Collections.emptyList();
@@ -212,8 +221,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     }
 
     private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
-                               List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
-                               AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
+            List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
+            AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
         List<Message> chatMessages = new ArrayList<>();
         // 1.1 System Context 角色设定
         if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
@@ -236,11 +245,18 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
             chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
         }
 
-        // 2. 构建 ChatOptions 对象
+        // 2.1 查询 tool 工具
+        Set<String> toolNames = null;
+        if (conversation.getRoleId() != null) {
+            AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
+            if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
+                toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
+            }
+        }
+        // 2.2 构建 ChatOptions 对象
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
-                conversation.getTemperature(), conversation.getMaxTokens(),
-                SetUtils.asSet("directory_list", "weather_query"));
+                conversation.getTemperature(), conversation.getMaxTokens(), toolNames);
         return new Prompt(chatMessages, chatOptions);
     }
 
@@ -255,8 +271,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
      * @return 消息上下文
      */
     private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages,
-                                                        AiChatConversationDO conversation,
-                                                        AiChatMessageSendReqVO sendReqVO) {
+            AiChatConversationDO conversation,
+            AiChatMessageSendReqVO sendReqVO) {
         if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
             return Collections.emptyList();
         }
@@ -285,9 +301,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     }
 
     private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
-                                              AiModelDO model, Long userId, Long roleId,
-                                              MessageType messageType, String content, Boolean useContext,
-                                              List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments) {
+            AiModelDO model, Long userId, Long roleId,
+            MessageType messageType, String content, Boolean useContext,
+            List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments) {
         AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
                 .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
                 .setType(messageType.getValue()).setContent(content).setUseContext(useContext)

+ 23 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleServiceImpl.java

@@ -39,11 +39,15 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
 
     @Resource
     private AiKnowledgeService knowledgeService;
+    @Resource
+    private AiToolService toolService;
 
     @Override
     public Long createChatRole(AiChatRoleSaveReqVO createReqVO) {
         // 校验文档
         validateDocuments(createReqVO.getKnowledgeIds());
+        // 校验工具
+        validateTools(createReqVO.getToolIds());
 
         // 保存角色
         AiChatRoleDO chatRole = BeanUtils.toBean(createReqVO, AiChatRoleDO.class);
@@ -55,6 +59,8 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
     public Long createChatRoleMy(AiChatRoleSaveMyReqVO createReqVO, Long userId) {
         // 校验文档
         validateDocuments(createReqVO.getKnowledgeIds());
+        // 校验工具
+        validateTools(createReqVO.getToolIds());
 
         // 保存角色
         AiChatRoleDO chatRole = BeanUtils.toBean(createReqVO, AiChatRoleDO.class).setUserId(userId)
@@ -69,6 +75,8 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
         validateChatRoleExists(updateReqVO.getId());
         // 校验文档
         validateDocuments(updateReqVO.getKnowledgeIds());
+        // 校验工具
+        validateTools(updateReqVO.getToolIds());
 
         // 更新角色
         AiChatRoleDO updateObj = BeanUtils.toBean(updateReqVO, AiChatRoleDO.class);
@@ -84,6 +92,8 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
         }
         // 校验文档
         validateDocuments(updateReqVO.getKnowledgeIds());
+        // 校验工具
+        validateTools(updateReqVO.getToolIds());
 
         // 更新
         AiChatRoleDO updateObj = BeanUtils.toBean(updateReqVO, AiChatRoleDO.class);
@@ -103,6 +113,19 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
         knowledgeIds.forEach(knowledgeService::validateKnowledgeExists);
     }
 
+    /**
+     * 校验工具是否存在
+     *
+     * @param toolIds 工具编号列表
+     */
+    private void validateTools(List<Long> toolIds) {
+        if (CollUtil.isEmpty(toolIds)) {
+            return;
+        }
+        // 遍历校验每个工具是否存在
+        toolIds.forEach(toolService::validateToolExists);
+    }
+
     @Override
     public void deleteChatRole(Long id) {
         // 校验存在

+ 33 - 7
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiToolService.java

@@ -6,6 +6,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.tool.AiToolSaveReqVO
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
 import jakarta.validation.Valid;
 
+import java.util.Collection;
+import java.util.List;
+
 /**
  * AI 工具 Service 接口
  *
@@ -14,7 +17,7 @@ import jakarta.validation.Valid;
 public interface AiToolService {
 
     /**
-     * 创建AI 工具
+     * 创建工具
      *
      * @param createReqVO 创建信息
      * @return 编号
@@ -22,33 +25,56 @@ public interface AiToolService {
     Long createTool(@Valid AiToolSaveReqVO createReqVO);
 
     /**
-     * 更新AI 工具
+     * 更新工具
      *
      * @param updateReqVO 更新信息
      */
     void updateTool(@Valid AiToolSaveReqVO updateReqVO);
 
     /**
-     * 删除AI 工具
+     * 删除工具
      *
      * @param id 编号
      */
     void deleteTool(Long id);
 
     /**
-     * 获得AI 工具
+     * 校验工具是否存在
      *
      * @param id 编号
-     * @return AI 工具
+     */
+    void validateToolExists(Long id);
+
+    /**
+     * 获得工具
+     *
+     * @param id 编号
+     * @return 工具
      */
     AiToolDO getTool(Long id);
 
     /**
-     * 获得AI 工具分页
+     * 获得工具列表
+     *
+     * @param ids 编号列表
+     * @return 工具列表
+     */
+    List<AiToolDO> getToolList(Collection<Long> ids);
+
+    /**
+     * 获得工具分页
      *
      * @param pageReqVO 分页查询
-     * @return AI 工具分页
+     * @return 工具分页
      */
     PageResult<AiToolDO> getToolPage(AiToolPageReqVO pageReqVO);
 
+    /**
+     * 获得工具列表
+     *
+     * @param status 状态
+     * @return 工具列表
+     */
+    List<AiToolDO> getToolListByStatus(Integer status);
+
 }

+ 15 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiToolServiceImpl.java

@@ -12,6 +12,9 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException;
 import org.springframework.stereotype.Service;
 import org.springframework.validation.annotation.Validated;
 
+import java.util.Collection;
+import java.util.List;
+
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.TOOL_NAME_NOT_EXISTS;
 import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.TOOL_NOT_EXISTS;
@@ -59,7 +62,8 @@ public class AiToolServiceImpl implements AiToolService {
         toolMapper.deleteById(id);
     }
 
-    private void validateToolExists(Long id) {
+    @Override
+    public void validateToolExists(Long id) {
         if (toolMapper.selectById(id) == null) {
             throw exception(TOOL_NOT_EXISTS);
         }
@@ -78,9 +82,19 @@ public class AiToolServiceImpl implements AiToolService {
         return toolMapper.selectById(id);
     }
 
+    @Override
+    public List<AiToolDO> getToolList(Collection<Long> ids) {
+        return toolMapper.selectBatchIds(ids);
+    }
+
     @Override
     public PageResult<AiToolDO> getToolPage(AiToolPageReqVO pageReqVO) {
         return toolMapper.selectPage(pageReqVO);
     }
 
+    @Override
+    public List<AiToolDO> getToolListByStatus(Integer status) {
+        return toolMapper.selectListByStatus(status);
+    }
+
 }