瀏覽代碼

【功能新增】AI:新增 AzureOpenAiEmbeddingModel、OpenAiEmbeddingModel 向量模型

YunaiV 5 月之前
父節點
當前提交
b9e8495712

+ 36 - 4
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -35,6 +35,7 @@ import lombok.SneakyThrows;
 import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
 import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
+import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiEmbeddingProperties;
 import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
 import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
@@ -49,6 +50,7 @@ import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStorePr
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
 import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
+import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.document.MetadataMode;
 import org.springframework.ai.embedding.BatchingStrategy;
@@ -59,6 +61,8 @@ import org.springframework.ai.ollama.OllamaEmbeddingModel;
 import org.springframework.ai.ollama.api.OllamaApi;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.ai.openai.OpenAiChatModel;
+import org.springframework.ai.openai.OpenAiEmbeddingModel;
+import org.springframework.ai.openai.OpenAiEmbeddingOptions;
 import org.springframework.ai.openai.OpenAiImageModel;
 import org.springframework.ai.openai.api.OpenAiApi;
 import org.springframework.ai.openai.api.OpenAiImageApi;
@@ -227,6 +231,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
     }
 
     @Override
+    @SuppressWarnings("EnhancedSwitchMigration")
     public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url, String model) {
         String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url, model);
         return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> {
@@ -237,10 +242,10 @@ public class AiModelFactoryImpl implements AiModelFactory {
                     return buildYiYanEmbeddingModel(apiKey, model);
                 case ZHI_PU:
                     return buildZhiPuEmbeddingModel(apiKey, url, model);
-//                case OPENAI:
-//                    return buildOpenAiChatModel(apiKey, url);
-//                case AZURE_OPENAI:
-//                    return buildAzureOpenAiChatModel(apiKey, url);
+                case OPENAI:
+                    return buildOpenAiEmbeddingModel(apiKey, url, model);
+                case AZURE_OPENAI:
+                    return buildAzureOpenAiEmbeddingModel(apiKey, url, model);
                 case OLLAMA:
                     return buildOllamaEmbeddingModel(url, model);
                 default:
@@ -474,6 +479,33 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return OllamaEmbeddingModel.builder().ollamaApi(ollamaApi).defaultOptions(ollamaOptions).build();
     }
 
+    /**
+     * 可参考 {@link OpenAiAutoConfiguration} 的 openAiEmbeddingModel 方法
+     */
+    private OpenAiEmbeddingModel buildOpenAiEmbeddingModel(String openAiToken, String url, String model) {
+        url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL);
+        OpenAiApi openAiApi = OpenAiApi.builder().baseUrl(url).apiKey(openAiToken).build();
+        OpenAiEmbeddingOptions openAiEmbeddingProperties = OpenAiEmbeddingOptions.builder().model(model).build();
+        return new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, openAiEmbeddingProperties);
+    }
+
+    // TODO @芋艿:手头暂时没密钥,使用建议再测试下
+    /**
+     * 可参考 {@link AzureOpenAiAutoConfiguration} 的 azureOpenAiEmbeddingModel 方法
+     */
+    private AzureOpenAiEmbeddingModel buildAzureOpenAiEmbeddingModel(String apiKey, String url, String model) {
+        AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration();
+        // 创建 OpenAIClient 对象
+        AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties();
+        connectionProperties.setApiKey(apiKey);
+        connectionProperties.setEndpoint(url);
+        OpenAIClientBuilder openAIClient = azureOpenAiAutoConfiguration.openAIClientBuilder(connectionProperties, null);
+        // 获取 AzureOpenAiChatProperties 对象
+        AzureOpenAiEmbeddingProperties embeddingProperties = SpringUtil.getBean(AzureOpenAiEmbeddingProperties.class);
+        return azureOpenAiAutoConfiguration.azureOpenAiEmbeddingModel(openAIClient, embeddingProperties,
+                null, null);
+    }
+
     // ========== 各种创建 VectorStore 的方法 ==========
 
     /**