Prechádzať zdrojové kódy

【功能新增】AI:新增 function call 示例。会继续完善!

YunaiV 5 mesiacov pred
rodič
commit
25a0fe908a

+ 95 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/tool/ListDirTool.java

@@ -0,0 +1,95 @@
+package cn.iocoder.yudao.module.ai.service.tool;
+
+import cn.hutool.core.date.LocalDateTimeUtil;
+import cn.hutool.core.io.FileUtil;
+import cn.hutool.core.util.ArrayUtil;
+import cn.hutool.core.util.StrUtil;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import org.springframework.ai.tool.annotation.Tool;
+import org.springframework.ai.tool.annotation.ToolParam;
+import org.springframework.stereotype.Component;
+
+import java.io.File;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static cn.hutool.core.date.DatePattern.NORM_DATETIME_PATTERN;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
+
+/**
+ * 目录内容列表工具:列出指定目录的内容
+ *
+ * @author 芋道源码
+ */
+@Component
+public class ListDirTool {
+
+    /**
+     * 列出指定目录的内容
+     *
+     * @param relativePath 要列出内容的目录路径,相对于工作区根目录
+     * @return 目录内容列表
+     */
+    @Tool(name = "listDir", description = "列出指定目录的内容")
+    public Response listDir(@ToolParam(description = "要列出内容的目录路径,相对于工作区根目录") String relativePath) {
+        // 校验目录存在
+        String path = StrUtil.blankToDefault(relativePath, ".");
+        Path dirPath = Paths.get(path);
+        if (!FileUtil.exist(dirPath.toString()) || !FileUtil.isDirectory(dirPath.toString())) {
+            return new Response(Collections.emptyList());
+        }
+        // 列出目录内容
+        File[] files = dirPath.toFile().listFiles();
+        if (ArrayUtil.isEmpty(files)) {
+            return new Response(Collections.emptyList());
+        }
+        return new Response(convertList(Arrays.asList(files), file -> new Response.File()
+                .setDirectory(file.isDirectory()).setName(file.getName())
+                .setSize(file.isFile() ? FileUtil.readableFileSize(file.length()) : null)
+                .setLastModified(
+                        LocalDateTimeUtil.format(LocalDateTimeUtil.of(file.lastModified()), NORM_DATETIME_PATTERN))));
+    }
+
+    @Data
+    @AllArgsConstructor
+    @NoArgsConstructor
+    public static class Response {
+
+        /**
+         * 目录内容列表
+         */
+        private List<File> files;
+
+        @Data
+        public static class File {
+
+            /**
+             * 是否为目录
+             */
+            private Boolean directory;
+
+            /**
+             * 名称
+             */
+            private String name;
+
+            /**
+             * 大小,仅对文件有效
+             */
+            private String size;
+
+            /**
+             * 最后修改时间
+             */
+            private String lastModified;
+
+        }
+
+    }
+
+}

+ 100 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/tool/ListDirToolB.java

@@ -0,0 +1,100 @@
+package cn.iocoder.yudao.module.ai.service.tool;
+
+import cn.hutool.core.date.LocalDateTimeUtil;
+import cn.hutool.core.io.FileUtil;
+import cn.hutool.core.util.ArrayUtil;
+import cn.hutool.core.util.StrUtil;
+import com.fasterxml.jackson.annotation.JsonClassDescription;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import org.springframework.stereotype.Component;
+
+import java.io.File;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Function;
+
+import static cn.hutool.core.date.DatePattern.NORM_DATETIME_PATTERN;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
+
+/**
+ * 目录内容列表工具:列出指定目录的内容
+ *
+ * @author 芋道源码
+ */
+@Component("listDir")
+public class ListDirToolB implements Function<ListDirToolB.Request, ListDirToolB.Response> {
+
+    @Data
+    @JsonClassDescription("列出指定目录的内容")
+    public static class Request {
+
+        /**
+         * 要列出内容的目录路径
+         */
+        @JsonPropertyDescription("要列出内容的目录路径,例如说:/Users/yunai")
+        private String path;
+
+    }
+
+    @Data
+    @AllArgsConstructor
+    @NoArgsConstructor
+    public static class Response {
+
+        /**
+         * 目录内容列表
+         */
+        private List<File> files;
+
+        @Data
+        public static class File {
+
+            /**
+             * 是否为目录
+             */
+            private Boolean directory;
+
+            /**
+             * 名称
+             */
+            private String name;
+
+            /**
+             * 大小,仅对文件有效
+             */
+            private String size;
+
+            /**
+             * 最后修改时间
+             */
+            private String lastModified;
+
+        }
+
+    }
+
+    @Override
+    public Response apply(Request request) {
+        // 校验目录存在
+        String path = StrUtil.blankToDefault(request.getPath(), ".");
+        Path dirPath = Paths.get(path);
+        if (!FileUtil.exist(dirPath.toString()) || !FileUtil.isDirectory(dirPath.toString())) {
+            return new Response(Collections.emptyList());
+        }
+        // 列出目录内容
+        File[] files = dirPath.toFile().listFiles();
+        if (ArrayUtil.isEmpty(files)) {
+            return new Response(Collections.emptyList());
+        }
+        return new Response(convertList(Arrays.asList(files), file ->
+                new Response.File().setDirectory(file.isDirectory()).setName(file.getName())
+                        .setLastModified(LocalDateTimeUtil.format(LocalDateTimeUtil.of(file.lastModified()), NORM_DATETIME_PATTERN))));
+    }
+
+}

+ 7 - 6
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java

@@ -33,7 +33,7 @@ import org.springframework.context.annotation.Bean;
  * @author fansili
  */
 @AutoConfiguration
