Răsfoiți Sursa

【代码重构】AI:“聊天模型”重构为“模型”,支持 type 模型类型

YunaiV 5 luni în urmă
părinte
comite
89d079349c
36 a modificat fișierele cu 617 adăugiri și 554 ștergeri
  1. 4 10
      yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java
  2. 2 2
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java
  3. 3 6
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageDrawReqVO.java
  4. 3 4
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiApiKeyController.java
  5. 0 84
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiChatModelController.java
  6. 89 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiModelController.java
  7. 2 2
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleRespVO.java
  8. 3 3
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelPageReqVO.java
  9. 6 3
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelRespVO.java
  10. 14 5
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelSaveReqVO.java
  11. 3 3
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java
  12. 3 3
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java
  13. 9 4
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java
  14. 3 3
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java
  15. 7 1
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java
  16. 1 1
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java
  17. 12 6
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiModelDO.java
  18. 12 5
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java
  19. 47 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatMapper.java
  20. 0 43
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatModelMapper.java
  21. 10 9
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java
  22. 10 13
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java
  23. 42 30
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java
  24. 11 7
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java
  25. 5 5
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java
  26. 13 15
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
  27. 4 47
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java
  28. 6 68
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java
  29. 0 92
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java
  30. 91 37
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java
  31. 131 0
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java
  32. 4 4
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java
  33. 15 17
      yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java
  34. 41 0
      yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiModelTypeEnum.java
  35. 11 1
      yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java
  36. 0 21
      yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

+ 4 - 10
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java

@@ -12,31 +12,25 @@ public interface ErrorCodeConstants {
     // ========== API 密钥 1-040-000-000 ==========
     ErrorCode API_KEY_NOT_EXISTS = new ErrorCode(1_040_000_000, "API 密钥不存在");
     ErrorCode API_KEY_DISABLE = new ErrorCode(1_040_000_001, "API 密钥已禁用!");
-    ErrorCode API_KEY_MIDJOURNEY_NOT_FOUND = new ErrorCode(1_040_000_900, "Midjourney 模型不存在");
-    ErrorCode API_KEY_SUNO_NOT_FOUND = new ErrorCode(1_040_000_901, "Suno 模型不存在");
-    ErrorCode API_KEY_IMAGE_NODE_FOUND = new ErrorCode(1_040_000_902, "平台({}) 图片模型未配置");
 
-    // ========== API 聊天模型 1-040-001-000 ==========
-    ErrorCode CHAT_MODEL_NOT_EXISTS = new ErrorCode(1_040_001_000, "模型不存在!");
-    ErrorCode CHAT_MODEL_DISABLE = new ErrorCode(1_040_001_001, "模型({})已禁用!");
-    ErrorCode CHAT_MODEL_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认聊天模型");
+    // ========== API 模型 1-040-001-000 ==========
+    ErrorCode MODEL_NOT_EXISTS = new ErrorCode(1_040_001_000, "模型不存在!");
+    ErrorCode MODEL_DISABLE = new ErrorCode(1_040_001_001, "模型({})已禁用!");
+    ErrorCode MODEL_DEFAULT_NOT_EXISTS = new ErrorCode(1_040_001_002, "操作失败,找不到默认模型");
 
     // ========== API 聊天角色 1-040-002-000 ==========
     ErrorCode CHAT_ROLE_NOT_EXISTS = new ErrorCode(1_040_002_000, "聊天角色不存在");
     ErrorCode CHAT_ROLE_DISABLE = new ErrorCode(1_040_001_001, "聊天角色({})已禁用!");
 
     // ========== API 聊天会话 1-040-003-000 ==========
-
     ErrorCode CHAT_CONVERSATION_NOT_EXISTS = new ErrorCode(1_040_003_000, "对话不存在!");
     ErrorCode CHAT_CONVERSATION_MODEL_ERROR = new ErrorCode(1_040_003_001, "操作失败,该聊天模型的配置不完整");
 
     // ========== API 聊天消息 1-040-004-000 ==========
-
     ErrorCode CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
     ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "对话生成异常!");
 
     // ========== API 绘画 1-040-005-000 ==========
-
     ErrorCode IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!");
     ErrorCode IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}");
     ErrorCode IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");

+ 2 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/conversation/AiChatConversationRespVO.java

@@ -1,6 +1,6 @@
 package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation;
 
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import com.fhs.core.trans.anno.Trans;
 import com.fhs.core.trans.constant.TransType;
