From 80eb5eb5906c84ec1a88fa507855b5a1f785de5d Mon Sep 17 00:00:00 2001 From: kennethcheng Date: Mon, 20 Apr 2026 02:57:36 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E6=B5=81=E5=BC=8F=20?= =?UTF-8?q?RAG=20=E9=97=AE=E7=AD=94=E4=B8=8E=20PDF=20=E4=B8=8A=E4=BC=A0?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增接口: - POST /api/chat/stream - 流式 RAG 问答 (SSE) - POST /api/documents/upload - PDF 文件上传 - POST /api/documents/upload/stream - 带进度 PDF 上传 新增功能: - CorsConfig 跨域配置(支持 localhost:8081, 5173) - FileUploadResponse/FileUploadProgress DTO - PDF 文本提取与向量化存储 - MD5 文件去重机制 配置更新: - embedding 模型更新为 BAAI/bge-m3 - multipart max-file-size: 100MB - ChatRequest topK 默认值 3 → 10 --- README.md | 144 +++++----- src/main/java/com/demo/config/CorsConfig.java | 41 +++ .../com/demo/controller/ChatController.java | 27 ++ .../demo/controller/DocumentController.java | 80 +++++- src/main/java/com/demo/dto/ChatRequest.java | 2 +- .../java/com/demo/dto/FileUploadProgress.java | 19 ++ .../java/com/demo/dto/FileUploadResponse.java | 19 ++ .../java/com/demo/service/ChatService.java | 53 +++- .../com/demo/service/DocumentService.java | 260 ++++++++++++++++++ src/main/resources/application.yaml | 15 +- 10 files changed, 581 insertions(+), 79 deletions(-) create mode 100644 src/main/java/com/demo/config/CorsConfig.java create mode 100644 src/main/java/com/demo/dto/FileUploadProgress.java create mode 100644 src/main/java/com/demo/dto/FileUploadResponse.java diff --git a/README.md b/README.md index 0061d6b..2ae837d 100644 --- a/README.md +++ b/README.md @@ -14,10 +14,13 @@ - 📝 **Markdown 支持** - 完整渲染代码块、表格、列表等格式 - 👨‍💻 **代码高亮** - Highlight.js 自动语言检测与语法着色 - 📚 **RAG 智能问答** - 基于向量检索的上下文感知回答 +- 📤 **PDF 文件上传** - PDF 文本提取与向量化存储 +- 📊 **上传进度流** - 实时上传进度反馈 - 🔍 **语义检索** - Milvus 向量数据库相似度搜索 - 🔗 **引用溯源** - 回答附带文档来源引用 - 🗃️ **向量数据库** - Milvus 分布式向量数据库 -- 🔍 **Embedding 服务** - SiliconFlow BAAI/bge-large-zh-v1.5 +- 🔍 **Embedding 服务** - SiliconFlow BAAI/bge-m3 +- 🌐 **CORS 跨域** - 支持多端访问 - 🎨 **精美界面** - 深色主题响应式设计 ## 🚀 快速开始 @@ -62,8 +65,10 @@ mvn spring-boot:run | 基础框架 | Spring Boot | 3.2.0 | | AI 框架 | Spring AI | 1.0.0-M3 | | AI 模型 | OpenAI/Ollama | - | -| Embedding | SiliconFlow BAAI/bge-large-zh-v1.5 | - | +| Embedding | SiliconFlow BAAI/bge-m3 | - | | 向量数据库 | Milvus | 2.3.4 | +| PDF 处理 | Apache PDFBox | 2.0.29 | +| 文件上传 | Commons FileUpload | 1.5 | | 响应式编程 | Spring WebFlux | 3.2.0 | ### 前端 @@ -81,39 +86,45 @@ springAiDemo/ ├── src/main/java/com/demo/ │ ├── MyApplication.java # Spring Boot 启动入口 │ ├── config/ -│ │ └── RagConfig.java # RAG 配置类 +│ │ ├── CorsConfig.java # CORS 跨域配置 +│ │ └── RagConfig.java # RAG 配置类 │ ├── controller/ -│ │ ├── ChatController.java # AI 聊天 API -│ │ └── DocumentController.java # 文档导入 API +│ │ ├── ChatController.java # AI 聊天 API +│ │ └── DocumentController.java # 文档处理 API │ ├── dto/ -│ │ ├── ApiResponse.java # 通用响应封装 -│ │ ├── ChatRequest.java # 聊天请求 DTO -│ │ └── ChatResponse.java # 聊天响应 DTO +│ │ ├── ApiResponse.java # 统一响应封装 +│ │ ├── ChatRequest.java # 聊天请求 DTO +│ │ ├── ChatResponse.java # 聊天响应 DTO +│ │ ├── FileUploadProgress.java # 上传进度 DTO +│ │ └── FileUploadResponse.java # 上传响应 DTO │ └── service/ │ ├── ChatService.java # RAG 聊天服务 │ └── DocumentService.java # 文档处理服务 ├── data/ -│ └── doris_intro.md # RAG 示例文档 +│ └── doris_intro.md # RAG 示例文档 └── src/main/resources/ - ├── application.yaml # 应用配置 - └── static/ # 前端资源 + ├── application.yaml # 应用配置 + └── static/ # 前端资源 ``` ## 💬 API 文档 ### 聊天接口 -| 方法 | 端点 | 描述 | 参数 | +| 方法 | 端点 | 描述 | 参数/Body | |:---|:---|:---|:---| | GET | `/api/chat/test` | 健康检查 | - | | GET | `/api/chat/ai` | 流式 AI 对话 | `msg` | -| POST | `/api/chat` | RAG 智能问答 | `{ "question": "..." }` | +| POST | `/api/chat` | RAG 智能问答 | `{ "question": "...", "topK": 10 }` | +| POST | `/api/chat/stream` | 流式 RAG 问答 | `{ "question": "..." }` (SSE) | ### 文档接口 -| 方法 | 端点 | 描述 | -|:---|:---|:---| -| GET | `/api/documents/import` | 导入文档到向量库 | +| 方法 | 端点 | 描述 | 参数 | +|:---|:---|:---|:---| +| POST | `/api/documents/import` | 导入 RAG 文档 | - | +| POST | `/api/documents/upload` | 上传 PDF 文件 | `file` (multipart) | +| POST | `/api/documents/upload/stream` | 上传 PDF (带进度) | `file` (multipart) | ### 请求示例 @@ -124,14 +135,22 @@ curl -X POST http://localhost:8080/api/chat \ -d '{"question": "Apache Doris 是什么?"}' ``` -**流式对话:** +**流式 RAG 问答 (SSE):** ```bash -curl "http://localhost:8080/api/chat/ai/stream?msg=讲一个故事" +curl -X POST http://localhost:8080/api/chat/stream \ + -H "Content-Type: application/json" \ + -d '{"question": "Apache Doris 性能如何?"}' +``` + +**PDF 文件上传:** +```bash +curl -X POST http://localhost:8080/api/documents/upload \ + -F "file=@/path/to/document.pdf" ``` **导入 RAG 文档:** ```bash -curl http://localhost:8080/api/documents/import +curl -X POST http://localhost:8080/api/documents/import ``` ### 响应格式 @@ -149,6 +168,14 @@ curl http://localhost:8080/api/documents/import } ``` +**POST /api/documents/upload/stream 响应 (SSE):** +```json +{"percent":0,"status":"STARTING","message":"开始上传..."} +{"percent":30,"status":"EXTRACTING","message":"正在提取文本..."} +{"percent":60,"status":"PROCESSING","message":"正在向量化..."} +{"percent":100,"status":"COMPLETED","message":"上传完成"} +``` + ## 📚 RAG 文档问答 ### 工作原理 @@ -157,40 +184,30 @@ curl http://localhost:8080/api/documents/import 用户问题 → Embedding → 向量检索 → 构建上下文 → LLM 生成回答 → 返回答案+引用 ``` -1. **文档导入** - 将 `.md` / `.txt` 文档读取并切割成 chunks -2. **向量化** - 使用 BAAI/bge-large-zh-v1.5 生成 1024 维向量 -3. **存储检索** -存入 Milvus,向量相似度搜索 top-K 结果 -4. **生成回答** - 将检索结果作为上下文,LLM 生成答案 - ### 使用步骤 -1. 将文档放入 `data/` 目录 -2. 调用导入接口:`curl http://localhost:8080/api/documents/import` -3. 通过 POST `/api/chat` 提问 - -### 文档处理配置 - -```yaml -document: - data-path: data # 文档目录 - chunk-size: 400 # 分割块大小 (tokens) - min-chunk-size: 200 # 最小块大小 - max-num-chunk: 10000 # 最大块数量 -``` +1. **方式一**:将 `.md` / `.txt` 文档放入 `data/` 目录,调用导入接口 +2. **方式二**:直接上传 PDF 文件,系统自动提取文本并向量化 +3. 通过 POST `/api/chat` 或流式 `/api/chat/stream` 提问 ## 🛠️ 配置说明 -### AI 对话配置 +### CORS 跨域配置 + +```yaml +cors: + allowed-origins: http://localhost:8081,http://localhost:5173 + allowed-methods: GET,POST,PUT,DELETE,OPTIONS +``` + +### 文件上传配置 ```yaml spring: - ai: - openai: - base-url: http://localhost:11434 - chat: - options: - model: gpt-oss:120b-cloud - temperature: 0.7 + servlet: + multipart: + max-file-size: 100MB + max-request-size: 100MB ``` ### Embedding 配置 @@ -202,24 +219,10 @@ spring: embedding: api-key: your-siliconflow-api-key base-url: https://api.siliconflow.cn - model: BAAI/bge-large-zh-v1.5 + model: BAAI/bge-m3 dimensions: 1024 ``` -### 向量数据库配置 - -```yaml -spring: - ai: - vectorstore: - milvus: - client: - host: 192.168.50.103 - port: 19530 - databaseName: doris_docs - collectionName: vector_store -``` - ## 🎬 架构图 ``` @@ -227,7 +230,7 @@ spring: │ Client │ │ ┌─────────────────┐ ┌─────────────────────────────────┐│ │ │ Web UI │ │ POST /api/chat ││ -│ │ (流式/非流式) │────▶│ { question: "..." } ││ +│ │ (流式/非流式) │────▶│ POST /api/chat/stream (SSE) ││ │ └─────────────────┘ └─────────────────────────────────┘│ └─────────────────────────────────────────────────────────────┘ │ @@ -242,12 +245,6 @@ spring: ┌─────────────────────────┐ ┌─────────────────────────┐ │ VectorStore │ │ ChatClient │ │ (Milvus 语义检索) │ │ (Ollama LLM) │ -└─────────────────────────┘ └─────────────────────────┘ - │ │ - ▼ ▼ -┌─────────────────────────┐ ┌─────────────────────────┐ -│ DocumentService │ │ SiliconFlow API │ -│ (文档切割/向量化) │ │ (Embedding) │ └─────────────────────────┘ └─────────────────────────┘ ``` @@ -261,15 +258,22 @@ spring: ### v1.1.0 (2026-04-19) - ✨ 新增 RAG 文档问答功能 -- ✨ 新增 DocumentController 文档导入 API - ✨ 配置 SiliconFlow Embedding 服务 - ✨ 集成 Milvus 向量数据库 ### v1.2.0 (2026-04-19) - ✨ 新增 POST /api/chat RAG 智能问答接口 - ✨ 新增 ChatService RAG 核心服务 -- ✨ 新增 ChatRequest/ChatResponse/ApiResponse DTO -- ✨ 新增引用溯源功能,返回文档来源 +- ✨ 新增引用溯源功能 + +### v1.3.0 (2026-04-20) +- ✨ 新增 POST /api/chat/stream 流式 RAG 问答接口 (SSE) +- ✨ 新增 PDF 文件上传接口 (POST /api/documents/upload) +- ✨ 新增带进度流式 PDF 上传 (POST /api/documents/upload/stream) +- ✨ 新增 CorsConfig 跨域配置 +- ✨ embedding 模型更新为 BAAI/bge-m3 +- ✨ ChatRequest topK 默认值调整为 10 +- 🐛 修复文件重复上传问题 (MD5 去重) ## 📗 License diff --git a/src/main/java/com/demo/config/CorsConfig.java b/src/main/java/com/demo/config/CorsConfig.java new file mode 100644 index 0000000..5222711 --- /dev/null +++ b/src/main/java/com/demo/config/CorsConfig.java @@ -0,0 +1,41 @@ +package com.demo.config; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.UrlBasedCorsConfigurationSource; +import org.springframework.web.filter.CorsFilter; + +import java.util.Arrays; + +@Configuration +public class CorsConfig { + + @Value("${cors.allowed-origins}") + private String allowedOrigins; + + @Value("${cors.allowed-methods}") + private String allowedMethods; + + @Value("${cors.allowed-headers}") + private String allowedHeaders; + + @Value("${cors.allow-credentials}") + private boolean allowCredentials; + + @Bean + public CorsFilter corsFilter() { + CorsConfiguration config = new CorsConfiguration(); + config.setAllowedOrigins(Arrays.asList(allowedOrigins.split(","))); + config.setAllowedMethods(Arrays.asList(allowedMethods.split(","))); + config.setAllowedHeaders(Arrays.asList(allowedHeaders.split(","))); + config.setAllowCredentials(allowCredentials); + config.setMaxAge(3600L); + + UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource(); + source.registerCorsConfiguration("/**", config); + + return new CorsFilter(source); + } +} diff --git a/src/main/java/com/demo/controller/ChatController.java b/src/main/java/com/demo/controller/ChatController.java index 72c49b3..a7c0978 100644 --- a/src/main/java/com/demo/controller/ChatController.java +++ b/src/main/java/com/demo/controller/ChatController.java @@ -4,8 +4,11 @@ import com.demo.dto.ApiResponse; import com.demo.dto.ChatRequest; import com.demo.dto.ChatResponse; import com.demo.service.ChatService; +import jakarta.validation.Valid; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.bind.annotation.*; import reactor.core.publisher.Flux; @@ -43,6 +46,30 @@ public class ChatController { return ApiResponse.success(response); } + /** + * 流式问答接口 - 直接返回纯文本流 + */ + @PostMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux> chatStream(@Valid @RequestBody ChatRequest request) { + log.info("接收到流式问答请求: {}", request.getQuestion()); + + return chatService.chatStream(request) + .map(chunk -> ServerSentEvent.builder() + .data(chunk) + .build()) + .doOnSubscribe(s -> log.info("开始流式传输")) + .doOnNext(chunk -> log.debug("发送文本块: {}", chunk)) + .doOnComplete(() -> log.info("流式传输完成")) + .doOnError(e -> log.error("流式传输错误", e)) + .onErrorResume(e -> { + log.error("流式问答处理失败", e); + // 错误信息也一并包装 + return Flux.just(ServerSentEvent.builder() + .data("\n\n[错误: " + e.getMessage() + "]") + .build()); + }); + } + /** * 流式的聊天接口,要注意如果中文有乱码,就是编码得问题,需要添加produces = "text/html;charset=UTF-8 * @param msg diff --git a/src/main/java/com/demo/controller/DocumentController.java b/src/main/java/com/demo/controller/DocumentController.java index ea883e9..bd9dd31 100644 --- a/src/main/java/com/demo/controller/DocumentController.java +++ b/src/main/java/com/demo/controller/DocumentController.java @@ -1,15 +1,21 @@ package com.demo.controller; +import com.demo.dto.ApiResponse; +import com.demo.dto.FileUploadProgress; +import com.demo.dto.FileUploadResponse; import com.demo.service.DocumentService; import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.multipart.MultipartFile; +import reactor.core.publisher.Flux; import java.io.IOException; - +@Slf4j @RestController @RequestMapping("api/documents") @RequiredArgsConstructor @@ -18,7 +24,7 @@ public class DocumentController { @Autowired public final DocumentService documentService; - @GetMapping("import") + @PostMapping("import") public String importDocument() { try { documentService.importDocument(); @@ -27,5 +33,69 @@ public class DocumentController { } return "ok"; } + @PostMapping(value = "/upload", consumes = MediaType.MULTIPART_FORM_DATA_VALUE) + public ResponseEntity> uploadPdfFile( + @RequestParam("file") MultipartFile file) { + try { + log.info("收到PDF文件上传请求: {}", file.getOriginalFilename()); + + // 验证文件 + if (file.isEmpty()) { + log.error("上传文件为空"); + return ResponseEntity.badRequest() + .body(ApiResponse.error("上传文件不能为空")); + } + + // 检查文件格式 + String contentType = file.getContentType(); + String originalFilename = file.getOriginalFilename(); + + if (originalFilename == null || !originalFilename.toLowerCase().endsWith(".pdf")) { + log.error("文件格式错误: {}", originalFilename); + return ResponseEntity.badRequest() + .body(ApiResponse.error("只能上传PDF格式文件")); + } + + // 检查文件大小 (30MB) + if (file.getSize() > 30 * 1024 * 1024) { + log.error("文件大小超出限制: {} bytes", file.getSize()); + return ResponseEntity.badRequest() + .body(ApiResponse.error("文件大小不能超过30MB")); + } + + // 处理文件上传 + FileUploadResponse response = documentService.uploadPdfFile(file); + + if (response.isDuplicate()) { + log.info("文件已存在,跳过处理: {}", originalFilename); + return ResponseEntity.ok(ApiResponse.success("文件已存在,无需重复上传", response)); + } + + log.info("文件上传成功: {}, 生成了 {} 个文档片段", + originalFilename, response.getChunkCount()); + return ResponseEntity.ok(ApiResponse.success("文件上传成功", response)); + + } catch (Exception e) { + log.error("文件上传失败", e); + return ResponseEntity.internalServerError() + .body(ApiResponse.error("文件上传失败: " + e.getMessage())); + } + } + + @PostMapping(value = "/upload/stream", consumes = MediaType.MULTIPART_FORM_DATA_VALUE) + public Flux uploadPdfFileWithProgress( + @RequestParam("file") MultipartFile file) { + try { + log.info("收到PDF文件上传请求(流式): {}", file.getOriginalFilename()); + return documentService.uploadPdfFileWithProgress(file); + } catch (Exception e) { + log.error("文件上传失败", e); + return Flux.just(FileUploadProgress.builder() + .status("FAILED") + .message("文件上传失败: " + e.getMessage()) + .percentage(0) + .build()); + } + } } diff --git a/src/main/java/com/demo/dto/ChatRequest.java b/src/main/java/com/demo/dto/ChatRequest.java index d4a94ea..7b5a3cf 100644 --- a/src/main/java/com/demo/dto/ChatRequest.java +++ b/src/main/java/com/demo/dto/ChatRequest.java @@ -8,6 +8,6 @@ public class ChatRequest { @NotBlank(message = "请输入你的问题") private String question; - private Integer topK = 3; + private Integer topK = 10; } diff --git a/src/main/java/com/demo/dto/FileUploadProgress.java b/src/main/java/com/demo/dto/FileUploadProgress.java new file mode 100644 index 0000000..419f6c8 --- /dev/null +++ b/src/main/java/com/demo/dto/FileUploadProgress.java @@ -0,0 +1,19 @@ +package com.demo.dto; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class FileUploadProgress { + private String fileName; + private long bytesProcessed; + private long totalBytes; + private int percentage; + private String status; // PROCESSING, COMPLETED, FAILED + private String message; +} \ No newline at end of file diff --git a/src/main/java/com/demo/dto/FileUploadResponse.java b/src/main/java/com/demo/dto/FileUploadResponse.java new file mode 100644 index 0000000..3c282a9 --- /dev/null +++ b/src/main/java/com/demo/dto/FileUploadResponse.java @@ -0,0 +1,19 @@ +package com.demo.dto; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class FileUploadResponse { + private String fileName; + private String fileSize; + private String md5Hash; + private int chunkCount; + private String message; + private boolean duplicate; +} \ No newline at end of file diff --git a/src/main/java/com/demo/service/ChatService.java b/src/main/java/com/demo/service/ChatService.java index 7ff0c62..c008187 100644 --- a/src/main/java/com/demo/service/ChatService.java +++ b/src/main/java/com/demo/service/ChatService.java @@ -2,7 +2,6 @@ package com.demo.service; import com.demo.dto.ChatRequest; import com.demo.dto.ChatResponse; -import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.Message; @@ -14,7 +13,9 @@ import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; +import java.time.Duration; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -98,4 +99,54 @@ public class ChatService { .timestamp(System.currentTimeMillis()) .build(); } + + /** + * 处理用户问题(流式响应 - 直接返回原始文本) + */ + public Flux chatStream(ChatRequest request) { + log.info("收到流式问题: {}", request.getQuestion()); + + try { + // 1. 向量检索相关文档 + List relevantDocs = vectorStore.similaritySearch( + SearchRequest.query(request.getQuestion()) + .withTopK(request.getTopK()) + ); + + log.info("检索到 {} 个相关文档", relevantDocs.size()); + + // 2. 构建上下文 + String context = relevantDocs.isEmpty() ? + "暂无相关文档,请基于你的知识回答。" : + relevantDocs.stream() + .map(Document::getContent) + .collect(Collectors.joining("\n\n---\n\n")); + + // 3. 构建提示词 + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(SYSTEM_PROMPT); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("context", context)); + UserMessage userMessage = new UserMessage(request.getQuestion()); + + Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); + + // 4. 返回原始文本流 + log.info("开始流式生成答案"); + + return chatClient.prompt(prompt) + .stream() + .content() +// .doOnNext(chunk -> log.debug("发送文本块: {}", chunk)) + .doOnComplete(() -> log.info("流式传输完成")) + .doOnError(e -> log.error("流式生成失败", e)) + .onErrorResume(e -> { + log.error("流式响应错误", e); + return Flux.just("\n\n[错误: " + e.getMessage() + "]"); + }) + .timeout(Duration.ofSeconds(60)); + + } catch (Exception e) { + log.error("流式问答处理异常", e); + return Flux.just("[错误: " + e.getMessage() + "]"); + } + } } diff --git a/src/main/java/com/demo/service/DocumentService.java b/src/main/java/com/demo/service/DocumentService.java index e2850bb..d424fd6 100644 --- a/src/main/java/com/demo/service/DocumentService.java +++ b/src/main/java/com/demo/service/DocumentService.java @@ -1,20 +1,29 @@ package com.demo.service; +import com.demo.dto.FileUploadProgress; +import com.demo.dto.FileUploadResponse; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.apache.pdfbox.pdmodel.PDDocument; +import org.apache.pdfbox.text.PDFTextStripper; import org.springframework.ai.document.Document; import org.springframework.ai.reader.TextReader; import org.springframework.ai.transformer.splitter.TokenTextSplitter; +import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.FileSystemResource; import org.springframework.stereotype.Service; +import org.springframework.web.multipart.MultipartFile; +import reactor.core.publisher.Flux; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.util.ArrayList; import java.util.List; import java.util.stream.Stream; @@ -79,4 +88,255 @@ public class DocumentService { vectorStore.add(allDocuments); } } + + /** + * 上传PDF文件 + */ + public FileUploadResponse uploadPdfFile(MultipartFile file) throws IOException, NoSuchAlgorithmException { + String originalFilename = file.getOriginalFilename(); + long fileSize = file.getSize(); + + log.info("开始处理PDF文件: {}, 大小: {} bytes", originalFilename, fileSize); + + // 计算文件MD5 + String md5Hash = calculateFileMD5(file); + log.info("文件MD5: {}", md5Hash); + + // 检查文件是否已存在 + if (isFileAlreadyUploaded(md5Hash)) { + log.info("文件已存在,跳过处理: {}", originalFilename); + return FileUploadResponse.builder() + .fileName(originalFilename) + .fileSize(formatFileSize(fileSize)) + .md5Hash(md5Hash) + .chunkCount(0) + .message("文件已存在,无需重复上传") + .duplicate(true) + .build(); + } + + // 提取PDF文本 + String pdfText = extractPdfText(file); + log.info("PDF文本提取完成,长度: {} 字符", pdfText.length()); + + // 创建文档对象 + Document document = new Document(pdfText); + document.getMetadata().put("source", originalFilename); + document.getMetadata().put("file_size", String.valueOf(fileSize)); + document.getMetadata().put("file_hash", md5Hash); + document.getMetadata().put("upload_time", String.valueOf(System.currentTimeMillis())); + + // 文本分割 + List splitDocuments = textSplitter.apply(List.of(document)); + log.info("文本分割完成,生成 {} 个文档片段", splitDocuments.size()); + + // 为每个分割的文档添加元数据 + splitDocuments.forEach(doc -> { + doc.getMetadata().put("source", originalFilename); + doc.getMetadata().put("file_size", String.valueOf(fileSize)); + doc.getMetadata().put("file_hash", md5Hash); + doc.getMetadata().put("upload_time", String.valueOf(System.currentTimeMillis())); + }); + + // 存储到向量数据库 + vectorStore.add(splitDocuments); + log.info("文档存储到向量数据库完成"); + + return FileUploadResponse.builder() + .fileName(originalFilename) + .fileSize(formatFileSize(fileSize)) + .md5Hash(md5Hash) + .chunkCount(splitDocuments.size()) + .message("文件上传成功") + .duplicate(false) + .build(); + } + + /** + * 带进度上传PDF文件 + */ + public Flux uploadPdfFileWithProgress(MultipartFile file) { + return Flux.create(sink -> { + try { + String originalFilename = file.getOriginalFilename(); + long fileSize = file.getSize(); + + log.info("开始流式处理PDF文件: {}", originalFilename); + + // 发送开始进度 + sink.next(FileUploadProgress.builder() + .fileName(originalFilename) + .bytesProcessed(0) + .totalBytes(fileSize) + .percentage(0) + .status("PROCESSING") + .message("开始处理文件...") + .build()); + + // 计算文件MD5 (20%) + sink.next(FileUploadProgress.builder() + .fileName(originalFilename) + .bytesProcessed(fileSize / 5) + .totalBytes(fileSize) + .percentage(20) + .status("PROCESSING") + .message("计算文件哈希值...") + .build()); + + String md5Hash = calculateFileMD5(file); + + // 检查文件是否已存在 + if (isFileAlreadyUploaded(md5Hash)) { + sink.next(FileUploadProgress.builder() + .fileName(originalFilename) + .bytesProcessed(fileSize) + .totalBytes(fileSize) + .percentage(100) + .status("COMPLETED") + .message("文件已存在,无需重复上传") + .build()); + sink.complete(); + return; + } + + // 提取PDF文本 (40%) + sink.next(FileUploadProgress.builder() + .fileName(originalFilename) + .bytesProcessed(fileSize * 2 / 5) + .totalBytes(fileSize) + .percentage(40) + .status("PROCESSING") + .message("提取PDF文本内容...") + .build()); + + String pdfText = extractPdfText(file); + + // 创建文档对象 + Document document = new Document(pdfText); + document.getMetadata().put("source", originalFilename); + document.getMetadata().put("file_size", String.valueOf(fileSize)); + document.getMetadata().put("file_hash", md5Hash); + document.getMetadata().put("upload_time", String.valueOf(System.currentTimeMillis())); + + // 文本分割 (60%) + sink.next(FileUploadProgress.builder() + .fileName(originalFilename) + .bytesProcessed(fileSize * 3 / 5) + .totalBytes(fileSize) + .percentage(60) + .status("PROCESSING") + .message("分割文本内容...") + .build()); + + List splitDocuments = textSplitter.apply(List.of(document)); + + // 为每个分割的文档添加元数据 + splitDocuments.forEach(doc -> { + doc.getMetadata().put("source", originalFilename); + doc.getMetadata().put("file_size", String.valueOf(fileSize)); + doc.getMetadata().put("file_hash", md5Hash); + doc.getMetadata().put("upload_time", String.valueOf(System.currentTimeMillis())); + }); + + // 存储到向量数据库 (80%) + sink.next(FileUploadProgress.builder() + .fileName(originalFilename) + .bytesProcessed(fileSize * 4 / 5) + .totalBytes(fileSize) + .percentage(80) + .status("PROCESSING") + .message("存储到向量数据库...") + .build()); + // 分批提交文档,减小embedding的压力 + List tmpDocuments = new ArrayList<>(); + for (int i = 0; i < splitDocuments.size(); i++) { + tmpDocuments.add(splitDocuments.get(i)); + if (i % 10 == 0) { + vectorStore.add(tmpDocuments); + tmpDocuments.clear(); + } + } + vectorStore.add(tmpDocuments); + + // 完成 (100%) + sink.next(FileUploadProgress.builder() + .fileName(originalFilename) + .bytesProcessed(fileSize) + .totalBytes(fileSize) + .percentage(100) + .status("COMPLETED") + .message("文件上传成功,生成了 " + splitDocuments.size() + " 个文档片段") + .build()); + + sink.complete(); + log.info("流式文件上传完成: {}, 生成了 {} 个文档片段", + originalFilename, splitDocuments.size()); + + } catch (Exception e) { + log.error("流式文件上传失败", e); + sink.next(FileUploadProgress.builder() + .status("FAILED") + .message("文件上传失败: " + e.getMessage()) + .percentage(0) + .build()); + sink.complete(); + } + }); + } + + /** + * 计算文件MD5哈希值 + */ + private String calculateFileMD5(MultipartFile file) throws IOException, NoSuchAlgorithmException { + MessageDigest md = MessageDigest.getInstance("MD5"); + byte[] fileBytes = file.getBytes(); + byte[] digest = md.digest(fileBytes); + + StringBuilder sb = new StringBuilder(); + for (byte b : digest) { + sb.append(String.format("%02x", b)); + } + return sb.toString(); + } + + /** + * 检查文件是否已上传 + */ + private boolean isFileAlreadyUploaded(String md5Hash) { + try { + // 使用向量存储的相似度搜索功能来检查是否存在相同MD5的文件 + // 这里简化处理,实际应用中可能需要更高效的查询方式 + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(3)); + + // 检查已存在的文档中是否有相同MD5的文件 + return results.stream() + .anyMatch(doc -> md5Hash.equals(doc.getMetadata().get("file_hash"))); + } catch (Exception e) { + log.warn("检查文件重复性失败", e); + return false; + } + } + + /** + * 提取PDF文本内容 + */ + private String extractPdfText(MultipartFile file) throws IOException { + try (PDDocument document = PDDocument.load(file.getInputStream())) { + PDFTextStripper stripper = new PDFTextStripper(); + return stripper.getText(document); + } + } + + /** + * 格式化文件大小 + */ + private String formatFileSize(long size) { + if (size < 1024) { + return size + " B"; + } else if (size < 1024 * 1024) { + return String.format("%.1f KB", size / 1024.0); + } else { + return String.format("%.1f MB", size / (1024.0 * 1024.0)); + } + } } diff --git a/src/main/resources/application.yaml b/src/main/resources/application.yaml index 35f1bf9..fdfd487 100644 --- a/src/main/resources/application.yaml +++ b/src/main/resources/application.yaml @@ -2,6 +2,10 @@ server: port: 8080 spring: + servlet: + multipart: + max-file-size: 100MB + max-request-size: 100MB ai: openai: api-key: ollama @@ -19,7 +23,7 @@ spring: api-key: key base-url: https://api.siliconflow.cn options: - model: BAAI/bge-large-zh-v1.5 + model: BAAI/bge-m3 dimensions: 1024 enable: true vectorstore: @@ -39,4 +43,11 @@ document: data-path: data chunk-size: 400 min-chunk-size: 200 - max-num-chunk: 10000 \ No newline at end of file + max-num-chunk: 10000 + +# CORS 配置 +cors: + allowed-origins: ${CORS_ORIGINS:http://localhost:8081,http://localhost:5173} + allowed-methods: GET,POST,PUT,DELETE,OPTIONS + allowed-headers: "*" + allow-credentials: true \ No newline at end of file