Kaynağa Gözat

【功能新增】AI:百川模型的接入

YunaiV 4 ay önce
ebeveyn
işleme
cd4813f7dd

+ 28 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java

@@ -4,6 +4,7 @@ import cn.hutool.core.util.StrUtil;
 import cn.hutool.extra.spring.SpringUtil;
 import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
 import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
+import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel;
 import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
 import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel;
 import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel;
@@ -193,6 +194,33 @@ public class YudaoAiAutoConfiguration {
         return new XingHuoChatModel(openAiChatModel);
     }
 
+    @Bean
+    @ConditionalOnProperty(value = "yudao.ai.baichuan.enable", havingValue = "true")
+    public BaiChuanChatModel baiChuanChatClient(YudaoAiProperties yudaoAiProperties) {
+        YudaoAiProperties.BaiChuanProperties properties = yudaoAiProperties.getBaichuan();
+        return buildBaiChuanChatClient(properties);
+    }
+
+    public BaiChuanChatModel buildBaiChuanChatClient(YudaoAiProperties.BaiChuanProperties properties) {
+        if (StrUtil.isEmpty(properties.getModel())) {
+            properties.setModel(BaiChuanChatModel.MODEL_DEFAULT);
+        }
+        OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
+                .openAiApi(OpenAiApi.builder()
+                        .baseUrl(BaiChuanChatModel.BASE_URL)
+                        .apiKey(properties.getApiKey())
+                        .build())
+                .defaultOptions(OpenAiChatOptions.builder()
+                        .model(properties.getModel())
+                        .temperature(properties.getTemperature())
+                        .maxTokens(properties.getMaxTokens())
+                        .topP(properties.getTopP())
+                        .build())
+                .toolCallingManager(getToolCallingManager())
+                .build();
+        return new BaiChuanChatModel(openAiChatModel);
+    }
+
     @Bean
     @ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
     public MidjourneyApi midjourneyApi(YudaoAiProperties yudaoAiProperties) {

+ 19 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java

@@ -43,6 +43,12 @@ public class YudaoAiProperties {
     @SuppressWarnings("SpellCheckingInspection")
     private XingHuoProperties xinghuo;
 
+    /**
+     * 百川
+     */
+    @SuppressWarnings("SpellCheckingInspection")
+    private BaiChuanProperties baichuan;
+
     /**
      * Midjourney 绘图
      */
@@ -122,6 +128,19 @@ public class YudaoAiProperties {
 
     }
 
+    @Data
+    public static class  BaiChuanProperties {
+
+        private String enable;
+        private String apiKey;
+
+        private String model;
+        private Double temperature;
+        private Integer maxTokens;
+        private Double topP;
+
+    }
+
     @Data
     public static class MidjourneyProperties {
 

+ 1 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java

@@ -27,6 +27,7 @@ public enum AiPlatformEnum implements ArrayValuable<String> {
     SILICON_FLOW("SiliconFlow", "硅基流动"), // 硅基流动
     MINI_MAX("MiniMax", "MiniMax"), // 稀宇科技
     MOONSHOT("Moonshot", "月之暗灭"), // KIMI
+    BAI_CHUAN("BaiChuan", "百川智能"), // 百川智能
 
     // ========== 国外平台 ==========
 

+ 14 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -11,6 +11,7 @@ import cn.hutool.extra.spring.SpringUtil;
 import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
 import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel;
 import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
 import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel;
 import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel;
@@ -150,6 +151,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
                     return buildMoonshotChatModel(apiKey, url);
                 case XING_HUO:
                     return buildXingHuoChatModel(apiKey);
+                case BAI_CHUAN:
+                    return buildBaiChuanChatModel(apiKey);
                 case OPENAI:
                     return buildOpenAiChatModel(apiKey, url);
                 case AZURE_OPENAI:
@@ -186,6 +189,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
                 return SpringUtil.getBean(MoonshotChatModel.class);
             case XING_HUO:
                 return SpringUtil.getBean(XingHuoChatModel.class);
+            case BAI_CHUAN:
+                return SpringUtil.getBean(AzureOpenAiChatModel.class);
             case OPENAI:
                 return SpringUtil.getBean(OpenAiChatModel.class);
             case AZURE_OPENAI:
@@ -441,6 +446,15 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new YudaoAiAutoConfiguration().buildXingHuoChatClient(properties);
     }
 
+    /**
+     * 可参考 {@link YudaoAiAutoConfiguration#baiChuanChatClient(YudaoAiProperties)}
+     */
+    private BaiChuanChatModel buildBaiChuanChatModel(String apiKey) {
+        YudaoAiProperties.BaiChuanProperties properties = new YudaoAiProperties.BaiChuanProperties()
+                .setApiKey(apiKey);
+        return new YudaoAiAutoConfiguration().buildBaiChuanChatClient(properties);
+    }
+
     /**
      * 可参考 {@link OpenAiAutoConfiguration} 的 openAiChatModel 方法
      */

+ 45 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/baichuan/BaiChuanChatModel.java

@@ -0,0 +1,45 @@
+package cn.iocoder.yudao.framework.ai.core.model.baichuan;
+
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.ChatOptions;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.openai.OpenAiChatModel;
+import reactor.core.publisher.Flux;
+
+/**
+ * 百川 {@link ChatModel} 实现类
+ *
+ * @author 芋道源码
+ */
+@Slf4j
+@RequiredArgsConstructor
+public class BaiChuanChatModel implements ChatModel {
+
+    public static final String BASE_URL = "https://api.baichuan-ai.com";
+
+    public static final String MODEL_DEFAULT = "Baichuan4-Turbo";
+
+    /**
+     * 兼容 OpenAI 接口,进行复用
+     */
+    private final OpenAiChatModel openAiChatModel;
+
+    @Override
+    public ChatResponse call(Prompt prompt) {
+        return openAiChatModel.call(prompt);
+    }
+
+    @Override
+    public Flux<ChatResponse> stream(Prompt prompt) {
+        return openAiChatModel.stream(prompt);
+    }
+
+    @Override
+    public ChatOptions getDefaultOptions() {
+        return openAiChatModel.getDefaultOptions();
+    }
+
+}

+ 1 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/siliconflow/SiliconFlowImageModel.java

@@ -149,7 +149,7 @@ public class SiliconFlowImageModel implements ImageModel {
                 .batchSize(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getN(), defaultOptions.getN()))
                 .width(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getWidth(), defaultOptions.getWidth()))
                 .height(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getHeight(), defaultOptions.getHeight()))
-                // Handle OpenAI specific image options
+                // Handle SiliconFlow specific image options
                 .negativePrompt(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNegativePrompt(), defaultOptions.getNegativePrompt()))
                 .numInferenceSteps(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNumInferenceSteps(), defaultOptions.getNumInferenceSteps()))
                 .guidanceScale(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getGuidanceScale(), defaultOptions.getGuidanceScale()))