@@ -31,7 +31,7 @@ public class AiChatConversationRespVO implements VO {
     private Long roleId;
 
     @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
-    @Trans(type = TransType.SIMPLE, target = AiChatModelDO.class, fields = "name", ref = "modelName")
+    @Trans(type = TransType.SIMPLE, target = AiModelDO.class, fields = "name", ref = "modelName")
     private Long modelId;
 
     @Schema(description = "模型标志", requiredMode = Schema.RequiredMode.REQUIRED, example = "ERNIE-Bot-turbo-0922")

+ 3 - 6
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageDrawReqVO.java

@@ -14,18 +14,15 @@ import java.util.Map;
 @Data
 public class AiImageDrawReqVO {
 
-    @Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI")
-    private String platform; // 参见 AiPlatformEnum 枚举
+    @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024")
+    @NotNull(message = "模型编号不能为空")
+    private Long modelId;
 
     @Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "画一个长城")
     @NotEmpty(message = "提示词不能为空")
     @Size(max = 1200, message = "提示词最大 1200")
     private String prompt;
 
-    @Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "stable-diffusion-v1-6")
-    @NotEmpty(message = "模型不能为空")
-    private String model;
-
     /**
      * 1. dall-e-2 模型:256x256、512x512、1024x1024
      * 2. dall-e-3 模型:1024x1024, 1792x1024, 或 1024x1792

+ 3 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiApiKeyController.java

@@ -6,9 +6,8 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelRespVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Parameter;
@@ -76,9 +75,9 @@ public class AiApiKeyController {
 
     @GetMapping("/simple-list")
     @Operation(summary = "获得 API 密钥分页列表")
-    public CommonResult<List<AiChatModelRespVO>> getApiKeySimpleList() {
+    public CommonResult<List<AiModelRespVO>> getApiKeySimpleList() {
         List<AiApiKeyDO> list = apiKeyService.getApiKeyList();
-        return success(convertList(list, key -> new AiChatModelRespVO().setId(key.getId()).setName(key.getName())));
+        return success(convertList(list, key -> new AiModelRespVO().setId(key.getId()).setName(key.getName())));
     }
 
 }

+ 0 - 84
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiChatModelController.java

@@ -1,84 +0,0 @@
-package cn.iocoder.yudao.module.ai.controller.admin.model;
-
-import cn.iocoder.yudao.framework.common.pojo.CommonResult;
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
-import io.swagger.v3.oas.annotations.Operation;
-import io.swagger.v3.oas.annotations.Parameter;
-import io.swagger.v3.oas.annotations.tags.Tag;
-import jakarta.annotation.Resource;
-import jakarta.validation.Valid;
-import org.springframework.security.access.prepost.PreAuthorize;
-import org.springframework.validation.annotation.Validated;
-import org.springframework.web.bind.annotation.*;
-
-import java.util.List;
-
-import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
-import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
-
-@Tag(name = "管理后台 - AI 聊天模型")
-@RestController
-@RequestMapping("/ai/chat-model")
-@Validated
-public class AiChatModelController {
-
-    @Resource
-    private AiChatModelService chatModelService;
-
-    @PostMapping("/create")
-    @Operation(summary = "创建聊天模型")
-    @PreAuthorize("@ss.hasPermission('ai:chat-model:create')")
-    public CommonResult<Long> createChatModel(@Valid @RequestBody AiChatModelSaveReqVO createReqVO) {
-        return success(chatModelService.createChatModel(createReqVO));
-    }
-
-    @PutMapping("/update")
-    @Operation(summary = "更新聊天模型")
-    @PreAuthorize("@ss.hasPermission('ai:chat-model:update')")
-    public CommonResult<Boolean> updateChatModel(@Valid @RequestBody AiChatModelSaveReqVO updateReqVO) {
-        chatModelService.updateChatModel(updateReqVO);
-        return success(true);
-    }
-
-    @DeleteMapping("/delete")
-    @Operation(summary = "删除聊天模型")
-    @Parameter(name = "id", description = "编号", required = true)
-    @PreAuthorize("@ss.hasPermission('ai:chat-model:delete')")
-    public CommonResult<Boolean> deleteChatModel(@RequestParam("id") Long id) {
-        chatModelService.deleteChatModel(id);
-        return success(true);
-    }
-
-    @GetMapping("/get")
-    @Operation(summary = "获得聊天模型")
-    @Parameter(name = "id", description = "编号", required = true, example = "1024")
-    @PreAuthorize("@ss.hasPermission('ai:chat-model:query')")
-    public CommonResult<AiChatModelRespVO> getChatModel(@RequestParam("id") Long id) {
-        AiChatModelDO chatModel = chatModelService.getChatModel(id);
-        return success(BeanUtils.toBean(chatModel, AiChatModelRespVO.class));
-    }
-
-    @GetMapping("/page")
-    @Operation(summary = "获得聊天模型分页")
-    @PreAuthorize("@ss.hasPermission('ai:chat-model:query')")
-    public CommonResult<PageResult<AiChatModelRespVO>> getChatModelPage(@Valid AiChatModelPageReqVO pageReqVO) {
-        PageResult<AiChatModelDO> pageResult = chatModelService.getChatModelPage(pageReqVO);
-        return success(BeanUtils.toBean(pageResult, AiChatModelRespVO.class));
-    }
-
-    @GetMapping("/simple-list")
-    @Operation(summary = "获得聊天模型列表")
-    @Parameter(name = "status", description = "状态", required = true, example = "1")
-    public CommonResult<List<AiChatModelRespVO>> getChatModelSimpleList(@RequestParam("status") Integer status) {
-        List<AiChatModelDO> list = chatModelService.getChatModelListByStatus(status);
-        return success(convertList(list, model -> new AiChatModelRespVO().setId(model.getId())
-                .setName(model.getName()).setModel(model.getModel())));
-    }
-
-}

+ 89 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/AiModelController.java

@@ -0,0 +1,89 @@
+package cn.iocoder.yudao.module.ai.controller.admin.model;
+
+import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
+import cn.iocoder.yudao.framework.common.pojo.CommonResult;
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
+import io.swagger.v3.oas.annotations.Operation;
+import io.swagger.v3.oas.annotations.Parameter;
+import io.swagger.v3.oas.annotations.tags.Tag;
+import jakarta.annotation.Resource;
+import jakarta.validation.Valid;
+import org.springframework.security.access.prepost.PreAuthorize;
+import org.springframework.validation.annotation.Validated;
+import org.springframework.web.bind.annotation.*;
+
+import java.util.List;
+
+import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
+
+@Tag(name = "管理后台 - AI 模型")
+@RestController
+@RequestMapping("/ai/model")
+@Validated
+public class AiModelController {
+
+    @Resource
+    private AiModelService modelService;
+
+    @PostMapping("/create")
+    @Operation(summary = "创建模型")
+    @PreAuthorize("@ss.hasPermission('ai:model:create')")
+    public CommonResult<Long> createModel(@Valid @RequestBody AiModelSaveReqVO createReqVO) {
+        return success(modelService.createModel(createReqVO));
+    }
+
+    @PutMapping("/update")
+    @Operation(summary = "更新模型")
+    @PreAuthorize("@ss.hasPermission('ai:model:update')")
+    public CommonResult<Boolean> updateModel(@Valid @RequestBody AiModelSaveReqVO updateReqVO) {
+        modelService.updateModel(updateReqVO);
+        return success(true);
+    }
+
+    @DeleteMapping("/delete")
+    @Operation(summary = "删除模型")
+    @Parameter(name = "id", description = "编号", required = true)
+    @PreAuthorize("@ss.hasPermission('ai:model:delete')")
+    public CommonResult<Boolean> deleteModel(@RequestParam("id") Long id) {
+        modelService.deleteModel(id);
+        return success(true);
+    }
+
+    @GetMapping("/get")
+    @Operation(summary = "获得模型")
+    @Parameter(name = "id", description = "编号", required = true, example = "1024")
+    @PreAuthorize("@ss.hasPermission('ai:model:query')")
+    public CommonResult<AiModelRespVO> getModel(@RequestParam("id") Long id) {
+        AiModelDO model = modelService.getModel(id);
+        return success(BeanUtils.toBean(model, AiModelRespVO.class));
+    }
+
+    @GetMapping("/page")
+    @Operation(summary = "获得模型分页")
+    @PreAuthorize("@ss.hasPermission('ai:model:query')")
+    public CommonResult<PageResult<AiModelRespVO>> getModelPage(@Valid AiModelPageReqVO pageReqVO) {
+        PageResult<AiModelDO> pageResult = modelService.getModelPage(pageReqVO);
+        return success(BeanUtils.toBean(pageResult, AiModelRespVO.class));
+    }
+
+    @GetMapping("/simple-list")
+    @Operation(summary = "获得模型列表")
+    @Parameter(name = "type", description = "类型", required = true, example = "1")
+    @Parameter(name = "platform", description = "平台", example = "midjourney")
+    public CommonResult<List<AiModelRespVO>> getModelSimpleList(
+            @RequestParam("type") Integer type,
+            @RequestParam(value = "platform", required = false) String platform) {
+        List<AiModelDO> list = modelService.getModelListByStatusAndType(
+                CommonStatusEnum.ENABLE.getStatus(), type, platform);
+        return success(convertList(list, model -> new AiModelRespVO().setId(model.getId())
+                .setName(model.getName()).setModel(model.getModel()).setPlatform(model.getPlatform())));
+    }
+
+}

+ 2 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatRole/AiChatRoleRespVO.java

@@ -1,6 +1,6 @@
 package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole;
 
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import com.fhs.core.trans.anno.Trans;
 import com.fhs.core.trans.constant.TransType;
 import com.fhs.core.trans.vo.VO;
@@ -20,7 +20,7 @@ public class AiChatRoleRespVO implements VO {
     private Long userId;
 
     @Schema(description = "模型编号", example = "17640")
-    @Trans(type = TransType.SIMPLE, target = AiChatModelDO.class, fields = {"name", "model"}, refs = {"modelName", "model"})
+    @Trans(type = TransType.SIMPLE, target = AiModelDO.class, fields = {"name", "model"}, refs = {"modelName", "model"})
     private Long modelId;
     @Schema(description = "模型名字", example = "张三")
     private String modelName;

+ 3 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelPageReqVO.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelPageReqVO.java

@@ -1,12 +1,12 @@
-package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel;
+package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
 
 import lombok.*;
 import io.swagger.v3.oas.annotations.media.Schema;
 import cn.iocoder.yudao.framework.common.pojo.PageParam;
 
-@Schema(description = "管理后台 - API 聊天模型分页 Request VO")
+@Schema(description = "管理后台 - API 模型分页 Request VO")
 @Data
-public class AiChatModelPageReqVO extends PageParam {
+public class AiModelPageReqVO extends PageParam {
 
     @Schema(description = "模型名字", example = "张三")
     private String name;

+ 6 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelRespVO.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelRespVO.java

@@ -1,13 +1,13 @@
-package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel;
+package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
 
 import io.swagger.v3.oas.annotations.media.Schema;
 import lombok.Data;
 
 import java.time.LocalDateTime;
 
-@Schema(description = "管理后台 - AI 聊天模型 Response VO")
+@Schema(description = "管理后台 - AI 模型 Response VO")
 @Data
-public class AiChatModelRespVO {
+public class AiModelRespVO {
 
     @Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "2630")
     private Long id;
@@ -24,6 +24,9 @@ public class AiChatModelRespVO {
     @Schema(description = "模型平台", example = "OpenAI")
     private String platform;
 
+    @Schema(description = "模型类型", example = "1")
+    private Integer type;
+
     @Schema(description = "排序", example = "1")
     private Integer sort;
 

+ 14 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/chatModel/AiChatModelSaveReqVO.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/model/vo/model/AiModelSaveReqVO.java

@@ -1,14 +1,17 @@
-package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel;
+package cn.iocoder.yudao.module.ai.controller.admin.model.vo.model;
 
+import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
+import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.common.validation.InEnum;
 import io.swagger.v3.oas.annotations.media.Schema;
-import lombok.*;
-import jakarta.validation.constraints.*;
+import jakarta.validation.constraints.NotEmpty;
+import jakarta.validation.constraints.NotNull;
+import lombok.Data;
 
-@Schema(description = "管理后台 - API 聊天模型新增/修改 Request VO")
+@Schema(description = "管理后台 - API 模型新增/修改 Request VO")
 @Data
-public class AiChatModelSaveReqVO {
+public class AiModelSaveReqVO {
 
     @Schema(description = "编号", example = "2630")
     private Long id;
@@ -27,8 +30,14 @@ public class AiChatModelSaveReqVO {
 
     @Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI")
     @NotEmpty(message = "模型平台不能为空")
+    @InEnum(AiPlatformEnum.class)
     private String platform;
 
+    @Schema(description = "模型类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
+    @NotNull(message = "模型类型不能为空")
+    @InEnum(AiModelTypeEnum.class)
+    private Integer type;
+
     @Schema(description = "排序", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
     @NotNull(message = "排序不能为空")
     private Integer sort;

+ 3 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java

@@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
 
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import com.baomidou.mybatisplus.annotation.KeySequence;
 import com.baomidou.mybatisplus.annotation.TableId;
@@ -76,13 +76,13 @@ public class AiChatConversationDO extends BaseDO {
     /**
      * 模型编号
      *
-     * 关联 {@link AiChatModelDO#getId()} 字段
+     * 关联 {@link AiModelDO#getId()} 字段
      */
     private Long modelId;
     /**
      * 模型标志
      *
-     * 冗余 {@link AiChatModelDO#getModel()} 字段
+     * 冗余 {@link AiModelDO#getModel()} 字段
      */
     private String model;
 

