langchain-learning/rag/rag_demo.py
2026-04-15 09:33:16 +08:00

69 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import os
import dotenv
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
# Root logger: INFO and above, each line stamped with time, logger name and level.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Pull key=value pairs from a local .env file into the process environment.
dotenv.load_dotenv()
## Environment setup: SiliconFlow exposes an OpenAI-compatible API, so its
## credentials are copied into the standard OPENAI_* variables that the
## OpenAI client behind langchain_openai picks up.
def _require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is unset or empty. Without this check the
            assignment into ``os.environ`` below would fail with an opaque
            ``TypeError`` (``os.environ`` only accepts ``str`` values).
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} is missing or empty; "
            "set it in your shell or in a .env file."
        )
    return value


os.environ["OPENAI_API_KEY"] = _require_env("SILICONFLOW_API_KEY")
os.environ["OPENAI_BASE_URL"] = _require_env("SILICONFLOW_BASE_URL")

# Default 'model_name' noted by the author: 'deepseek-ai/DeepSeek-V3.1'.
llm = ChatOpenAI(model="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B")
## 1. Domain corpus for the demo: short facts about software testing
##    (equivalence partitioning, boundary values, integration, regression and
##    performance testing, and the fields a test case should contain).
docs = [
    "等价类划分是一种黑盒测试方法,将输入数据划分为有效等价类和无效等价类。",
    "边界值分析通常作为等价类划分的补充,重点测试输入输出的边界条件。",
    "集成测试用于验证模块间接口的正确性,常见策略包括自顶向下和自底向上。",
    "回归测试是在软件变更后执行的测试,确保原有功能不受新修改影响。",
    "性能测试包括负载测试、压力测试和耐久性测试,用于评估系统的响应能力。",
    "测试用例应包含测试ID、模块、前置条件、步骤、预期结果和优先级信息。",
]
## 2. Split the documents into chunks (optional for this corpus: each entry is
##    a single short sentence; the log line below reports how many chunks
##    each document produced).
splitter = CharacterTextSplitter()
texts = []
for doc in docs:
    chunks = splitter.split_text(doc)
    texts.extend(chunks)
    logging.info("文档切分:原文=%s -> %d 个分片", doc, len(chunks))
logging.info(texts)
## 3. Embed every chunk and build a FAISS vector store from the vectors.
# This is the first call to the embedding endpoint; a typical log line is:
#   HTTP Request: POST https://api.siliconflow.cn/v1/embeddings "HTTP/1.1 200 OK"
_EMBEDDING_MODEL = "netease-youdao/bce-embedding-base_v1"
embeddings = OpenAIEmbeddings(model=_EMBEDDING_MODEL)
vectorstore = FAISS.from_texts(texts=texts, embedding=embeddings)

logging.info("构建向量数据库完成")
logging.info(vectorstore)
## 4. Build the RAG chain. The retriever returns the top-k (k=2) chunks by
##    similarity to the query.
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 2})
# Embedding the query triggers another embeddings call; a typical log line is:
#   HTTP Request: POST https://api.siliconflow.cn/v1/embeddings "HTTP/1.1 200 OK"
# NOTE(review): RetrievalQA is deprecated in recent LangChain releases in favour
# of create_retrieval_chain. It is kept here because switching would change the
# shape of the response logged at the end of this script.
chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
query = "什么是等价类划分?"

## Inspect the retrieval step on its own.
# Retrievers are Runnables: .invoke() replaces the deprecated
# get_relevant_documents(), which is scheduled for removal.
retrieved_docs = retriever.invoke(query)
logging.info("---------")
for retrieved_doc in retrieved_docs:
    logging.info(retrieved_doc)
    logging.info("---------")

## 5. Ask the question: the chain fetches matching chunks from the vector store
##    and passes them to the LLM as context. A typical log line is:
##    HTTP Request: POST https://api.siliconflow.cn/v1/chat/completions "HTTP/1.1 200 OK"
# RetrievalQA's input key is "query"; passing it explicitly is equivalent to
# passing the bare string.
response = chain.invoke({"query": query})
logging.info(response)