Jelajahi Sumber

【功能新增】AI:增加 RedisVectorStore 向量库的接入

YunaiV 5 bulan lalu
induk
melakukan
588c9fe323

+ 10 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java

@@ -32,6 +32,7 @@ import org.springframework.stereotype.Service;
 
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@@ -48,9 +49,14 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGM
 @Slf4j
 public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService {
 
-    public static final String VECTOR_STORE_METADATA_KNOWLEDGE_ID = "knowledgeId";
-    public static final String VECTOR_STORE_METADATA_DOCUMENT_ID = "documentId";
-    public static final String VECTOR_STORE_METADATA_SEGMENT_ID = "segmentId";
+    private static final String VECTOR_STORE_METADATA_KNOWLEDGE_ID = "knowledgeId";
+    private static final String VECTOR_STORE_METADATA_DOCUMENT_ID = "documentId";
+    private static final String VECTOR_STORE_METADATA_SEGMENT_ID = "segmentId";
+
+    private static final Map<String, Class<?>> VECTOR_STORE_METADATA_TYPES = Map.of(
+            VECTOR_STORE_METADATA_KNOWLEDGE_ID, String.class,
+            VECTOR_STORE_METADATA_DOCUMENT_ID, String.class,
+            VECTOR_STORE_METADATA_SEGMENT_ID, String.class);
 
     @Resource
     private AiKnowledgeSegmentMapper segmentMapper;
@@ -257,7 +263,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
     }
 
     private VectorStore getVectorStoreById(AiKnowledgeDO knowledge) {
-        return modelService.getOrCreateVectorStore(knowledge.getEmbeddingModelId());
+        return modelService.getOrCreateVectorStore(knowledge.getEmbeddingModelId(), VECTOR_STORE_METADATA_TYPES);
     }
 
     private VectorStore getVectorStoreById(Long knowledgeId) {

+ 3 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java

@@ -13,6 +13,7 @@ import org.springframework.ai.vectorstore.VectorStore;
 
 import javax.annotation.Nullable;
 import java.util.List;
+import java.util.Map;
 
 /**
  * AI 模型 Service 接口
@@ -125,8 +126,9 @@ public interface AiModelService {
      * 获得 VectorStore 对象
      *
      * @param id 编号
+     * @param metadataFields 元数据的定义
      * @return VectorStore 对象
      */
-    VectorStore getOrCreateVectorStore(Long id);
+    VectorStore getOrCreateVectorStore(Long id, Map<String, Class<?>> metadataFields);
 
 }

+ 5 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelServiceImpl.java

@@ -22,6 +22,7 @@ import org.springframework.stereotype.Service;
 import org.springframework.validation.annotation.Validated;
 
 import java.util.List;
+import java.util.Map;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
@@ -151,7 +152,7 @@ public class AiModelServiceImpl implements AiModelService {
     }
 
     @Override
-    public VectorStore getOrCreateVectorStore(Long id) {
+    public VectorStore getOrCreateVectorStore(Long id, Map<String, Class<?>> metadataFields) {
         // 获取模型 + 密钥
         AiModelDO model = validateModel(id);
         AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
@@ -162,8 +163,9 @@ public class AiModelServiceImpl implements AiModelService {
                 platform, apiKey.getApiKey(), apiKey.getUrl(), model.getModel());
 
         // 创建或获取 VectorStore 对象
-//        return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel);
-        return modelFactory.getOrCreateVectorStore(QdrantVectorStore.class, embeddingModel);
+//        return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel, metadataFields);
+        return modelFactory.getOrCreateVectorStore(QdrantVectorStore.class, embeddingModel, metadataFields);
+//        return modelFactory.getOrCreateVectorStore(RedisVectorStore.class, embeddingModel, metadataFields);
     }
 
 }

+ 3 - 28
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java

@@ -12,12 +12,12 @@ import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreProperties;
+import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
 import org.springframework.ai.openai.OpenAiChatModel;
 import org.springframework.ai.openai.OpenAiChatOptions;
 import org.springframework.ai.openai.api.OpenAiApi;
 import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
 import org.springframework.ai.tokenizer.TokenCountEstimator;
-import org.springframework.ai.transformer.splitter.TokenTextSplitter;
 import org.springframework.boot.autoconfigure.AutoConfiguration;
 import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
 import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -31,7 +31,8 @@ import org.springframework.context.annotation.Lazy;
  */
 @AutoConfiguration
 @EnableConfigurationProperties({YudaoAiProperties.class,
-        QdrantVectorStoreProperties.class // 解析 Qdrant 配置
+        QdrantVectorStoreProperties.class, // 解析 Qdrant 配置
+        RedisVectorStoreProperties.class, // 解析 Redis 配置
 })
 @Slf4j
 public class YudaoAiAutoConfiguration {
@@ -200,32 +201,6 @@ public class YudaoAiAutoConfiguration {
 //        return new TransformersEmbeddingModel(MetadataMode.EMBED);
 //    }
 
-    /**
-     * TODO @xin 默认版本先不弄,目前都先取对应的 EmbeddingModel
-     */
-//    @Bean
-//    @Lazy // TODO 芋艿:临时注释,避免无法启动
-//    public RedisVectorStore vectorStore(TransformersEmbeddingModel embeddingModel, RedisVectorStoreProperties properties,
-//                                        RedisProperties redisProperties) {
-//        var config = RedisVectorStore.RedisVectorStoreConfig.builder()
-//                .withIndexName(properties.getIndex())
-//                .withPrefix(properties.getPrefix())
-//                .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
-//                .build();
-//
-//        RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
-//                new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
-//                properties.isInitializeSchema());
-//        redisVectorStore.afterPropertiesSet();
-//        return redisVectorStore;
-//    }
-    @Bean
-    @Lazy // TODO 芋艿:临时注释,避免无法启动
-    public TokenTextSplitter tokenTextSplitter() {
-        //TODO  @xin 配置提取
-        return new TokenTextSplitter(500, 100, 5, 10000, true);
-    }
-
     @Bean
     @Lazy // TODO 芋艿:临时注释,避免无法启动
     public TokenCountEstimator tokenCountEstimator() {

+ 7 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java

@@ -8,6 +8,8 @@ import org.springframework.ai.embedding.EmbeddingModel;
 import org.springframework.ai.image.ImageModel;
 import org.springframework.ai.vectorstore.VectorStore;
 
+import java.util.Map;
+
 /**
  * AI Model 模型工厂的接口类
  *
@@ -96,13 +98,16 @@ public interface AiModelFactory {
 
     /**
      * Obtains a VectorStore object based on the given configuration.
-     * <p>
+     *
      * If it does not exist yet, it is created.
      *
      * @param type           vector store type
      * @param embeddingModel embedding model
+     * @param metadataFields metadata fields, as a map from field name to the field's Java type
      * @return the VectorStore object
      */
-    VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type, EmbeddingModel embeddingModel);
+    VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type,
+                                       EmbeddingModel embeddingModel,
+                                       Map<String, Class<?>> metadataFields);
 
 }

+ 70 - 30
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -19,6 +19,7 @@ import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
 import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
+import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
 import com.alibaba.cloud.ai.autoconfigure.dashscope.DashScopeAutoConfiguration;
 import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
 import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
@@ -39,12 +40,13 @@ import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
 import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreAutoConfiguration;
 import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreProperties;
+import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreAutoConfiguration;
+import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
 import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
 import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.document.MetadataMode;
-import org.springframework.ai.embedding.BatchingStrategy;
 import org.springframework.ai.embedding.EmbeddingModel;
 import org.springframework.ai.image.ImageModel;
 import org.springframework.ai.ollama.OllamaChatModel;
@@ -67,20 +69,26 @@ import org.springframework.ai.vectorstore.VectorStore;
 import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention;
 import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
 import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore;
+import org.springframework.ai.vectorstore.redis.RedisVectorStore;
 import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
 import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
 import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
 import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
 import org.springframework.beans.BeansException;
 import org.springframework.beans.factory.ObjectProvider;
+import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
 import org.springframework.web.client.RestClient;
+import redis.clients.jedis.JedisPooled;
 
 import java.io.File;
 import java.time.Duration;
 import java.util.List;
+import java.util.Map;
 import java.util.Timer;
 import java.util.TimerTask;
 
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
+
 /**
  * AI Model 模型工厂的实现类
  *
@@ -225,7 +233,9 @@ public class AiModelFactoryImpl implements AiModelFactory {
     }
 
     @Override
-    public VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type, EmbeddingModel embeddingModel) {
+    public VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type,
+            EmbeddingModel embeddingModel,
+            Map<String, Class<?>> metadataFields) {
         String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel, type);
         return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
             if (type == SimpleVectorStore.class) {
@@ -234,23 +244,10 @@ public class AiModelFactoryImpl implements AiModelFactory {
             if (type == QdrantVectorStore.class) {
                 return buildQdrantVectorStore(embeddingModel);
             }
+            if (type == RedisVectorStore.class) {
+                return buildRedisVectorStore(embeddingModel, metadataFields);
+            }
             throw new IllegalArgumentException(StrUtil.format("未知类型({})", type));
-            // TODO @芋艿:先临时使用 store
-            // TODO @芋艿:@xin:后续看看,是不是切到阿里云之类的
-            // String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
-            // var config = RedisVectorStore.RedisVectorStoreConfig.builder()
-            // .withIndexName(cacheKey)
-            // .withPrefix(prefix)
-            // .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId",
-            // Schema.FieldType.NUMERIC))
-            // .build();
-            // RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
-            // RedisVectorStore redisVectorStore = new RedisVectorStore(config,
-            // embeddingModel,
-            // new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
-            // true);
-            // redisVectorStore.afterPropertiesSet();
-            // return redisVectorStore;
         });
     }
 
@@ -469,21 +466,65 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return vectorStore;
     }
 
+    /**
+     * Builds a {@link QdrantVectorStore} for the given embedding model, following the vectorStore
+     * method of {@link QdrantVectorStoreAutoConfiguration}.
+     *
+     * The Qdrant gRPC client is created by hand from the {@link QdrantVectorStoreProperties} bean
+     * (host, port, TLS flag and optional API key), the store itself is produced by calling the
+     * auto-configuration's vectorStore method directly, and afterPropertiesSet() is invoked on the
+     * result before it is returned.
+     *
+     * @param embeddingModel embedding model handed to the vector store
+     * @return the QdrantVectorStore, after afterPropertiesSet() has been called on it
+     */
+    @SneakyThrows
     private QdrantVectorStore buildQdrantVectorStore(EmbeddingModel embeddingModel) {
         QdrantVectorStoreAutoConfiguration configuration = new QdrantVectorStoreAutoConfiguration();
-        QdrantVectorStoreProperties vectorStoreProperties = SpringUtil.getBean(QdrantVectorStoreProperties.class);
+        QdrantVectorStoreProperties properties = SpringUtil.getBean(QdrantVectorStoreProperties.class);
         // Create the QdrantClient object, following the QdrantVectorStoreAutoConfiguration implementation
         QdrantGrpcClient.Builder grpcClientBuilder = QdrantGrpcClient.newBuilder(
-                vectorStoreProperties.getHost(), vectorStoreProperties.getPort(), vectorStoreProperties.isUseTls());
-        if (StrUtil.isNotEmpty(vectorStoreProperties.getApiKey())) {
-            grpcClientBuilder.withApiKey(vectorStoreProperties.getApiKey());
+                properties.getHost(), properties.getPort(), properties.isUseTls());
+        if (StrUtil.isNotEmpty(properties.getApiKey())) {
+            grpcClientBuilder.withApiKey(properties.getApiKey());
         }
         QdrantClient qdrantClient = new QdrantClient(grpcClientBuilder.build());
-        // Obtain the batchingStrategy, following the QdrantVectorStoreAutoConfiguration implementation
-        BatchingStrategy batchingStrategy = ReflectUtil.invoke(configuration, "batchingStrategy");
-
         // Create the QdrantVectorStore object
-        ObjectProvider<ObservationRegistry> observationRegistry = new ObjectProvider<>() {
+        QdrantVectorStore vectorStore = configuration.vectorStore(embeddingModel, properties, qdrantClient,
+                getObservationRegistry(), getCustomObservationConvention(),
+                ReflectUtil.invoke(configuration, "batchingStrategy"));
+        // Initialize the index
+        vectorStore.afterPropertiesSet();
+        return vectorStore;
+    }
+
+    /**
+     * Builds a {@link RedisVectorStore} for the given embedding model, following the vectorStore
+     * method of {@link RedisVectorStoreAutoConfiguration}.
+     *
+     * @param embeddingModel embedding model handed to the store builder
+     * @param metadataFields metadata field name mapped to its Java type: {@link Number} subtypes are
+     *                       registered as numeric fields, {@link Boolean} as tag fields, and every
+     *                       other type as text fields
+     * @return the RedisVectorStore, after afterPropertiesSet() has been called on it
+     */
+    private RedisVectorStore buildRedisVectorStore(EmbeddingModel embeddingModel,
+                                                   Map<String, Class<?>> metadataFields) {
+        // Create the JedisPooled client. The configured username/password are passed along as well:
+        // connecting with host and port alone fails as soon as the Redis server requires AUTH.
+        // NOTE(review): the database index, SSL and timeout settings of RedisProperties are still
+        // not applied here - confirm whether they are needed for the target deployments.
+        RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
+        JedisPooled jedisPooled = new JedisPooled(redisProperties.getHost(), redisProperties.getPort(),
+                redisProperties.getUsername(), redisProperties.getPassword());
+        // Load the RedisVectorStoreProperties (index name, key prefix, schema initialization flag)
+        RedisVectorStoreAutoConfiguration configuration = new RedisVectorStoreAutoConfiguration();
+        RedisVectorStoreProperties properties = SpringUtil.getBean(RedisVectorStoreProperties.class);
+        RedisVectorStore redisVectorStore = RedisVectorStore.builder(jedisPooled, embeddingModel)
+                .indexName(properties.getIndex()).prefix(properties.getPrefix())
+                .initializeSchema(properties.isInitializeSchema())
+                // Translate each declared Java type into the matching Redis metadata field kind
+                .metadataFields(convertList(metadataFields.entrySet(), entry -> {
+                    String fieldName = entry.getKey();
+                    Class<?> fieldType = entry.getValue();
+                    if (Number.class.isAssignableFrom(fieldType)) {
+                        return RedisVectorStore.MetadataField.numeric(fieldName);
+                    }
+                    if (Boolean.class.isAssignableFrom(fieldType)) {
+                        return RedisVectorStore.MetadataField.tag(fieldName);
+                    }
+                    return RedisVectorStore.MetadataField.text(fieldName);
+                }))
+                .observationRegistry(getObservationRegistry().getObject())
+                .customObservationConvention(getCustomObservationConvention().getObject())
+                .batchingStrategy(ReflectUtil.invoke(configuration, "batchingStrategy"))
+                .build();
+        // Initialize the index
+        redisVectorStore.afterPropertiesSet();
+        return redisVectorStore;
+    }
+
+    private static ObjectProvider<ObservationRegistry> getObservationRegistry() {
+        return new ObjectProvider<>() {
 
             @Override
             public ObservationRegistry getObject() throws BeansException {
@@ -491,16 +532,15 @@ public class AiModelFactoryImpl implements AiModelFactory {
             }
 
         };
-        ObjectProvider <VectorStoreObservationConvention> customObservationConvention = new ObjectProvider<>() {
+    }
 
+    private static ObjectProvider<VectorStoreObservationConvention> getCustomObservationConvention() {
+        return new ObjectProvider<>() {
             @Override
             public VectorStoreObservationConvention getObject() throws BeansException {
                 return new DefaultVectorStoreObservationConvention();
             }
-
         };
-        return configuration.vectorStore(embeddingModel, vectorStoreProperties, qdrantClient,
-                observationRegistry, customObservationConvention, batchingStrategy);
     }
 
 }

+ 5 - 3
yudao-server/src/main/resources/application.yaml

@@ -149,10 +149,12 @@ spring:
   ai:
     vectorstore: # Vector store
       redis:
-        index: default-index
-        prefix: "default:"
+        initialize-schema: true
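+        index: knowledge_index # Name of the vector index in Redis: the index identifier used to store and retrieve vector data; all related vector search operations run against this index
+        prefix: "knowledge_segment:" # Key prefix for vector data stored in Redis: prepended to the key of every vector entry stored in Redis; each document is a hash structure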
       qdrant:
-        collection-name: knowledge_segment
+        initialize-schema: true
+        collection-name: knowledge_segment # Name of the vector collection in Qdrant: the collection identifier used to store vector data; all related vector operations run in this collection
         host: 127.0.0.1
         port: 6334
         use-tls: false