|  | @@ -1,5 +1,6 @@
 | 
	
		
			
				|  |  |  package cn.iocoder.yudao.module.ai.service.impl;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import cn.hutool.core.util.StrUtil;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.framework.ai.exception.AiException;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.framework.ai.image.ImageGeneration;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.framework.ai.image.ImagePrompt;
 | 
	
	
		
			
				|  | @@ -9,18 +10,20 @@ import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageModelEnum;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageStyleEnum;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
 | 
	
		
			
				|  |  | +import cn.iocoder.yudao.framework.ai.midjourney.api.req.ReRollReq;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
 | 
	
		
			
				|  |  | +import cn.iocoder.yudao.framework.common.pojo.PageResult;
 | 
	
		
			
				|  |  | +import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
 | 
	
		
			
				|  |  | +import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
 | 
	
		
			
				|  |  | -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
 | 
	
		
			
				|  |  | -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
 | 
	
		
			
				|  |  | -import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReqVO;
 | 
	
		
			
				|  |  | +import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
 | 
	
		
			
				|  |  | -import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
 | 
	
		
			
				|  |  | +import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
 | 
	
		
			
				|  |  |  import cn.iocoder.yudao.module.ai.service.AiImageService;
 | 
	
		
			
				|  |  |  import jakarta.annotation.PostConstruct;
 | 
	
		
			
				|  |  |  import lombok.AllArgsConstructor;
 | 
	
	
		
			
				|  | @@ -28,6 +31,9 @@ import lombok.extern.slf4j.Slf4j;
 | 
	
		
			
				|  |  |  import org.springframework.stereotype.Service;
 | 
	
		
			
				|  |  |  import org.springframework.transaction.annotation.Transactional;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import java.util.Collections;
 | 
	
		
			
				|  |  | +import java.util.List;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  /**
 | 
	
		
			
				|  |  |   * ai 作图
 | 
	
		
			
				|  |  |   *
 | 
	
	
		
			
				|  | @@ -61,6 +67,23 @@ public class AiImageServiceImpl implements AiImageService {
 | 
	
		
			
				|  |  |          });
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    @Override
 | 
	
		
			
				|  |  | +    public PageResult<AiImageListRespVO> list(AiImageListReqVO req) {
 | 
	
		
			
				|  |  | +        // 获取登录用户
 | 
	
		
			
				|  |  | +        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
 | 
	
		
			
				|  |  | +        // 查询当前用户下所有的绘画记录
 | 
	
		
			
				|  |  | +        PageResult<AiImageDO> pageResult = aiImageMapper.selectPage(req,
 | 
	
		
			
				|  |  | +                new LambdaQueryWrapperX<AiImageDO>()
 | 
	
		
			
				|  |  | +                        .eq(AiImageDO::getUserId, loginUserId)
 | 
	
		
			
				|  |  | +                        .orderByDesc(AiImageDO::getId)
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +        // 转换 PageResult<AiImageListRespVO> 返回
 | 
	
		
			
				|  |  | +        PageResult<AiImageListRespVO> result = new PageResult<>();
 | 
	
		
			
				|  |  | +        result.setTotal(pageResult.getTotal());
 | 
	
		
			
				|  |  | +        result.setList(AiImageConvert.INSTANCE.convertAiImageListRespVO(pageResult.getList()));
 | 
	
		
			
				|  |  | +        return result;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      @Override
 | 
	
		
			
				|  |  |      public AiImageDallDrawingRespVO dallDrawing(AiImageDallDrawingReqVO req) {
 | 
	
		
			
				|  |  |          // 获取 model
 | 
	
	
		
			
				|  | @@ -79,7 +102,8 @@ public class AiImageServiceImpl implements AiImageService {
 | 
	
		
			
				|  |  |              ImageGeneration imageGeneration = imageResponse.getResult();
 | 
	
		
			
				|  |  |              // 保存数据库
 | 
	
		
			
				|  |  |              doSave(req.getPrompt(), req.getSize(), req.getModal(),
 | 
	
		
			
				|  |  | -                    imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
 | 
	
		
			
				|  |  | +                    imageGeneration.getOutput().getUrl(), AiImageDrawingStatusEnum.COMPLETE, null,
 | 
	
		
			
				|  |  | +                    null, null, null);
 | 
	
		
			
				|  |  |              // 返回 flex
 | 
	
		
			
				|  |  |              respVO.setUrl(imageGeneration.getOutput().getUrl());
 | 
	
		
			
				|  |  |              respVO.setBase64(imageGeneration.getOutput().getB64Json());
 | 
	
	
		
			
				|  | @@ -87,7 +111,8 @@ public class AiImageServiceImpl implements AiImageService {
 | 
	
		
			
				|  |  |          } catch (AiException aiException) {
 | 
	
		
			
				|  |  |              // 保存数据库
 | 
	
		
			
				|  |  |              doSave(req.getPrompt(), req.getSize(), req.getModal(),
 | 
	
		
			
				|  |  | -                    null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
 | 
	
		
			
				|  |  | +                    null, AiImageDrawingStatusEnum.FAIL, aiException.getMessage(),
 | 
	
		
			
				|  |  | +                    null, null, null);
 | 
	
		
			
				|  |  |              // 发送错误信息
 | 
	
		
			
				|  |  |              respVO.setErrorMessage(aiException.getMessage());
 | 
	
		
			
				|  |  |              return respVO;
 | 
	
	
		
			
				|  | @@ -99,7 +124,8 @@ public class AiImageServiceImpl implements AiImageService {
 | 
	
		
			
				|  |  |      public void midjourney(AiImageMidjourneyReqVO req) {
 | 
	
		
			
				|  |  |          // 保存数据库
 | 
	
		
			
				|  |  |          AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
 | 
	
		
			
				|  |  | -                null, AiChatDrawingStatusEnum.SUBMIT, null);
 | 
	
		
			
				|  |  | +                null, AiImageDrawingStatusEnum.SUBMIT, null,
 | 
	
		
			
				|  |  | +                null, null, null);
 | 
	
		
			
				|  |  |          // 提交 midjourney 任务
 | 
	
		
			
				|  |  |          Boolean imagine = midjourneyInteractionsApi.imagine(aiImageDO.getId(), req.getPrompt());
 | 
	
		
			
				|  |  |          if (!imagine) {
 | 
	
	
		
			
				|  | @@ -107,23 +133,71 @@ public class AiImageServiceImpl implements AiImageService {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -//    private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
 | 
	
		
			
				|  |  | -//        try {
 | 
	
		
			
				|  |  | -//            sseEmitter.send(object, MediaType.APPLICATION_JSON);
 | 
	
		
			
				|  |  | -//        } catch (IOException e) {
 | 
	
		
			
				|  |  | -//            throw new RuntimeException(e);
 | 
	
		
			
				|  |  | -//        } finally {
 | 
	
		
			
				|  |  | -//            // 发送 complete
 | 
	
		
			
				|  |  | -//            sseEmitter.complete();
 | 
	
		
			
				|  |  | -//        }
 | 
	
		
			
				|  |  | -//    }
 | 
	
		
			
				|  |  | +    @Transactional(rollbackFor = Exception.class)
 | 
	
		
			
				|  |  | +    @Override
 | 
	
		
			
				|  |  | +    public void midjourneyOperate(AiImageMidjourneyOperateReqVO req) {
 | 
	
		
			
				|  |  | +        // 校验是否存在
 | 
	
		
			
				|  |  | +        AiImageDO aiImageDO = validateExists(req);
 | 
	
		
			
				|  |  | +        // 获取 midjourneyOperations
 | 
	
		
			
				|  |  | +        List<AiImageMidjourneyOperationsVO> midjourneyOperations = getMidjourneyOperations(aiImageDO);
 | 
	
		
			
				|  |  | +        // 校验 OperateId 是否存在
 | 
	
		
			
				|  |  | +        AiImageMidjourneyOperationsVO midjourneyOperationsVO = validateMidjourneyOperationsExists(midjourneyOperations, req.getOperateId());
 | 
	
		
			
				|  |  | +        // 校验 messageId
 | 
	
		
			
				|  |  | +        validateMessageId(aiImageDO.getMjMessageId(), req.getMessageId());
 | 
	
		
			
				|  |  | +        // 获取 mjOperationName
 | 
	
		
			
				|  |  | +        String mjOperationName = midjourneyOperationsVO.getLabel();
 | 
	
		
			
				|  |  | +        // 保存一个 image 任务记录
 | 
	
		
			
				|  |  | +        doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModal(),
 | 
	
		
			
				|  |  | +                null, AiImageDrawingStatusEnum.SUBMIT, null,
 | 
	
		
			
				|  |  | +                req.getMessageId(), req.getOperateId(), mjOperationName);
 | 
	
		
			
				|  |  | +        // 提交操作
 | 
	
		
			
				|  |  | +        midjourneyInteractionsApi.reRoll(
 | 
	
		
			
				|  |  | +                new ReRollReq()
 | 
	
		
			
				|  |  | +                        .setCustomId(req.getOperateId())
 | 
	
		
			
				|  |  | +                        .setMessageId(req.getMessageId())
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private void validateMessageId(String mjMessageId, String messageId) {
 | 
	
		
			
				|  |  | +        if (!mjMessageId.equals(messageId)) {
 | 
	
		
			
				|  |  | +            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_MESSAGE_ID_INCORRECT);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private AiImageMidjourneyOperationsVO validateMidjourneyOperationsExists(List<AiImageMidjourneyOperationsVO> midjourneyOperations, String operateId) {
 | 
	
		
			
				|  |  | +        for (AiImageMidjourneyOperationsVO midjourneyOperation : midjourneyOperations) {
 | 
	
		
			
				|  |  | +            if (midjourneyOperation.getCustom_id().equals(operateId)) {
 | 
	
		
			
				|  |  | +                return midjourneyOperation;
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_OPERATION_NOT_EXISTS);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private List<AiImageMidjourneyOperationsVO> getMidjourneyOperations(AiImageDO aiImageDO) {
 | 
	
		
			
				|  |  | +        if (StrUtil.isBlank(aiImageDO.getMjOperations())) {
 | 
	
		
			
				|  |  | +            return Collections.emptyList();
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        return JsonUtils.parseArray(aiImageDO.getMjOperations(), AiImageMidjourneyOperationsVO.class);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private AiImageDO validateExists(AiImageMidjourneyOperateReqVO req) {
 | 
	
		
			
				|  |  | +        AiImageDO aiImageDO = aiImageMapper.selectById(req.getId());
 | 
	
		
			
				|  |  | +        if (aiImageDO == null) {
 | 
	
		
			
				|  |  | +            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        return aiImageDO;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      private AiImageDO doSave(String prompt,
 | 
	
		
			
				|  |  | -                        String size,
 | 
	
		
			
				|  |  | -                        String model,
 | 
	
		
			
				|  |  | -                        String imageUrl,
 | 
	
		
			
				|  |  | -                        AiChatDrawingStatusEnum drawingStatusEnum,
 | 
	
		
			
				|  |  | -                        String drawingError) {
 | 
	
		
			
				|  |  | +                             String size,
 | 
	
		
			
				|  |  | +                             String model,
 | 
	
		
			
				|  |  | +                             String drawingImageUrl,
 | 
	
		
			
				|  |  | +                             AiImageDrawingStatusEnum drawingStatusEnum,
 | 
	
		
			
				|  |  | +                             String drawingErrorMessage,
 | 
	
		
			
				|  |  | +                             String mjMessageId,
 | 
	
		
			
				|  |  | +                             String mjOperationId,
 | 
	
		
			
				|  |  | +                             String mjOperationName) {
 | 
	
		
			
				|  |  |          // 保存数据库
 | 
	
		
			
				|  |  |          Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
 | 
	
		
			
				|  |  |          AiImageDO aiImageDO = new AiImageDO();
 | 
	
	
		
			
				|  | @@ -132,9 +206,15 @@ public class AiImageServiceImpl implements AiImageService {
 | 
	
		
			
				|  |  |          aiImageDO.setSize(size);
 | 
	
		
			
				|  |  |          aiImageDO.setModal(model);
 | 
	
		
			
				|  |  |          aiImageDO.setUserId(loginUserId);
 | 
	
		
			
				|  |  | -        aiImageDO.setDrawingImageUrl(imageUrl);
 | 
	
		
			
				|  |  | +        // TODO @芋艿 如何上传到自己服务器
 | 
	
		
			
				|  |  | +        aiImageDO.setImageUrl(null);
 | 
	
		
			
				|  |  |          aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
 | 
	
		
			
				|  |  | -        aiImageDO.setDrawingError(drawingError);
 | 
	
		
			
				|  |  | +        aiImageDO.setDrawingImageUrl(drawingImageUrl);
 | 
	
		
			
				|  |  | +        aiImageDO.setDrawingErrorMessage(drawingErrorMessage);
 | 
	
		
			
				|  |  | +        //
 | 
	
		
			
				|  |  | +        aiImageDO.setMjMessageId(mjMessageId);
 | 
	
		
			
				|  |  | +        aiImageDO.setMjOperationId(mjOperationId);
 | 
	
		
			
				|  |  | +        aiImageDO.setMjOperationName(mjOperationName);
 | 
	
		
			
				|  |  |          aiImageMapper.insert(aiImageDO);
 | 
	
		
			
				|  |  |          return aiImageDO;
 | 
	
		
			
				|  |  |      }
 |