|
@@ -1,456 +0,0 @@
|
|
|
-/*
|
|
|
- * Copyright 2023 - 2024 the original author or authors.
|
|
|
- *
|
|
|
- * Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
- * you may not use this file except in compliance with the License.
|
|
|
- * You may obtain a copy of the License at
|
|
|
- *
|
|
|
- * https://www.apache.org/licenses/LICENSE-2.0
|
|
|
- *
|
|
|
- * Unless required by applicable law or agreed to in writing, software
|
|
|
- * distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
- * See the License for the specific language governing permissions and
|
|
|
- * limitations under the License.
|
|
|
- */
|
|
|
-package org.springframework.ai.vectorstore;
|
|
|
-
|
|
|
-import org.slf4j.Logger;
|
|
|
-import org.slf4j.LoggerFactory;
|
|
|
-import org.springframework.ai.document.Document;
|
|
|
-import org.springframework.ai.embedding.EmbeddingModel;
|
|
|
-import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
|
|
|
-import org.springframework.beans.factory.InitializingBean;
|
|
|
-import org.springframework.util.Assert;
|
|
|
-import org.springframework.util.CollectionUtils;
|
|
|
-import redis.clients.jedis.JedisPooled;
|
|
|
-import redis.clients.jedis.Pipeline;
|
|
|
-import redis.clients.jedis.json.Path2;
|
|
|
-import redis.clients.jedis.search.*;
|
|
|
-import redis.clients.jedis.search.Schema.FieldType;
|
|
|
-import redis.clients.jedis.search.schemafields.*;
|
|
|
-import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm;
|
|
|
-
|
|
|
-import java.text.MessageFormat;
|
|
|
-import java.util.*;
|
|
|
-import java.util.function.Function;
|
|
|
-import java.util.function.Predicate;
|
|
|
-import java.util.stream.Collectors;
|
|
|
-
|
|
|
-/**
|
|
|
- * The RedisVectorStore is for managing and querying vector data in a Redis database. It
|
|
|
- * offers functionalities like adding, deleting, and performing similarity searches on
|
|
|
- * documents.
|
|
|
- *
|
|
|
- * The store utilizes RedisJSON and RedisSearch to handle JSON documents and to index and
|
|
|
- * search vector data. It supports various vector algorithms (e.g., FLAT, HSNW) for
|
|
|
- * efficient similarity searches. Additionally, it allows for custom metadata fields in
|
|
|
- * the documents to be stored alongside the vector and content data.
|
|
|
- *
|
|
|
- * This class requires a RedisVectorStoreConfig configuration object for initialization,
|
|
|
- * which includes settings like Redis URI, index name, field names, and vector algorithms.
|
|
|
- * It also requires an EmbeddingModel to convert documents into embeddings before storing
|
|
|
- * them.
|
|
|
- *
|
|
|
- * @author Julien Ruaux
|
|
|
- * @author Christian Tzolov
|
|
|
- * @author Eddú Meléndez
|
|
|
- * @see VectorStore
|
|
|
- * @see RedisVectorStoreConfig
|
|
|
- * @see EmbeddingModel
|
|
|
- */
|
|
|
-public class RedisVectorStore implements VectorStore, InitializingBean {
|
|
|
-
|
|
|
- public enum Algorithm {
|
|
|
-
|
|
|
- FLAT, HSNW
|
|
|
-
|
|
|
- }
|
|
|
-
|
|
|
- public record MetadataField(String name, FieldType fieldType) {
|
|
|
-
|
|
|
- public static MetadataField text(String name) {
|
|
|
- return new MetadataField(name, FieldType.TEXT);
|
|
|
- }
|
|
|
-
|
|
|
- public static MetadataField numeric(String name) {
|
|
|
- return new MetadataField(name, FieldType.NUMERIC);
|
|
|
- }
|
|
|
-
|
|
|
- public static MetadataField tag(String name) {
|
|
|
- return new MetadataField(name, FieldType.TAG);
|
|
|
- }
|
|
|
-
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * Configuration for the Redis vector store.
|
|
|
- */
|
|
|
- public static final class RedisVectorStoreConfig {
|
|
|
-
|
|
|
- private final String indexName;
|
|
|
-
|
|
|
- private final String prefix;
|
|
|
-
|
|
|
- private final String contentFieldName;
|
|
|
-
|
|
|
- private final String embeddingFieldName;
|
|
|
-
|
|
|
- private final Algorithm vectorAlgorithm;
|
|
|
-
|
|
|
- private final List<MetadataField> metadataFields;
|
|
|
-
|
|
|
- private RedisVectorStoreConfig() {
|
|
|
- this(builder());
|
|
|
- }
|
|
|
-
|
|
|
- private RedisVectorStoreConfig(Builder builder) {
|
|
|
- this.indexName = builder.indexName;
|
|
|
- this.prefix = builder.prefix;
|
|
|
- this.contentFieldName = builder.contentFieldName;
|
|
|
- this.embeddingFieldName = builder.embeddingFieldName;
|
|
|
- this.vectorAlgorithm = builder.vectorAlgorithm;
|
|
|
- this.metadataFields = builder.metadataFields;
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * Start building a new configuration.
|
|
|
- * @return The entry point for creating a new configuration.
|
|
|
- */
|
|
|
- public static Builder builder() {
|
|
|
-
|
|
|
- return new Builder();
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * {@return the default config}
|
|
|
- */
|
|
|
- public static RedisVectorStoreConfig defaultConfig() {
|
|
|
-
|
|
|
- return builder().build();
|
|
|
- }
|
|
|
-
|
|
|
- public static class Builder {
|
|
|
-
|
|
|
- private String indexName = DEFAULT_INDEX_NAME;
|
|
|
-
|
|
|
- private String prefix = DEFAULT_PREFIX;
|
|
|
-
|
|
|
- private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME;
|
|
|
-
|
|
|
- private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME;
|
|
|
-
|
|
|
- private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM;
|
|
|
-
|
|
|
- private List<MetadataField> metadataFields = new ArrayList<>();
|
|
|
-
|
|
|
- private Builder() {
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * Configures the Redis index name to use.
|
|
|
- * @param name the index name to use
|
|
|
- * @return this builder
|
|
|
- */
|
|
|
- public Builder withIndexName(String name) {
|
|
|
- this.indexName = name;
|
|
|
- return this;
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * Configures the Redis key prefix to use (default: "embedding:").
|
|
|
- * @param prefix the prefix to use
|
|
|
- * @return this builder
|
|
|
- */
|
|
|
- public Builder withPrefix(String prefix) {
|
|
|
- this.prefix = prefix;
|
|
|
- return this;
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * Configures the Redis content field name to use.
|
|
|
- * @param name the content field name to use
|
|
|
- * @return this builder
|
|
|
- */
|
|
|
- public Builder withContentFieldName(String name) {
|
|
|
- this.contentFieldName = name;
|
|
|
- return this;
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * Configures the Redis embedding field name to use.
|
|
|
- * @param name the embedding field name to use
|
|
|
- * @return this builder
|
|
|
- */
|
|
|
- public Builder withEmbeddingFieldName(String name) {
|
|
|
- this.embeddingFieldName = name;
|
|
|
- return this;
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * Configures the Redis vector algorithmto use.
|
|
|
- * @param algorithm the vector algorithm to use
|
|
|
- * @return this builder
|
|
|
- */
|
|
|
- public Builder withVectorAlgorithm(Algorithm algorithm) {
|
|
|
- this.vectorAlgorithm = algorithm;
|
|
|
- return this;
|
|
|
- }
|
|
|
-
|
|
|
- public Builder withMetadataFields(MetadataField... fields) {
|
|
|
- return withMetadataFields(Arrays.asList(fields));
|
|
|
- }
|
|
|
-
|
|
|
- public Builder withMetadataFields(List<MetadataField> fields) {
|
|
|
- this.metadataFields = fields;
|
|
|
- return this;
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * {@return the immutable configuration}
|
|
|
- */
|
|
|
- public RedisVectorStoreConfig build() {
|
|
|
-
|
|
|
- return new RedisVectorStoreConfig(this);
|
|
|
- }
|
|
|
-
|
|
|
- }
|
|
|
-
|
|
|
- }
|
|
|
-
|
|
|
- private final boolean initializeSchema;
|
|
|
-
|
|
|
- public static final String DEFAULT_INDEX_NAME = "spring-ai-index";
|
|
|
-
|
|
|
- public static final String DEFAULT_CONTENT_FIELD_NAME = "content";
|
|
|
-
|
|
|
- public static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding";
|
|
|
-
|
|
|
- public static final String DEFAULT_PREFIX = "embedding:";
|
|
|
-
|
|
|
- public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW;
|
|
|
-
|
|
|
- private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]";
|
|
|
-
|
|
|
- private static final Path2 JSON_SET_PATH = Path2.of("$");
|
|
|
-
|
|
|
- private static final String JSON_PATH_PREFIX = "$.";
|
|
|
-
|
|
|
- private static final Logger logger = LoggerFactory.getLogger(RedisVectorStore.class);
|
|
|
-
|
|
|
- private static final Predicate<Object> RESPONSE_OK = Predicate.isEqual("OK");
|
|
|
-
|
|
|
- private static final Predicate<Object> RESPONSE_DEL_OK = Predicate.isEqual(1l);
|
|
|
-
|
|
|
- private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32";
|
|
|
-
|
|
|
- private static final String EMBEDDING_PARAM_NAME = "BLOB";
|
|
|
-
|
|
|
- public static final String DISTANCE_FIELD_NAME = "vector_score";
|
|
|
-
|
|
|
- private static final String DEFAULT_DISTANCE_METRIC = "COSINE";
|
|
|
-
|
|
|
- private final JedisPooled jedis;
|
|
|
-
|
|
|
- private final EmbeddingModel embeddingModel;
|
|
|
-
|
|
|
- private final RedisVectorStoreConfig config;
|
|
|
-
|
|
|
- private FilterExpressionConverter filterExpressionConverter;
|
|
|
-
|
|
|
- public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis,
|
|
|
- boolean initializeSchema) {
|
|
|
-
|
|
|
- Assert.notNull(config, "Config must not be null");
|
|
|
- Assert.notNull(embeddingModel, "Embedding model must not be null");
|
|
|
- this.initializeSchema = initializeSchema;
|
|
|
-
|
|
|
- this.jedis = jedis;
|
|
|
- this.embeddingModel = embeddingModel;
|
|
|
- this.config = config;
|
|
|
- this.filterExpressionConverter = new RedisFilterExpressionConverter(this.config.metadataFields);
|
|
|
- }
|
|
|
-
|
|
|
- public JedisPooled getJedis() {
|
|
|
- return this.jedis;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void add(List<Document> documents) {
|
|
|
- try (Pipeline pipeline = this.jedis.pipelined()) {
|
|
|
- for (Document document : documents) {
|
|
|
- var embedding = this.embeddingModel.embed(document);
|
|
|
- document.setEmbedding(embedding);
|
|
|
-
|
|
|
- var fields = new HashMap<String, Object>();
|
|
|
- fields.put(this.config.embeddingFieldName, embedding);
|
|
|
- fields.put(this.config.contentFieldName, document.getContent());
|
|
|
- fields.putAll(document.getMetadata());
|
|
|
- pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields);
|
|
|
- }
|
|
|
- List<Object> responses = pipeline.syncAndReturnAll();
|
|
|
- Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_OK)).findAny();
|
|
|
- if (errResponse.isPresent()) {
|
|
|
- String message = MessageFormat.format("Could not add document: {0}", errResponse.get());
|
|
|
- if (logger.isErrorEnabled()) {
|
|
|
- logger.error(message);
|
|
|
- }
|
|
|
- throw new RuntimeException(message);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- private String key(String id) {
|
|
|
- return this.config.prefix + id;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public Optional<Boolean> delete(List<String> idList) {
|
|
|
- try (Pipeline pipeline = this.jedis.pipelined()) {
|
|
|
- for (String id : idList) {
|
|
|
- pipeline.jsonDel(key(id));
|
|
|
- }
|
|
|
- List<Object> responses = pipeline.syncAndReturnAll();
|
|
|
- Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_DEL_OK)).findAny();
|
|
|
- if (errResponse.isPresent()) {
|
|
|
- if (logger.isErrorEnabled()) {
|
|
|
- logger.error("Could not delete document: {}", errResponse.get());
|
|
|
- }
|
|
|
- return Optional.of(false);
|
|
|
- }
|
|
|
- return Optional.of(true);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public List<Document> similaritySearch(SearchRequest request) {
|
|
|
-
|
|
|
- Assert.isTrue(request.getTopK() > 0, "The number of documents to returned must be greater than zero");
|
|
|
- Assert.isTrue(request.getSimilarityThreshold() >= 0 && request.getSimilarityThreshold() <= 1,
|
|
|
- "The similarity score is bounded between 0 and 1; least to most similar respectively.");
|
|
|
-
|
|
|
- String filter = nativeExpressionFilter(request);
|
|
|
-
|
|
|
- String queryString = String.format(QUERY_FORMAT, filter, request.getTopK(), this.config.embeddingFieldName,
|
|
|
- EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME);
|
|
|
-
|
|
|
- List<String> returnFields = new ArrayList<>();
|
|
|
- this.config.metadataFields.stream().map(MetadataField::name).forEach(returnFields::add);
|
|
|
- returnFields.add(this.config.embeddingFieldName);
|
|
|
- returnFields.add(this.config.contentFieldName);
|
|
|
- returnFields.add(DISTANCE_FIELD_NAME);
|
|
|
- var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery()));
|
|
|
- Query query = new Query(queryString).addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(embedding))
|
|
|
- .returnFields(returnFields.toArray(new String[0]))
|
|
|
- .setSortBy(DISTANCE_FIELD_NAME, true)
|
|
|
- .dialect(2);
|
|
|
-
|
|
|
- SearchResult result = this.jedis.ftSearch(this.config.indexName, query);
|
|
|
- return result.getDocuments()
|
|
|
- .stream()
|
|
|
- .filter(d -> similarityScore(d) >= request.getSimilarityThreshold())
|
|
|
- .map(this::toDocument)
|
|
|
- .toList();
|
|
|
- }
|
|
|
-
|
|
|
- private Document toDocument(redis.clients.jedis.search.Document doc) {
|
|
|
- var id = doc.getId().substring(this.config.prefix.length());
|
|
|
- var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName)
|
|
|
- : null;
|
|
|
- Map<String, Object> metadata = this.config.metadataFields.stream()
|
|
|
- .map(MetadataField::name)
|
|
|
- .filter(doc::hasProperty)
|
|
|
- .collect(Collectors.toMap(Function.identity(), doc::getString));
|
|
|
- metadata.put(DISTANCE_FIELD_NAME, 1 - similarityScore(doc));
|
|
|
- return new Document(id, content, metadata);
|
|
|
- }
|
|
|
-
|
|
|
- private float similarityScore(redis.clients.jedis.search.Document doc) {
|
|
|
- return (2 - Float.parseFloat(doc.getString(DISTANCE_FIELD_NAME))) / 2;
|
|
|
- }
|
|
|
-
|
|
|
- private String nativeExpressionFilter(SearchRequest request) {
|
|
|
- if (request.getFilterExpression() == null) {
|
|
|
- return "*";
|
|
|
- }
|
|
|
- return "(" + this.filterExpressionConverter.convertExpression(request.getFilterExpression()) + ")";
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void afterPropertiesSet() {
|
|
|
-
|
|
|
- if (!this.initializeSchema) {
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
- // If index already exists don't do anything
|
|
|
- if (this.jedis.ftList().contains(this.config.indexName)) {
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
- String response = this.jedis.ftCreate(this.config.indexName,
|
|
|
- FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.config.prefix), schemaFields());
|
|
|
- if (!RESPONSE_OK.test(response)) {
|
|
|
- String message = MessageFormat.format("Could not create index: {0}", response);
|
|
|
- throw new RuntimeException(message);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- private Iterable<SchemaField> schemaFields() {
|
|
|
- Map<String, Object> vectorAttrs = new HashMap<>();
|
|
|
- vectorAttrs.put("DIM", this.embeddingModel.dimensions());
|
|
|
- vectorAttrs.put("DISTANCE_METRIC", DEFAULT_DISTANCE_METRIC);
|
|
|
- vectorAttrs.put("TYPE", VECTOR_TYPE_FLOAT32);
|
|
|
- List<SchemaField> fields = new ArrayList<>();
|
|
|
- fields.add(TextField.of(jsonPath(this.config.contentFieldName)).as(this.config.contentFieldName).weight(1.0));
|
|
|
- fields.add(VectorField.builder()
|
|
|
- .fieldName(jsonPath(this.config.embeddingFieldName))
|
|
|
- .algorithm(vectorAlgorithm())
|
|
|
- .attributes(vectorAttrs)
|
|
|
- .as(this.config.embeddingFieldName)
|
|
|
- .build());
|
|
|
-
|
|
|
- if (!CollectionUtils.isEmpty(this.config.metadataFields)) {
|
|
|
- for (MetadataField field : this.config.metadataFields) {
|
|
|
- fields.add(schemaField(field));
|
|
|
- }
|
|
|
- }
|
|
|
- return fields;
|
|
|
- }
|
|
|
-
|
|
|
- private SchemaField schemaField(MetadataField field) {
|
|
|
- String fieldName = jsonPath(field.name);
|
|
|
- switch (field.fieldType) {
|
|
|
- case NUMERIC:
|
|
|
- return NumericField.of(fieldName).as(field.name);
|
|
|
- case TAG:
|
|
|
- return TagField.of(fieldName).as(field.name);
|
|
|
- case TEXT:
|
|
|
- return TextField.of(fieldName).as(field.name);
|
|
|
- default:
|
|
|
- throw new IllegalArgumentException(
|
|
|
- MessageFormat.format("Field {0} has unsupported type {1}", field.name, field.fieldType));
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- private VectorAlgorithm vectorAlgorithm() {
|
|
|
- if (config.vectorAlgorithm == Algorithm.HSNW) {
|
|
|
- return VectorAlgorithm.HNSW;
|
|
|
- }
|
|
|
- return VectorAlgorithm.FLAT;
|
|
|
- }
|
|
|
-
|
|
|
- private String jsonPath(String field) {
|
|
|
- return JSON_PATH_PREFIX + field;
|
|
|
- }
|
|
|
-
|
|
|
- private static float[] toFloatArray(List<Double> embeddingDouble) {
|
|
|
- float[] embeddingFloat = new float[embeddingDouble.size()];
|
|
|
- int i = 0;
|
|
|
- for (Double d : embeddingDouble) {
|
|
|
- embeddingFloat[i++] = d.floatValue();
|
|
|
- }
|
|
|
- return embeddingFloat;
|
|
|
- }
|
|
|
-
|
|
|
-}
|