Self-RAG 系统实现详解
一、Self-RAG 简介
Self-RAG(Self-Reflective Retrieval-Augmented Generation)赋予大模型 “自我反思与修正” 能力,可有效改善检索不准确、生成结果不可靠等问题。
二、Self-RAG 工作流程
检索文档:携带初始或修改后的问题进行文档检索,获取文档后开展上下文相关性评估。
上下文评估:
- 若评估为不相关,且此前未执行过 Query2Doc 转换,则进行 Query2Doc 转换,随后返回步骤 1 重新检索。若评估为相关,或虽不相关但已尝试过 Query2Doc 转换,则筛选出相关文档,进入生成答案环节 。
生成答案:依据筛选出的相关文档生成答案。
评估答案:
检查答案是否基于上下文(Supported) :
- 若 “否”(存在幻觉情况),返回步骤 3 重新生成答案。若 “是”,继续检查答案是否有用。
检查答案是否有用(Useful) :
若 “是”,流程结束(END)。
若 “否”(答案跑题或无效),检查查询重写次数:
- 未达重写上限,执行查询重写(Rewrite Query),返回步骤 1 重新检索。已达重写上限,直接结束(END),不再进行修正尝试。
三、实现所需节点
- 检索节点:基于问题检索信息。上下文评估节点:评估检索结果与问题的相关性,决定后续操作。生成答案节点:在相关性评估通过后生成答案。评估答案是否有用节点:判断答案有效性,决定是否进入查询重写流程。转换查询节点(query2doc) :在上下文评估不通过或答案无用时,转换查询内容 。重写查询节点:对无效答案对应的查询进行重写。结束节点:当答案有用或与问题完全不相关时,终止流程。
四、节点实现详解
1. 检索节点
思考过程:
输入:用户问题
操作:在向量数据库检索相似文档
函数接收参数:用户问题、最大查询数
检索方法:通过向量数据库的 similarity_search
获取目标列表
上下文构建:遍历 related_doc
,字符串拼接文档内容形成 context
输出:初步检索的文档列表
实现代码:
python
def rag_retrieve(question, k=3): related_docs = zhidu_db.similarity_search(question, k=3) context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(related_docs)]) return related_docs, context
2. 转换查询节点
思考过程:
输入:LangGraph 节点函数的 State 对象(含流程所有状态信息)
操作:调用 Query2doc
传入用户问题,通过提示词模板解答问题
输出:Query2doc
函数返回的答案
实现代码:
def transform_query2doc(state): print("---transform_query2doc---") # node input state_dict = state["keys"] question = state_dict["question"] # 获取原始问题 documents = state_dict["documents"] # 获取当前文档(后续使用) context = state_dict["context"] # 获取当前上下文(后续使用) query2doc_count = state_dict.get("query2doc_count", 0) # 获取转换计数 rewrite_count = state_get("rewrite_count", 0) # 获取重写计数 # task - 核心操作! context_query = query2doc(question) # node output return {"keys": {"context": context, "documents": documents, "question": question, # 保留原始问题 "context_query": context_query, # 添加转换后的查询 "query2doc_count": query2doc_count + 1, # 增加转换计数 "rewrite_count": rewrite_count}} # 保持重写计数不变
3. 重写查询节点
思考过程:
输入:LangGraph 节点函数的 State 对象(含流程所有状态信息)
过程:调用 question_rewrite
传入用户问题,借助提示词模板解答
输出:question_rewrite
函数返回的答案,统计并限制重写次数避免死循环
实现代码:
def transform_query_rewrite(state): print("---transform_query---") state_dict = state["keys"] question = state_dict["question"] documents = state_dict["documents"] context = state_dict["context"] query2doc_count = state_dict.get("query2doc_count", 0) rewrite_count = state_dict.get("rewrite_count", 0) context_query = question_rewrite(question) return { "keys": { "context": context, "documents": documents, "question": question, "context_query": context_query, "query2doc_count": query2doc_count, "rewrite_count": rewrite_count + 1 } }
4. 上下文评估节点
思考过程:
输入:LangGraph 节点函数的标准输入 state 对象,提取用户问题和检索文档
操作:遍历文档列表,评估每个文档与问题的相关性,筛选相关文档
输出:更新后的 state 字典,包含仅含相关文档的 documents
列表、基于相关文档构建的 context
字符串、可能的标识位及其他状态信息
实现代码:
def grade_documents(state): print("---Determines whether the retrieved documents are relevant to the question---") state_dict = state["keys"] question, documents = state_dict["question"], state_dict["documents"] query2doc_count = state_dict.get("query2doc_count", 0) rewrite_count = state_dict.get("rewrite_count", 0) filtered_docs, retrieve_enhance = [], "No" for d in documents: grade = context_grade_chain.invoke({"question": question, "context": d.page_content}) print(f"Document (first 50): {d.page_content[:50]}... Grade: {grade}") if "yes" in grade.lower(): filtered_docs.append(d) else: retrieve_enhance = "Yes" if query2doc_count > 0: retrieve_enhance = "No" context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(filtered_docs)]) if filtered_docs else "" return { "keys": { "context": context, "documents": filtered_docs, "question": question, "run_retrieve_enhance": retrieve_enhance, "query2doc_count": query2doc_count, "rewrite_count": rewrite_count } }
5. 生成答案节点
思考过程:
输入:State 对象,其中包含用户问题、筛选后的文档或构建的上下文
操作:构建模板,填充问题和上下文,调用 LLM 生成答案
输出:更新后的 state 字典,新增 generation
键存储 LLM 生成的答案,其余状态不变
实现代码:
from langchain.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParserdef generate(state): print("---Generate answer---") state_dict = state["keys"] question = state_dict["question"] documents = state_dict["documents"] query2doc_count = state_dict.get("query2doc_count", 0) rewrite_count = state_dict.get("rewrite_count", 0) context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(documents)]) prompt = PromptTemplate( input_variables=["question", "context"], template=prompt_template ) rag_chain = prompt | llm | StrOutputParser() generation = rag_chain.invoke({"context": context, "question": question}) return { "keys": { "context": context, "question": question, "documents": documents, "generation": generation, "query2doc_count": query2doc_count, "rewrite_count": rewrite_count } }
6. 评估答案是否有用节点
思考过程:
输入:state 对象、当前问题、生成答案、相关上下文与文档、查询重写计数器参数
操作:评估答案有效性、是否基于上下文,依评估结果决定流程走向
输出:表示下一跳转节点名称的字符串,用于工作流判断
实现代码:
# --- 4. 答案评估函数部分 ---def grade_generation_v_documents_and_question(state): print("---Determines whether the answer is relevant to the question---") # 日志:开始评估答案 # a. 获取当前状态信息 state_dict = state["keys"] question = state_dict["question"] context = state_dict["context"] # 注意:这里用的是合并后的 context 字符串 generation = state_dict["generation"] # 获取上一步生成的答案 rewrite_count = state_dict.get("rewrite_count", 0) # 获取查询重写次数 # b. 第一层检查:答案是否基于上下文? print("---GRADE GENERATION vs CONTEXT (Supported?)---") # 日志:检查答案是否基于上下文 grade = answer_supported_chain.invoke({"generation": generation, "context": context}) # c. 判断第一层检查结果 if "yes" in grade.lower(): # 如果答案是基于上下文的 ("yes") print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---") # 日志:判定答案有依据 # d. 第二层检查:答案是否有用 (针对问题)? print("---GRADE GENERATION vs QUESTION (Useful?)---") # 日志:检查答案是否有用 score = answer_useful_chain.invoke({"question": question, "generation": generation}) # e. 判断第二层检查结果 if "yes" in score.lower(): # 如果答案有用 ("yes") print("---DECISION: GENERATION ADDRESSES QUESTION---") # 日志:判定答案有用 return "useful" # 返回 "useful",流程将走向 END else: # 如果答案无用 ("no") # f. 检查是否已重写过查询 if rewrite_count < 1: # 如果还没重写过 (次数小于1) print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION, REWRITE---") # 日志:判定答案无用,准备重写 return "not useful" # 返回 "not useful",流程将走向 transform_query_rewrite else: # 如果已经重写过一次或更多次 print("---DECISION: GENERATION USELESS AFTER REWRITE, END---") # 日志:重写后答案仍无用,结束 return "end" # 返回 "end",流程将走向 END (放弃治疗) else: # 如果答案不基于上下文 ("no") print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY GENERATE---") # 日志:判定答案无依据,重试生成 return "not supported" # 返回 "not supported",流程将走向 generate (重新生成)
7. 结束节点
思考过程:在 workflow.add_conditional_edges
或 workflow.add_edge
中设置流程指向 END,当流程执行到 END 时,app.stream()
或 app.invoke()
停止并返回最终状态或结果。
实现代码:
workflow.add_conditional_edges( "generate", # 源节点 grade_generation_v_documents_and_question, # 条件判断函数 { # 结果到目标节点的映射 "not supported": "generate", # 如果答案不被支持,回到 generate "useful": END, # 如果答案有用,流向 END "end": END, # 如果达到重写次数上限,也流向 END "not useful": "transform_query_rewrite", # 如果答案无用且未达上限,去重写 },)
五、将节点串联成工作流
思考过程:
回顾自适应 RAG 工作流程,明确各环节先后顺序与逻辑关系 。
采用 TypedDict
定义统一结构存储问题、文档、上下文、答案、计数器等中间状态。
确定节点函数(retrieve
、grade_documents
、generate
、transform_query2doc
、transform_query_rewrite
)与条件判断函数(decide_to_generate
、grade_generation_v_documents_and_question
)。
实例化 StateGraph
对象,注册节点、指定起点、链接节点(固定顺序用 add_edge
,条件判断用 add_conditional_edges
) 。
实现代码:
class GraphState(TypedDict): keys: Dict[str, any]workflow = StateGraph(GraphState)# 添加节点workflow.add_node("retrieve", retrieve)workflow.add_node("grade_documents", grade_documents)workflow.add_node("generate", generate)workflow.add_node("transform_query2doc", transform_query2doc)workflow.add_node("transform_query_rewrite", transform_query_rewrite)# 设置入口点workflow.set_entry_point("retrieve")# 添加固定边workflow.add_edge("retrieve", "grade_documents")workflow.add_edge("transform_query2doc", "retrieve") # Query2Doc 后重新检索workflow.add_edge("transform_query_rewrite", "retrieve") # Rewrite Query 后重新检索# 添加条件边 - 根据上下文评估结果决定走向workflow.add_conditional_edges( "grade_documents", # 源节点 decide_to_generate, # 条件判断函数 { # 结果到目标节点的映射 "transform_query2doc": "transform_query2doc", # 如果需要转换查询 "generate": "generate", # 如果可以直接生成 },)# 添加条件边 - 根据答案评估结果决定走向workflow.add_conditional_edges( "generate", # 源节点 grade_generation_v_documents_and_question, # 条件判断函数 { # 结果到目标节点的映射 "not supported": "generate", # 如果答案不被支持,回到 generate "useful": END, # 如果答案有用,流向 END "end": END, # 如果达到重写次数上限,也流向 END "not useful": "transform_query_rewrite", # 如果答案无用且未达上限,去重写 },)