Browse Source

【代码重构】AI:“聊天模型”重构为“模型”,支持 type 模型类型

YunaiV 5 tháng trước cách đây
mục cha
commit
433e91da8e

+ 5 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java

@@ -45,7 +45,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
     private AiChatConversationMapper chatConversationMapper;
 
     @Resource
-    private AiModelService chatModalService;
+    private AiModelService modalService;
     @Resource
     private AiChatRoleService chatRoleService;
     @Resource
@@ -56,8 +56,8 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
         // 1.1 获得 AiChatRoleDO 聊天角色
         AiChatRoleDO role = createReqVO.getRoleId() != null ? chatRoleService.validateChatRole(createReqVO.getRoleId()) : null;
         // 1.2 获得 AiModelDO 聊天模型
-        AiModelDO model = role != null && role.getModelId() != null ? chatModalService.validateModel(role.getModelId())
-                : chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
+        AiModelDO model = role != null && role.getModelId() != null ? modalService.validateModel(role.getModelId())
+                : modalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
         Assert.notNull(model, "必须找到默认模型");
         validateChatModel(model);
 
@@ -89,7 +89,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
         // 1.2 校验模型是否存在(修改模型的情况)
         AiModelDO model = null;
         if (updateReqVO.getModelId() != null) {
-            model = chatModalService.validateModel(updateReqVO.getModelId());
+            model = modalService.validateModel(updateReqVO.getModelId());
         }
 
         // 1.3 校验知识库是否存在
@@ -144,6 +144,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
         if (ObjectUtil.isAllNotEmpty(model.getTemperature(), model.getMaxTokens(), model.getMaxContexts())) {
             return;
         }
+        Assert.equals(model.getType(), AiModelTypeEnum.CHAT.getType(), "模型类型不正确:" + model);
         throw exception(CHAT_CONVERSATION_MODEL_ERROR);
     }
 

+ 3 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java

@@ -28,12 +28,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
     private AiKnowledgeMapper knowledgeMapper;
 
     @Resource
-    private AiModelService chatModelService;
+    private AiModelService modelService;
 
     @Override
     public Long createKnowledge(AiKnowledgeSaveReqVO createReqVO) {
         // 1. 校验模型配置
-        AiModelDO model = chatModelService.validateModel(createReqVO.getEmbeddingModelId());
+        AiModelDO model = modelService.validateModel(createReqVO.getEmbeddingModelId());
 
         // 2. 插入知识库
         AiKnowledgeDO knowledge = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
@@ -47,7 +47,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
         // 1.1 校验知识库存在
         validateKnowledgeExists(updateReqVO.getId());
         // 1.2 校验模型配置
-        AiModelDO model = chatModelService.validateModel(updateReqVO.getEmbeddingModelId());
+        AiModelDO model = modelService.validateModel(updateReqVO.getEmbeddingModelId());
 
         // 2. 更新知识库
         AiKnowledgeDO updateObj = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class)

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelServiceImpl.java

@@ -33,7 +33,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
  */
 @Service
 @Validated
-public class AiChatModelServiceImpl implements AiModelService {
+public class AiModelServiceImpl implements AiModelService {
 
     @Resource
     private AiApiKeyService apiKeyService;

+ 4 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -54,7 +54,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.WRITE_NOT_EXIS
 public class AiWriteServiceImpl implements AiWriteService {
 
     @Resource
-    private AiModelService chatModalService;
+    private AiModelService modalService;
     @Resource
     private AiChatRoleService chatRoleService;
 
@@ -76,7 +76,7 @@ public class AiWriteServiceImpl implements AiWriteService {
                 ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
         // 1.3 校验平台
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
-        StreamingChatModel chatModel = chatModalService.getChatModel(model.getKeyId());
+        StreamingChatModel chatModel = modalService.getChatModel(model.getKeyId());
 
         // 2. 插入写作信息
         AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, write -> write.setUserId(userId)
@@ -110,10 +110,10 @@ public class AiWriteServiceImpl implements AiWriteService {
     private AiModelDO getModel(AiChatRoleDO writeRole) {
         AiModelDO model = null;
         if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
-            model = chatModalService.getModel(writeRole.getModelId());
+            model = modalService.getModel(writeRole.getModelId());
         }
         if (model == null) {
-            model = chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
+            model = modalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
         }
         Assert.notNull(model, "[AI] 获取不到模型");
         return model;

+ 0 - 4
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/OpenAiImageModelTests.java

@@ -5,13 +5,9 @@ import org.junit.jupiter.api.Test;
 import org.springframework.ai.image.ImageOptions;
 import org.springframework.ai.image.ImagePrompt;
 import org.springframework.ai.image.ImageResponse;
-import org.springframework.ai.openai.OpenAiChatModel;
-import org.springframework.ai.openai.OpenAiChatOptions;
 import org.springframework.ai.openai.OpenAiImageModel;
 import org.springframework.ai.openai.OpenAiImageOptions;
-import org.springframework.ai.openai.api.OpenAiApi;
 import org.springframework.ai.openai.api.OpenAiImageApi;
-import org.springframework.web.client.RestClient;
 
 /**
  * {@link OpenAiImageModel} 集成测试类