浏览代码

【代码重构】AI:“聊天模型”重构为“模型”,支持 type 模型类型

YunaiV 5 月之前
父节点
当前提交
1c9c9790cd

+ 3 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/midjourney/AiMidjourneyImagineReqVO.java

@@ -13,9 +13,9 @@ public class AiMidjourneyImagineReqVO {
     @NotEmpty(message = "提示词不能为空!")
     private String prompt;
 
-    @Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "midjourney")
-    @NotEmpty(message = "模型不能为空")
-    private String model; // 参考 MidjourneyApi.ModelEnum
+    @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
+    @NotNull(message = "模型编号不能为空")
+    private Long modelId;
 
     @Schema(description = "图片宽度", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
     @NotNull(message = "图片宽度不能为空")

+ 25 - 20
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.image;
 import cn.hutool.core.bean.BeanUtil;
 import cn.hutool.core.codec.Base64;
 import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.lang.Assert;
 import cn.hutool.core.map.MapUtil;
 import cn.hutool.core.util.ObjUtil;
 import cn.hutool.core.util.StrUtil;
@@ -217,52 +218,56 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     @Transactional(rollbackFor = Exception.class)
-    public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
-        MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
-        // 1. 保存数据库
-        AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
+    public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO drawReqVO) {
+        // 1. 校验模型
+        AiModelDO model = modelService.validateModel(drawReqVO.getModelId());
+        Assert.equals(model.getPlatform(), AiPlatformEnum.MIDJOURNEY.getPlatform(), "平台不匹配");
+        MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(model.getId());
+
+        // 2. 保存数据库
+        AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
                 .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
-                .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
+                .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform()).setModelId(model.getId()).setModel(model.getName());
         imageMapper.insert(image);
 
-        // 2. 调用 Midjourney Proxy 提交任务
-        List<String> base64Array = StrUtil.isBlank(reqVO.getReferImageUrl()) ? null :
-                Collections.singletonList("data:image/jpeg;base64,".concat(Base64.encode(HttpUtil.downloadBytes(reqVO.getReferImageUrl()))));
+        // 3. 调用 Midjourney Proxy 提交任务
+        List<String> base64Array = StrUtil.isBlank(drawReqVO.getReferImageUrl()) ? null :
+                Collections.singletonList("data:image/jpeg;base64,".concat(Base64.encode(HttpUtil.downloadBytes(drawReqVO.getReferImageUrl()))));
         MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(
-                base64Array, reqVO.getPrompt(),null,
-                MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(),
-                        reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel()));
+                base64Array, drawReqVO.getPrompt(),null,
+                MidjourneyApi.ImagineRequest.buildState(drawReqVO.getWidth(),
+                        drawReqVO.getHeight(), drawReqVO.getVersion(), model.getModel()));
         MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest);
 
-        // 3. 情况一【失败】:抛出业务异常
+        // 4.1 情况一【失败】:抛出业务异常
         if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(imagineResponse.code())) {
             String description = imagineResponse.description().contains("quota_not_enough") ?
                     "账户余额不足" : imagineResponse.description();
             throw exception(IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
         }
 
-        // 4. 情况二【成功】:更新 taskId 和参数
+        // 4.2 情况二【成功】:更新 taskId 和参数
         imageMapper.updateById(new AiImageDO().setId(image.getId())
-                .setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(reqVO)));
+                .setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(drawReqVO)));
         return image.getId();
     }
 
     @Override
     public Integer midjourneySync() {
-        MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
         // 1.1 获取 Midjourney 平台,状态在 “进行中” 的 image
-        List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
+        List<AiImageDO> images = imageMapper.selectListByStatusAndPlatform(
                 AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
-        if (CollUtil.isEmpty(imageList)) {
+        if (CollUtil.isEmpty(images)) {
             return 0;
         }
         // 1.2 调用 Midjourney Proxy 获取任务进展
-        List<MidjourneyApi.Notify> taskList = midjourneyApi.getTaskList(convertSet(imageList, AiImageDO::getTaskId));
+        MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(images.get(0).getModelId());
+        List<MidjourneyApi.Notify> taskList = midjourneyApi.getTaskList(convertSet(images, AiImageDO::getTaskId));
         Map<String, MidjourneyApi.Notify> taskMap = convertMap(taskList, MidjourneyApi.Notify::id);
 
         // 2. 逐个处理,更新进展
         int count = 0;
-        for (AiImageDO image : imageList) {
+        for (AiImageDO image : images) {
             MidjourneyApi.Notify notify = taskMap.get(image.getTaskId());
             if (notify == null) {
                 log.error("[midjourneySync][image({}) 查询不到进展]", image);
@@ -320,12 +325,12 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
-        MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
         // 1.1 检查 image
         AiImageDO image = validateImageExists(reqVO.getId());
         if (ObjUtil.notEqual(userId, image.getUserId())) {
             throw exception(IMAGE_NOT_EXISTS);
         }
+        MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(image.getModelId());
         // 1.2 检查 customId
         MidjourneyApi.Button button = CollUtil.findOne(image.getButtons(),
                 buttonX -> buttonX.customId().equals(reqVO.getCustomId()));

+ 3 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java

@@ -137,9 +137,9 @@ public class AiChatModelServiceImpl implements AiModelService {
     }
 
     @Override
-    public MidjourneyApi getMidjourneyApi() {
-        AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
-                AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
+    public MidjourneyApi getMidjourneyApi(Long id) {
+        AiModelDO model = validateModel(id);
+        AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
         return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
     }
 

+ 2 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java

@@ -109,9 +109,10 @@ public interface AiModelService {
     /**
      * 获得 MidjourneyApi 对象
      *
+     * @param id 编号
      * @return MidjourneyApi 对象
      */
-    MidjourneyApi getMidjourneyApi();
+    MidjourneyApi getMidjourneyApi(Long id);
 
     /**
      * 获得 SunoApi 对象