+ 3 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java

@@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
 
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import com.baomidou.mybatisplus.annotation.KeySequence;
 import com.baomidou.mybatisplus.annotation.TableField;
@@ -83,13 +83,13 @@ public class AiChatMessageDO extends BaseDO {
     /**
      * 模型标志
      *
-     * 冗余 {@link AiChatModelDO#getModel()}
+     * 冗余 {@link AiModelDO#getModel()}
      */
     private String model;
     /**
      * 模型编号
      *
-     * 关联 {@link AiChatModelDO#getId()} 字段
+     * 关联 {@link AiModelDO#getId()} 字段
      */
     private Long modelId;
 

+ 9 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java

@@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.image;
 
 import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
 import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
 import com.baomidou.mybatisplus.annotation.KeySequence;
@@ -52,11 +52,16 @@ public class AiImageDO extends BaseDO {
      * 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum}
      */
     private String platform;
-    // TODO @芋艿:modelId?
     /**
-     * 模型
+     * 模型编号
      *
-     * 冗余 {@link AiChatModelDO#getModel()}
+     * 关联 {@link AiModelDO#getId()}
+     */
+    private Long modelId;
+    /**
+     * 模型标识
+     *
+     * 冗余 {@link AiModelDO#getModel()}
      */
     private String model;
 

+ 3 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java

@@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.knowledge;
 
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import com.baomidou.mybatisplus.annotation.KeySequence;
 import com.baomidou.mybatisplus.annotation.TableId;
 import com.baomidou.mybatisplus.annotation.TableName;
@@ -35,13 +35,13 @@ public class AiKnowledgeDO extends BaseDO {
     /**
      * 向量模型编号
      *
-     * 关联 {@link AiChatModelDO#getId()}
+     * 关联 {@link AiModelDO#getId()}
      */
     private Long embeddingModelId;
     /**
      * 模型标识
      *
-     * 冗余 {@link AiChatModelDO#getModel()}
+     * 冗余 {@link AiModelDO#getModel()}
      */
     private String embeddingModel;
 

+ 7 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java

@@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.mindmap;
 
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import com.baomidou.mybatisplus.annotation.KeySequence;
 import com.baomidou.mybatisplus.annotation.TableId;
 import com.baomidou.mybatisplus.annotation.TableName;
@@ -36,7 +37,12 @@ public class AiMindMapDO extends BaseDO {
      * 枚举 {@link AiPlatformEnum}
      */
     private String platform;
-    // TODO @芋艿:modelId?
+    /**
+     * 模型编号
+     *
+     * 关联 {@link AiModelDO#getId()}
+     */
+    private Long modelId;
     /**
      * 模型
      */

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java

@@ -58,7 +58,7 @@ public class AiChatRoleDO extends BaseDO {
     /**
      * 模型编号
      *
-     * 关联 {@link AiChatModelDO#getId()} 字段
+     * 关联 {@link AiModelDO#getId()} 字段
      */
     private Long modelId;
 

+ 12 - 6
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatModelDO.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiModelDO.java

@@ -1,5 +1,6 @@
 package cn.iocoder.yudao.module.ai.dal.dataobject.model;
 
+import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
@@ -8,23 +9,22 @@ import com.baomidou.mybatisplus.annotation.TableId;
 import com.baomidou.mybatisplus.annotation.TableName;
 import lombok.*;
 
-// TODO @芋艿,需要改造,增加 type
 /**
- * AI 聊天模型 DO
+ * AI 模型 DO
  *
- * 默认聊天模型:{@link #status} 为开启,并且 {@link #sort} 排序第一
+ * 默认模型:{@link #status} 为开启,并且 {@link #sort} 排序第一
  *
  * @author fansili
  * @since 2024/4/24 19:39
  */
-@TableName("ai_chat_model")
-@KeySequence("ai_chat_model_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
+@TableName("ai_model")
+@KeySequence("ai_model_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
 @Data
 @EqualsAndHashCode(callSuper = true)
 @Builder
 @NoArgsConstructor
 @AllArgsConstructor
-public class AiChatModelDO extends BaseDO {
+public class AiModelDO extends BaseDO {
 
     /**
      * 编号
@@ -51,6 +51,12 @@ public class AiChatModelDO extends BaseDO {
      * 枚举 {@link AiPlatformEnum}
      */
     private String platform;
+    /**
+     * 类型
+     *
+     * 枚举 {@link AiModelTypeEnum}
+     */
+    private Integer type;
 
     /**
      * 排序值

+ 12 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java

@@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.write;
 
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
 import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
 import com.baomidou.mybatisplus.annotation.KeySequence;
 import com.baomidou.mybatisplus.annotation.TableId;
@@ -44,7 +46,12 @@ public class AiWriteDO extends BaseDO {
      * 枚举 {@link AiPlatformEnum}
      */
     private String platform;
-    // TODO @芋艿:modelId?
+    /**
+     * 模型编号
+     *
+     * 关联 {@link AiModelDO#getId()}
+     */
+    private Long modelId;
     /**
      * 模型
      */
@@ -67,25 +74,25 @@ public class AiWriteDO extends BaseDO {
     /**
      * 长度提示词
      *
-     * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LENGTH}
+     * 字典:{@link DictTypeConstants#AI_WRITE_LENGTH}
      */
     private Integer length;
     /**
      * 格式提示词
      *
-     * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_FORMAT}
+     * 字典:{@link DictTypeConstants#AI_WRITE_FORMAT}
      */
     private Integer format;
     /**
      * 语气提示词
      *
-     * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_TONE}
+     * 字典:{@link DictTypeConstants#AI_WRITE_TONE}
      */
     private Integer tone;
     /**
      * 语言提示词
      *
-     * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LANGUAGE}
+     * 字典:{@link DictTypeConstants#AI_WRITE_LANGUAGE}
      */
     private Integer language;
 

+ 47 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatMapper.java

@@ -0,0 +1,47 @@
+package cn.iocoder.yudao.module.ai.dal.mysql.model;
+
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
+import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
+import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import org.apache.ibatis.annotations.Mapper;
+
+import javax.annotation.Nullable;
+import java.util.List;
+
+/**
+ * API 模型 Mapper
+ *
+ * @author fansili
+ */
+@Mapper
+public interface AiChatMapper extends BaseMapperX<AiModelDO> {
+
+    default AiModelDO selectFirstByStatus(Integer type, Integer status) {
+        return selectOne(new QueryWrapperX<AiModelDO>()
+                .eq("type", type)
+                .eq("status", status)
+                .limitN(1)
+                .orderByAsc("sort"));
+    }
+
+    default PageResult<AiModelDO> selectPage(AiModelPageReqVO reqVO) {
+        return selectPage(reqVO, new LambdaQueryWrapperX<AiModelDO>()
+                .likeIfPresent(AiModelDO::getName, reqVO.getName())
+                .eqIfPresent(AiModelDO::getModel, reqVO.getModel())
+                .eqIfPresent(AiModelDO::getPlatform, reqVO.getPlatform())
+                .orderByAsc(AiModelDO::getSort));
+    }
+
+    default List<AiModelDO> selectListByStatusAndType(Integer status, Integer type,
+                                                      @Nullable String platform) {
+        return selectList(new LambdaQueryWrapperX<AiModelDO>()
+                .eq(AiModelDO::getStatus, status)
+                .eq(AiModelDO::getType, type)
+                .eqIfPresent(AiModelDO::getPlatform, platform)
+                .orderByAsc(AiModelDO::getSort));
+    }
+
+}

+ 0 - 43
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatModelMapper.java

@@ -1,43 +0,0 @@
-package cn.iocoder.yudao.module.ai.dal.mysql.model;
-
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
-import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
-import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-import org.apache.ibatis.annotations.Mapper;
-
-import java.util.Collection;
-import java.util.List;
-
-/**
- * API 聊天模型 Mapper
- *
- * @author fansili
- */
-@Mapper
-public interface AiChatModelMapper extends BaseMapperX<AiChatModelDO> {
-
-    default AiChatModelDO selectFirstByStatus(Integer status) {
-        return selectOne(new QueryWrapperX<AiChatModelDO>()
-                .eq("status", status)
-                .limitN(1)
-                .orderByAsc("sort"));
-    }
-
-    default PageResult<AiChatModelDO> selectPage(AiChatModelPageReqVO reqVO) {
-        return selectPage(reqVO, new LambdaQueryWrapperX<AiChatModelDO>()
-                .likeIfPresent(AiChatModelDO::getName, reqVO.getName())
-                .eqIfPresent(AiChatModelDO::getModel, reqVO.getModel())
-                .eqIfPresent(AiChatModelDO::getPlatform, reqVO.getPlatform())
-                .orderByAsc(AiChatModelDO::getSort));
-    }
-
-    default List<AiChatModelDO> selectList(Integer status) {
-        return selectList(new LambdaQueryWrapperX<AiChatModelDO>()
-                .eq(AiChatModelDO::getStatus, status)
-                .orderByAsc(AiChatModelDO::getSort));
-    }
-
-}

+ 10 - 9
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java

@@ -4,17 +4,18 @@ import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.lang.Assert;
 import cn.hutool.core.util.ObjUtil;
 import cn.hutool.core.util.ObjectUtil;
+import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
 import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
@@ -44,7 +45,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
     private AiChatConversationMapper chatConversationMapper;
 
     @Resource
-    private AiChatModelService chatModalService;
+    private AiModelService chatModalService;
     @Resource
     private AiChatRoleService chatRoleService;
     @Resource
@@ -54,9 +55,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
     public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) {
         // 1.1 获得 AiChatRoleDO 聊天角色
         AiChatRoleDO role = createReqVO.getRoleId() != null ? chatRoleService.validateChatRole(createReqVO.getRoleId()) : null;
-        // 1.2 获得 AiChatModelDO 聊天模型
-        AiChatModelDO model = role != null && role.getModelId() != null ? chatModalService.validateChatModel(role.getModelId())
-                : chatModalService.getRequiredDefaultChatModel();
+        // 1.2 获得 AiModelDO 聊天模型
+        AiModelDO model = role != null && role.getModelId() != null ? chatModalService.validateModel(role.getModelId())
+                : chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
         Assert.notNull(model, "必须找到默认模型");
         validateChatModel(model);
 
@@ -86,9 +87,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
             throw exception(CHAT_CONVERSATION_NOT_EXISTS);
         }
         // 1.2 校验模型是否存在(修改模型的情况)
-        AiChatModelDO model = null;
+        AiModelDO model = null;
         if (updateReqVO.getModelId() != null) {
-            model = chatModalService.validateChatModel(updateReqVO.getModelId());
+            model = chatModalService.validateModel(updateReqVO.getModelId());
         }
 
         // 1.3 校验知识库是否存在
@@ -139,7 +140,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
         chatConversationMapper.deleteById(id);
     }
 
-    private void validateChatModel(AiChatModelDO model) {
+    private void validateChatModel(AiModelDO model) {
         if (ObjectUtil.isAllNotEmpty(model.getTemperature(), model.getMaxTokens(), model.getMaxContexts())) {
             return;
         }

+ 10 - 13
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -15,13 +15,12 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.chat.messages.Message;
@@ -63,9 +62,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     @Resource
     private AiChatConversationService chatConversationService;
     @Resource
-    private AiChatModelService chatModalService;
-    @Resource
-    private AiApiKeyService apiKeyService;
+    private AiModelService modalService;
     @Resource
     private AiKnowledgeSegmentService knowledgeSegmentService;
 
@@ -78,8 +75,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         }
         List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
         // 1.2 校验模型
-        AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
-        ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
+        AiModelDO model = modalService.validateModel(conversation.getModelId());
+        ChatModel chatModel = modalService.getChatModel(model.getKeyId());
 
         // 2. 插入 user 发送消息
         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@@ -112,8 +109,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         }
         List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
         // 1.2 校验模型
-        AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
-        StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
+        AiModelDO model = modalService.validateModel(conversation.getModelId());
+        StreamingChatModel chatModel = modalService.getChatModel(model.getKeyId());
 
         // 2. 插入 user 发送消息
         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@@ -161,8 +158,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         return null;
     }
 
-    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList,
-                               AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
+    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, List<AiKnowledgeSegmentDO> segmentList,
+                               AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
         // 1. 构建 Prompt Message 列表
         List<Message> chatMessages = new ArrayList<>();
 
@@ -232,7 +229,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     }
 
     private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
-                                              AiChatModelDO model, Long userId, Long roleId,
+                                              AiModelDO model, Long userId, Long roleId,
                                               MessageType messageType, String content, Boolean useContext) {
         AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
                 .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)

+ 42 - 30
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -12,13 +12,17 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
 import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
 import cn.iocoder.yudao.module.infra.api.file.FileApi;
 import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
 import jakarta.annotation.Resource;
@@ -55,13 +59,13 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
 public class AiImageServiceImpl implements AiImageService {
 
     @Resource
-    private AiImageMapper imageMapper;
+    private AiModelService modelService;
 
     @Resource
-    private FileApi fileApi;
+    private AiImageMapper imageMapper;
 
     @Resource
-    private AiApiKeyService apiKeyService;
+    private FileApi fileApi;
 
     @Override
     public PageResult<AiImageDO> getImagePageMy(Long userId, AiImagePageReqVO pageReqVO) {
@@ -88,23 +92,31 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
-        // 1. 保存数据库
-        AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
-                .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
+        // 1. 校验模型
+        AiModelDO model = modelService.validateModel(drawReqVO.getModelId());
+
+        // 2. 保存数据库
+        AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId)
+                .setPlatform(model.getPlatform()).setModelId(model.getId()).setModel(model.getModel())
+                .setPublicStatus(false).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
         imageMapper.insert(image);
-        // 2. 异步绘制,后续前端通过返回的 id 进行轮询结果
-        getSelf().executeDrawImage(image, drawReqVO);
+
+        // 3. 异步绘制,后续前端通过返回的 id 进行轮询结果
+        getSelf().executeDrawImage(image, drawReqVO, model);
         return image.getId();
     }
 
     @Async
-    public void executeDrawImage(AiImageDO image, AiImageDrawReqVO req) {
+    public void executeDrawImage(AiImageDO image, AiImageDrawReqVO reqVO, AiModelDO model) {
         try {
             // 1.1 构建请求
-            ImageOptions request = buildImageOptions(req);
+            ImageOptions request = buildImageOptions(reqVO, model);
             // 1.2 执行请求
-            ImageModel imageModel = apiKeyService.getImageModel(AiPlatformEnum.validatePlatform(req.getPlatform()));
-            ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
+            ImageModel imageModel = modelService.getImageModel(model.getId());
+            ImageResponse response = imageModel.call(new ImagePrompt(reqVO.getPrompt(), request));
+            if (response.getResult() == null) {
+                throw new IllegalArgumentException("生成结果为空");
+            }
 
             // 2. 上传到文件服务
             String b64Json = response.getResult().getOutput().getB64Json();
@@ -116,25 +128,25 @@ public class AiImageServiceImpl implements AiImageService {
             imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.SUCCESS.getStatus())
                     .setPicUrl(filePath).setFinishTime(LocalDateTime.now()));
         } catch (Exception ex) {
-            log.error("[doDall][image({}) 生成异常]", image, ex);
+            log.error("[executeDrawImage][image({}) 生成异常]", image, ex);
             imageMapper.updateById(new AiImageDO().setId(image.getId())
                     .setStatus(AiImageStatusEnum.FAIL.getStatus())
                     .setErrorMessage(ex.getMessage()).setFinishTime(LocalDateTime.now()));
         }
     }
 
-    private static ImageOptions buildImageOptions(AiImageDrawReqVO draw) {
-        if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
+    private static ImageOptions buildImageOptions(AiImageDrawReqVO draw, AiModelDO model) {
+        if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
             // https://platform.openai.com/docs/api-reference/images/create
-            return OpenAiImageOptions.builder().withModel(draw.getModel())
+            return OpenAiImageOptions.builder().withModel(model.getModel())
                     .withHeight(draw.getHeight()).withWidth(draw.getWidth())
                     .withStyle(MapUtil.getStr(draw.getOptions(), "style")) // 风格
                     .withResponseFormat("b64_json")
                     .build();
-        } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
+        } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
             // https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
             // https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
-            return StabilityAiImageOptions.builder().model(draw.getModel())
+            return StabilityAiImageOptions.builder().model(model.getModel())
                     .height(draw.getHeight()).width(draw.getWidth())
                     .seed(Long.valueOf(draw.getOptions().get("seed")))
                     .cfgScale(Float.valueOf(draw.getOptions().get("scale")))
@@ -143,22 +155,22 @@ public class AiImageServiceImpl implements AiImageService {
                     .stylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
                     .clipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
                     .build();
-        } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
+        } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
             return DashScopeImageOptions.builder()
-                    .withModel(draw.getModel()).withN(1)
+                    .withModel(model.getModel()).withN(1)
                     .withHeight(draw.getHeight()).withWidth(draw.getWidth())
                     .build();
-        } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
+        } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
             return QianFanImageOptions.builder()
-                    .model(draw.getModel()).N(1)
+                    .model(model.getModel()).N(1)
                     .height(draw.getHeight()).width(draw.getWidth())
                     .build();
-        } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
+        } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
             return ZhiPuAiImageOptions.builder()
