|
@@ -12,13 +12,17 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
|
|
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
|
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
|
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
|
|
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
|
|
|
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
|
|
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
|
|
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
|
|
|
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
|
|
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
|
|
|
import jakarta.annotation.Resource;
|
|
@@ -55,13 +59,13 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
|
|
|
public class AiImageServiceImpl implements AiImageService {
|
|
|
|
|
|
@Resource
|
|
|
- private AiImageMapper imageMapper;
|
|
|
+ private AiModelService modelService;
|
|
|
|
|
|
@Resource
|
|
|
- private FileApi fileApi;
|
|
|
+ private AiImageMapper imageMapper;
|
|
|
|
|
|
@Resource
|
|
|
- private AiApiKeyService apiKeyService;
|
|
|
+ private FileApi fileApi;
|
|
|
|
|
|
@Override
|
|
|
public PageResult<AiImageDO> getImagePageMy(Long userId, AiImagePageReqVO pageReqVO) {
|
|
@@ -88,23 +92,31 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
|
|
|
@Override
|
|
|
public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
|
|
|
- // 1. 保存数据库
|
|
|
- AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
|
|
|
- .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
|
|
+ // 1. 校验模型
|
|
|
+ AiModelDO model = modelService.validateModel(drawReqVO.getModelId());
|
|
|
+
|
|
|
+ // 2. 保存数据库
|
|
|
+ AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId)
|
|
|
+ .setPlatform(model.getPlatform()).setModelId(model.getId()).setModel(model.getModel())
|
|
|
+ .setPublicStatus(false).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
|
|
imageMapper.insert(image);
|
|
|
- // 2. 异步绘制,后续前端通过返回的 id 进行轮询结果
|
|
|
- getSelf().executeDrawImage(image, drawReqVO);
|
|
|
+
|
|
|
+ // 3. 异步绘制,后续前端通过返回的 id 进行轮询结果
|
|
|
+ getSelf().executeDrawImage(image, drawReqVO, model);
|
|
|
return image.getId();
|
|
|
}
|
|
|
|
|
|
@Async
|
|
|
- public void executeDrawImage(AiImageDO image, AiImageDrawReqVO req) {
|
|
|
+ public void executeDrawImage(AiImageDO image, AiImageDrawReqVO reqVO, AiModelDO model) {
|
|
|
try {
|
|
|
// 1.1 构建请求
|
|
|
- ImageOptions request = buildImageOptions(req);
|
|
|
+ ImageOptions request = buildImageOptions(reqVO, model);
|
|
|
// 1.2 执行请求
|
|
|
- ImageModel imageModel = apiKeyService.getImageModel(AiPlatformEnum.validatePlatform(req.getPlatform()));
|
|
|
- ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
|
|
|
+ ImageModel imageModel = modelService.getImageModel(model.getId());
|
|
|
+ ImageResponse response = imageModel.call(new ImagePrompt(reqVO.getPrompt(), request));
|
|
|
+ if (response.getResult() == null) {
|
|
|
+ throw new IllegalArgumentException("生成结果为空");
|
|
|
+ }
|
|
|
|
|
|
// 2. 上传到文件服务
|
|
|
String b64Json = response.getResult().getOutput().getB64Json();
|
|
@@ -116,25 +128,25 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.SUCCESS.getStatus())
|
|
|
.setPicUrl(filePath).setFinishTime(LocalDateTime.now()));
|
|
|
} catch (Exception ex) {
|
|
|
- log.error("[doDall][image({}) 生成异常]", image, ex);
|
|
|
+ log.error("[executeDrawImage][image({}) 生成异常]", image, ex);
|
|
|
imageMapper.updateById(new AiImageDO().setId(image.getId())
|
|
|
.setStatus(AiImageStatusEnum.FAIL.getStatus())
|
|
|
.setErrorMessage(ex.getMessage()).setFinishTime(LocalDateTime.now()));
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private static ImageOptions buildImageOptions(AiImageDrawReqVO draw) {
|
|
|
- if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
|
|
|
+ private static ImageOptions buildImageOptions(AiImageDrawReqVO draw, AiModelDO model) {
|
|
|
+ if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
|
|
|
// https://platform.openai.com/docs/api-reference/images/create
|
|
|
- return OpenAiImageOptions.builder().withModel(draw.getModel())
|
|
|
+ return OpenAiImageOptions.builder().withModel(model.getModel())
|
|
|
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
|
|
.withStyle(MapUtil.getStr(draw.getOptions(), "style")) // 风格
|
|
|
.withResponseFormat("b64_json")
|
|
|
.build();
|
|
|
- } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
|
|
|
+ } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
|
|
|
// https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
|
|
|
// https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
|
|
|
- return StabilityAiImageOptions.builder().model(draw.getModel())
|
|
|
+ return StabilityAiImageOptions.builder().model(model.getModel())
|
|
|
.height(draw.getHeight()).width(draw.getWidth())
|
|
|
.seed(Long.valueOf(draw.getOptions().get("seed")))
|
|
|
.cfgScale(Float.valueOf(draw.getOptions().get("scale")))
|
|
@@ -143,22 +155,22 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
.stylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
|
|
|
.clipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
|
|
|
.build();
|
|
|
- } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
|
|
|
+ } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
|
|
|
return DashScopeImageOptions.builder()
|
|
|
- .withModel(draw.getModel()).withN(1)
|
|
|
+ .withModel(model.getModel()).withN(1)
|
|
|
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
|
|
.build();
|
|
|
- } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
|
|
|
+ } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
|
|
|
return QianFanImageOptions.builder()
|
|
|
- .model(draw.getModel()).N(1)
|
|
|
+ .model(model.getModel()).N(1)
|
|
|
.height(draw.getHeight()).width(draw.getWidth())
|
|
|
.build();
|
|
|
- } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
|
|
|
+ } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
|
|
|
return ZhiPuAiImageOptions.builder()
|
|
|
- .model(draw.getModel())
|
|
|
+ .model(model.getModel())
|
|
|
.build();
|
|
|
}
|
|
|
- throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
|
|
|
+ throw new IllegalArgumentException("不支持的 AI 平台:" + model.getPlatform());
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -206,7 +218,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
@Override
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
|
public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
|
|
|
- MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
|
|
|
+ MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
|
|
|
// 1. 保存数据库
|
|
|
AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
|
|
|
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
|
|
@@ -237,7 +249,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
|
|
|
@Override
|
|
|
public Integer midjourneySync() {
|
|
|
- MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
|
|
|
+ MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
|
|
|
// 1.1 获取 Midjourney 平台,状态在 “进行中” 的 image
|
|
|
List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
|
|
|
AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
|
|
@@ -308,7 +320,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
|
|
|
@Override
|
|
|
public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
|
|
|
- MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
|
|
|
+ MidjourneyApi midjourneyApi = modelService.getMidjourneyApi();
|
|
|
// 1.1 检查 image
|
|
|
AiImageDO image = validateImageExists(reqVO.getId());
|
|
|
if (ObjUtil.notEqual(userId, image.getUserId())) {
|