掘金 人工智能 前天 10:43
可以自我反思的检索增强生成
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文详细介绍了Self-RAG(Self-Reflective Retrieval-Augmented Generation)系统的实现细节,该系统赋予大模型“自我反思与修正”能力,以提高检索准确性和生成结果的可靠性。文章深入探讨了Self-RAG的工作流程,包括文档检索、上下文评估、答案生成与评估、以及查询重写等关键环节。此外,文章还剖析了实现Self-RAG所需的关键节点,如检索节点、上下文评估节点、生成答案节点等,并提供了详尽的实现代码,帮助读者理解其内部运作机制。

🧐 Self-RAG 是一种增强大模型能力的系统,通过“自我反思与修正”来改善检索和生成结果。

📄 系统工作流程包括:检索文档、上下文评估、生成答案、评估答案等步骤,循环迭代以优化结果。

🧩 实现 Self-RAG 需构建多个关键节点,如检索节点、上下文评估节点、生成答案节点等,每个节点负责特定功能。

🔄 上下文评估节点负责判断检索结果与问题的相关性,并根据评估结果进行后续操作,如转换查询或重写查询。

✅ 答案评估节点会检查答案的上下文支持性和有用性,从而决定是否需要重新生成答案或结束流程。

Self-RAG 系统实现详解

一、Self-RAG 简介

Self-RAG(Self-Reflective Retrieval-Augmented Generation)赋予大模型 “自我反思与修正” 能力,可有效改善检索不准确、生成结果不可靠等问题。

二、Self-RAG 工作流程

    检索文档:携带初始或修改后的问题进行文档检索,获取文档后开展上下文相关性评估。

    上下文评估

      若评估为不相关,且此前未执行过 Query2Doc 转换,则进行 Query2Doc 转换,随后返回步骤 1 重新检索。若评估为相关,或虽不相关但已尝试过 Query2Doc 转换,则筛选出相关文档,进入生成答案环节 。

    生成答案:依据筛选出的相关文档生成答案。

    评估答案

      检查答案是否基于上下文(Supported)

        若 “否”(存在幻觉情况),返回步骤 3 重新生成答案。若 “是”,继续检查答案是否有用。

      检查答案是否有用(Useful)

        若 “是”,流程结束(END)。

        若 “否”(答案跑题或无效),检查查询重写次数:

          未达重写上限,执行查询重写(Rewrite Query),返回步骤 1 重新检索。已达重写上限,直接结束(END),不再进行修正尝试。

三、实现所需节点

    检索节点:基于问题检索信息。上下文评估节点:评估检索结果与问题的相关性,决定后续操作。生成答案节点:在相关性评估通过后生成答案。评估答案是否有用节点:判断答案有效性,决定是否进入查询重写流程。转换查询节点(query2doc) :在上下文评估不通过或答案无用时,转换查询内容 。重写查询节点:对无效答案对应的查询进行重写。结束节点:当答案有用或与问题完全不相关时,终止流程。

四、节点实现详解

1. 检索节点

思考过程

实现代码

python

def rag_retrieve(question, k=3):     related_docs = zhidu_db.similarity_search(question, k=3)     context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(related_docs)])     return related_docs, context

2. 转换查询节点

思考过程

实现代码

def transform_query2doc(state):    print("---transform_query2doc---")    # node input    state_dict = state["keys"]    question = state_dict["question"] # 获取原始问题    documents = state_dict["documents"] # 获取当前文档(后续使用)    context = state_dict["context"]   # 获取当前上下文(后续使用)    query2doc_count = state_dict.get("query2doc_count", 0) # 获取转换计数    rewrite_count = state_get("rewrite_count", 0) # 获取重写计数    # task - 核心操作!    context_query = query2doc(question)    # node output    return {"keys": {"context": context,                    "documents": documents,                    "question": question, # 保留原始问题                    "context_query": context_query, # 添加转换后的查询                    "query2doc_count": query2doc_count + 1, # 增加转换计数                    "rewrite_count": rewrite_count}} # 保持重写计数不变

3. 重写查询节点

思考过程

实现代码

def transform_query_rewrite(state):     print("---transform_query---")     state_dict = state["keys"]     question = state_dict["question"]     documents = state_dict["documents"]     context = state_dict["context"]     query2doc_count = state_dict.get("query2doc_count", 0)     rewrite_count = state_dict.get("rewrite_count", 0)    context_query = question_rewrite(question)    return {        "keys": {            "context": context,            "documents": documents,            "question": question,            "context_query": context_query,            "query2doc_count": query2doc_count,            "rewrite_count": rewrite_count + 1        }    }

4. 上下文评估节点

思考过程

实现代码

def grade_documents(state):     print("---Determines whether the retrieved documents are relevant to the question---")     state_dict = state["keys"]     question, documents = state_dict["question"], state_dict["documents"]     query2doc_count = state_dict.get("query2doc_count", 0)     rewrite_count = state_dict.get("rewrite_count", 0)    filtered_docs, retrieve_enhance = [], "No"    for d in documents:        grade = context_grade_chain.invoke({"question": question, "context": d.page_content})        print(f"Document (first 50): {d.page_content[:50]}... Grade: {grade}")        if "yes" in grade.lower():            filtered_docs.append(d)        else:            retrieve_enhance = "Yes"    if query2doc_count > 0:        retrieve_enhance = "No"    context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(filtered_docs)]) if filtered_docs else ""    return {        "keys": {            "context": context,            "documents": filtered_docs,            "question": question,            "run_retrieve_enhance": retrieve_enhance,            "query2doc_count": query2doc_count,            "rewrite_count": rewrite_count        }    }

5. 生成答案节点

思考过程

实现代码

