Selaa lähdekoodia

【代码重构】AI:使用 OpenAiApi 接入 deepseek

YunaiV 5 kuukautta sitten
vanhempi
sitoutus
f582c9cfa3

+ 10 - 131
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/deepseek/DeepSeekChatModel.java

@@ -1,166 +1,45 @@
 package cn.iocoder.yudao.framework.ai.core.model.deepseek;
 
-import cn.hutool.core.collection.ListUtil;
-import cn.hutool.core.lang.Assert;
+import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
-import org.springframework.ai.chat.model.Generation;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.model.ModelOptionsUtils;
-import org.springframework.ai.openai.OpenAiChatOptions;
-import org.springframework.ai.openai.api.OpenAiApi;
-import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
-import org.springframework.ai.retry.RetryUtils;
-import org.springframework.http.ResponseEntity;
-import org.springframework.retry.support.RetryTemplate;
+import org.springframework.ai.openai.OpenAiChatModel;
 import reactor.core.publisher.Flux;
 
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-import static cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions.MODEL_DEFAULT;
-
 /**
  * DeepSeek {@link ChatModel} 实现类
  *
  * @author fansili
  */
 @Slf4j
+@RequiredArgsConstructor
 public class DeepSeekChatModel implements ChatModel {
 
-    private static final String BASE_URL = "https://api.deepseek.com";
+    public static final String BASE_URL = "https://api.deepseek.com";
 
-    private final DeepSeekChatOptions defaultOptions;
-    private final RetryTemplate retryTemplate;
+    public static final String MODEL_DEFAULT = "deepseek-chat";
 
     /**
-     * DeepSeek 兼容 OpenAI 的 HTTP 接口,所以复用它的实现,简化接入成本
-     *
-     * 不过要注意,DeepSeek 没有完全兼容,所以不能使用 {@link org.springframework.ai.openai.OpenAiChatModel} 调用,但是实现会参考它
+     * 兼容 OpenAI 接口,进行复用
      */
-    private final OpenAiApi openAiApi;
-
-    public DeepSeekChatModel(String apiKey) {
-        this(apiKey, DeepSeekChatOptions.builder().model(MODEL_DEFAULT).temperature(0.7F).build());
-    }
-
-    public DeepSeekChatModel(String apiKey, DeepSeekChatOptions options) {
-       this(apiKey, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
-    }
-
-    public DeepSeekChatModel(String apiKey, DeepSeekChatOptions options, RetryTemplate retryTemplate) {
-        Assert.notEmpty(apiKey, "apiKey 不能为空");
-        Assert.notNull(options, "options 不能为空");
-        Assert.notNull(retryTemplate, "retryTemplate 不能为空");
-        this.openAiApi = new OpenAiApi(BASE_URL, apiKey);
-        this.defaultOptions = options;
-        this.retryTemplate = retryTemplate;
-    }
+    private final OpenAiChatModel openAiChatModel;
 
     @Override
     public ChatResponse call(Prompt prompt) {
-        OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false);
-        return this.retryTemplate.execute(ctx -> {
-            // 1.1 发起调用
-            ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = openAiApi.chatCompletionEntity(request);
-            // 1.2 校验结果
-            OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody();
-            if (chatCompletion == null) {
-                log.warn("No chat completion returned for prompt: {}", prompt);
-                return new ChatResponse(ListUtil.of());
-            }
-            List<OpenAiApi.ChatCompletion.Choice> choices = chatCompletion.choices();
-            if (choices == null) {
-                log.warn("No choices returned for prompt: {}", prompt);
-                return new ChatResponse(ListUtil.of());
-            }
-
-            // 2. 转换 ChatResponse 返回
-            List<Generation> generations = choices.stream().map(choice -> {
-                Generation generation = new Generation(choice.message().content(), toMap(chatCompletion.id(), choice));
-                if (choice.finishReason() != null) {
-                    generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
-                }
-                return generation;
-            }).toList();
-            return new ChatResponse(generations,
-                    OpenAiChatResponseMetadata.from(completionEntity.getBody()));
-        });
-    }
-
-    private Map<String, Object> toMap(String id, OpenAiApi.ChatCompletion.Choice choice) {
-        Map<String, Object> map = new HashMap<>();
-        OpenAiApi.ChatCompletionMessage message = choice.message();
-        if (message.role() != null) {
-            map.put("role", message.role().name());
-        }
-        if (choice.finishReason() != null) {
-            map.put("finishReason", choice.finishReason().name());
-        }
-        map.put("id", id);
-        return map;
+        return openAiChatModel.call(prompt);
     }
 
     @Override
     public Flux<ChatResponse> stream(Prompt prompt) {
-        OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true);
-        return this.retryTemplate.execute(ctx -> {
-            // 1. 发起调用
-            Flux<OpenAiApi.ChatCompletionChunk> response = this.openAiApi.chatCompletionStream(request);
-            return response.map(chatCompletion -> {
-                String id = chatCompletion.id();
-                // 2. 转换 ChatResponse 返回
-                List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
-                    String finish = (choice.finishReason() != null ? choice.finishReason().name() : "");
-                    String role = (choice.delta().role() != null ? choice.delta().role().name() : "");
-                    if (choice.finishReason() == OpenAiApi.ChatCompletionFinishReason.STOP) {
-                        // 兜底处理 DeepSeek 返回 STOP 时,role 为空的情况
-                        role = OpenAiApi.ChatCompletionMessage.Role.ASSISTANT.name();
-                    }
-                    Generation generation = new Generation(choice.delta().content(),
-                            Map.of("id", id, "role", role, "finishReason", finish));
-                    if (choice.finishReason() != null) {
-                        generation = generation.withGenerationMetadata(
-                                ChatGenerationMetadata.from(choice.finishReason().name(), null));
-                    }
-                    return generation;
-                }).toList();
-                return new ChatResponse(generations);
-            });
-        });
-    }
-
-    OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
-        // 1. 构建 ChatCompletionMessage 对象
-        List<OpenAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m ->
-                new OpenAiApi.ChatCompletionMessage(m.getContent(), OpenAiApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))).toList();
-        OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
-
-        // 2.1 补充 prompt 内置的 options
-        if (prompt.getOptions() != null) {
-            if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
-                OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
-                        ChatOptions.class, OpenAiChatOptions.class);
-                request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.ChatCompletionRequest.class);
-            } else {
-                throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
-                        + prompt.getOptions().getClass().getSimpleName());
-            }
-        }
-        // 2.2 补充默认 options
-        if (this.defaultOptions != null) {
-            request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.ChatCompletionRequest.class);
-        }
-        return request;
+        return openAiChatModel.stream(prompt);
     }
 
     @Override
     public ChatOptions getDefaultOptions() {
-         return DeepSeekChatOptions.fromOptions(defaultOptions);
+        return openAiChatModel.getDefaultOptions();
     }
 
 }

