69 lines
2.7 KiB
Python
69 lines
2.7 KiB
Python
import logging
|
||
|
||
from langchain.chains.retrieval_qa.base import RetrievalQA
|
||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||
import os
|
||
import dotenv
|
||
from langchain_text_splitters import CharacterTextSplitter
|
||
from langchain_community.vectorstores import FAISS
|
||
|
||
# Root logger: INFO and above, with timestamp, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# Load variables from a .env file (if present) into the process environment.
dotenv.load_dotenv()

## Set environment variables
# The SiliconFlow credentials are re-exported under the OPENAI_* names.
# Fail fast with a readable message when one is missing or empty: assigning
# None to os.environ would otherwise raise "TypeError: str expected, not
# NoneType", which does not say which variable is the problem.
for _target, _source in (
    ("OPENAI_API_KEY", "SILICONFLOW_API_KEY"),
    ("OPENAI_BASE_URL", "SILICONFLOW_BASE_URL"),
):
    _value = os.getenv(_source)
    if not _value:
        raise RuntimeError(
            f"Environment variable {_source} is not set; "
            "define it in the environment or in a .env file."
        )
    os.environ[_target] = _value

# Default 'model_name': 'deepseek-ai/DeepSeek-V3.1'
# (note kept from the original; the model actually used is the one below)
llm = ChatOpenAI(model="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B")
## 1. Prepare documents for one domain: software-testing knowledge
docs = [
    "等价类划分是一种黑盒测试方法,将输入数据划分为有效等价类和无效等价类。",
    "边界值分析通常作为等价类划分的补充,重点测试输入输出的边界条件。",
    "集成测试用于验证模块间接口的正确性,常见策略包括自顶向下和自底向上。",
    "回归测试是在软件变更后执行的测试,确保原有功能不受新修改影响。",
    "性能测试包括负载测试、压力测试和耐久性测试,用于评估系统的响应能力。",
    "测试用例应包含测试ID、模块、前置条件、步骤、预期结果和优先级信息。",
]

## 2. Split the documents (optional)
splitter = CharacterTextSplitter()
texts = []
for source_text in docs:
    # Split one document and append its pieces to the flat list of chunks.
    pieces = splitter.split_text(source_text)
    texts += pieces
    logging.info("文档切分:原文=%s -> %d 个分片", source_text, len(pieces))
# NOTE(review): the original indentation was lost; this final dump of all
# chunks is assumed to run once, after the loop - confirm.
logging.info(texts)
## 3. Embedding (vectorization) and building the vector store
embeddings = OpenAIEmbeddings(
    model="netease-youdao/bce-embedding-base_v1",
)
## First call to the embedding model: HTTP Request: POST https://api.siliconflow.cn/v1/embeddings "HTTP/1.1 200 OK"
vectorstore = FAISS.from_texts(texts=texts, embedding=embeddings)
logging.info("构建向量数据库完成")
logging.info(vectorstore)

## 4. Build the RAG chain; the k parameter is the top-K number of chunks
retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 2},
)
## HTTP Request: POST https://api.siliconflow.cn/v1/embeddings "HTTP/1.1 200 OK"
chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
)
query = "什么是等价类划分?"

## Explore the retrieval step on its own
# Retrievers are invoked through invoke(), the same way the chain is invoked
# below; get_relevant_documents() is a deprecated retriever API.
retrieved_docs = retriever.invoke(query)
logging.info("---------")
for retrieved_doc in retrieved_docs:
    logging.info(retrieved_doc)
logging.info("---------")

## 5. Run the query (the chain looks things up in the vector store itself)
## HTTP Request: POST https://api.siliconflow.cn/v1/chat/completions "HTTP/1.1 200 OK"
response = chain.invoke(query)
logging.info(response)