-                    .model(draw.getModel())
+                    .model(model.getModel())
                     .build();
         }
-        throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
+        throw new IllegalArgumentException("不支持的 AI 平台:" + model.getPlatform());
     }
 
     @Override
@@ -206,7 +218,7 @@ public class AiImageServiceImpl implements AiImageService {
     @Override
     @Transactional(rollbackFor = Exception.class)
     public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
-        MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
+        MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
         // 1. 保存数据库
         AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
                 .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
@@ -237,7 +249,7 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     public Integer midjourneySync() {
-        MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
+        MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
         // 1.1 获取 Midjourney 平台,状态在 “进行中” 的 image
         List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
                 AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
@@ -308,7 +320,7 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
-        MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
+        MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
         // 1.1 检查 image
         AiImageDO image = validateImageExists(reqVO.getId());
         if (ObjUtil.notEqual(userId, image.getUserId())) {

+ 11 - 7
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java

@@ -7,14 +7,17 @@ import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
-import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.*;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSaveReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
 import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
 import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.document.Document;
@@ -33,8 +36,8 @@ import java.util.Objects;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
-import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
 import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
 
 /**
  * AI 知识库分片 Service 实现类
@@ -58,7 +61,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
     @Lazy // 延迟加载,避免循环依赖
     private AiKnowledgeDocumentService knowledgeDocumentService;
     @Resource
-    private AiApiKeyService apiKeyService;
+    private AiModelService modelService;
 
     @Resource
     private TokenCountEstimator tokenCountEstimator;
@@ -180,7 +183,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
         AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId());
 
         // 2.1 向量检索
-        VectorStore vectorStore = apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId());
+        VectorStore vectorStore = getVectorStoreById(knowledge);
         List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder()
                 .query(reqBO.getContent())
                 .topK(ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK()))
@@ -251,11 +254,12 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
     }
 
     private VectorStore getVectorStoreById(AiKnowledgeDO knowledge) {
-        return apiKeyService.getOrCreateVectorStoreByModelId(knowledge.getEmbeddingModelId());
+        return modelService.getOrCreateVectorStore(knowledge.getEmbeddingModelId());
     }
 
     private VectorStore getVectorStoreById(Long knowledgeId) {
-        return getVectorStoreById(knowledgeService.validateKnowledgeExists(knowledgeId));
+        AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(knowledgeId);
+        return getVectorStoreById(knowledge);
     }
 
     private static List<Document> splitContentByToken(String content, Integer segmentMaxTokens) {

+ 5 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java

@@ -5,9 +5,9 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Service;
@@ -28,12 +28,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
     private AiKnowledgeMapper knowledgeMapper;
 
     @Resource
-    private AiChatModelService chatModelService;
+    private AiModelService chatModelService;
 
     @Override
     public Long createKnowledge(AiKnowledgeSaveReqVO createReqVO) {
         // 1. 校验模型配置
-        AiChatModelDO model = chatModelService.validateChatModel(createReqVO.getEmbeddingModelId());
+        AiModelDO model = chatModelService.validateModel(createReqVO.getEmbeddingModelId());
 
         // 2. 插入知识库
         AiKnowledgeDO knowledge = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
@@ -47,7 +47,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
         // 1.1 校验知识库存在
         validateKnowledgeExists(updateReqVO.getId());
         // 1.2 校验模型配置
-        AiChatModelDO model = chatModelService.validateChatModel(updateReqVO.getEmbeddingModelId());
+        AiModelDO model = chatModelService.validateModel(updateReqVO.getEmbeddingModelId());
 
         // 2. 更新知识库
         AiKnowledgeDO updateObj = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class)

+ 13 - 15
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java

@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.mindmap;
 import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.lang.Assert;
 import cn.hutool.core.util.StrUtil;
+import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
@@ -12,14 +13,13 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper;
 import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.chat.messages.Message;
@@ -50,9 +50,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MIND_MAP_NOT_E
 public class AiMindMapServiceImpl implements AiMindMapService {
 
     @Resource
-    private AiApiKeyService apiKeyService;
-    @Resource
-    private AiChatModelService chatModalService;
+    private AiModelService modalService;
     @Resource
     private AiChatRoleService chatRoleService;
 
@@ -65,17 +63,17 @@ public class AiMindMapServiceImpl implements AiMindMapService {
         AiChatRoleDO role = CollUtil.getFirst(
                 chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
         // 1.1 获取导图执行模型
-        AiChatModelDO model = getModel(role);
+        AiModelDO model = getModel(role);
         // 1.2 获取角色设定消息
         String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
                 ? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
         // 1.3 校验平台
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
-        ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
+        ChatModel chatModel = modalService.getChatModel(model.getId());
 
         // 2. 插入思维导图信息
-        AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
-                mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
+        AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, mindMap -> mindMap.setUserId(userId)
+                .setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel()));
         mindMapMapper.insert(mindMapDO);
 
         // 3.1 构建 Prompt,并进行调用
@@ -103,7 +101,7 @@ public class AiMindMapServiceImpl implements AiMindMapService {
 
     }
 
-    private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
+    private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) {
         // 1. 构建 message 列表
         List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
         // 2. 构建 options 对象
@@ -123,13 +121,13 @@ public class AiMindMapServiceImpl implements AiMindMapService {
         return chatMessages;
     }
 
-    private AiChatModelDO getModel(AiChatRoleDO role) {
-        AiChatModelDO model = null;
+    private AiModelDO getModel(AiChatRoleDO role) {
+        AiModelDO model = null;
         if (role != null && role.getModelId() != null) {
-            model = chatModalService.getChatModel(role.getModelId());
+            model = modalService.getModel(role.getModelId());
         }
         if (model == null) {
-            model = chatModalService.getRequiredDefaultChatModel();
+            model = modalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
         }
         Assert.notNull(model, "[AI] 获取不到模型");
         return model;

+ 4 - 47
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java

@@ -1,16 +1,10 @@
 package cn.iocoder.yudao.module.ai.service.model;
 
-import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
-import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
 import jakarta.validation.Valid;
-import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.image.ImageModel;
-import org.springframework.ai.vectorstore.VectorStore;
 
 import java.util.List;
 
@@ -74,50 +68,13 @@ public interface AiApiKeyService {
      */
     List<AiApiKeyDO> getApiKeyList();
 
-    // ========== 与 spring-ai 集成 ==========
-
-    /**
-     * 获得 ChatModel 对象
-     *
-     * @param id 编号
-     * @return ChatModel 对象
-     */
-    ChatModel getChatModel(Long id);
-
     /**
-     * 获得 ImageModel 对象
-     *
-     * TODO 可优化点:目前默认获取 platform 对应的第一个开启的配置用于绘画;后续可以支持配置选择
+     * 获得默认的 API 密钥
      *
      * @param platform 平台
-     * @return ImageModel 对象
-     */
-    ImageModel getImageModel(AiPlatformEnum platform);
-
-    /**
-     * 获得 MidjourneyApi 对象
-     *
-     * TODO 可优化点:目前默认获取 Midjourney 对应的第一个开启的配置用于绘画;后续可以支持配置选择
-     *
-     * @return MidjourneyApi 对象
-     */
-    MidjourneyApi getMidjourneyApi();
-
-    /**
-     * 获得 SunoApi 对象
-     *
-     * TODO 可优化点:目前默认获取 Suno 对应的第一个开启的配置用于音乐;后续可以支持配置选择
-     *
-     * @return SunoApi 对象
-     */
-    SunoApi getSunoApi();
-
-    /**
-     * 获得 VectorStore 对象
-     *
-     * @param modelId 编号
-     * @return VectorStore 对象
+     * @param status 状态
+     * @return API 密钥
      */
-    VectorStore getOrCreateVectorStoreByModelId(Long modelId);
+    AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status);
 
 }

+ 6 - 68
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java

@@ -1,31 +1,21 @@
 package cn.iocoder.yudao.module.ai.service.model;
 
-import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
-import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
-import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
 import jakarta.annotation.Resource;
-import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.image.ImageModel;
-import org.springframework.ai.vectorstore.SimpleVectorStore;
-import org.springframework.ai.vectorstore.VectorStore;
-import org.springframework.context.annotation.Lazy;
 import org.springframework.stereotype.Service;
 import org.springframework.validation.annotation.Validated;
 
 import java.util.List;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
-import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_DISABLE;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_NOT_EXISTS;
 
 /**
  * AI API 密钥 Service 实现类
@@ -39,14 +29,6 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
     @Resource
     private AiApiKeyMapper apiKeyMapper;
 
-    // TODO @芋艿:后续要不要改?
-    @Resource
-    @Lazy // 延迟加载,解决渲染依赖
-    private AiChatModelService chatModelService;
-
-    @Resource
-    private AiModelFactory modelFactory;
-
     @Override
     public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
         // 插入
@@ -105,57 +87,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
         return apiKeyMapper.selectList();
     }
 
-    // ========== 与 spring-ai 集成 ==========
-
-    @Override
-    public ChatModel getChatModel(Long id) {
-        AiApiKeyDO apiKey = validateApiKey(id);
-        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
-        return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
-    }
-
-    @Override
-    public ImageModel getImageModel(AiPlatformEnum platform) {
-        AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
-        if (apiKey == null) {
-            throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName());
-        }
-        return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
-    }
-
     @Override
-    public MidjourneyApi getMidjourneyApi() {
-        AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
-                AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
+    public AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status) {
+        AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform, status);
         if (apiKey == null) {
-            throw exception(API_KEY_MIDJOURNEY_NOT_FOUND);
-        }
-        return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
-    }
-
-    @Override
-    public SunoApi getSunoApi() {
-        AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
-                AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
-        if (apiKey == null) {
-            throw exception(API_KEY_SUNO_NOT_FOUND);
+            throw exception(API_KEY_NOT_EXISTS);
         }
-        return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
-    }
-
-    @Override
-    public VectorStore getOrCreateVectorStoreByModelId(Long modelId) {
-        // 获取模型 + 密钥
-        AiChatModelDO chatModel = chatModelService.validateChatModel(modelId);
-        AiApiKeyDO apiKey = validateApiKey(chatModel.getKeyId());
-        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
-
-        // 创建或获取 EmbeddingModel 对象
-        EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(),
-                apiKey.getUrl(), chatModel.getModel());
-
-        // 创建或获取 VectorStore 对象
-        return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel);
+        return apiKey;
     }
 
 }

+ 0 - 92
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java

@@ -1,92 +0,0 @@
-package cn.iocoder.yudao.module.ai.service.model;
-
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-import jakarta.validation.Valid;
-
-import java.util.Collection;
-import java.util.List;
-
-import java.util.Set;
-
-/**
- * AI 聊天模型 Service 接口
- *
- * @author fansili
- * @since 2024/4/24 19:42
- */
-public interface AiChatModelService {
-
-    /**
-     * 创建聊天模型
-     *
-     * @param createReqVO 创建信息
-     * @return 编号
-     */
-    Long createChatModel(@Valid AiChatModelSaveReqVO createReqVO);
-
-    /**
-     * 更新聊天模型
-     *
-     * @param updateReqVO 更新信息
-     */
-    void updateChatModel(@Valid AiChatModelSaveReqVO updateReqVO);
-
-    /**
-     * 删除聊天模型
-     *
-     * @param id 编号
-     */
-    void deleteChatModel(Long id);
-
-    /**
-     * 获得聊天模型
-     *
-     * @param id 编号
-     * @return 聊天模型
-     */
-    AiChatModelDO getChatModel(Long id);
-
-    /**
-     * 获得默认的聊天模型
-     *
-     * 如果获取不到,则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常
-     *
-     * @return 聊天模型
-     */
-    AiChatModelDO getRequiredDefaultChatModel();
-
-    /**
-     * 获得聊天模型分页
-     *
-     * @param pageReqVO 分页查询
-     * @return 聊天模型分页
-     */
-    PageResult<AiChatModelDO> getChatModelPage(AiChatModelPageReqVO pageReqVO);
-
-    /**
-     * 校验聊天模型
-     *
-     * @param id 编号
-     * @return 聊天模型
-     */
-    AiChatModelDO validateChatModel(Long id);
-
-    /**
-     * 获得聊天模型列表
-     *
-     * @param status 状态
-     * @return 聊天模型列表
-     */
-    List<AiChatModelDO> getChatModelListByStatus(Integer status);
-
-    /**
-     * 获得聊天模型列表
-     *
-     * @param ids 编号数组
-     * @return 模型列表
-     */
-    List<AiChatModelDO> getChatModelList(Collection<Long> ids);
-}

+ 91 - 37
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java

@@ -1,114 +1,168 @@
 package cn.iocoder.yudao.module.ai.service.model;
 
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
+import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
+import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatModelMapper;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatMapper;
 import jakarta.annotation.Resource;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.image.ImageModel;
+import org.springframework.ai.vectorstore.SimpleVectorStore;
+import org.springframework.ai.vectorstore.VectorStore;
 import org.springframework.stereotype.Service;
 import org.springframework.validation.annotation.Validated;
 
-import java.util.Collection;
 import java.util.List;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
 
 /**
- * AI 聊天模型 Service 实现类
+ * AI 模型 Service 实现类
  *
  * @author fansili
  */
 @Service
 @Validated
-public class AiChatModelServiceImpl implements AiChatModelService {
+public class AiChatModelServiceImpl implements AiModelService {
 
     @Resource
     private AiApiKeyService apiKeyService;
 
     @Resource
-    private AiChatModelMapper chatModelMapper;
+    private AiChatMapper modelMapper;
+
+    @Resource
+    private AiModelFactory modelFactory;
 
     @Override
-    public Long createChatModel(AiChatModelSaveReqVO createReqVO) {
+    public Long createModel(AiModelSaveReqVO createReqVO) {
         // 1. 校验
         AiPlatformEnum.validatePlatform(createReqVO.getPlatform());
         apiKeyService.validateApiKey(createReqVO.getKeyId());
 
         // 2. 插入
-        AiChatModelDO chatModel = BeanUtils.toBean(createReqVO, AiChatModelDO.class);
-        chatModelMapper.insert(chatModel);
-        return chatModel.getId();
+        AiModelDO model = BeanUtils.toBean(createReqVO, AiModelDO.class);
+        modelMapper.insert(model);
+        return model.getId();
     }
 
     @Override
-    public void updateChatModel(AiChatModelSaveReqVO updateReqVO) {
+    public void updateModel(AiModelSaveReqVO updateReqVO) {
         // 1. 校验
-        validateChatModelExists(updateReqVO.getId());
+        validateModelExists(updateReqVO.getId());
         AiPlatformEnum.validatePlatform(updateReqVO.getPlatform());
         apiKeyService.validateApiKey(updateReqVO.getKeyId());
 
         // 2. 更新
-        AiChatModelDO updateObj = BeanUtils.toBean(updateReqVO, AiChatModelDO.class);
-        chatModelMapper.updateById(updateObj);
+        AiModelDO updateObj = BeanUtils.toBean(updateReqVO, AiModelDO.class);
+        modelMapper.updateById(updateObj);
     }
 
     @Override
-    public void deleteChatModel(Long id) {
+    public void deleteModel(Long id) {
         // 校验存在
-        validateChatModelExists(id);
+        validateModelExists(id);
         // 删除
-        chatModelMapper.deleteById(id);
+        modelMapper.deleteById(id);
     }
 
-    private AiChatModelDO validateChatModelExists(Long id) {
-        AiChatModelDO model = chatModelMapper.selectById(id);
-        if (chatModelMapper.selectById(id) == null) {
-            throw exception(CHAT_MODEL_NOT_EXISTS);
+    private AiModelDO validateModelExists(Long id) {
+        AiModelDO model = modelMapper.selectById(id);
+        if (modelMapper.selectById(id) == null) {
+            throw exception(MODEL_NOT_EXISTS);
         }
         return model;
     }
 
     @Override
-    public AiChatModelDO getChatModel(Long id) {
-        return chatModelMapper.selectById(id);
+    public AiModelDO getModel(Long id) {
+        return modelMapper.selectById(id);
     }
 
     @Override
-    public AiChatModelDO getRequiredDefaultChatModel() {
-        AiChatModelDO model = chatModelMapper.selectFirstByStatus(CommonStatusEnum.ENABLE.getStatus());
+    public AiModelDO getRequiredDefaultModel(Integer type) {
+        AiModelDO model = modelMapper.selectFirstByStatus(type, CommonStatusEnum.ENABLE.getStatus());
         if (model == null) {
-            throw exception(CHAT_MODEL_DEFAULT_NOT_EXISTS);
+            throw exception(MODEL_DEFAULT_NOT_EXISTS);
         }
         return model;
     }
 
     @Override
-    public PageResult<AiChatModelDO> getChatModelPage(AiChatModelPageReqVO pageReqVO) {
-        return chatModelMapper.selectPage(pageReqVO);
+    public PageResult<AiModelDO> getModelPage(AiModelPageReqVO pageReqVO) {
+        return modelMapper.selectPage(pageReqVO);
     }
 
     @Override
-    public AiChatModelDO validateChatModel(Long id) {
-        AiChatModelDO model = validateChatModelExists(id);
+    public AiModelDO validateModel(Long id) {
+        AiModelDO model = validateModelExists(id);
         if (CommonStatusEnum.isDisable(model.getStatus())) {
-            throw exception(CHAT_MODEL_DISABLE);
+            throw exception(MODEL_DISABLE);
         }
         return model;
     }
 
     @Override
-    public List<AiChatModelDO> getChatModelListByStatus(Integer status) {
-        return chatModelMapper.selectList(status);
+    public List<AiModelDO> getModelListByStatusAndType(Integer status, Integer type,
+                                                       String platform) {
+        return modelMapper.selectListByStatusAndType(status, type, platform);
+    }
+
+    // ========== 与 Spring AI 集成 ==========
+
+    @Override
+    public ChatModel getChatModel(Long id) {
+        AiModelDO model = validateModel(id);
+        AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
+        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
+        return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
+    }
+
+    @Override
+    public ImageModel getImageModel(Long id) {
+        AiModelDO model = validateModel(id);
+        AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
+        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
+        return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
+    }
+
+    @Override
+    public MidjourneyApi getMidjourneyApi() {
+        AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
+                AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
+        return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
+    }
+
+    @Override
+    public SunoApi getSunoApi() {
+        AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
+                AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
+        return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
     }
 
     @Override
-    public List<AiChatModelDO> getChatModelList(Collection<Long> ids) {
-        return chatModelMapper.selectBatchIds(ids);
+    public VectorStore getOrCreateVectorStore(Long id) {
+        // 获取模型 + 密钥
+        AiModelDO model = validateModel(id);
+        AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
+        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
+
+        // 创建或获取 EmbeddingModel 对象
+        EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(
+                platform, apiKey.getApiKey(), apiKey.getUrl(), model.getModel());
+
+        // 创建或获取 VectorStore 对象
+        return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel);
     }
 
 }

+ 131 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java

@@ -0,0 +1,131 @@
+package cn.iocoder.yudao.module.ai.service.model;
+
+import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
+import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import jakarta.validation.Valid;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.image.ImageModel;
+import org.springframework.ai.vectorstore.VectorStore;
+
+import javax.annotation.Nullable;
+import java.util.List;
+
+/**
+ * AI 模型 Service 接口
+ *
+ * @author fansili
+ * @since 2024/4/24 19:42
+ */
+public interface AiModelService {
+
+    /**
+     * 创建模型
+     *
+     * @param createReqVO 创建信息
+     * @return 编号
+     */
+    Long createModel(@Valid AiModelSaveReqVO createReqVO);
+
+    /**
+     * 更新模型
+     *
+     * @param updateReqVO 更新信息
+     */
+    void updateModel(@Valid AiModelSaveReqVO updateReqVO);
+
+    /**
+     * 删除模型
+     *
+     * @param id 编号
+     */
+    void deleteModel(Long id);
+
+    /**
+     * 获得模型
+     *
+     * @param id 编号
+     * @return 模型
+     */
+    AiModelDO getModel(Long id);
+
+    /**
+     * 获得默认的模型
+     *
+     * 如果获取不到,则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常
+     *
+     * @return 模型
+     */
+    AiModelDO getRequiredDefaultModel(Integer type);
+
+    /**
+     * 获得模型分页
+     *
+     * @param pageReqVO 分页查询
+     * @return 模型分页
+     */
+    PageResult<AiModelDO> getModelPage(AiModelPageReqVO pageReqVO);
+
+    /**
+     * 校验模型是否可使用
+     *
+     * @param id 编号
+     * @return 模型
+     */
+    AiModelDO validateModel(Long id);
+
+    /**
+     * 获得模型列表
+     *
+     * @param status 状态
+     * @param type 类型
+     * @param platform 平台,允许空
+     * @return 模型列表
+     */
+    List<AiModelDO> getModelListByStatusAndType(Integer status, Integer type,
+                                                @Nullable String platform);
+
+    // ========== 与 Spring AI 集成 ==========
+
+    /**
+     * 获得 ChatModel 对象
+     *
+     * @param id 编号
+     * @return ChatModel 对象
+     */
+    ChatModel getChatModel(Long id);
+
+    /**
+     * 获得 ImageModel 对象
+     *
+     * @param id 编号
+     * @return ImageModel 对象
+     */
+    ImageModel getImageModel(Long id);
+
+    /**
+     * 获得 MidjourneyApi 对象
+     *
+     * @return MidjourneyApi 对象
+     */
+    MidjourneyApi getMidjourneyApi();
+
+    /**
+     * 获得 SunoApi 对象
+     *
+     * @return SunoApi 对象
+     */
+    SunoApi getSunoApi();
+
+    /**
+     * 获得 VectorStore 对象
+     *
+     * @param id 编号
+     * @return VectorStore 对象
+     */
+    VectorStore getOrCreateVectorStore(Long id);
+
+}

+ 4 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java

@@ -16,7 +16,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper;
 import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum;
 import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
 import cn.iocoder.yudao.module.infra.api.file.FileApi;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
@@ -41,7 +41,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MUSIC_NOT_EXIS
 public class AiMusicServiceImpl implements AiMusicService {
 
     @Resource
-    private AiApiKeyService apiKeyService;
+    private AiModelService modelService;
 
     @Resource
     private AiMusicMapper musicMapper;
@@ -53,7 +53,7 @@ public class AiMusicServiceImpl implements AiMusicService {
     @Transactional(rollbackFor = Exception.class)
     public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) {
         // 1. 调用 Suno 生成音乐
-        SunoApi sunoApi = apiKeyService.getSunoApi();
+        SunoApi sunoApi = modelService.getSunoApi();
         List<SunoApi.MusicData> musicDataList;
         if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) {
             // 1.1 描述模式
@@ -88,7 +88,7 @@ public class AiMusicServiceImpl implements AiMusicService {
         log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", streamingTask.size());
 
         // GET 请求,为避免参数过长,分批次处理
-        SunoApi sunoApi = apiKeyService.getSunoApi();
+        SunoApi sunoApi = modelService.getSunoApi();
         CollUtil.split(streamingTask, 36).forEach(chunkList -> {
             Map<String, Long> taskIdMap = convertMap(chunkList, AiMusicDO::getTaskId, AiMusicDO::getId);
             List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet()));

+ 15 - 17
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.write;
 import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.lang.Assert;
 import cn.hutool.core.util.StrUtil;
+import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
@@ -11,17 +12,16 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWritePageReqVO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
 import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
 import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
 import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
@@ -54,17 +54,15 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.WRITE_NOT_EXIS
 public class AiWriteServiceImpl implements AiWriteService {
 
     @Resource
-    private AiApiKeyService apiKeyService;
-    @Resource
-    private AiChatModelService chatModalService;
+    private AiModelService chatModalService;
     @Resource
     private AiChatRoleService chatRoleService;
 
     @Resource
-    private DictDataApi dictDataApi;
+    private AiWriteMapper writeMapper;
 
     @Resource
-    private AiWriteMapper writeMapper;
+    private DictDataApi dictDataApi;
 
     @Override
     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
@@ -72,17 +70,17 @@ public class AiWriteServiceImpl implements AiWriteService {
         AiChatRoleDO writeRole = CollUtil.getFirst(
                 chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
         // 1.1 获取写作执行模型
-        AiChatModelDO model = getModel(writeRole);
+        AiModelDO model = getModel(writeRole);
         // 1.2 获取角色设定消息
         String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
                 ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
         // 1.3 校验平台
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
-        StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
+        StreamingChatModel chatModel = chatModalService.getChatModel(model.getKeyId());
 
         // 2. 插入写作信息
-        AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class,
-                write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
+        AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, write -> write.setUserId(userId)
+                        .setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel()));
         writeMapper.insert(writeDO);
 
         // 3.1 构建 Prompt,并进行调用
@@ -109,19 +107,19 @@ public class AiWriteServiceImpl implements AiWriteService {
         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
     }
 
-    private AiChatModelDO getModel(AiChatRoleDO writeRole) {
-        AiChatModelDO model = null;
+    private AiModelDO getModel(AiChatRoleDO writeRole) {
+        AiModelDO model = null;
         if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
-            model = chatModalService.getChatModel(writeRole.getModelId());
+            model = chatModalService.getModel(writeRole.getModelId());
         }
         if (model == null) {
-            model = chatModalService.getRequiredDefaultChatModel();
+            model = chatModalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
         }
         Assert.notNull(model, "[AI] 获取不到模型");
         return model;
     }
 
-    private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
+    private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) {
         // 1. 构建 message 列表
         List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
         // 2. 构建 options 对象

+ 41 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiModelTypeEnum.java

@@ -0,0 +1,41 @@
+package cn.iocoder.yudao.framework.ai.core.enums;
+
+import cn.iocoder.yudao.framework.common.core.ArrayValuable;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+
+import java.util.Arrays;
+
+/**
+ * AI 模型类型的枚举
+ *
+ * @author 芋道源码
+ */
+@Getter
+@RequiredArgsConstructor
+public enum AiModelTypeEnum implements ArrayValuable<Integer> {
+
+    CHAT(1, "对话"),
+    IMAGE(2, "图片"),
+    VOICE(3, "语音"),
+    VIDEO(4, "视频"),
+    EMBEDDING(5, "向量"),
+    RERANK(6, "重排序");
+
+    /**
+     * 类型
+     */
+    private final Integer type;
+    /**
+     * 类型名
+     */
+    private final String name;
+
+    public static final Integer[] ARRAYS = Arrays.stream(values()).map(AiModelTypeEnum::getType).toArray(Integer[]::new);
+
+    @Override
+    public Integer[] array() {
+        return ARRAYS;
+    }
+
+}

+ 11 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java

@@ -1,8 +1,11 @@
 package cn.iocoder.yudao.framework.ai.core.enums;
 
+import cn.iocoder.yudao.framework.common.core.ArrayValuable;
 import lombok.AllArgsConstructor;
 import lombok.Getter;
 
+import java.util.Arrays;
+
 /**
  * AI 模型平台
  *
@@ -10,7 +13,7 @@ import lombok.Getter;
  */
 @Getter
 @AllArgsConstructor
-public enum AiPlatformEnum {
+public enum AiPlatformEnum implements ArrayValuable<String> {
 
     // ========== 国内平台 ==========
 
@@ -44,6 +47,8 @@ public enum AiPlatformEnum {
      */
     private final String name;
 
+    public static final String[] ARRAYS = Arrays.stream(values()).map(AiPlatformEnum::getPlatform).toArray(String[]::new);
+
     public static AiPlatformEnum validatePlatform(String platform) {
         for (AiPlatformEnum platformEnum : AiPlatformEnum.values()) {
             if (platformEnum.getPlatform().equals(platform)) {
@@ -53,4 +58,9 @@ public enum AiPlatformEnum {
         throw new IllegalArgumentException("非法平台: " + platform);
     }
 
+    @Override
+    public String[] array() {
+        return ARRAYS;
+    }
+
 }

+ 0 - 21
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -456,25 +456,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return vectorStore;
     }
 
-    /**
-     * 创建向量存储文件
-     *
-     * @param embeddingModel 嵌入模型
-     * @return 向量存储文件
-     */
-    private File createVectorStoreFile(EmbeddingModel embeddingModel) {
-        // 获取简单类名
-        String simpleClassName = embeddingModel.getClass().getSimpleName();
-        // 获取用户主目录
-        String userHome = FileUtil.getUserHomePath();
-        // 创建vector_store目录
-        File vectorStoreDir = new File(userHome, "vector_store");
-        if (!vectorStoreDir.exists()) {
-            vectorStoreDir.mkdirs();
-        }
-
-        // 创建文件
-        return new File(vectorStoreDir, "simple_" + simpleClassName + ".json");
-    }
-
 }