feat: 新增流式 RAG 问答与 PDF 上传功能
新增接口: - POST /api/chat/stream - 流式 RAG 问答 (SSE) - POST /api/documents/upload - PDF 文件上传 - POST /api/documents/upload/stream - 带进度 PDF 上传 新增功能: - CorsConfig 跨域配置(支持 localhost:8081, 5173) - FileUploadResponse/FileUploadProgress DTO - PDF 文本提取与向量化存储 - MD5 文件去重机制 配置更新: - embedding 模型更新为 BAAI/bge-m3 - multipart max-file-size: 100MB - ChatRequest topK 默认值 3 → 10
This commit is contained in:
parent
710fe14d7f
commit
80eb5eb590
144
README.md
144
README.md
@ -14,10 +14,13 @@
|
||||
- 📝 **Markdown 支持** - 完整渲染代码块、表格、列表等格式
|
||||
- 👨💻 **代码高亮** - Highlight.js 自动语言检测与语法着色
|
||||
- 📚 **RAG 智能问答** - 基于向量检索的上下文感知回答
|
||||
- 📤 **PDF 文件上传** - PDF 文本提取与向量化存储
|
||||
- 📊 **上传进度流** - 实时上传进度反馈
|
||||
- 🔍 **语义检索** - Milvus 向量数据库相似度搜索
|
||||
- 🔗 **引用溯源** - 回答附带文档来源引用
|
||||
- 🗃️ **向量数据库** - Milvus 分布式向量数据库
|
||||
- 🔍 **Embedding 服务** - SiliconFlow BAAI/bge-large-zh-v1.5
|
||||
- 🔍 **Embedding 服务** - SiliconFlow BAAI/bge-m3
|
||||
- 🌐 **CORS 跨域** - 支持多端访问
|
||||
- 🎨 **精美界面** - 深色主题响应式设计
|
||||
|
||||
## 🚀 快速开始
|
||||
@ -62,8 +65,10 @@ mvn spring-boot:run
|
||||
| 基础框架 | Spring Boot | 3.2.0 |
|
||||
| AI 框架 | Spring AI | 1.0.0-M3 |
|
||||
| AI 模型 | OpenAI/Ollama | - |
|
||||
| Embedding | SiliconFlow BAAI/bge-large-zh-v1.5 | - |
|
||||
| Embedding | SiliconFlow BAAI/bge-m3 | - |
|
||||
| 向量数据库 | Milvus | 2.3.4 |
|
||||
| PDF 处理 | Apache PDFBox | 2.0.29 |
|
||||
| 文件上传 | Commons FileUpload | 1.5 |
|
||||
| 响应式编程 | Spring WebFlux | 3.2.0 |
|
||||
|
||||
### 前端
|
||||
@ -81,39 +86,45 @@ springAiDemo/
|
||||
├── src/main/java/com/demo/
|
||||
│ ├── MyApplication.java # Spring Boot 启动入口
|
||||
│ ├── config/
|
||||
│ │ └── RagConfig.java # RAG 配置类
|
||||
│ │ ├── CorsConfig.java # CORS 跨域配置
|
||||
│ │ └── RagConfig.java # RAG 配置类
|
||||
│ ├── controller/
|
||||
│ │ ├── ChatController.java # AI 聊天 API
|
||||
│ │ └── DocumentController.java # 文档导入 API
|
||||
│ │ ├── ChatController.java # AI 聊天 API
|
||||
│ │ └── DocumentController.java # 文档处理 API
|
||||
│ ├── dto/
|
||||
│ │ ├── ApiResponse.java # 通用响应封装
|
||||
│ │ ├── ChatRequest.java # 聊天请求 DTO
|
||||
│ │ └── ChatResponse.java # 聊天响应 DTO
|
||||
│ │ ├── ApiResponse.java # 统一响应封装
|
||||
│ │ ├── ChatRequest.java # 聊天请求 DTO
|
||||
│ │ ├── ChatResponse.java # 聊天响应 DTO
|
||||
│ │ ├── FileUploadProgress.java # 上传进度 DTO
|
||||
│ │ └── FileUploadResponse.java # 上传响应 DTO
|
||||
│ └── service/
|
||||
│ ├── ChatService.java # RAG 聊天服务
|
||||
│ └── DocumentService.java # 文档处理服务
|
||||
├── data/
|
||||
│ └── doris_intro.md # RAG 示例文档
|
||||
│ └── doris_intro.md # RAG 示例文档
|
||||
└── src/main/resources/
|
||||
├── application.yaml # 应用配置
|
||||
└── static/ # 前端资源
|
||||
├── application.yaml # 应用配置
|
||||
└── static/ # 前端资源
|
||||
```
|
||||
|
||||
## 💬 API 文档
|
||||
|
||||
### 聊天接口
|
||||
|
||||
| 方法 | 端点 | 描述 | 参数 |
|
||||
| 方法 | 端点 | 描述 | 参数/Body |
|
||||
|:---|:---|:---|:---|
|
||||
| GET | `/api/chat/test` | 健康检查 | - |
|
||||
| GET | `/api/chat/ai` | 流式 AI 对话 | `msg` |
|
||||
| POST | `/api/chat` | RAG 智能问答 | `{ "question": "..." }` |
|
||||
| POST | `/api/chat` | RAG 智能问答 | `{ "question": "...", "topK": 10 }` |
|
||||
| POST | `/api/chat/stream` | 流式 RAG 问答 | `{ "question": "..." }` (SSE) |
|
||||
|
||||
### 文档接口
|
||||
|
||||
| 方法 | 端点 | 描述 |
|
||||
|:---|:---|:---|
|
||||
| GET | `/api/documents/import` | 导入文档到向量库 |
|
||||
| 方法 | 端点 | 描述 | 参数 |
|
||||
|:---|:---|:---|:---|
|
||||
| POST | `/api/documents/import` | 导入 RAG 文档 | - |
|
||||
| POST | `/api/documents/upload` | 上传 PDF 文件 | `file` (multipart) |
|
||||
| POST | `/api/documents/upload/stream` | 上传 PDF (带进度) | `file` (multipart) |
|
||||
|
||||
### 请求示例
|
||||
|
||||
@ -124,14 +135,22 @@ curl -X POST http://localhost:8080/api/chat \
|
||||
-d '{"question": "Apache Doris 是什么?"}'
|
||||
```
|
||||
|
||||
**流式对话:**
|
||||
**流式 RAG 问答 (SSE):**
|
||||
```bash
|
||||
curl "http://localhost:8080/api/chat/ai/stream?msg=讲一个故事"
|
||||
curl -X POST http://localhost:8080/api/chat/stream \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"question": "Apache Doris 性能如何?"}'
|
||||
```
|
||||
|
||||
**PDF 文件上传:**
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/documents/upload \
|
||||
-F "file=@/path/to/document.pdf"
|
||||
```
|
||||
|
||||
**导入 RAG 文档:**
|
||||
```bash
|
||||
curl http://localhost:8080/api/documents/import
|
||||
curl -X POST http://localhost:8080/api/documents/import
|
||||
```
|
||||
|
||||
### 响应格式
|
||||
@ -149,6 +168,14 @@ curl http://localhost:8080/api/documents/import
|
||||
}
|
||||
```
|
||||
|
||||
**POST /api/documents/upload/stream 响应 (SSE):**
|
||||
```json
|
||||
{"percent":0,"status":"STARTING","message":"开始上传..."}
|
||||
{"percent":30,"status":"EXTRACTING","message":"正在提取文本..."}
|
||||
{"percent":60,"status":"PROCESSING","message":"正在向量化..."}
|
||||
{"percent":100,"status":"COMPLETED","message":"上传完成"}
|
||||
```
|
||||
|
||||
## 📚 RAG 文档问答
|
||||
|
||||
### 工作原理
|
||||
@ -157,40 +184,30 @@ curl http://localhost:8080/api/documents/import
|
||||
用户问题 → Embedding → 向量检索 → 构建上下文 → LLM 生成回答 → 返回答案+引用
|
||||
```
|
||||
|
||||
1. **文档导入** - 将 `.md` / `.txt` 文档读取并切割成 chunks
|
||||
2. **向量化** - 使用 BAAI/bge-large-zh-v1.5 生成 1024 维向量
|
||||
3. **存储检索** -存入 Milvus,向量相似度搜索 top-K 结果
|
||||
4. **生成回答** - 将检索结果作为上下文,LLM 生成答案
|
||||
|
||||
### 使用步骤
|
||||
|
||||
1. 将文档放入 `data/` 目录
|
||||
2. 调用导入接口:`curl http://localhost:8080/api/documents/import`
|
||||
3. 通过 POST `/api/chat` 提问
|
||||
|
||||
### 文档处理配置
|
||||
|
||||
```yaml
|
||||
document:
|
||||
data-path: data # 文档目录
|
||||
chunk-size: 400 # 分割块大小 (tokens)
|
||||
min-chunk-size: 200 # 最小块大小
|
||||
max-num-chunk: 10000 # 最大块数量
|
||||
```
|
||||
1. **方式一**:将 `.md` / `.txt` 文档放入 `data/` 目录,调用导入接口
|
||||
2. **方式二**:直接上传 PDF 文件,系统自动提取文本并向量化
|
||||
3. 通过 POST `/api/chat` 或流式 `/api/chat/stream` 提问
|
||||
|
||||
## 🛠️ 配置说明
|
||||
|
||||
### AI 对话配置
|
||||
### CORS 跨域配置
|
||||
|
||||
```yaml
|
||||
cors:
|
||||
allowed-origins: http://localhost:8081,http://localhost:5173
|
||||
allowed-methods: GET,POST,PUT,DELETE,OPTIONS
|
||||
```
|
||||
|
||||
### 文件上传配置
|
||||
|
||||
```yaml
|
||||
spring:
|
||||
ai:
|
||||
openai:
|
||||
base-url: http://localhost:11434
|
||||
chat:
|
||||
options:
|
||||
model: gpt-oss:120b-cloud
|
||||
temperature: 0.7
|
||||
servlet:
|
||||
multipart:
|
||||
max-file-size: 100MB
|
||||
max-request-size: 100MB
|
||||
```
|
||||
|
||||
### Embedding 配置
|
||||
@ -202,24 +219,10 @@ spring:
|
||||
embedding:
|
||||
api-key: your-siliconflow-api-key
|
||||
base-url: https://api.siliconflow.cn
|
||||
model: BAAI/bge-large-zh-v1.5
|
||||
model: BAAI/bge-m3
|
||||
dimensions: 1024
|
||||
```
|
||||
|
||||
### 向量数据库配置
|
||||
|
||||
```yaml
|
||||
spring:
|
||||
ai:
|
||||
vectorstore:
|
||||
milvus:
|
||||
client:
|
||||
host: 192.168.50.103
|
||||
port: 19530
|
||||
databaseName: doris_docs
|
||||
collectionName: vector_store
|
||||
```
|
||||
|
||||
## 🎬 架构图
|
||||
|
||||
```
|
||||
@ -227,7 +230,7 @@ spring:
|
||||
│ Client │
|
||||
│ ┌─────────────────┐ ┌─────────────────────────────────┐│
|
||||
│ │ Web UI │ │ POST /api/chat ││
|
||||
│ │ (流式/非流式) │────▶│ { question: "..." } ││
|
||||
│ │ (流式/非流式) │────▶│ POST /api/chat/stream (SSE) ││
|
||||
│ └─────────────────┘ └─────────────────────────────────┘│
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
@ -242,12 +245,6 @@ spring:
|
||||
┌─────────────────────────┐ ┌─────────────────────────┐
|
||||
│ VectorStore │ │ ChatClient │
|
||||
│ (Milvus 语义检索) │ │ (Ollama LLM) │
|
||||
└─────────────────────────┘ └─────────────────────────┘
|
||||
│ │
|
||||
▼ ▼
|
||||
┌─────────────────────────┐ ┌─────────────────────────┐
|
||||
│ DocumentService │ │ SiliconFlow API │
|
||||
│ (文档切割/向量化) │ │ (Embedding) │
|
||||
└─────────────────────────┘ └─────────────────────────┘
|
||||
```
|
||||
|
||||
@ -261,15 +258,22 @@ spring:
|
||||
|
||||
### v1.1.0 (2026-04-19)
|
||||
- ✨ 新增 RAG 文档问答功能
|
||||
- ✨ 新增 DocumentController 文档导入 API
|
||||
- ✨ 配置 SiliconFlow Embedding 服务
|
||||
- ✨ 集成 Milvus 向量数据库
|
||||
|
||||
### v1.2.0 (2026-04-19)
|
||||
- ✨ 新增 POST /api/chat RAG 智能问答接口
|
||||
- ✨ 新增 ChatService RAG 核心服务
|
||||
- ✨ 新增 ChatRequest/ChatResponse/ApiResponse DTO
|
||||
- ✨ 新增引用溯源功能,返回文档来源
|
||||
- ✨ 新增引用溯源功能
|
||||
|
||||
### v1.3.0 (2026-04-20)
|
||||
- ✨ 新增 POST /api/chat/stream 流式 RAG 问答接口 (SSE)
|
||||
- ✨ 新增 PDF 文件上传接口 (POST /api/documents/upload)
|
||||
- ✨ 新增带进度流式 PDF 上传 (POST /api/documents/upload/stream)
|
||||
- ✨ 新增 CorsConfig 跨域配置
|
||||
- ✨ embedding 模型更新为 BAAI/bge-m3
|
||||
- ✨ ChatRequest topK 默认值调整为 10
|
||||
- 🐛 修复文件重复上传问题 (MD5 去重)
|
||||
|
||||
## 📗 License
|
||||
|
||||
|
||||
41
src/main/java/com/demo/config/CorsConfig.java
Normal file
41
src/main/java/com/demo/config/CorsConfig.java
Normal file
@ -0,0 +1,41 @@
|
||||
package com.demo.config;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.web.cors.CorsConfiguration;
|
||||
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
|
||||
import org.springframework.web.filter.CorsFilter;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
@Configuration
|
||||
public class CorsConfig {
|
||||
|
||||
@Value("${cors.allowed-origins}")
|
||||
private String allowedOrigins;
|
||||
|
||||
@Value("${cors.allowed-methods}")
|
||||
private String allowedMethods;
|
||||
|
||||
@Value("${cors.allowed-headers}")
|
||||
private String allowedHeaders;
|
||||
|
||||
@Value("${cors.allow-credentials}")
|
||||
private boolean allowCredentials;
|
||||
|
||||
@Bean
|
||||
public CorsFilter corsFilter() {
|
||||
CorsConfiguration config = new CorsConfiguration();
|
||||
config.setAllowedOrigins(Arrays.asList(allowedOrigins.split(",")));
|
||||
config.setAllowedMethods(Arrays.asList(allowedMethods.split(",")));
|
||||
config.setAllowedHeaders(Arrays.asList(allowedHeaders.split(",")));
|
||||
config.setAllowCredentials(allowCredentials);
|
||||
config.setMaxAge(3600L);
|
||||
|
||||
UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource();
|
||||
source.registerCorsConfiguration("/**", config);
|
||||
|
||||
return new CorsFilter(source);
|
||||
}
|
||||
}
|
||||
@ -4,8 +4,11 @@ import com.demo.dto.ApiResponse;
|
||||
import com.demo.dto.ChatRequest;
|
||||
import com.demo.dto.ChatResponse;
|
||||
import com.demo.service.ChatService;
|
||||
import jakarta.validation.Valid;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.codec.ServerSentEvent;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
@ -43,6 +46,30 @@ public class ChatController {
|
||||
return ApiResponse.success(response);
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式问答接口 - 直接返回纯文本流
|
||||
*/
|
||||
@PostMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
public Flux<ServerSentEvent<String>> chatStream(@Valid @RequestBody ChatRequest request) {
|
||||
log.info("接收到流式问答请求: {}", request.getQuestion());
|
||||
|
||||
return chatService.chatStream(request)
|
||||
.map(chunk -> ServerSentEvent.<String>builder()
|
||||
.data(chunk)
|
||||
.build())
|
||||
.doOnSubscribe(s -> log.info("开始流式传输"))
|
||||
.doOnNext(chunk -> log.debug("发送文本块: {}", chunk))
|
||||
.doOnComplete(() -> log.info("流式传输完成"))
|
||||
.doOnError(e -> log.error("流式传输错误", e))
|
||||
.onErrorResume(e -> {
|
||||
log.error("流式问答处理失败", e);
|
||||
// 错误信息也一并包装
|
||||
return Flux.just(ServerSentEvent.<String>builder()
|
||||
.data("\n\n[错误: " + e.getMessage() + "]")
|
||||
.build());
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式的聊天接口,要注意如果中文有乱码,就是编码得问题,需要添加produces = "text/html;charset=UTF-8
|
||||
* @param msg
|
||||
|
||||
@ -1,15 +1,21 @@
|
||||
package com.demo.controller;
|
||||
|
||||
import com.demo.dto.ApiResponse;
|
||||
import com.demo.dto.FileUploadProgress;
|
||||
import com.demo.dto.FileUploadResponse;
|
||||
import com.demo.service.DocumentService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
||||
@Slf4j
|
||||
@RestController
|
||||
@RequestMapping("api/documents")
|
||||
@RequiredArgsConstructor
|
||||
@ -18,7 +24,7 @@ public class DocumentController {
|
||||
@Autowired
|
||||
public final DocumentService documentService;
|
||||
|
||||
@GetMapping("import")
|
||||
@PostMapping("import")
|
||||
public String importDocument() {
|
||||
try {
|
||||
documentService.importDocument();
|
||||
@ -27,5 +33,69 @@ public class DocumentController {
|
||||
}
|
||||
return "ok";
|
||||
}
|
||||
@PostMapping(value = "/upload", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
|
||||
public ResponseEntity<ApiResponse<FileUploadResponse>> uploadPdfFile(
|
||||
@RequestParam("file") MultipartFile file) {
|
||||
try {
|
||||
log.info("收到PDF文件上传请求: {}", file.getOriginalFilename());
|
||||
|
||||
// 验证文件
|
||||
if (file.isEmpty()) {
|
||||
log.error("上传文件为空");
|
||||
return ResponseEntity.badRequest()
|
||||
.body(ApiResponse.error("上传文件不能为空"));
|
||||
}
|
||||
|
||||
// 检查文件格式
|
||||
String contentType = file.getContentType();
|
||||
String originalFilename = file.getOriginalFilename();
|
||||
|
||||
if (originalFilename == null || !originalFilename.toLowerCase().endsWith(".pdf")) {
|
||||
log.error("文件格式错误: {}", originalFilename);
|
||||
return ResponseEntity.badRequest()
|
||||
.body(ApiResponse.error("只能上传PDF格式文件"));
|
||||
}
|
||||
|
||||
// 检查文件大小 (30MB)
|
||||
if (file.getSize() > 30 * 1024 * 1024) {
|
||||
log.error("文件大小超出限制: {} bytes", file.getSize());
|
||||
return ResponseEntity.badRequest()
|
||||
.body(ApiResponse.error("文件大小不能超过30MB"));
|
||||
}
|
||||
|
||||
// 处理文件上传
|
||||
FileUploadResponse response = documentService.uploadPdfFile(file);
|
||||
|
||||
if (response.isDuplicate()) {
|
||||
log.info("文件已存在,跳过处理: {}", originalFilename);
|
||||
return ResponseEntity.ok(ApiResponse.success("文件已存在,无需重复上传", response));
|
||||
}
|
||||
|
||||
log.info("文件上传成功: {}, 生成了 {} 个文档片段",
|
||||
originalFilename, response.getChunkCount());
|
||||
return ResponseEntity.ok(ApiResponse.success("文件上传成功", response));
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("文件上传失败", e);
|
||||
return ResponseEntity.internalServerError()
|
||||
.body(ApiResponse.error("文件上传失败: " + e.getMessage()));
|
||||
}
|
||||
}
|
||||
|
||||
@PostMapping(value = "/upload/stream", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
|
||||
public Flux<FileUploadProgress> uploadPdfFileWithProgress(
|
||||
@RequestParam("file") MultipartFile file) {
|
||||
try {
|
||||
log.info("收到PDF文件上传请求(流式): {}", file.getOriginalFilename());
|
||||
return documentService.uploadPdfFileWithProgress(file);
|
||||
} catch (Exception e) {
|
||||
log.error("文件上传失败", e);
|
||||
return Flux.just(FileUploadProgress.builder()
|
||||
.status("FAILED")
|
||||
.message("文件上传失败: " + e.getMessage())
|
||||
.percentage(0)
|
||||
.build());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -8,6 +8,6 @@ public class ChatRequest {
|
||||
|
||||
@NotBlank(message = "请输入你的问题")
|
||||
private String question;
|
||||
private Integer topK = 3;
|
||||
private Integer topK = 10;
|
||||
|
||||
}
|
||||
|
||||
19
src/main/java/com/demo/dto/FileUploadProgress.java
Normal file
19
src/main/java/com/demo/dto/FileUploadProgress.java
Normal file
@ -0,0 +1,19 @@
|
||||
package com.demo.dto;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class FileUploadProgress {
|
||||
private String fileName;
|
||||
private long bytesProcessed;
|
||||
private long totalBytes;
|
||||
private int percentage;
|
||||
private String status; // PROCESSING, COMPLETED, FAILED
|
||||
private String message;
|
||||
}
|
||||
19
src/main/java/com/demo/dto/FileUploadResponse.java
Normal file
19
src/main/java/com/demo/dto/FileUploadResponse.java
Normal file
@ -0,0 +1,19 @@
|
||||
package com.demo.dto;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class FileUploadResponse {
|
||||
private String fileName;
|
||||
private String fileSize;
|
||||
private String md5Hash;
|
||||
private int chunkCount;
|
||||
private String message;
|
||||
private boolean duplicate;
|
||||
}
|
||||
@ -2,7 +2,6 @@ package com.demo.service;
|
||||
|
||||
import com.demo.dto.ChatRequest;
|
||||
import com.demo.dto.ChatResponse;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
@ -14,7 +13,9 @@ import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
@ -98,4 +99,54 @@ public class ChatService {
|
||||
.timestamp(System.currentTimeMillis())
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理用户问题(流式响应 - 直接返回原始文本)
|
||||
*/
|
||||
public Flux<String> chatStream(ChatRequest request) {
|
||||
log.info("收到流式问题: {}", request.getQuestion());
|
||||
|
||||
try {
|
||||
// 1. 向量检索相关文档
|
||||
List<Document> relevantDocs = vectorStore.similaritySearch(
|
||||
SearchRequest.query(request.getQuestion())
|
||||
.withTopK(request.getTopK())
|
||||
);
|
||||
|
||||
log.info("检索到 {} 个相关文档", relevantDocs.size());
|
||||
|
||||
// 2. 构建上下文
|
||||
String context = relevantDocs.isEmpty() ?
|
||||
"暂无相关文档,请基于你的知识回答。" :
|
||||
relevantDocs.stream()
|
||||
.map(Document::getContent)
|
||||
.collect(Collectors.joining("\n\n---\n\n"));
|
||||
|
||||
// 3. 构建提示词
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(SYSTEM_PROMPT);
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("context", context));
|
||||
UserMessage userMessage = new UserMessage(request.getQuestion());
|
||||
|
||||
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
|
||||
|
||||
// 4. 返回原始文本流
|
||||
log.info("开始流式生成答案");
|
||||
|
||||
return chatClient.prompt(prompt)
|
||||
.stream()
|
||||
.content()
|
||||
// .doOnNext(chunk -> log.debug("发送文本块: {}", chunk))
|
||||
.doOnComplete(() -> log.info("流式传输完成"))
|
||||
.doOnError(e -> log.error("流式生成失败", e))
|
||||
.onErrorResume(e -> {
|
||||
log.error("流式响应错误", e);
|
||||
return Flux.just("\n\n[错误: " + e.getMessage() + "]");
|
||||
})
|
||||
.timeout(Duration.ofSeconds(60));
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("流式问答处理异常", e);
|
||||
return Flux.just("[错误: " + e.getMessage() + "]");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,20 +1,29 @@
|
||||
package com.demo.service;
|
||||
|
||||
import com.demo.dto.FileUploadProgress;
|
||||
import com.demo.dto.FileUploadResponse;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.pdfbox.pdmodel.PDDocument;
|
||||
import org.apache.pdfbox.text.PDFTextStripper;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.reader.TextReader;
|
||||
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.core.io.FileSystemResource;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Stream;
|
||||
@ -79,4 +88,255 @@ public class DocumentService {
|
||||
vectorStore.add(allDocuments);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 上传PDF文件
|
||||
*/
|
||||
public FileUploadResponse uploadPdfFile(MultipartFile file) throws IOException, NoSuchAlgorithmException {
|
||||
String originalFilename = file.getOriginalFilename();
|
||||
long fileSize = file.getSize();
|
||||
|
||||
log.info("开始处理PDF文件: {}, 大小: {} bytes", originalFilename, fileSize);
|
||||
|
||||
// 计算文件MD5
|
||||
String md5Hash = calculateFileMD5(file);
|
||||
log.info("文件MD5: {}", md5Hash);
|
||||
|
||||
// 检查文件是否已存在
|
||||
if (isFileAlreadyUploaded(md5Hash)) {
|
||||
log.info("文件已存在,跳过处理: {}", originalFilename);
|
||||
return FileUploadResponse.builder()
|
||||
.fileName(originalFilename)
|
||||
.fileSize(formatFileSize(fileSize))
|
||||
.md5Hash(md5Hash)
|
||||
.chunkCount(0)
|
||||
.message("文件已存在,无需重复上传")
|
||||
.duplicate(true)
|
||||
.build();
|
||||
}
|
||||
|
||||
// 提取PDF文本
|
||||
String pdfText = extractPdfText(file);
|
||||
log.info("PDF文本提取完成,长度: {} 字符", pdfText.length());
|
||||
|
||||
// 创建文档对象
|
||||
Document document = new Document(pdfText);
|
||||
document.getMetadata().put("source", originalFilename);
|
||||
document.getMetadata().put("file_size", String.valueOf(fileSize));
|
||||
document.getMetadata().put("file_hash", md5Hash);
|
||||
document.getMetadata().put("upload_time", String.valueOf(System.currentTimeMillis()));
|
||||
|
||||
// 文本分割
|
||||
List<Document> splitDocuments = textSplitter.apply(List.of(document));
|
||||
log.info("文本分割完成,生成 {} 个文档片段", splitDocuments.size());
|
||||
|
||||
// 为每个分割的文档添加元数据
|
||||
splitDocuments.forEach(doc -> {
|
||||
doc.getMetadata().put("source", originalFilename);
|
||||
doc.getMetadata().put("file_size", String.valueOf(fileSize));
|
||||
doc.getMetadata().put("file_hash", md5Hash);
|
||||
doc.getMetadata().put("upload_time", String.valueOf(System.currentTimeMillis()));
|
||||
});
|
||||
|
||||
// 存储到向量数据库
|
||||
vectorStore.add(splitDocuments);
|
||||
log.info("文档存储到向量数据库完成");
|
||||
|
||||
return FileUploadResponse.builder()
|
||||
.fileName(originalFilename)
|
||||
.fileSize(formatFileSize(fileSize))
|
||||
.md5Hash(md5Hash)
|
||||
.chunkCount(splitDocuments.size())
|
||||
.message("文件上传成功")
|
||||
.duplicate(false)
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 带进度上传PDF文件
|
||||
*/
|
||||
public Flux<FileUploadProgress> uploadPdfFileWithProgress(MultipartFile file) {
|
||||
return Flux.create(sink -> {
|
||||
try {
|
||||
String originalFilename = file.getOriginalFilename();
|
||||
long fileSize = file.getSize();
|
||||
|
||||
log.info("开始流式处理PDF文件: {}", originalFilename);
|
||||
|
||||
// 发送开始进度
|
||||
sink.next(FileUploadProgress.builder()
|
||||
.fileName(originalFilename)
|
||||
.bytesProcessed(0)
|
||||
.totalBytes(fileSize)
|
||||
.percentage(0)
|
||||
.status("PROCESSING")
|
||||
.message("开始处理文件...")
|
||||
.build());
|
||||
|
||||
// 计算文件MD5 (20%)
|
||||
sink.next(FileUploadProgress.builder()
|
||||
.fileName(originalFilename)
|
||||
.bytesProcessed(fileSize / 5)
|
||||
.totalBytes(fileSize)
|
||||
.percentage(20)
|
||||
.status("PROCESSING")
|
||||
.message("计算文件哈希值...")
|
||||
.build());
|
||||
|
||||
String md5Hash = calculateFileMD5(file);
|
||||
|
||||
// 检查文件是否已存在
|
||||
if (isFileAlreadyUploaded(md5Hash)) {
|
||||
sink.next(FileUploadProgress.builder()
|
||||
.fileName(originalFilename)
|
||||
.bytesProcessed(fileSize)
|
||||
.totalBytes(fileSize)
|
||||
.percentage(100)
|
||||
.status("COMPLETED")
|
||||
.message("文件已存在,无需重复上传")
|
||||
.build());
|
||||
sink.complete();
|
||||
return;
|
||||
}
|
||||
|
||||
// 提取PDF文本 (40%)
|
||||
sink.next(FileUploadProgress.builder()
|
||||
.fileName(originalFilename)
|
||||
.bytesProcessed(fileSize * 2 / 5)
|
||||
.totalBytes(fileSize)
|
||||
.percentage(40)
|
||||
.status("PROCESSING")
|
||||
.message("提取PDF文本内容...")
|
||||
.build());
|
||||
|
||||
String pdfText = extractPdfText(file);
|
||||
|
||||
// 创建文档对象
|
||||
Document document = new Document(pdfText);
|
||||
document.getMetadata().put("source", originalFilename);
|
||||
document.getMetadata().put("file_size", String.valueOf(fileSize));
|
||||
document.getMetadata().put("file_hash", md5Hash);
|
||||
document.getMetadata().put("upload_time", String.valueOf(System.currentTimeMillis()));
|
||||
|
||||
// 文本分割 (60%)
|
||||
sink.next(FileUploadProgress.builder()
|
||||
.fileName(originalFilename)
|
||||
.bytesProcessed(fileSize * 3 / 5)
|
||||
.totalBytes(fileSize)
|
||||
.percentage(60)
|
||||
.status("PROCESSING")
|
||||
.message("分割文本内容...")
|
||||
.build());
|
||||
|
||||
List<Document> splitDocuments = textSplitter.apply(List.of(document));
|
||||
|
||||
// 为每个分割的文档添加元数据
|
||||
splitDocuments.forEach(doc -> {
|
||||
doc.getMetadata().put("source", originalFilename);
|
||||
doc.getMetadata().put("file_size", String.valueOf(fileSize));
|
||||
doc.getMetadata().put("file_hash", md5Hash);
|
||||
doc.getMetadata().put("upload_time", String.valueOf(System.currentTimeMillis()));
|
||||
});
|
||||
|
||||
// 存储到向量数据库 (80%)
|
||||
sink.next(FileUploadProgress.builder()
|
||||
.fileName(originalFilename)
|
||||
.bytesProcessed(fileSize * 4 / 5)
|
||||
.totalBytes(fileSize)
|
||||
.percentage(80)
|
||||
.status("PROCESSING")
|
||||
.message("存储到向量数据库...")
|
||||
.build());
|
||||
// 分批提交文档,减小embedding的压力
|
||||
List<Document> tmpDocuments = new ArrayList<>();
|
||||
for (int i = 0; i < splitDocuments.size(); i++) {
|
||||
tmpDocuments.add(splitDocuments.get(i));
|
||||
if (i % 10 == 0) {
|
||||
vectorStore.add(tmpDocuments);
|
||||
tmpDocuments.clear();
|
||||
}
|
||||
}
|
||||
vectorStore.add(tmpDocuments);
|
||||
|
||||
// 完成 (100%)
|
||||
sink.next(FileUploadProgress.builder()
|
||||
.fileName(originalFilename)
|
||||
.bytesProcessed(fileSize)
|
||||
.totalBytes(fileSize)
|
||||
.percentage(100)
|
||||
.status("COMPLETED")
|
||||
.message("文件上传成功,生成了 " + splitDocuments.size() + " 个文档片段")
|
||||
.build());
|
||||
|
||||
sink.complete();
|
||||
log.info("流式文件上传完成: {}, 生成了 {} 个文档片段",
|
||||
originalFilename, splitDocuments.size());
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("流式文件上传失败", e);
|
||||
sink.next(FileUploadProgress.builder()
|
||||
.status("FAILED")
|
||||
.message("文件上传失败: " + e.getMessage())
|
||||
.percentage(0)
|
||||
.build());
|
||||
sink.complete();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 计算文件MD5哈希值
|
||||
*/
|
||||
private String calculateFileMD5(MultipartFile file) throws IOException, NoSuchAlgorithmException {
|
||||
MessageDigest md = MessageDigest.getInstance("MD5");
|
||||
byte[] fileBytes = file.getBytes();
|
||||
byte[] digest = md.digest(fileBytes);
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (byte b : digest) {
|
||||
sb.append(String.format("%02x", b));
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查文件是否已上传
|
||||
*/
|
||||
private boolean isFileAlreadyUploaded(String md5Hash) {
|
||||
try {
|
||||
// 使用向量存储的相似度搜索功能来检查是否存在相同MD5的文件
|
||||
// 这里简化处理,实际应用中可能需要更高效的查询方式
|
||||
List<Document> results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(3));
|
||||
|
||||
// 检查已存在的文档中是否有相同MD5的文件
|
||||
return results.stream()
|
||||
.anyMatch(doc -> md5Hash.equals(doc.getMetadata().get("file_hash")));
|
||||
} catch (Exception e) {
|
||||
log.warn("检查文件重复性失败", e);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 提取PDF文本内容
|
||||
*/
|
||||
private String extractPdfText(MultipartFile file) throws IOException {
|
||||
try (PDDocument document = PDDocument.load(file.getInputStream())) {
|
||||
PDFTextStripper stripper = new PDFTextStripper();
|
||||
return stripper.getText(document);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化文件大小
|
||||
*/
|
||||
private String formatFileSize(long size) {
|
||||
if (size < 1024) {
|
||||
return size + " B";
|
||||
} else if (size < 1024 * 1024) {
|
||||
return String.format("%.1f KB", size / 1024.0);
|
||||
} else {
|
||||
return String.format("%.1f MB", size / (1024.0 * 1024.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,6 +2,10 @@ server:
|
||||
port: 8080
|
||||
|
||||
spring:
|
||||
servlet:
|
||||
multipart:
|
||||
max-file-size: 100MB
|
||||
max-request-size: 100MB
|
||||
ai:
|
||||
openai:
|
||||
api-key: ollama
|
||||
@ -19,7 +23,7 @@ spring:
|
||||
api-key: key
|
||||
base-url: https://api.siliconflow.cn
|
||||
options:
|
||||
model: BAAI/bge-large-zh-v1.5
|
||||
model: BAAI/bge-m3
|
||||
dimensions: 1024
|
||||
enable: true
|
||||
vectorstore:
|
||||
@ -39,4 +43,11 @@ document:
|
||||
data-path: data
|
||||
chunk-size: 400
|
||||
min-chunk-size: 200
|
||||
max-num-chunk: 10000
|
||||
max-num-chunk: 10000
|
||||
|
||||
# CORS 配置
|
||||
cors:
|
||||
allowed-origins: ${CORS_ORIGINS:http://localhost:8081,http://localhost:5173}
|
||||
allowed-methods: GET,POST,PUT,DELETE,OPTIONS
|
||||
allowed-headers: "*"
|
||||
allow-credentials: true
|
||||
Loading…
Reference in New Issue
Block a user