+ 1 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java

@@ -50,6 +50,7 @@ public class AiUtils {
             case HUN_YUAN: // 复用 OpenAI 客户端
             case XING_HUO: // 复用 OpenAI 客户端
             case SILICON_FLOW: // 复用 OpenAI 客户端
+            case BAI_CHUAN: // 复用 OpenAI 客户端
                 return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
                         .toolNames(toolNames).build();
             case AZURE_OPENAI:

+ 68 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/BaiChuanChatModelTests.java

@@ -0,0 +1,68 @@
+package cn.iocoder.yudao.framework.ai.chat;
+
+import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel;
+import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.openai.OpenAiChatModel;
+import org.springframework.ai.openai.OpenAiChatOptions;
+import org.springframework.ai.openai.api.OpenAiApi;
+import reactor.core.publisher.Flux;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * {@link BaiChuanChatModel} 集成测试
+ *
+ * @author 芋道源码
+ */
+public class BaiChuanChatModelTests {
+
+    private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
+            .openAiApi(OpenAiApi.builder()
+                    .baseUrl(BaiChuanChatModel.BASE_URL)
+                    .apiKey("sk-61b6766a94c70786ed02673f5e16af3c") // apiKey
+                    .build())
+            .defaultOptions(OpenAiChatOptions.builder()
+                    .model("Baichuan4-Turbo") // 模型(https://platform.baichuan-ai.com/docs/api)
+                    .temperature(0.7)
+                    .build())
+            .build();
+
+    private final DeepSeekChatModel chatModel = new DeepSeekChatModel(openAiChatModel);
+
+    @Test
+    @Disabled
+    public void testCall() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        ChatResponse response = chatModel.call(new Prompt(messages));
+        // 打印结果
+        System.out.println(response);
+    }
+
+    @Test
+    @Disabled
+    public void testStream() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
+        // 打印结果
+        flux.doOnNext(System.out::println).then().block();
+    }
+
+}