|
@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.image;
|
|
|
import cn.hutool.core.bean.BeanUtil;
|
|
|
import cn.hutool.core.codec.Base64;
|
|
|
import cn.hutool.core.collection.CollUtil;
|
|
|
+import cn.hutool.core.lang.Assert;
|
|
|
import cn.hutool.core.map.MapUtil;
|
|
|
import cn.hutool.core.util.ObjUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
@@ -217,52 +218,56 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
|
|
|
@Override
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
|
- public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
|
|
|
- MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
|
|
|
- // 1. 保存数据库
|
|
|
- AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
|
|
|
+ public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO drawReqVO) {
|
|
|
+ // 1. 校验模型
|
|
|
+ AiModelDO model = modelService.validateModel(drawReqVO.getModelId());
|
|
|
+ Assert.equals(model.getPlatform(), AiPlatformEnum.MIDJOURNEY.getPlatform(), "平台不匹配");
|
|
|
+ MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(model.getId());
|
|
|
+
|
|
|
+ // 2. 保存数据库
|
|
|
+ AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
|
|
|
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
|
|
|
- .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
|
|
|
+ .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform()).setModelId(model.getId()).setModel(model.getName());
|
|
|
imageMapper.insert(image);
|
|
|
|
|
|
- // 2. 调用 Midjourney Proxy 提交任务
|
|
|
- List<String> base64Array = StrUtil.isBlank(reqVO.getReferImageUrl()) ? null :
|
|
|
- Collections.singletonList("data:image/jpeg;base64,".concat(Base64.encode(HttpUtil.downloadBytes(reqVO.getReferImageUrl()))));
|
|
|
+ // 3. 调用 Midjourney Proxy 提交任务
|
|
|
+ List<String> base64Array = StrUtil.isBlank(drawReqVO.getReferImageUrl()) ? null :
|
|
|
+ Collections.singletonList("data:image/jpeg;base64,".concat(Base64.encode(HttpUtil.downloadBytes(drawReqVO.getReferImageUrl()))));
|
|
|
MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(
|
|
|
- base64Array, reqVO.getPrompt(),null,
|
|
|
- MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(),
|
|
|
- reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel()));
|
|
|
+ base64Array, drawReqVO.getPrompt(),null,
|
|
|
+ MidjourneyApi.ImagineRequest.buildState(drawReqVO.getWidth(),
|
|
|
+ drawReqVO.getHeight(), drawReqVO.getVersion(), model.getModel()));
|
|
|
MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest);
|
|
|
|
|
|
- // 3. 情况一【失败】:抛出业务异常
|
|
|
+ // 4.1 情况一【失败】:抛出业务异常
|
|
|
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(imagineResponse.code())) {
|
|
|
String description = imagineResponse.description().contains("quota_not_enough") ?
|
|
|
"账户余额不足" : imagineResponse.description();
|
|
|
throw exception(IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
|
|
|
}
|
|
|
|
|
|
- // 4. 情况二【成功】:更新 taskId 和参数
|
|
|
+ // 4.2 情况二【成功】:更新 taskId 和参数
|
|
|
imageMapper.updateById(new AiImageDO().setId(image.getId())
|
|
|
- .setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(reqVO)));
|
|
|
+ .setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(drawReqVO)));
|
|
|
return image.getId();
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public Integer midjourneySync() {
|
|
|
- MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
|
|
|
// 1.1 获取 Midjourney 平台,状态在 “进行中” 的 image
|
|
|
- List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
|
|
|
+ List<AiImageDO> images = imageMapper.selectListByStatusAndPlatform(
|
|
|
AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
|
|
|
- if (CollUtil.isEmpty(imageList)) {
|
|
|
+ if (CollUtil.isEmpty(images)) {
|
|
|
return 0;
|
|
|
}
|
|
|
// 1.2 调用 Midjourney Proxy 获取任务进展
|
|
|
- List<MidjourneyApi.Notify> taskList = midjourneyApi.getTaskList(convertSet(imageList, AiImageDO::getTaskId));
|
|
|
+ MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(images.get(0).getModelId());
|
|
|
+ List<MidjourneyApi.Notify> taskList = midjourneyApi.getTaskList(convertSet(images, AiImageDO::getTaskId));
|
|
|
Map<String, MidjourneyApi.Notify> taskMap = convertMap(taskList, MidjourneyApi.Notify::id);
|
|
|
|
|
|
// 2. 逐个处理,更新进展
|
|
|
int count = 0;
|
|
|
- for (AiImageDO image : imageList) {
|
|
|
+ for (AiImageDO image : images) {
|
|
|
MidjourneyApi.Notify notify = taskMap.get(image.getTaskId());
|
|
|
if (notify == null) {
|
|
|
log.error("[midjourneySync][image({}) 查询不到进展]", image);
|
|
@@ -320,12 +325,12 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
|
|
|
@Override
|
|
|
public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
|
|
|
- MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
|
|
|
// 1.1 检查 image
|
|
|
AiImageDO image = validateImageExists(reqVO.getId());
|
|
|
if (ObjUtil.notEqual(userId, image.getUserId())) {
|
|
|
throw exception(IMAGE_NOT_EXISTS);
|
|
|
}
|
|
|
+ MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(image.getModelId());
|
|
|
// 1.2 检查 customId
|
|
|
MidjourneyApi.Button button = CollUtil.findOne(image.getButtons(),
|
|
|
buttonX -> buttonX.customId().equals(reqVO.getCustomId()));
|