Browse Source

【代码优化】AI:硅基流动的图片生成

YunaiV 4 months ago
parent
commit
ef5e56d560

+ 2 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -207,6 +207,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
                 return SpringUtil.getBean(QianFanImageModel.class);
             case ZHI_PU:
                 return SpringUtil.getBean(ZhiPuAiImageModel.class);
+            case SILICON_FLOW:
+                return SpringUtil.getBean(SiliconFlowImageModel.class);
             case OPENAI:
                 return SpringUtil.getBean(OpenAiImageModel.class);
             case STABLE_DIFFUSION:

+ 2 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/siliconflow/SiliconFlowApiConstants.java

@@ -27,6 +27,8 @@ public final class SiliconFlowApiConstants {
 
 	public static final String MODEL_DEFAULT = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B";
 
+    public static final String DEFAULT_IMAGE_MODEL = "Kwai-Kolors/Kolors";
+
 	public static final String PROVIDER_NAME = "Siiconflow";
 
 }

+ 29 - 4
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/siliconflow/SiliconFlowImageModel.java

@@ -31,6 +31,7 @@ import org.springframework.ai.openai.api.OpenAiImageApi;
 import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
 import org.springframework.ai.retry.RetryUtils;
 import org.springframework.http.ResponseEntity;
+import org.springframework.lang.Nullable;
 import org.springframework.retry.support.RetryTemplate;
 import org.springframework.util.Assert;
 
@@ -82,7 +83,8 @@ public class SiliconFlowImageModel implements ImageModel {
 
 	@Override
 	public ImageResponse call(ImagePrompt imagePrompt) {
-		SiliconFlowImageApi.SiliconflowImageRequest imageRequest = createRequest(imagePrompt);
+        SiliconFlowImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions);
+        SiliconFlowImageApi.SiliconflowImageRequest imageRequest = createRequest(imagePrompt, requestImageOptions);
 
 		var observationContext = ImageModelObservationContext.builder()
 			.imagePrompt(imagePrompt)
@@ -105,13 +107,14 @@ public class SiliconFlowImageModel implements ImageModel {
 			});
 	}
 
-	private SiliconFlowImageApi.SiliconflowImageRequest createRequest(ImagePrompt imagePrompt) {
+	private SiliconFlowImageApi.SiliconflowImageRequest createRequest(ImagePrompt imagePrompt,
+                                                                      SiliconFlowImageOptions requestImageOptions) {
 		String instructions = imagePrompt.getInstructions().get(0).getText();
 
 		SiliconFlowImageApi.SiliconflowImageRequest imageRequest = new SiliconFlowImageApi.SiliconflowImageRequest(instructions,
-				imagePrompt.getOptions().getModel());
+                SiliconFlowApiConstants.DEFAULT_IMAGE_MODEL);
 
-		return ModelOptionsUtils.merge(imagePrompt.getOptions(), imageRequest, SiliconFlowImageApi.SiliconflowImageRequest.class);
+		return ModelOptionsUtils.merge(requestImageOptions, imageRequest, SiliconFlowImageApi.SiliconflowImageRequest.class);
 	}
 
 	private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity,
@@ -131,4 +134,26 @@ public class SiliconFlowImageModel implements ImageModel {
 		ImageResponseMetadata openAiImageResponseMetadata = new ImageResponseMetadata(imageApiResponse.created());
 		return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
 	}
+
+    private SiliconFlowImageOptions mergeOptions(@Nullable ImageOptions runtimeOptions, SiliconFlowImageOptions defaultOptions) {
+        var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, ImageOptions.class,
+                SiliconFlowImageOptions.class);
+
+        if (runtimeOptionsForProvider == null) {
+            return defaultOptions;
+        }
+
+        return SiliconFlowImageOptions.builder()
+                // Handle portable image options
+                .model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
+                .batchSize(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getN(), defaultOptions.getN()))
+                .width(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getWidth(), defaultOptions.getWidth()))
+                .height(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getHeight(), defaultOptions.getHeight()))
+                // Handle OpenAI specific image options
+                .negativePrompt(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNegativePrompt(), defaultOptions.getNegativePrompt()))
+                .numInferenceSteps(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNumInferenceSteps(), defaultOptions.getNumInferenceSteps()))
+                .guidanceScale(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getGuidanceScale(), defaultOptions.getGuidanceScale()))
+                .seed(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getSeed(), defaultOptions.getSeed()))
+                .build();
+    }
 }

+ 2 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/siliconflow/SiliconFlowImageOptions.java

@@ -89,12 +89,12 @@ public class SiliconFlowImageOptions implements ImageOptions {
 
     @Override
     public Integer getN() {
-        return null;
+        return batchSize;
     }
 
     @Override
     public String getResponseFormat() {
-        return null;
+        return "url";
     }
 
     @Override