Kaynağa Gözat

【代码优化】AI:适配 Spring AI 1.0.6 对 Ollama 的逻辑

YunaiV 5 ay önce
ebeveyn
işleme
d05a7bd59a

+ 12 - 11
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -20,6 +20,7 @@ import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
 import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
 import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
 import cn.iocoder.yudao.module.infra.api.file.FileApi;
+import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.image.ImageModel;
@@ -133,14 +134,14 @@ public class AiImageServiceImpl implements AiImageService {
         } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
             // https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
             // https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
-            return StabilityAiImageOptions.builder().withModel(draw.getModel())
-                    .withHeight(draw.getHeight()).withWidth(draw.getWidth())
-                    .withSeed(Long.valueOf(draw.getOptions().get("seed")))
-                    .withCfgScale(Float.valueOf(draw.getOptions().get("scale")))
-                    .withSteps(Integer.valueOf(draw.getOptions().get("steps")))
-                    .withSampler(String.valueOf(draw.getOptions().get("sampler")))
-                    .withStylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
-                    .withClipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
+            return StabilityAiImageOptions.builder().model(draw.getModel())
+                    .height(draw.getHeight()).width(draw.getWidth())
+                    .seed(Long.valueOf(draw.getOptions().get("seed")))
+                    .cfgScale(Float.valueOf(draw.getOptions().get("scale")))
+                    .steps(Integer.valueOf(draw.getOptions().get("steps")))
+                    .sampler(String.valueOf(draw.getOptions().get("sampler")))
+                    .stylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
+                    .clipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
                     .build();
         } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
             return DashScopeImageOptions.builder()
@@ -149,12 +150,12 @@ public class AiImageServiceImpl implements AiImageService {
                     .build();
         } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
             return QianFanImageOptions.builder()
-                    .withModel(draw.getModel()).withN(1)
-                    .withHeight(draw.getHeight()).withWidth(draw.getWidth())
+                    .model(draw.getModel()).N(1)
+                    .height(draw.getHeight()).width(draw.getWidth())
                     .build();
         } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
             return ZhiPuAiImageOptions.builder()
-                    .withModel(draw.getModel())
+                    .model(draw.getModel())
                     .build();
         }
         throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());

+ 0 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java

@@ -35,7 +35,6 @@ public class YudaoAiAutoConfiguration {
         return new AiModelFactoryImpl();
     }
 
-
     // ========== 各种 AI Client 创建 ==========
 
     @Bean

+ 52 - 36
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java

@@ -1,6 +1,20 @@
 package cn.iocoder.yudao.framework.ai.chat;
 
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.ollama.OllamaChatModel;
+import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaModel;
+import org.springframework.ai.ollama.api.OllamaOptions;
+import reactor.core.publisher.Flux;
+
+import java.util.ArrayList;
+import java.util.List;
 
 /**
  * {@link OllamaChatModel} 集成测试
@@ -9,41 +23,43 @@ import org.springframework.ai.ollama.OllamaChatModel;
  */
 public class LlamaChatModelTests {
 
-//    private final OllamaApi ollamaApi = new OllamaApi(
-//            "http://127.0.0.1:11434");
-//    private final OllamaChatModel chatModel = new OllamaChatModel(ollamaApi,
-//            OllamaOptions.create().withModel(OllamaModel.LLAMA3.getModelName()));
-//
-//    @Test
-//    @Disabled
-//    public void testCall() {
-//        // 准备参数
-//        List<Message> messages = new ArrayList<>();
-//        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
-//        messages.add(new UserMessage("1 + 1 = ?"));
-//
-//        // 调用
-//        ChatResponse response = chatModel.call(new Prompt(messages));
-//        // 打印结果
-//        System.out.println(response);
-//        System.out.println(response.getResult().getOutput());
-//    }
-//
-//    @Test
-//    @Disabled
-//    public void testStream() {
-//        // 准备参数
-//        List<Message> messages = new ArrayList<>();
-//        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
-//        messages.add(new UserMessage("1 + 1 = ?"));
-//
-//        // 调用
-//        Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
-//        // 打印结果
-//        flux.doOnNext(response -> {
-////            System.out.println(response);
-//            System.out.println(response.getResult().getOutput());
-//        }).then().block();
-//    }
+    private final OllamaChatModel chatModel = OllamaChatModel.builder()
+            .ollamaApi(new OllamaApi("http://127.0.0.1:11434")) // Ollama 服务地址
+            .defaultOptions(OllamaOptions.builder()
+                    .model(OllamaModel.LLAMA3.getName()) // 模型
+                    .build())
+            .build();
+
+    @Test
+    @Disabled
+    public void testCall() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        ChatResponse response = chatModel.call(new Prompt(messages));
+        // 打印结果
+        System.out.println(response);
+        System.out.println(response.getResult().getOutput());
+    }
+
+    @Test
+    @Disabled
+    public void testStream() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
+        // 打印结果
+        flux.doOnNext(response -> {
+//            System.out.println(response);
+            System.out.println(response.getResult().getOutput());
+        }).then().block();
+    }
 
 }