Browse Source

【功能新增】AI:新增 QianFanEmbeddingModel 向量模型

YunaiV 5 months ago
parent
commit
20eb3013f2

+ 26 - 4
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -64,6 +64,8 @@ import org.springframework.ai.openai.api.OpenAiApi;
 import org.springframework.ai.openai.api.OpenAiImageApi;
 import org.springframework.ai.openai.api.common.OpenAiApiConstants;
 import org.springframework.ai.qianfan.QianFanChatModel;
+import org.springframework.ai.qianfan.QianFanEmbeddingModel;
+import org.springframework.ai.qianfan.QianFanEmbeddingOptions;
 import org.springframework.ai.qianfan.QianFanImageModel;
 import org.springframework.ai.qianfan.api.QianFanApi;
 import org.springframework.ai.qianfan.api.QianFanImageApi;
@@ -78,6 +80,7 @@ import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore;
 import org.springframework.ai.vectorstore.redis.RedisVectorStore;
 import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
 import org.springframework.ai.zhipuai.ZhiPuAiEmbeddingModel;
+import org.springframework.ai.zhipuai.ZhiPuAiEmbeddingOptions;
 import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
 import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
 import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
@@ -230,11 +233,16 @@ public class AiModelFactoryImpl implements AiModelFactory {
             switch (platform) {
                 case TONG_YI:
                     return buildTongYiEmbeddingModel(apiKey, model);
-                case OLLAMA:
-                    return buildOllamaEmbeddingModel(url, model);
-                // TODO @芋艿:yiyan
+                case YI_YAN:
+                    return buildYiYanEmbeddingModel(apiKey, model);
                 case ZHI_PU:
                     return buildZhiPuEmbeddingModel(apiKey, url, model);
+//                case OPENAI:
+//                    return buildOpenAiChatModel(apiKey, url);
+//                case AZURE_OPENAI:
+//                    return buildAzureOpenAiChatModel(apiKey, url);
+                case OLLAMA:
+                    return buildOllamaEmbeddingModel(url, model);
                 default:
                     throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
             }
@@ -443,7 +451,21 @@ public class AiModelFactoryImpl implements AiModelFactory {
     private ZhiPuAiEmbeddingModel buildZhiPuEmbeddingModel(String apiKey, String url, String model) {
         url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
         ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi(url, apiKey);
-        return new ZhiPuAiEmbeddingModel(zhiPuAiApi);
+        ZhiPuAiEmbeddingOptions zhiPuAiEmbeddingOptions = ZhiPuAiEmbeddingOptions.builder().model(model).build();
+        return new ZhiPuAiEmbeddingModel(zhiPuAiApi, MetadataMode.EMBED, zhiPuAiEmbeddingOptions);
+    }
+
+    /**
+     * 可参考 {@link QianFanAutoConfiguration} 的 qianFanEmbeddingModel 方法
+     */
+    private QianFanEmbeddingModel buildYiYanEmbeddingModel(String key, String model) {
+        List<String> keys = StrUtil.split(key, '|');
+        Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
+        String appKey = keys.get(0);
+        String secretKey = keys.get(1);
+        QianFanApi qianFanApi = new QianFanApi(appKey, secretKey);
+        QianFanEmbeddingOptions qianFanEmbeddingOptions = QianFanEmbeddingOptions.builder().model(model).build();
+        return new QianFanEmbeddingModel(qianFanApi, MetadataMode.EMBED, qianFanEmbeddingOptions);
     }
 
     private OllamaEmbeddingModel buildOllamaEmbeddingModel(String url, String model) {