+ 0 - 55
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/deepseek/DeepSeekChatOptions.java

@@ -1,55 +0,0 @@
-package cn.iocoder.yudao.framework.ai.core.model.deepseek;
-
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.springframework.ai.chat.prompt.ChatOptions;
-
-/**
- * DeepSeek {@link ChatOptions} 实现类
- *
- * 参考文档:<a href="https://platform.deepseek.com/api-docs/zh-cn/">快速开始</a>
- *
- * @author fansili
- */
-@Data
-@NoArgsConstructor
-@AllArgsConstructor
-@Builder
-public class DeepSeekChatOptions implements ChatOptions {
-
-    public static final String MODEL_DEFAULT = "deepseek-chat";
-
-    /**
-     * 模型
-     */
-    private String model;
-    /**
-     * 温度
-     */
-    private Float temperature;
-    /**
-     * 最大 Token
-     */
-    private Integer maxTokens;
-    /**
-     * topP
-     */
-    private Float topP;
-
-    @Override
-    public Integer getTopK() {
-        return null;
-    }
-
-    public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) {
-        return DeepSeekChatOptions.builder()
-                .model(fromOptions.getModel())
-                .temperature(fromOptions.getTemperature())
-                .maxTokens(fromOptions.getMaxTokens())
-                .topP(fromOptions.getTopP())
-                .build();
-    }
-
-}

+ 0 - 55
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/xinghuo/XingHuoChatOptions.java

@@ -1,55 +0,0 @@
-package cn.iocoder.yudao.framework.ai.core.model.xinghuo;
-
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.springframework.ai.chat.prompt.ChatOptions;
-
-/**
- * 讯飞星火 {@link ChatOptions} 实现类
- *
- * 参考文档:<a href="https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html">HTTP 调用</a>
- *
- * @author fansili
- */
-@Data
-@NoArgsConstructor
-@AllArgsConstructor
-@Builder
-public class XingHuoChatOptions implements ChatOptions {
-
-    public static final String MODEL_DEFAULT = "generalv3.5";
-
-    /**
-     * 模型
-     */
-    private String model;
-    /**
-     * 温度
-     */
-    private Float temperature;
-    /**
-     * 最大 Token
-     */
-    private Integer maxTokens;
-    /**
-     * K 个候选
-     */
-    private Integer topK;
-
-    @Override
-    public Float getTopP() {
-        return null;
-    }
-
-    public static XingHuoChatOptions fromOptions(XingHuoChatOptions fromOptions) {
-        return XingHuoChatOptions.builder()
-                .model(fromOptions.getModel())
-                .temperature(fromOptions.getTemperature())
-                .maxTokens(fromOptions.getMaxTokens())
-                .topK(fromOptions.getTopK())
-                .build();
-    }
-
-}

+ 15 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/DeepSeekChatModelTests.java

@@ -8,6 +8,9 @@ import org.springframework.ai.chat.messages.SystemMessage;
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.openai.OpenAiChatModel;
+import org.springframework.ai.openai.OpenAiChatOptions;
+import org.springframework.ai.openai.api.OpenAiApi;
 import reactor.core.publisher.Flux;
 
 import java.util.ArrayList;
@@ -20,7 +23,18 @@ import java.util.List;
  */
 public class DeepSeekChatModelTests {
 
-    private final DeepSeekChatModel chatModel = new DeepSeekChatModel("sk-e94db327cc7d457d99a8de8810fc6b12");
+    private static final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
+            .openAiApi(OpenAiApi.builder()
+                    .baseUrl(DeepSeekChatModel.BASE_URL)
+                    .apiKey("sk-e52047409b144d97b791a6a46a2d") // apiKey
+                    .build())
+            .defaultOptions(OpenAiChatOptions.builder()
+                    .model("deepseek-chat") // 模型
+                    .temperature(0.7)
+                    .build())
+            .build();
+
+    private final DeepSeekChatModel chatModel = new DeepSeekChatModel(openAiChatModel);
 
     @Test
     @Disabled