|
@@ -1,9 +1,11 @@
|
|
|
package cn.iocoder.yudao.framework.ai.core.factory;
|
|
|
|
|
|
+import cn.hutool.core.io.FileUtil;
|
|
|
import cn.hutool.core.lang.Assert;
|
|
|
import cn.hutool.core.lang.Singleton;
|
|
|
import cn.hutool.core.lang.func.Func0;
|
|
|
import cn.hutool.core.util.ArrayUtil;
|
|
|
+import cn.hutool.core.util.RuntimeUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
import cn.hutool.extra.spring.SpringUtil;
|
|
|
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
|
|
@@ -24,6 +26,7 @@ import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingModel;
|
|
|
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingOptions;
|
|
|
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
|
|
|
import com.azure.ai.openai.OpenAIClientBuilder;
|
|
|
+import lombok.SneakyThrows;
|
|
|
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
|
|
|
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
|
|
|
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
|
|
@@ -60,7 +63,11 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
|
|
|
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
|
|
|
import org.springframework.web.client.RestClient;
|
|
|
|
|
|
+import java.io.File;
|
|
|
+import java.time.Duration;
|
|
|
import java.util.List;
|
|
|
+import java.util.Timer;
|
|
|
+import java.util.TimerTask;
|
|
|
|
|
|
/**
|
|
|
* AI Model 模型工厂的实现类
|
|
@@ -73,7 +80,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|
|
public ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url) {
|
|
|
String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url);
|
|
|
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
|
|
|
- //noinspection EnhancedSwitchMigration
|
|
|
+ // noinspection EnhancedSwitchMigration
|
|
|
switch (platform) {
|
|
|
case TONG_YI:
|
|
|
return buildTongYiChatModel(apiKey);
|
|
@@ -105,7 +112,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|
|
|
|
|
@Override
|
|
|
public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
|
|
|
- //noinspection EnhancedSwitchMigration
|
|
|
+ // noinspection EnhancedSwitchMigration
|
|
|
switch (platform) {
|
|
|
case TONG_YI:
|
|
|
return SpringUtil.getBean(DashScopeChatModel.class);
|
|
@@ -136,7 +143,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|
|
|
|
|
@Override
|
|
|
public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
|
|
|
- //noinspection EnhancedSwitchMigration
|
|
|
+ // noinspection EnhancedSwitchMigration
|
|
|
switch (platform) {
|
|
|
case TONG_YI:
|
|
|
return SpringUtil.getBean(DashScopeImageModel.class);
|
|
@@ -155,7 +162,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|
|
|
|
|
@Override
|
|
|
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
|
|
|
- //noinspection EnhancedSwitchMigration
|
|
|
+ // noinspection EnhancedSwitchMigration
|
|
|
switch (platform) {
|
|
|
case TONG_YI:
|
|
|
return buildTongYiImagesModel(apiKey);
|
|
@@ -174,9 +181,11 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|
|
|
|
|
@Override
|
|
|
public MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url) {
|
|
|
- String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey, url);
|
|
|
+ String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey,
|
|
|
+ url);
|
|
|
return Singleton.get(cacheKey, (Func0<MidjourneyApi>) () -> {
|
|
|
- YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class).getMidjourney();
|
|
|
+ YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class)
|
|
|
+ .getMidjourney();
|
|
|
return new MidjourneyApi(url, apiKey, properties.getNotifyUrl());
|
|
|
});
|
|
|
}
|
|
@@ -204,25 +213,31 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel) {
|
|
|
-// String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
|
|
|
- String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel);
|
|
|
+ public VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type, EmbeddingModel embeddingModel) {
|
|
|
+ // String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey,
|
|
|
+ // url);
|
|
|
+ String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel, type);
|
|
|
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
|
|
|
+ if (type == SimpleVectorStore.class) {
|
|
|
+ return buildSimpleVectorStore(embeddingModel);
|
|
|
+ }
|
|
|
+ throw new IllegalArgumentException(StrUtil.format("未知类型({})", type));
|
|
|
// TODO @芋艿:先临时使用 store
|
|
|
- return SimpleVectorStore.builder(embeddingModel).build();
|
|
|
// TODO @芋艿:@xin:后续看看,是不是切到阿里云之类的
|
|
|
-// String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
|
|
|
-// var config = RedisVectorStore.RedisVectorStoreConfig.builder()
|
|
|
-// .withIndexName(cacheKey)
|
|
|
-// .withPrefix(prefix)
|
|
|
-// .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
|
|
|
-// .build();
|
|
|
-// RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
|
|
|
-// RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
|
|
|
-// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
|
|
|
-// true);
|
|
|
-// redisVectorStore.afterPropertiesSet();
|
|
|
-// return redisVectorStore;
|
|
|
+ // String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
|
|
|
+ // var config = RedisVectorStore.RedisVectorStoreConfig.builder()
|
|
|
+ // .withIndexName(cacheKey)
|
|
|
+ // .withPrefix(prefix)
|
|
|
+ // .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId",
|
|
|
+ // Schema.FieldType.NUMERIC))
|
|
|
+ // .build();
|
|
|
+ // RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
|
|
|
+ // RedisVectorStore redisVectorStore = new RedisVectorStore(config,
|
|
|
+ // embeddingModel,
|
|
|
+ // new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
|
|
|
+ // true);
|
|
|
+ // redisVectorStore.afterPropertiesSet();
|
|
|
+ // return redisVectorStore;
|
|
|
});
|
|
|
}
|
|
|
|
|
@@ -307,7 +322,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|
|
*/
|
|
|
private ChatModel buildSiliconFlowChatModel(String apiKey) {
|
|
|
YudaoAiProperties.SiliconFlowProperties properties = new YudaoAiProperties.SiliconFlowProperties()
|
|
|
- .setApiKey(apiKey);
|
|
|
+ .setApiKey(apiKey);
|
|
|
return new YudaoAiAutoConfiguration().buildSiliconFlowChatClient(properties);
|
|
|
}
|
|
|
|
|
@@ -397,7 +412,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|
|
*/
|
|
|
private DashScopeEmbeddingModel buildTongYiEmbeddingModel(String apiKey, String model) {
|
|
|
DashScopeApi dashScopeApi = new DashScopeApi(apiKey);
|
|
|
- DashScopeEmbeddingOptions dashScopeEmbeddingOptions = DashScopeEmbeddingOptions.builder().withModel(model).build();
|
|
|
+ DashScopeEmbeddingOptions dashScopeEmbeddingOptions = DashScopeEmbeddingOptions.builder().withModel(model)
|
|
|
+ .build();
|
|
|
return new DashScopeEmbeddingModel(dashScopeApi, MetadataMode.EMBED, dashScopeEmbeddingOptions);
|
|
|
}
|
|
|
|
|
@@ -407,4 +423,58 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|
|
return OllamaEmbeddingModel.builder().ollamaApi(ollamaApi).defaultOptions(ollamaOptions).build();
|
|
|
}
|
|
|
|
|
|
+ // ========== 各种创建 VectorStore 的方法 ==========
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 注意:仅适合本地测试使用,生产建议还是使用 Qdrant、Milvus 等
|
|
|
+ */
|
|
|
+ @SneakyThrows
|
|
|
+ @SuppressWarnings("ResultOfMethodCallIgnored")
|
|
|
+ private SimpleVectorStore buildSimpleVectorStore(EmbeddingModel embeddingModel) {
|
|
|
+ SimpleVectorStore vectorStore = SimpleVectorStore.builder(embeddingModel).build();
|
|
|
+ // 启动加载
|
|
|
+ File file = new File(StrUtil.format("{}/vector_store/simple_{}.json",
|
|
|
+ FileUtil.getUserHomePath(), embeddingModel.getClass().getSimpleName()));
|
|
|
+ if (!file.exists()) {
|
|
|
+ FileUtil.mkParentDirs(file);
|
|
|
+ file.createNewFile();
|
|
|
+ } else if (file.length() > 0) {
|
|
|
+ vectorStore.load(file);
|
|
|
+ }
|
|
|
+ // 定时持久化,每分钟一次
|
|
|
+ Timer timer = new Timer("SimpleVectorStoreTimer-" + file.getAbsolutePath());
|
|
|
+ timer.scheduleAtFixedRate(new TimerTask() {
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void run() {
|
|
|
+ vectorStore.save(file);
|
|
|
+ }
|
|
|
+
|
|
|
+ }, Duration.ofMinutes(1).toMillis(), Duration.ofMinutes(1).toMillis());
|
|
|
+ // 关闭时,进行持久化
|
|
|
+ RuntimeUtil.addShutdownHook(() -> vectorStore.save(file));
|
|
|
+ return vectorStore;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 创建向量存储文件
|
|
|
+ *
|
|
|
+ * @param embeddingModel 嵌入模型
|
|
|
+ * @return 向量存储文件
|
|
|
+ */
|
|
|
+ private File createVectorStoreFile(EmbeddingModel embeddingModel) {
|
|
|
+ // 获取简单类名
|
|
|
+ String simpleClassName = embeddingModel.getClass().getSimpleName();
|
|
|
+ // 获取用户主目录
|
|
|
+ String userHome = FileUtil.getUserHomePath();
|
|
|
+ // 创建vector_store目录
|
|
|
+ File vectorStoreDir = new File(userHome, "vector_store");
|
|
|
+ if (!vectorStoreDir.exists()) {
|
|
|
+ vectorStoreDir.mkdirs();
|
|
|
+ }
|
|
|
+
|
|
|
+ // 创建文件
|
|
|
+ return new File(vectorStoreDir, "simple_" + simpleClassName + ".json");
|
|
|
+ }
|
|
|
+
|
|
|
}
|