|
@@ -5,6 +5,7 @@ import cn.hutool.core.lang.Assert;
|
|
|
import cn.hutool.core.lang.Singleton;
|
|
|
import cn.hutool.core.lang.func.Func0;
|
|
|
import cn.hutool.core.util.ArrayUtil;
|
|
|
+import cn.hutool.core.util.ReflectUtil;
|
|
|
import cn.hutool.core.util.RuntimeUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
import cn.hutool.extra.spring.SpringUtil;
|
|
@@ -26,6 +27,9 @@ import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingModel;
|
|
|
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingOptions;
|
|
|
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
|
|
|
import com.azure.ai.openai.OpenAIClientBuilder;
|
|
|
+import io.micrometer.observation.ObservationRegistry;
|
|
|
+import io.qdrant.client.QdrantClient;
|
|
|
+import io.qdrant.client.QdrantGrpcClient;
|
|
|
import lombok.SneakyThrows;
|
|
|
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
|
|
|
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
|
|
@@ -33,11 +37,14 @@ import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionPr
|
|
|
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
|
|
|
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
|
|
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
|
|
|
+import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreAutoConfiguration;
|
|
|
+import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreProperties;
|
|
|
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
|
|
|
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
|
|
|
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
|
|
|
import org.springframework.ai.chat.model.ChatModel;
|
|
|
import org.springframework.ai.document.MetadataMode;
|
|
|
+import org.springframework.ai.embedding.BatchingStrategy;
|
|
|
import org.springframework.ai.embedding.EmbeddingModel;
|
|
|
import org.springframework.ai.image.ImageModel;
|
|
|
import org.springframework.ai.ollama.OllamaChatModel;
|
|
@@ -57,10 +64,15 @@ import org.springframework.ai.stabilityai.StabilityAiImageModel;
|
|
|
import org.springframework.ai.stabilityai.api.StabilityAiApi;
|
|
|
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
|
|
import org.springframework.ai.vectorstore.VectorStore;
|
|
|
+import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention;
|
|
|
+import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
|
|
|
+import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore;
|
|
|
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
|
|
|
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
|
|
|
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
|
|
|
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
|
|
|
+import org.springframework.beans.BeansException;
|
|
|
+import org.springframework.beans.factory.ObjectProvider;
|
|
|
import org.springframework.web.client.RestClient;
|
|
|
|
|
|
import java.io.File;
|
|
@@ -214,13 +226,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|
|
|
|
|
@Override
|
|
|
public VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type, EmbeddingModel embeddingModel) {
|
|
|
- // String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey,
|
|
|
- // url);
|
|
|
String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel, type);
|
|
|
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
|
|
|
if (type == SimpleVectorStore.class) {
|
|
|
return buildSimpleVectorStore(embeddingModel);
|
|
|
}
|
|
|
+ if (type == QdrantVectorStore.class) {
|
|
|
+ return buildQdrantVectorStore(embeddingModel);
|
|
|
+ }
|
|
|
throw new IllegalArgumentException(StrUtil.format("未知类型({})", type));
|
|
|
// TODO @芋艿:先临时使用 store
|
|
|
// TODO @芋艿:@xin:后续看看,是不是切到阿里云之类的
|
|
@@ -456,4 +469,38 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|
|
return vectorStore;
|
|
|
}
|
|
|
|
|
|
+ private QdrantVectorStore buildQdrantVectorStore(EmbeddingModel embeddingModel) {
|
|
|
+ QdrantVectorStoreAutoConfiguration configuration = new QdrantVectorStoreAutoConfiguration();
|
|
|
+ QdrantVectorStoreProperties vectorStoreProperties = SpringUtil.getBean(QdrantVectorStoreProperties.class);
|
|
|
+ // 参考 QdrantVectorStoreAutoConfiguration 实现,创建 QdrantClient 对象
|
|
|
+ QdrantGrpcClient.Builder grpcClientBuilder = QdrantGrpcClient.newBuilder(
|
|
|
+ vectorStoreProperties.getHost(), vectorStoreProperties.getPort(), vectorStoreProperties.isUseTls());
|
|
|
+ if (StrUtil.isNotEmpty(vectorStoreProperties.getApiKey())) {
|
|
|
+ grpcClientBuilder.withApiKey(vectorStoreProperties.getApiKey());
|
|
|
+ }
|
|
|
+ QdrantClient qdrantClient = new QdrantClient(grpcClientBuilder.build());
|
|
|
+ // 参考 QdrantVectorStoreAutoConfiguration 实现,实现 batchingStrategy
|
|
|
+ BatchingStrategy batchingStrategy = ReflectUtil.invoke(configuration, "batchingStrategy");
|
|
|
+
|
|
|
+ // 创建 QdrantVectorStore 对象
|
|
|
+ ObjectProvider<ObservationRegistry> observationRegistry = new ObjectProvider<>() {
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public ObservationRegistry getObject() throws BeansException {
|
|
|
+ return SpringUtil.getBean(ObservationRegistry.class);
|
|
|
+ }
|
|
|
+
|
|
|
+ };
|
|
|
+ ObjectProvider <VectorStoreObservationConvention> customObservationConvention = new ObjectProvider<>() {
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public VectorStoreObservationConvention getObject() throws BeansException {
|
|
|
+ return new DefaultVectorStoreObservationConvention();
|
|
|
+ }
|
|
|
+
|
|
|
+ };
|
|
|
+ return configuration.vectorStore(embeddingModel, vectorStoreProperties, qdrantClient,
|
|
|
+ observationRegistry, customObservationConvention, batchingStrategy);
|
|
|
+ }
|
|
|
+
|
|
|
}
|