Retrieval-Augmented Generation (RAG) combines vector retrieval with a large language model, and it can significantly improve answer accuracy and context utilization in scenarios such as question answering, summarization, and document analysis.

In this article we use LangChain to build a complete RAG pipeline, with PGVector as the vector database and LangGraph providing a state graph that manages the control flow.
Initializing the LLM (llm_env.py)
We first load the gpt-4o-mini model with LangChain's init_chat_model helper; this model handles all question answering later on.
```python
# llm_env.py
from langchain.chat_models import init_chat_model

llm = init_chat_model("gpt-4o-mini", model_provider="openai")
```
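init_chat_model expects an OpenAI key to be available in the environment. A minimal sketch of supplying one interactively (this uses the standard OPENAI_API_KEY variable; any secret store works just as well):

```python
import getpass
import os

# Prompt for the key only if it is not already set in the environment
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")
```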
The main RAG pipeline (rag.py)
The main flow of the RAG system is shown below. It covers four steps: document loading and splitting, vector storage, state-graph modeling (analyze → retrieve → generate), and an interactive Q&A loop.
```python
# rag.py
import os
import sys
import time

sys.path.append(os.getcwd())

from llm_set import llm_env
from langchain_openai import OpenAIEmbeddings
from langchain_postgres import PGVector
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict, Annotated
from typing import Literal
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, BaseMessage
from langchain_core.prompts import ChatPromptTemplate

# Initialize the LLM (loaded in llm_env.py)
llm = llm_env.llm

# Embedding model
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")

# Initialize the vector store
vector_store = PGVector(
    embeddings=embeddings,
    collection_name="my_rag_docs",
    connection="postgresql+psycopg2://postgres:123456@localhost:5433/langchainvector",
)

# Load the web page content
url = "https://python.langchain.com/docs/tutorials/qa_chat_history/"
loader = WebBaseLoader(web_paths=(url,))
docs = loader.load()
for doc in docs:
    doc.metadata["source"] = url

# Split the text into chunks
text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=50)
all_splits = text_splitter.split_documents(docs)

# Tag each chunk with a "section" metadata field
total_documents = len(all_splits)
third = total_documents // 3
for i, document in enumerate(all_splits):
    if i < third:
        document.metadata["section"] = "beginning"
    elif i < 2 * third:
        document.metadata["section"] = "middle"
    else:
        document.metadata["section"] = "end"

# Skip indexing if vectors for this source already exist
existing = vector_store.similarity_search(url, k=1, filter={"source": url})
if not existing:
    _ = vector_store.add_documents(documents=all_splits)
    print("Documents embedded and stored")
```
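Before building the graph, it is worth confirming that the splitter and the section tagging behaved as expected. A quick check along these lines (not part of the script above) would be:

```python
# Sanity check: how many chunks, and which sections the edge chunks landed in
print(f"{total_documents} chunks")
print(all_splits[0].metadata)   # expected section: 'beginning'
print(all_splits[-1].metadata)  # expected section: 'end'
```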
The analyze, retrieve, and generate modules
Next we define the three functions that make up the LangGraph flow: analyze → retrieve → generate.
```python
class Search(TypedDict):
    query: Annotated[str, "The question to be answered"]
    section: Annotated[
        Literal["beginning", "middle", "end"],
        ...,
        "Section to query.",
    ]


class State(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]
    query: Search
    context: List[Document]
    answer: str


# Analyze intent -> extract the query text and the target section
def analyze(state: State):
    structured_llm = llm.with_structured_output(Search)
    query = structured_llm.invoke(state["messages"])
    return {"query": query}


# Similarity search, optionally filtered by section
def retrieve(state: State):
    query = state["query"]
    # Search is a TypedDict, so check for the key rather than an attribute
    if "section" in query:
        filter = {"section": query["section"]}
    else:
        filter = None
    retrieved_docs = vector_store.similarity_search(query["query"], filter=filter)
    return {"context": retrieved_docs}
```
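Because Search is a TypedDict, with_structured_output returns a plain dict, which is why retrieve tests for the "section" key instead of an attribute. A hypothetical one-off smoke test of the analyze step (the question text here is made up):

```python
from langchain_core.messages import HumanMessage

# Run analyze on a hand-built state to inspect the structured query
state = {"messages": [HumanMessage("What does the end of the tutorial say about chat history?")]}
print(analyze(state)["query"])
# e.g. {'query': '...', 'section': 'end'}
```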
The generation step builds an answer from a ChatPromptTemplate and the retrieved context:
```python
prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", "Using the following context as best you can: {context}, answer the question: {question}."),
    ]
)


def generate(state: State):
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt_template.invoke({
        "question": state["query"]["query"],
        "context": docs_content,
    })
    response = llm.invoke(messages)
    return {"answer": response.content, "messages": [response]}
```
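prompt_template.invoke returns a ChatPromptValue; rendering it with dummy values is an easy way to see exactly what the model receives (an illustration, not part of the pipeline):

```python
# Render the template to inspect the final system message
value = prompt_template.invoke({"question": "What is RAG?", "context": "..."})
for message in value.to_messages():
    print(message.type, ":", message.content)
```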
Building the LangGraph state graph
With the state structure defined, we assemble the LangGraph:
```python
graph_builder = StateGraph(State).add_sequence([analyze, retrieve, generate])
graph_builder.add_edge(START, "analyze")
```
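add_sequence chains analyze → retrieve → generate in order, and the explicit edge from START makes analyze the entry point. To double-check the topology you can compile a throwaway copy and print it (draw_mermaid is LangGraph's built-in renderer; the real compile below attaches a checkpointer):

```python
# Optional sanity check of the wiring
preview = graph_builder.compile()
print(preview.get_graph().draw_mermaid())
```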
Saving intermediate state (checkpoints) in Postgres
We use PostgresSaver to record the intermediate state of each conversation turn:
```python
DB_URI = "postgresql://postgres:123456@localhost:5433/langchaindemo?sslmode=disable"

with PostgresSaver.from_conn_string(DB_URI) as checkpointer:
    checkpointer.setup()

    graph = graph_builder.compile(checkpointer=checkpointer)

    input_thread_id = input("Enter a thread_id: ")
    time_str = time.strftime("%Y%m%d", time.localtime())
    config = {"configurable": {"thread_id": f"rag-{time_str}-demo-{input_thread_id}"}}

    print("Type a question, or 'exit' to quit.")

    while True:
        query = input("You: ")
        if query.strip().lower() == "exit":
            break

        input_messages = [HumanMessage(query)]
        response = graph.invoke({"messages": input_messages}, config=config)
        print(response["answer"])
```
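Because every turn is checkpointed under the thread_id, re-running the script with the same id resumes the earlier conversation. The stored state can also be read back through the compiled graph, e.g. (a sketch using the same config as above, assuming at least one turn has run):

```python
# Inspect the latest checkpoint for this thread
snapshot = graph.get_state(config)
print(snapshot.values["answer"])         # last generated answer
print(len(snapshot.values["messages"]))  # accumulated message history
```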
Summary
By combining LangChain's modular building blocks with the PGVector vector store and LangGraph's state management, we built a RAG system that is interactive, persistent, and aware of document structure. Its strengths include:
- Structured query understanding (section-scoped retrieval)
- Automatic chunking with section metadata tagging
- State-flow tracking and recovery
- Extensible to document upload, cache optimization, and multi-user configuration