在 LangChain 框架中,chain_type
主要出现在 load_qa_chain
和 RetrievalQA
等组件中,用于控制文档处理方式。以下是主要的 chain_type 类型及其特点:
四种主要的 Chain Types
类型 | 关键特点 | 适用场景 | 优点 | 缺点 |
---|---|---|---|---|
stuff (默认) | 将所有文档合并为一个提示 | 文档少且短(<4K tokens) | ✅ 单次LLM调用 ✅ 保留文档间关系 | ❌ 文档长时会超出上下文限制 |
map_reduce | 分两步处理: 1. 单独处理每个文档 2. 合并结果 | 大量文档/长文档 | ✅ 处理任意长度文档 ✅ 并行处理文档 | ❌ 多次LLM调用 ❌ 增加延迟和成本 |
refine | 迭代处理文档: 用前结果+新文档迭代优化 | 需要逐步精炼答案 | ✅ 答案质量高 ✅ 处理长文档 | ❌ 顺序处理效率低 ❌ 最高调用次数 |
map_rerank | 为每个文档评分, 选最高分答案 | 问答任务 答案提取 | ✅ 返回置信度 ✅ 精准答案定位 | ❌ 仅适用于特定任务 |
详细说明与代码示例
1. stuff
(合并处理)
from langchain.chains import load_qa_chainchain = load_qa_chain(llm, chain_type="stuff")result = chain({"input_documents": docs, "question": "总结主要内容"})
- 工作原理:将所有文档内容拼接成单个提示适合:摘要生成、简单问答(文档<4K tokens)
2. map_reduce
(映射-归纳)
chain = load_qa_chain(llm, chain_type="map_reduce")# 高级参数设置chain = load_qa_chain( llm, chain_type="map_reduce", return_intermediate_steps=True, # 返回中间结果 map_prompt=MAP_PROMPT, # 自定义映射提示 combine_prompt=COMBINE_PROMPT # 自定义归并提示)
处理流程:
- Map阶段:为每个文档生成独立答案Reduce阶段:合并所有答案生成最终结果
适合:长文档处理(如整本书分析)、跨文档综合问答
3. refine
(迭代优化)
chain = load_qa_chain(llm, chain_type="refine")# 带中间步骤的结果result = chain({ "input_documents": docs, "question": "技术演进的关键节点", "return_refine_steps": True # 返回迭代过程})
处理流程:
- 用第一个文档生成初始答案用前结果+新文档迭代优化答案重复直到处理完所有文档
适合:需要逐步精炼的场景(如研究报告分析)
4. map_rerank
(评分优选)
from langchain.chains.qa_with_sources.map_rerank_prompt import PROMPTchain = load_qa_chain( llm, chain_type="map_rerank", prompt=PROMPT, # 特殊评分提示 metadata_keys=['source'] # 包含元数据)
处理流程:
- 为每个文档生成答案+置信度评分选择最高分的答案作为最终结果
输出格式:{"answer": "...", "score": 0.95, "source": "doc1.pdf"}
适合:精确答案提取、带来源引用的问答
选择建议
场景 | 推荐类型 |
---|---|
文档少且短(<10页) | stuff |
长文档/整书分析 | map_reduce |
需要精确来源引用 | map_rerank |
答案需要逐步优化 | refine |
实时性要求高 | stuff |
成本敏感场景 | stuff (调用次数最少) |
使用技巧
文档预处理:
# 分割长文档from langchain.text_splitter import RecursiveCharacterTextSplittertext_splitter = RecursiveCharacterTextSplitter(chunk_size=1000)docs = text_splitter.split_documents(long_doc)
自定义提示模板:
from langchain.prompts import PromptTemplateMAP_TEMPLATE = """根据以下片段回答问题:{context}问题:{question}答案:"""MAP_PROMPT = PromptTemplate(...)
混合使用策略:
# 先用map_reduce过滤,再用stuff精处理filtered_docs = filter_relevant_docs(question, docs)chain = load_qa_chain(llm, chain_type="stuff")result = chain.run(input_documents=filtered_docs, question=question)
实际选择时应根据文档规模、任务复杂度和性能要求综合评估。对于大多数应用场景,map_reduce
提供了最佳的长度与效果平衡。
源码位置概览
chain_type | 核心实现文件 | 关键类 |
---|---|---|
stuff | langchain/chains/combine_documents/stuff.py | StuffDocumentsChain |
map_reduce | langchain/chains/combine_documents/map_reduce.py | MapReduceDocumentsChain |
refine | langchain/chains/combine_documents/refine.py | RefineDocumentsChain |
map_rerank | langchain/chains/combine_documents/map_rerank.py | MapRerankDocumentsChain |
入口函数 | langchain/chains/question_answering/load.py | load_qa_chain() 函数 |
🔍 详细源码分析
1. stuff
类型
文件路径: langchain/chains/combine_documents/stuff.py
核心代码片段:
class StuffDocumentsChain(BaseCombineDocumentsChain): """Chain that combines documents by stuffing into context.""" def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: # 拼接所有文档内容 inputs = {k: v for k, v in kwargs.items() if k not in self.input_key} inputs[self.document_variable_name] = self._combine_docs(docs, **kwargs) return inputs
2. map_reduce
类型
文件路径: langchain/chains/combine_documents/map_reduce.py
核心逻辑:
class MapReduceDocumentsChain(BaseCombineDocumentsChain): def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: # 1. Map阶段 map_results = self.llm_chain.apply( # 对每个文档单独处理 [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs], ) # 2. Reduce阶段 result = self.combine_document_chain.run( input_documents=map_results, **kwargs ) return result
3. refine
类型
文件路径: langchain/chains/combine_documents/refine.py
迭代逻辑:
class RefineDocumentsChain(BaseCombineDocumentsChain): def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: # 初始文档处理 current_result = self.initial_llm_chain.run( **{self.document_variable_name: docs[0].page_content}, **kwargs ) # 迭代优化 for doc in docs[1:]: current_result = self.refine_llm_chain.run( question=kwargs["question"], existing_answer=current_result, **{self.document_variable_name: doc.page_content}, ) return current_result
4. map_rerank
类型
文件路径: langchain/chains/combine_documents/map_rerank.py
评分逻辑:
class MapRerankDocumentsChain(BaseCombineDocumentsChain): def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: # 为每个文档评分 results = [] for doc in docs: result = self.llm_chain.predict_and_parse( **{self.document_variable_name: doc.page_content}, **kwargs ) results.append((result["score"], result["answer"])) # 选择最高分答案 sorted_results = sorted(results, key=lambda x: x[0], reverse=True) return sorted_results[0][1]
5. 入口函数 load_qa_chain()
文件路径: langchain/chains/question_answering/load.py
chain_type 分发逻辑:
def load_qa_chain( llm: BaseLanguageModel, chain_type: str = "stuff", **kwargs: Any,) -> BaseCombineDocumentsChain: loader_mapping: Dict[str, Any] = { "stuff": StuffDocumentsChain, "map_reduce": MapReduceDocumentsChain, "refine": RefineDocumentsChain, "map_rerank": MapRerankDocumentsChain, } if chain_type not in loader_mapping: raise ValueError(f"Invalid chain_type {chain_type}") class_obj = loader_mapping[chain_type] return class_obj.from_llm(llm, **kwargs)
定制化扩展点
如需自定义 chain 行为,可关注以下方法:
提示工程:
# 自定义提示模板from langchain.prompts import PromptTemplatecustom_prompt = PromptTemplate( input_variables=["context", "question"], template="基于以下内容回答问题:\n{context}\n问题: {question}")chain = load_qa_chain( llm, chain_type="stuff", prompt=custom_prompt)
文档预处理:
# 自定义文档分割器from langchain.text_splitter import CharacterTextSplittertext_splitter = CharacterTextSplitter( chunk_size=500, chunk_overlap=50)
回调处理:
# 添加处理回调from langchain.callbacks import StdOutCallbackHandlerchain.run( input_documents=docs, question=query, callbacks=[StdOutCallbackHandler()])
调试建议
查看内部状态:
# 对于 map_reduce 查看中间结果chain = load_qa_chain( llm, chain_type="map_reduce", return_intermediate_steps=True)result = chain({"input_documents": docs, "question": query})print(result["intermediate_steps"])
源码调试技巧:
- 在
langchain/chains/combine_documents/base.py
中的 BaseCombineDocumentsChain
基类设置断点监控 _call
方法的执行流程日志跟踪:
import logginglogging.basicConfig(level=logging.DEBUG)
通过分析这些源码文件,可以深入理解每种 chain_type 的内部工作机制,并根据需要扩展或修改其行为。