-@EnableConfigurationProperties({YudaoAiProperties.class,
+@EnableConfigurationProperties({ YudaoAiProperties.class,
         QdrantVectorStoreProperties.class, // 解析 Qdrant 配置
         RedisVectorStoreProperties.class, // 解析 Redis 配置
         MilvusVectorStoreProperties.class, MilvusServiceClientProperties.class // 解析 Milvus 配置
@@ -139,15 +139,16 @@ public class YudaoAiAutoConfiguration {
         }
         // 特殊:由于混元大模型不提供 deepseek,而是通过知识引擎,所以需要区分下 URL
         if (StrUtil.isEmpty(properties.getBaseUrl())) {
-            properties.setBaseUrl(StrUtil.startWithIgnoreCase(properties.getModel(), "deepseek") ?
-                    HunYuanChatModel.DEEP_SEEK_BASE_URL : HunYuanChatModel.BASE_URL);
+            properties.setBaseUrl(
+                    StrUtil.startWithIgnoreCase(properties.getModel(), "deepseek") ? HunYuanChatModel.DEEP_SEEK_BASE_URL
+                            : HunYuanChatModel.BASE_URL);
         }
         // 创建 OpenAiChatModel、HunYuanChatModel 对象
         OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
                 .openAiApi(OpenAiApi.builder()
-                      .baseUrl(properties.getBaseUrl())
-                      .apiKey(properties.getApiKey())
-                      .build())
+                        .baseUrl(properties.getBaseUrl())
+                        .apiKey(properties.getApiKey())
+                        .build())
                 .defaultOptions(OpenAiChatOptions.builder()
                         .model(properties.getModel())
                         .temperature(properties.getTemperature())

+ 7 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -61,6 +61,7 @@ import org.springframework.ai.minimax.MiniMaxChatModel;
 import org.springframework.ai.minimax.MiniMaxEmbeddingModel;
 import org.springframework.ai.minimax.MiniMaxEmbeddingOptions;
 import org.springframework.ai.minimax.api.MiniMaxApi;
+import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.moonshot.MoonshotChatModel;
 import org.springframework.ai.moonshot.api.MoonshotApi;
 import org.springframework.ai.ollama.OllamaChatModel;
@@ -431,7 +432,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
     private static OpenAiChatModel buildOpenAiChatModel(String openAiToken, String url) {
         url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL);
         OpenAiApi openAiApi = OpenAiApi.builder().baseUrl(url).apiKey(openAiToken).build();
-        return OpenAiChatModel.builder().openAiApi(openAiApi).build();
+        return OpenAiChatModel.builder().openAiApi(openAiApi).toolCallingManager(getToolCallingManager()).build();
     }
 
     // TODO @芋艿:手头暂时没密钥,使用建议再测试下
@@ -465,7 +466,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
      */
     private static OllamaChatModel buildOllamaChatModel(String url) {
         OllamaApi ollamaApi = new OllamaApi(url);
-        return OllamaChatModel.builder().ollamaApi(ollamaApi).build();
+        return OllamaChatModel.builder().ollamaApi(ollamaApi).toolCallingManager(getToolCallingManager()).build();
     }
 
     private StabilityAiImageModel buildStabilityAiImageModel(String apiKey, String url) {
@@ -699,4 +700,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return SpringUtil.getBean(BatchingStrategy.class);
     }
 
+    private static ToolCallingManager getToolCallingManager() {
+        return SpringUtil.getBean(ToolCallingManager.class);
+    }
+
 }

+ 12 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java

@@ -24,14 +24,18 @@ public class AiUtils {
         // noinspection EnhancedSwitchMigration
         switch (platform) {
             case TONG_YI:
+                // TODO functions
                 return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens).build();
             case YI_YAN:
                 return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
             case ZHI_PU:
+                // TODO functions
                 return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
             case MINI_MAX:
+                // TODO functions
                 return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
             case MOONSHOT:
+                // TODO functions
                 return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
             case OPENAI:
             case DEEP_SEEK: // 复用 OpenAI 客户端
@@ -39,12 +43,18 @@ public class AiUtils {
             case HUN_YUAN: // 复用 OpenAI 客户端
             case XING_HUO: // 复用 OpenAI 客户端
             case SILICON_FLOW: // 复用 OpenAI 客户端
-                return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
+                return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
+//                        .toolNames() TODO
+                        .toolNames("listDir")
+                        .build();
             case AZURE_OPENAI:
                 // TODO 芋艿:貌似没 model 字段???!
+                // TODO 芋艿:.toolNames() TODO
                 return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens).build();
             case OLLAMA:
-                return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens).build();
+                // TODO 芋艿:.toolNames() TODO
+                return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
+                        .toolNames("listDir").build();
             default:
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
         }

+ 5 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/package-info.java

@@ -4,7 +4,10 @@
  * models 包路径:
  *  1. xinghuo 包:【讯飞】星火,自己实现
  *  2. deepseek 包:【深度求索】DeepSeek,自己实现
- *  3. midjourney 包:Midjourney API,对接 https://github.com/novicezk/midjourney-proxy 实现
- *  4. suno 包:Suno API,对接 https://github.com/gcui-art/suno-api 实现
+ *  3. doubao 包:【字节豆包】DouBao,自己实现
+ *  4. hunyuan 包:【腾讯混元】HunYuan,自己实现
+ *  5. siliconflow 包:【硅基硅流】SiliconFlow,自己实现
+ *  6. midjourney 包:Midjourney API,对接 https://github.com/novicezk/midjourney-proxy 实现
+ *  7. suno 包:Suno API,对接 https://github.com/gcui-art/suno-api 实现
  */
 package cn.iocoder.yudao.framework.ai;