from langchain.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParserdef generate(state):     print("---Generate answer---")     state_dict = state["keys"]     question = state_dict["question"]     documents = state_dict["documents"]     query2doc_count = state_dict.get("query2doc_count", 0)     rewrite_count = state_dict.get("rewrite_count", 0)     context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(documents)])     prompt = PromptTemplate( input_variables=["question", "context"], template=prompt_template )     rag_chain = prompt | llm | StrOutputParser()     generation = rag_chain.invoke({"context": context, "question": question})     return {         "keys": {             "context": context,             "question": question,             "documents": documents,             "generation": generation,             "query2doc_count": query2doc_count,             "rewrite_count": rewrite_count         }     }

6. 评估答案是否有用节点

思考过程

实现代码

# --- 4. 答案评估函数部分 ---def grade_generation_v_documents_and_question(state):    print("---Determines whether the answer is relevant to the question---") # 日志:开始评估答案    # a. 获取当前状态信息    state_dict = state["keys"]    question = state_dict["question"]    context = state_dict["context"] # 注意:这里用的是合并后的 context 字符串    generation = state_dict["generation"] # 获取上一步生成的答案    rewrite_count = state_dict.get("rewrite_count", 0) # 获取查询重写次数    # b. 第一层检查:答案是否基于上下文?    print("---GRADE GENERATION vs CONTEXT (Supported?)---") # 日志:检查答案是否基于上下文    grade = answer_supported_chain.invoke({"generation": generation, "context": context})    # c. 判断第一层检查结果    if "yes" in grade.lower(): # 如果答案是基于上下文的 ("yes")        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---") # 日志:判定答案有依据        # d. 第二层检查:答案是否有用 (针对问题)?        print("---GRADE GENERATION vs QUESTION (Useful?)---") # 日志:检查答案是否有用        score = answer_useful_chain.invoke({"question": question, "generation": generation})        # e. 判断第二层检查结果        if "yes" in score.lower(): # 如果答案有用 ("yes")            print("---DECISION: GENERATION ADDRESSES QUESTION---") # 日志:判定答案有用            return "useful" # 返回 "useful",流程将走向 END        else: # 如果答案无用 ("no")            # f. 检查是否已重写过查询            if rewrite_count < 1: # 如果还没重写过 (次数小于1)                print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION, REWRITE---") # 日志:判定答案无用,准备重写                return "not useful" # 返回 "not useful",流程将走向 transform_query_rewrite            else: # 如果已经重写过一次或更多次                print("---DECISION: GENERATION USELESS AFTER REWRITE, END---") # 日志:重写后答案仍无用,结束                return "end" # 返回 "end",流程将走向 END (放弃治疗)    else: # 如果答案不基于上下文 ("no")        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY GENERATE---") # 日志:判定答案无依据,重试生成        return "not supported" # 返回 "not supported",流程将走向 generate (重新生成)

7. 结束节点

思考过程:在 workflow.add_conditional_edges 或 workflow.add_edge 中设置流程指向 END,当流程执行到 END 时,app.stream() 或 app.invoke() 停止并返回最终状态或结果。

实现代码

workflow.add_conditional_edges(    "generate", # 源节点    grade_generation_v_documents_and_question, # 条件判断函数    { # 结果到目标节点的映射        "not supported": "generate", # 如果答案不被支持,回到 generate        "useful": END, # 如果答案有用,流向 END        "end": END, # 如果达到重写次数上限,也流向 END        "not useful": "transform_query_rewrite", # 如果答案无用且未达上限,去重写    },)

五、将节点串联成工作流

思考过程

    回顾自适应 RAG 工作流程,明确各环节先后顺序与逻辑关系 。

    采用 TypedDict 定义统一结构存储问题、文档、上下文、答案、计数器等中间状态。

    确定节点函数(retrievegrade_documentsgeneratetransform_query2doctransform_query_rewrite)与条件判断函数(decide_to_generategrade_generation_v_documents_and_question)。

    实例化 StateGraph 对象,注册节点、指定起点、链接节点(固定顺序用 add_edge,条件判断用 add_conditional_edges) 。

实现代码

class GraphState(TypedDict):    keys: Dict[str, any]workflow = StateGraph(GraphState)# 添加节点workflow.add_node("retrieve", retrieve)workflow.add_node("grade_documents", grade_documents)workflow.add_node("generate", generate)workflow.add_node("transform_query2doc", transform_query2doc)workflow.add_node("transform_query_rewrite", transform_query_rewrite)# 设置入口点workflow.set_entry_point("retrieve")# 添加固定边workflow.add_edge("retrieve", "grade_documents")workflow.add_edge("transform_query2doc", "retrieve") # Query2Doc 后重新检索workflow.add_edge("transform_query_rewrite", "retrieve") # Rewrite Query 后重新检索# 添加条件边 - 根据上下文评估结果决定走向workflow.add_conditional_edges(    "grade_documents", # 源节点    decide_to_generate, # 条件判断函数    { # 结果到目标节点的映射        "transform_query2doc": "transform_query2doc", # 如果需要转换查询        "generate": "generate", # 如果可以直接生成    },)# 添加条件边 - 根据答案评估结果决定走向workflow.add_conditional_edges(    "generate", # 源节点    grade_generation_v_documents_and_question, # 条件判断函数    { # 结果到目标节点的映射        "not supported": "generate", # 如果答案不被支持,回到 generate        "useful": END, # 如果答案有用,流向 END        "end": END, # 如果达到重写次数上限,也流向 END        "not useful": "transform_query_rewrite", # 如果答案无用且未达上限,去重写    },)


Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

Self-RAG 大模型 检索增强生成 RAG
相关文章