1. Modular Design and Extension Architecture
In the previous posts we covered the RAG system's architecture, core modules, and deployment options in detail. This post digs into how to extend the system: implementing custom processors and vectorization methods to meet more diverse business needs.
1.1 Extensibility Design Principles
Our RAG system follows the extensibility design principles below. The core design patterns (illustrated with a short sketch after the list) are:
- Abstract base classes: define unified interfaces so implementations stay consistent
- Factory pattern: dynamically create the appropriate component instance from configuration
- Strategy pattern: select different algorithm strategies at runtime
- Decorator pattern: extend behavior without modifying existing code
- Dependency injection: inject dependencies through configuration to reduce coupling between components
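To make these patterns concrete, here is a minimal, self-contained sketch, independent of the project's actual classes (all names below are illustrative), showing how an abstract base class, a decorator, and constructor-based dependency injection can combine for document processors:

```python
from abc import ABC, abstractmethod
from typing import List


class BaseProcessor(ABC):
    """Abstract base class: every processor exposes the same interface."""

    @abstractmethod
    def process_text(self, text: str) -> List[str]:
        ...


class PlainTextProcessor(BaseProcessor):
    """One concrete strategy: naive paragraph splitting."""

    def process_text(self, text: str) -> List[str]:
        return [p.strip() for p in text.split("\n\n") if p.strip()]


class LoggingProcessor(BaseProcessor):
    """Decorator: adds logging around any processor without modifying it."""

    def __init__(self, inner: BaseProcessor):  # dependency injected via the constructor
        self.inner = inner

    def process_text(self, text: str) -> List[str]:
        chunks = self.inner.process_text(text)
        print(f"{type(self.inner).__name__} produced {len(chunks)} chunks")
        return chunks


processor: BaseProcessor = LoggingProcessor(PlainTextProcessor())
print(processor.process_text("first paragraph\n\nsecond paragraph"))
```

In the sections that follow, the same ideas appear at project scale: BaseDocumentProcessor and BaseVectorizer play the role of the abstract base classes, while DocumentProcessor and VectorizationFactory act as the factories.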
2. Developing Custom Document Processors
2.1 The Processor Abstract Interface
All processors inherit from the BaseDocumentProcessor abstract base class, defined in src/data_processing/processors/base.py:
```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from langchain.schema import Document
# ProcessorConfig is defined in the same module (src/data_processing/processors/base.py).


class BaseDocumentProcessor(ABC):
    """Base class for document processors."""

    def __init__(self, config: Optional[ProcessorConfig] = None):
        """Initialize the processor with a configuration."""
        self.config = config or ProcessorConfig()
        self._setup_logging()

    @abstractmethod
    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Process a single file."""
        pass

    @abstractmethod
    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Process raw text."""
        pass
```
2.2 Implementing a Custom Academic Paper Processor
Suppose we need to handle academic-paper PDFs with a specific structure. We can implement it like this:
```python
import io
import re
from typing import Any, Dict, List, Optional

import PyPDF2
from langchain.schema import Document

from src.data_processing.processors.base import BaseDocumentProcessor


class AcademicPaperProcessor(BaseDocumentProcessor):
    """Processor specialized for academic-paper PDFs."""

    # Note: _split_section (section-aware chunking) and _add_metadata are helper
    # methods not shown in this excerpt.

    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Process a paper PDF file."""
        if mime_type != "application/pdf":
            raise ValueError(f"不支持的文件类型: {mime_type},仅支持PDF")

        text = self._extract_pdf_text(file_content)
        sections = self._parse_paper_structure(text)

        documents = []
        for section_name, section_text in sections.items():
            metadata = {
                "source": filename,
                "section": section_name,
                "paper_type": "academic",
                "file_type": "pdf",
            }
            chunks = self._split_section(section_text, section_name)
            for i, chunk in enumerate(chunks):
                doc = Document(
                    page_content=chunk,
                    metadata={**metadata, "chunk_id": i, "total_chunks": len(chunks)},
                )
                documents.append(self._add_metadata(doc))
        return documents

    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Process raw text, reusing the same section-based chunking as process_file."""
        sections = self._parse_paper_structure(text)
        documents = []
        for section_name, section_text in sections.items():
            for i, chunk in enumerate(self._split_section(section_text, section_name)):
                doc = Document(
                    page_content=chunk,
                    metadata={**(metadata or {}), "section": section_name, "chunk_id": i},
                )
                documents.append(self._add_metadata(doc))
        return documents

    def _extract_pdf_text(self, pdf_bytes: bytes) -> str:
        """Extract text from the PDF while preserving the paper layout."""
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
        full_text = []
        for page_num in range(len(pdf_reader.pages)):
            page = pdf_reader.pages[page_num]
            full_text.append(page.extract_text())
        return "\n\n".join(full_text)

    def _parse_paper_structure(self, text: str) -> Dict[str, str]:
        """Parse the paper structure: title, abstract, introduction, methods, results, discussion, etc."""
        sections = {}

        title_match = re.search(r'^(.+?)(?=\n\n)', text)
        if title_match:
            sections["title"] = title_match.group(1).strip()

        abstract_match = re.search(r'Abstract[:.\s]+(.+?)(?=\n\n\d+\.|\n\nIntroduction)', text, re.DOTALL)
        if abstract_match:
            sections["abstract"] = abstract_match.group(1).strip()

        section_patterns = [
            (r'Introduction(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "introduction"),
            (r'Methods(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "methods"),
            (r'Results(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "results"),
            (r'Discussion(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "discussion"),
            (r'Conclusion(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z]|$)', "conclusion"),
            (r'References(?:\s|\n)+(.+?)(?=$)', "references"),
        ]
        for pattern, section_name in section_patterns:
            section_match = re.search(pattern, text, re.DOTALL)
            if section_match:
                sections[section_name] = section_match.group(1).strip()
        return sections
```
2.3 Registering the Custom Processor
Once implemented, register the custom processor with the processor factory:
```python
from src.data_processing.processors.academic_paper_processor import AcademicPaperProcessor


class DocumentProcessor:
    """Processor factory that picks the right processor based on MIME type."""

    def __init__(self, config=None):
        """Initialize the processor factory."""
        self.config = config or ProcessorConfig()
        self._processors = {}
        self._register_default_processors()

    def _register_default_processors(self):
        """Register the default processors."""
        self._processors.update({
            "application/pdf": PDFProcessor(self.config),
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document": WordProcessor(self.config),
        })
        # Override the PDF processor when the configured document type is an academic paper.
        if self.config.doc_type == DocumentType.ACADEMIC_PAPER:
            self._processors["application/pdf"] = AcademicPaperProcessor(self.config)

    def register_processor(self, mime_type, processor):
        """Register a custom processor for a MIME type."""
        self._processors[mime_type] = processor
```
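Besides the config-driven override above, a processor can also be swapped in at runtime through register_processor. A minimal sketch follows; the module paths mirror the layout used in this article and may differ in your project, and the factory's public dispatch method is not shown in the excerpt, so the registered processor is fetched directly for illustration.

```python
from src.data_processing.processors.base import ProcessorConfig
from src.data_processing.processors.academic_paper_processor import AcademicPaperProcessor

config = ProcessorConfig()
factory = DocumentProcessor(config)

# Route all PDFs to the academic-paper processor, regardless of config.doc_type.
factory.register_processor("application/pdf", AcademicPaperProcessor(config))

# Look up the registered processor directly (the factory's dispatch method is not shown above).
processor = factory._processors["application/pdf"]
with open("paper.pdf", "rb") as f:
    docs = processor.process_file(f.read(), "paper.pdf", "application/pdf")
print(f"{len(docs)} chunks produced")
```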
3. Implementing Custom Vectorization Methods
3.1 The Vectorizer Interface and Factory Pattern
All vectorizers inherit from the BaseVectorizer abstract base class, defined in src/data_processing/vectorization/base.py:
```python
import os
from abc import ABC, abstractmethod
from typing import List

import numpy as np


class BaseVectorizer(ABC):
    """Base class that defines the vectorizer interface."""

    def __init__(self, cache_dir: str = './cache/vectorization'):
        """Initialize the vectorizer."""
        self.cache_dir = cache_dir
        self._ensure_cache_dir()

    def _ensure_cache_dir(self):
        """Create the cache directory if it does not exist."""
        os.makedirs(self.cache_dir, exist_ok=True)

    @abstractmethod
    def vectorize(self, text: str) -> np.ndarray:
        """Convert a single text into a vector."""
        pass

    @abstractmethod
    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Convert a batch of texts into vectors."""
        pass
```
The factory pattern is used to create the different vectorizer types:
```python
class VectorizationFactory:
    """Factory class for creating different kinds of vectorizers."""

    _vectorizers = {
        'tfidf': TfidfVectorizer,
        'word2vec': Word2VecVectorizer,
        'bert': BertVectorizer,
        'bge-m3': BgeVectorizer,
    }

    @classmethod
    def create_vectorizer(cls, method: str = 'tfidf', **kwargs) -> BaseVectorizer:
        """Create a vectorizer for the given method name."""
        method = method.lower()
        if method not in cls._vectorizers:
            supported_methods = ", ".join(cls._vectorizers.keys())
            raise ValueError(f"不支持的向量化方法: {method}。支持的方法有: {supported_methods}")
        # Instantiate the registered vectorizer class with the caller's keyword arguments.
        return cls._vectorizers[method](**kwargs)
```
3.2 Integrating OpenAI Embedding Models
Below we implement an OpenAI vectorizer that turns text into embeddings from the OpenAI API:
```python
import logging
import os
import time
from typing import List

import numpy as np
from openai import OpenAI

from .base import BaseVectorizer


class OpenAIVectorizer(BaseVectorizer):
    """Vectorizer backed by the OpenAI embeddings API."""

    def __init__(self, model_name="text-embedding-3-small", batch_size=32,
                 api_key=None, dimensions=1536, cache_dir='./cache/vectorization'):
        """Initialize the OpenAI vectorizer.

        Args:
            model_name: Name of the OpenAI embedding model.
            batch_size: Batch size for API calls.
            api_key: OpenAI API key.
            dimensions: Embedding dimensionality.
            cache_dir: Cache directory.
        """
        super().__init__(cache_dir=cache_dir)
        self.model_name = model_name
        self.batch_size = batch_size
        self.dimensions = dimensions
        self.client = OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
        self.logger = logging.getLogger(self.__class__.__name__)

    def vectorize(self, text: str) -> np.ndarray:
        """Convert a single text into a vector.

        Args:
            text: Text to vectorize.

        Returns:
            The vector representation of the text.
        """
        if not text.strip():
            return np.zeros(self.dimensions)
        try:
            response = self.client.embeddings.create(
                model=self.model_name,
                input=text,
                dimensions=self.dimensions
            )
            embedding = response.data[0].embedding
            return np.array(embedding)
        except Exception as e:
            self.logger.error(f"OpenAI向量化失败: {str(e)}")
            return np.zeros(self.dimensions)

    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Convert a batch of texts into vectors.

        Args:
            texts: Texts to vectorize.

        Returns:
            A list of vector representations.
        """
        texts = [text for text in texts if text.strip()]
        if not texts:
            return [np.zeros(self.dimensions)]

        results = []
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i:i + self.batch_size]
            try:
                response = self.client.embeddings.create(
                    model=self.model_name,
                    input=batch,
                    dimensions=self.dimensions
                )
                batch_embeddings = [np.array(data.embedding) for data in response.data]
                results.extend(batch_embeddings)
                # Light rate limiting between batches.
                if len(texts) > self.batch_size and i + self.batch_size < len(texts):
                    time.sleep(0.5)
            except Exception as e:
                self.logger.error(f"OpenAI批量向量化失败: {str(e)}")
                results.extend([np.zeros(self.dimensions) for _ in batch])
        return results

    def get_dimensions(self) -> int:
        """Return the embedding dimensionality."""
        return self.dimensions
```
3.3 Registering the Custom Vectorizer
The new vectorizer must be registered with the vectorization factory:
```python
from src.data_processing.vectorization.openai_vectorizer import OpenAIVectorizer


class VectorizationFactory:
    """Factory class for creating different kinds of vectorizers."""

    _vectorizers = {
        'tfidf': TfidfVectorizer,
        'word2vec': Word2VecVectorizer,
        'bert': BertVectorizer,
        'bge-m3': BgeVectorizer,
        'openai': OpenAIVectorizer,
    }

    @staticmethod
    def _get_config_from_env(method: str) -> Dict[str, Any]:
        """Read vectorizer configuration from environment variables."""
        config = {}
        # ... branches for the other methods omitted ...
        if method == 'openai':
            config['model_name'] = os.getenv('OPENAI_EMBEDDING_MODEL', 'text-embedding-3-small')
            config['batch_size'] = int(os.getenv('OPENAI_BATCH_SIZE', '32'))
            config['dimensions'] = int(os.getenv('OPENAI_EMBEDDING_DIMENSIONS', '1536'))
        return config
```
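Once registered, callers obtain the new vectorizer through the same factory entry point. A short usage sketch; it assumes OPENAI_API_KEY is set in the environment and that the factory module lives at the path used throughout this article:

```python
from src.data_processing.vectorization.factory import VectorizationFactory

# 'openai' now resolves to OpenAIVectorizer; keyword arguments are passed to its constructor.
vectorizer = VectorizationFactory.create_vectorizer("openai", dimensions=1024, batch_size=16)
vec = vectorizer.vectorize("检索增强生成是什么?")
print(vec.shape)  # (1024,)
```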
4. Implementing a Hybrid Retrieval Strategy
4.1 A Multi-Model Hybrid Retriever
In practice, a single retrieval method rarely covers every need. We can implement a hybrid retrieval strategy that combines the strengths of several methods:
```python
import asyncio
import logging
from typing import Any, Dict, List, Tuple

import jieba
from langchain.schema import Document

from src.data_processing.vectorization.factory import VectorizationFactory


class HybridSearchRetriever:
    """Hybrid retriever that combines the strengths of several retrieval methods."""

    def __init__(self, vector_store, keyword_weight=0.3, semantic_weight=0.7, rerank_model=None):
        """Initialize the hybrid retriever."""
        self.vector_store = vector_store
        self.keyword_weight = keyword_weight
        self.semantic_weight = semantic_weight
        self.rerank_model = rerank_model
        self.logger = logging.getLogger(self.__class__.__name__)
        self._initialize_bm25_index()
        self.vectorizer = VectorizationFactory.create_vectorizer('bge-m3')

    def _initialize_bm25_index(self):
        """Build the BM25 keyword index from the documents in the vector store."""
        from rank_bm25 import BM25Okapi

        docs = self.vector_store.get_all_documents()
        texts = [doc.page_content for doc in docs]
        tokenized_corpus = [list(jieba.cut(text)) for text in texts]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.doc_ids = [doc.metadata.get('doc_id') for doc in docs]
        self.documents = docs

    async def retrieve(self, query: str, top_k: int = 5, threshold: float = 0.0):
        """Hybrid retrieval: keyword search plus vector search, merged and optionally reranked."""
        keyword_results = await self._keyword_search(query, top_k * 2)
        vector_results = await self._vector_search(query, top_k * 2)
        merged_results = self._merge_results(keyword_results, vector_results)

        # _rerank_results (cross-encoder reranking) is not shown in this excerpt.
        if self.rerank_model and len(merged_results) > top_k:
            merged_results = await self._rerank_results(query, merged_results, top_k)

        filtered_results = [(doc, score) for doc, score in merged_results if score >= threshold]
        return filtered_results[:top_k]

    async def _keyword_search(self, query: str, top_k: int):
        """BM25 keyword search."""
        tokenized_query = list(jieba.cut(query))
        bm25_scores = self.bm25.get_scores(tokenized_query)
        results = []
        for i, score in enumerate(bm25_scores):
            if score > 0:
                results.append((self.documents[i], score))
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    async def _vector_search(self, query: str, top_k: int):
        """Vector similarity search."""
        return await self.vector_store.asimilarity_search_with_score(query, top_k)

    def _merge_results(self, keyword_results, vector_results):
        """Merge keyword and vector search results with a weighted score."""
        merged_map = {}
        max_keyword_score = max([score for _, score in keyword_results]) if keyword_results else 1.0

        for doc, score in keyword_results:
            doc_id = doc.metadata.get("id")
            normalized_score = score / max_keyword_score
            merged_map[doc_id] = {
                "doc": doc,
                "keyword_score": normalized_score,
                "vector_score": 0.0,
            }

        for doc, score in vector_results:
            doc_id = doc.metadata.get("id")
            if doc_id in merged_map:
                merged_map[doc_id]["vector_score"] = score
            else:
                merged_map[doc_id] = {
                    "doc": doc,
                    "keyword_score": 0.0,
                    "vector_score": score,
                }

        result_list = []
        for item in merged_map.values():
            final_score = (
                self.keyword_weight * item["keyword_score"]
                + self.semantic_weight * item["vector_score"]
            )
            result_list.append((item["doc"], final_score))

        result_list.sort(key=lambda x: x[1], reverse=True)
        return result_list
```
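A usage sketch follows; it assumes the vector store exposes the get_all_documents() and asimilarity_search_with_score() methods used above:

```python
import asyncio

# vector_store: any store implementing get_all_documents() and asimilarity_search_with_score().
retriever = HybridSearchRetriever(vector_store, keyword_weight=0.4, semantic_weight=0.6)

results = asyncio.run(retriever.retrieve("如何配置向量数据库?", top_k=5, threshold=0.2))
for doc, score in results:
    print(f"{score:.3f}  {doc.metadata.get('source')}")
```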
4.2 Extending to Cross-Modal Retrieval
Beyond text, a RAG system can be extended to handle images, audio, and other multimodal data:
```python
import asyncio


class MultiModalRetriever:
    """Multimodal retriever supporting text, images, and other modalities."""

    def __init__(self, vector_stores, embedding_models):
        """Initialize with one vector store and one embedding model per modality."""
        self.vector_stores = vector_stores
        self.embedding_models = embedding_models

    async def retrieve(self, query, modal_type=None, top_k=5):
        """Multimodal retrieval."""
        # _detect_modal_type and _merge_modal_results are not shown in this excerpt.
        if modal_type == "auto":
            modal_type = self._detect_modal_type(query)

        if modal_type == "text":
            # A text query searches both the text index (BGE) and the image index (CLIP).
            clip_embedding = self.embedding_models["clip"].encode_text(query)
            bge_embedding = self.embedding_models["bge"].vectorize(query)
            tasks = [
                self.vector_stores["text"].asimilarity_search_by_vector(bge_embedding, top_k),
                self.vector_stores["image"].asimilarity_search_by_vector(clip_embedding, top_k),
            ]
            text_results, image_results = await asyncio.gather(*tasks)
            return self._merge_modal_results(text_results, image_results, top_k)

        elif modal_type == "image":
            image_embedding = self.embedding_models["clip"].encode_image(query)
            results = await self.vector_stores["image"].asimilarity_search_by_vector(
                image_embedding, top_k
            )
            return results
```
5. A User-Feedback Optimization Mechanism
5.1 Collecting and Storing Feedback Data
To further improve retrieval quality, we can add a user feedback mechanism:
```python
import logging
import sqlite3


class FeedbackOptimizer:
    """Optimize RAG retrieval results based on user feedback."""

    def __init__(self, vector_store, feedback_db=None):
        """Initialize the feedback optimizer."""
        self.vector_store = vector_store
        self.feedback_db = feedback_db or self._initialize_feedback_db()
        self.logger = logging.getLogger(self.__class__.__name__)

    def _initialize_feedback_db(self):
        """Create the feedback database and tables if they do not exist."""
        conn = sqlite3.connect('data/feedback.db')
        c = conn.cursor()
        c.execute('''
            CREATE TABLE IF NOT EXISTS feedback (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                query_text TEXT,
                doc_id TEXT,
                is_relevant INTEGER,
                timestamp TEXT
            )
        ''')
        c.execute('''
            CREATE TABLE IF NOT EXISTS query_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                query_text TEXT,
                results_count INTEGER,
                timestamp TEXT
            )
        ''')
        conn.commit()
        return conn

    def record_feedback(self, query, doc_id, is_relevant):
        """Record a relevance judgment from the user."""
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "INSERT INTO feedback (query_text, doc_id, is_relevant, timestamp) VALUES (?, ?, ?, datetime('now'))",
            (query, doc_id, 1 if is_relevant else 0)
        )
        self.feedback_db.commit()

    def record_query(self, query, results_count):
        """Log a query and how many results it returned."""
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "INSERT INTO query_log (query_text, results_count, timestamp) VALUES (?, ?, datetime('now'))",
            (query, results_count)
        )
        self.feedback_db.commit()

    def get_relevance_feedback_for_query(self, query, limit=10):
        """Fetch aggregated relevance feedback for a given query."""
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "SELECT doc_id, is_relevant, COUNT(*) FROM feedback WHERE query_text = ? GROUP BY doc_id, is_relevant",
            (query,)
        )
        return cursor.fetchall()

    def optimize_results(self, query, initial_results, top_k=5):
        """Re-rank retrieval results using historical feedback."""
        feedback = self.get_relevance_feedback_for_query(query)
        if not feedback:
            return initial_results[:top_k]

        feedback_dict = {}
        for doc_id, is_relevant, count in feedback:
            if doc_id not in feedback_dict:
                feedback_dict[doc_id] = {"relevant": 0, "irrelevant": 0}
            if is_relevant:
                feedback_dict[doc_id]["relevant"] += count
            else:
                feedback_dict[doc_id]["irrelevant"] += count

        adjusted_results = []
        for doc, score in initial_results:
            doc_id = doc.metadata.get("id")
            adjustment = 0
            if doc_id in feedback_dict:
                relevant = feedback_dict[doc_id]["relevant"]
                irrelevant = feedback_dict[doc_id]["irrelevant"]
                if relevant + irrelevant > 0:
                    # Shift the score by at most ±0.2 depending on the vote balance.
                    adjustment = (relevant - irrelevant) / (relevant + irrelevant) * 0.2
            adjusted_score = min(1.0, max(0.0, score + adjustment))
            adjusted_results.append((doc, adjusted_score))

        adjusted_results.sort(key=lambda x: x[1], reverse=True)
        self.record_query(query, len(initial_results))
        return adjusted_results[:top_k]
```
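To see how the ±0.2 adjustment inside optimize_results behaves, here is a quick standalone check of the formula with hypothetical vote counts:

```python
# Hypothetical feedback: 4 users marked the document relevant, 1 marked it irrelevant.
relevant, irrelevant = 4, 1
adjustment = (relevant - irrelevant) / (relevant + irrelevant) * 0.2   # +0.12
original_score = 0.71
adjusted_score = min(1.0, max(0.0, original_score + adjustment))
print(round(adjustment, 2), round(adjusted_score, 2))  # 0.12 0.83
```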
5.2 Implementing the Feedback Interface
To collect user feedback, we add feedback buttons to the frontend, backed by the following API endpoints:
```python
from typing import List, Optional

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel

router = APIRouter()


class FeedbackRequest(BaseModel):
    query: str
    doc_id: str
    is_relevant: bool


@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Submit a document-relevance judgment."""
    try:
        feedback_optimizer = get_feedback_optimizer()  # dependency accessor, not shown in this excerpt
        feedback_optimizer.record_feedback(
            request.query,
            request.doc_id,
            request.is_relevant
        )
        return {"status": "success", "message": "反馈已记录"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"提交反馈失败: {str(e)}")


@router.get("/feedback/stats")
async def get_feedback_stats(query: Optional[str] = None):
    """Return feedback statistics, optionally scoped to a single query."""
    try:
        feedback_optimizer = get_feedback_optimizer()
        if query:
            stats = feedback_optimizer.get_relevance_feedback_for_query(query)
        else:
            stats = feedback_optimizer.get_global_feedback_stats()
        return {"status": "success", "data": stats}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取反馈统计失败: {str(e)}")
```
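Once the router is mounted, the frontend (or any HTTP client) can call these endpoints. A quick sketch using requests; the base URL and the assumption that the router is mounted without a prefix are illustrative and depend on your deployment:

```python
import requests

BASE = "http://localhost:8000"  # adjust to your deployment / router prefix

# Mark a document as relevant for a query.
resp = requests.post(f"{BASE}/feedback", json={
    "query": "如何配置向量数据库?",
    "doc_id": "doc-42",
    "is_relevant": True,
})
print(resp.json())

# Fetch feedback statistics for the same query.
stats = requests.get(f"{BASE}/feedback/stats", params={"query": "如何配置向量数据库?"})
print(stats.json())
```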
6. Extending the Query Intent Analyzer
6.1 Custom Intent Processors
We already have a recommendation query processor (src/chains/processors/recommendation.py). We can now extend the system with more specialized intent handling, starting with a classifier that recognizes query intent from vector similarity:
```python
from typing import Any, Dict, List

import numpy as np

from src.data_processing.vectorization.factory import VectorizationFactory


class QueryIntentClassifier:
    """Query intent classifier based on vector similarity to example queries."""

    def __init__(self, embedding_model=None):
        """Initialize the intent classifier."""
        self.embedding_model = embedding_model or VectorizationFactory.create_vectorizer("bge-m3")
        self.intent_examples = {
            "比较分析": [
                "A和B有什么区别?",
                "哪一个更好,X还是Y?",
                "比较一下P和Q的优缺点"
            ],
            "因果解释": [
                "为什么会出现这种情况?",
                "导致X的原因是什么?",
                "这个问题的根源是什么?"
            ],
            "列举信息": [
                "列出所有的X",
                "有哪些方法可以做Y?",
                "X包含哪些组成部分?"
            ],
            "概念解释": [
                "什么是X?",
                "X的定义是什么?",
                "如何理解Y概念?"
            ],
            "数据统计": [
                "X的平均值是多少?",
                "Y的增长率是多少?",
                "Z的分布情况怎样?"
            ],
            "操作指导": [
                "如何做X?",
                "执行Y的步骤是什么?",
                "使用Z的方法有哪些?"
            ],
            "推荐建议": [
                "推荐几款好用的X",
                "有什么适合Y的工具?",
                "帮我选择一个合适的Z"
            ]
        }
        self.intent_vectors = self._compute_intent_vectors()

    def _compute_intent_vectors(self):
        """Precompute a normalized average vector for each intent."""
        intent_vectors = {}
        for intent, examples in self.intent_examples.items():
            vectors = self.embedding_model.batch_vectorize(examples)
            avg_vector = np.mean(vectors, axis=0)
            avg_vector = avg_vector / np.linalg.norm(avg_vector)
            intent_vectors[intent] = avg_vector
        return intent_vectors

    def classify_intent(self, query: str):
        """Classify the query intent; fall back to a generic intent below 0.5 similarity."""
        query_vector = self.embedding_model.vectorize(query)
        # Normalize so the dot product below is cosine similarity.
        query_vector = query_vector / np.linalg.norm(query_vector)

        similarities = {}
        for intent, vector in self.intent_vectors.items():
            similarities[intent] = np.dot(query_vector, vector)

        max_intent = max(similarities, key=similarities.get)
        max_similarity = similarities[max_intent]
        if max_similarity < 0.5:
            return "一般信息查询", max_similarity
        return max_intent, max_similarity
```
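A usage sketch, assuming the bge-m3 model files are available locally so the factory can load them:

```python
classifier = QueryIntentClassifier()

for q in ["Milvus和FAISS有什么区别?", "如何部署向量数据库?"]:
    intent, similarity = classifier.classify_intent(q)
    print(q, "->", intent, round(float(similarity), 3))
# Likely routing (model dependent): the first query to 比较分析, the second to 操作指导.
```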
6.2 A Custom Comparison Query Processor
Each intent needs a dedicated processor. Taking comparison analysis as an example:
```python
import re
from typing import Any, Dict, List

from .base import QueryProcessor


class ComparisonQueryProcessor(QueryProcessor):
    """Processor for queries that ask for a comparison between entities."""

    def process(self, query: str, documents: List[Any], **kwargs) -> Dict[str, Any]:
        """Handle a comparison query."""
        if not documents:
            return {
                "answer": "抱歉,没有找到相关的比较信息。",
                "sources": []
            }

        entities = self._extract_comparison_entities(query)
        if len(entities) < 2:
            entities = self._extract_entities_from_documents(documents)

        entity_info = self._collect_entity_information(entities, documents)
        comparison_table = self._generate_comparison_table(entity_info)
        conclusion = self._generate_comparison_conclusion(entity_info, query)

        answer = f"根据您的比较请求,以下是{', '.join(entities)}的对比分析:\n\n{comparison_table}\n\n{conclusion}"
        return {
            "answer": answer,
            "sources": documents,
            "entities": entities,
            "comparison_table": comparison_table
        }

    def can_handle(self, query: str) -> bool:
        """Decide whether the query is a comparison query."""
        comparison_keywords = ["比较", "区别", "差异", "优缺点", "对比", "相比", "VS", "好坏"]
        comparison_patterns = [
            r"(.+)和(.+)的区别",
            r"(.+)与(.+)的(差异|不同)",
            r"(.+)相比(.+)怎么样",
            r"(.+)还是(.+)更好"
        ]
        if any(keyword in query for keyword in comparison_keywords):
            return True
        for pattern in comparison_patterns:
            if re.search(pattern, query):
                return True
        return False

    def _extract_comparison_entities(self, query: str) -> List[str]:
        """Extract the entities to be compared from the query."""
        patterns = [
            r"(.+)和(.+)的区别",
            r"(.+)与(.+)的(差异|不同)",
            r"(.+)相比(.+)怎么样",
            r"(.+)还是(.+)更好"
        ]
        for pattern in patterns:
            match = re.search(pattern, query)
            if match:
                return [match.group(1).strip(), match.group(2).strip()]
        return []

    def _extract_entities_from_documents(self, documents: List[Any]) -> List[str]:
        """Extract candidate entities from the documents (stub)."""
        return []

    def _collect_entity_information(self, entities: List[str], documents: List[Any]) -> Dict[str, Dict]:
        """Collect information about each entity from the documents."""
        entity_info = {}
        for entity in entities:
            entity_info[entity] = {
                "advantages": [],
                "disadvantages": [],
                "features": {},
                "mentions": 0
            }
            for doc in documents:
                text = doc.page_content
                if entity in text:
                    entity_info[entity]["mentions"] += text.count(entity)
                    advantages = self._extract_advantages(text, entity)
                    entity_info[entity]["advantages"].extend(advantages)
                    disadvantages = self._extract_disadvantages(text, entity)
                    entity_info[entity]["disadvantages"].extend(disadvantages)
                    features = self._extract_features(text, entity)
                    for feature, value in features.items():
                        if feature in entity_info[entity]["features"]:
                            entity_info[entity]["features"][feature].append(value)
                        else:
                            entity_info[entity]["features"][feature] = [value]
        return entity_info

    def _extract_advantages(self, text: str, entity: str) -> List[str]:
        """Extract advantages of an entity (stub; real extraction logic omitted)."""
        patterns = [
            f"{entity}的优点",
            f"{entity}的好处",
            f"{entity}的优势"
        ]
        return []

    def _extract_disadvantages(self, text: str, entity: str) -> List[str]:
        """Extract disadvantages of an entity (stub)."""
        return []

    def _extract_features(self, text: str, entity: str) -> Dict[str, str]:
        """Extract feature/value pairs for an entity (stub)."""
        return {}

    def _generate_comparison_table(self, entity_info: Dict[str, Dict]) -> str:
        """Generate a Markdown comparison table."""
        entities = list(entity_info.keys())
        table = f"| 特性 | {' | '.join(entities)} |\n"
        table += f"| --- | {' | '.join(['---' for _ in entities])} |\n"

        all_features = set()
        for entity, info in entity_info.items():
            all_features.update(info["features"].keys())

        for feature in sorted(all_features):
            row = f"| {feature} | "
            for entity in entities:
                if feature in entity_info[entity]["features"]:
                    values = entity_info[entity]["features"][feature]
                    row += f"{values[0]} | "
                else:
                    row += "- | "
            table += row + "\n"

        table += f"| 优点 | {' | '.join([', '.join(info['advantages'][:3]) or '-' for _, info in entity_info.items()])} |\n"
        table += f"| 缺点 | {' | '.join([', '.join(info['disadvantages'][:3]) or '-' for _, info in entity_info.items()])} |\n"
        return table

    def _generate_comparison_conclusion(self, entity_info: Dict[str, Dict], query: str) -> str:
        """Generate a short conclusion for the comparison."""
        entities = list(entity_info.keys())
        if len(entities) < 2:
            return "无法生成比较结论,找不到足够的实体信息。"
        return "根据以上对比,每个选项都有各自的优缺点,具体选择取决于您的具体需求和场景。"
```
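A usage sketch; it assumes the QueryProcessor base class needs no constructor arguments, and the sample document below is purely illustrative:

```python
from langchain.schema import Document

retrieved_docs = [
    Document(page_content="Milvus支持分布式部署,FAISS是一个轻量级的向量检索库。", metadata={"id": "doc-1"}),
]

processor = ComparisonQueryProcessor()
query = "Milvus和FAISS的区别"
if processor.can_handle(query):
    result = processor.process(query, documents=retrieved_docs)
    print(result["entities"])           # ['Milvus', 'FAISS']
    print(result["comparison_table"])
```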
7. Practical Case Studies
7.1 A Custom Legal Document Processor
Legal documents have their own peculiarities. Below is a processor dedicated to legal documents:
```python
import re
from typing import Any, Dict, List, Optional

from langchain.schema import Document

from src.data_processing.processors.base import BaseDocumentProcessor, ProcessorConfig


class LegalDocumentProcessor(BaseDocumentProcessor):
    """Processor specialized for legal documents."""

    def __init__(self, config: Optional[ProcessorConfig] = None):
        """Initialize the legal document processor."""
        super().__init__(config)
        self.legal_terms = self._load_legal_terms()

    def _load_legal_terms(self):
        """Load the legal-term glossary."""
        return {
            "原告": "起诉方,请求法院裁判的一方",
            "被告": "被起诉方,被请求法院裁判的一方",
            "诉讼": "通过法院解决纠纷的法律程序",
        }

    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Process a legal document file."""
        # _extract_pdf_text / _extract_word_text are not shown in this excerpt.
        if mime_type == "application/pdf":
            text = self._extract_pdf_text(file_content)
        elif mime_type in ("application/msword",
                           "application/vnd.openxmlformats-officedocument.wordprocessingml.document"):
            text = self._extract_word_text(file_content)
        else:
            text = file_content.decode(self.config.encoding, errors='ignore')
        return self.process_text(text, {"source": filename, "mime_type": mime_type})

    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Process legal document text."""
        sections = self._parse_legal_document_structure(text)
        documents = []
        for section_name, section_content in sections.items():
            section_metadata = {
                **(metadata or {}),
                "section": section_name,
                "document_type": "legal"
            }
            annotated_content = self._annotate_legal_terms(section_content)
            doc = Document(page_content=annotated_content, metadata=section_metadata)
            documents.append(self._add_metadata(doc))
        return documents

    def _parse_legal_document_structure(self, text: str) -> Dict[str, str]:
        """Parse the structure of a legal document."""
        sections = {}
        section_patterns = [
            (r"案号[::]\s*(.+?)(?=\n)", "case_number"),
            (r"原告[::]\s*(.+?)(?=\n被告)", "plaintiff"),
            (r"被告[::]\s*(.+?)(?=\n)", "defendant"),
            (r"诉讼请求[::]\s*(.+?)(?=\n)", "claims"),
            (r"事实与理由[::]\s*(.+?)(?=\n)", "facts_and_reasons"),
            (r"裁判结果[::]\s*(.+?)(?=\n)", "judgment"),
            (r"裁判理由[::]\s*(.+?)(?=\n)", "reasoning")
        ]
        for pattern, section_name in section_patterns:
            match = re.search(pattern, text, re.DOTALL)
            if match:
                sections[section_name] = match.group(1).strip()

        # Fall back to paragraph splitting when no structured fields are found.
        if not sections:
            paragraphs = re.split(r'\n\s*\n', text)
            for i, para in enumerate(paragraphs):
                sections[f"paragraph_{i+1}"] = para.strip()
        return sections

    def _annotate_legal_terms(self, text: str) -> str:
        """Annotate the first occurrence of each legal term with its definition."""
        annotated_text = text
        for term, definition in self.legal_terms.items():
            if term in annotated_text:
                term_with_note = f"{term}[注: {definition}]"
                annotated_text = annotated_text.replace(term, term_with_note, 1)
        return annotated_text
```
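A usage sketch with a tiny, made-up filing; it assumes the base class provides the _add_metadata helper used above:

```python
processor = LegalDocumentProcessor()

sample = "案号:(2023)某民初123号\n原告:张三\n被告:李四\n诉讼请求:请求判令被告支付欠款。\n"
docs = processor.process_text(sample, {"source": "起诉状.txt"})
for doc in docs:
    print(doc.metadata["section"], "->", doc.page_content[:40])
```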
7.2 A Financial Data Vectorizer
Given the peculiarities of financial data, we can implement a dedicated vectorizer:
```python
import re
from typing import Any, Dict, List

import numpy as np

from src.data_processing.vectorization.base import BaseVectorizer


class FinancialDataVectorizer(BaseVectorizer):
    """Vectorizer specialized for financial text and data."""

    def __init__(self, cache_dir='./cache/vectorization', base_model='bge-m3', numerical_weight=0.3):
        """Initialize the financial data vectorizer."""
        super().__init__(cache_dir)
        from src.data_processing.vectorization.factory import VectorizationFactory
        self.base_vectorizer = VectorizationFactory.create_vectorizer(base_model)
        self.numerical_weight = numerical_weight
        self.financial_terms = self._load_financial_terms()

    def _load_financial_terms(self):
        """Load the financial-term vocabulary."""
        return [
            "股票", "债券", "基金", "期货", "期权", "保险", "理财",
            "利率", "汇率", "通货膨胀", "GDP", "PPI", "CPI", "PMI",
            "资产", "负债", "股东", "投资", "风险", "收益", "波动"
        ]

    def vectorize(self, text: str) -> np.ndarray:
        """Convert financial text into a vector that fuses semantics and numerical features."""
        numerical_features = self._extract_numerical_features(text)
        semantic_vector = self.base_vectorizer.vectorize(text)
        return self._combine_features(semantic_vector, numerical_features)

    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Vectorize a batch of financial texts."""
        return [self.vectorize(text) for text in texts]

    def _extract_numerical_features(self, text: str) -> Dict[str, float]:
        """Extract numerical features (percentages, monetary amounts) from the text."""
        features = {}

        percentage_pattern = r'(\d+\.?\d*)%'
        percentages = re.findall(percentage_pattern, text)
        if percentages:
            features['percentage_avg'] = sum(float(p) for p in percentages) / len(percentages)
            features['percentage_count'] = len(percentages)

        amount_pattern = r'(\d+\.?\d*)\s*(万|亿|千|百万|美元|元|美金|英镑|欧元)'
        amounts = re.findall(amount_pattern, text)
        if amounts:
            std_amounts = []
            for amount, unit in amounts:
                value = float(amount)
                if unit == '万':
                    value *= 10000
                elif unit == '亿':
                    value *= 100000000
                std_amounts.append(value)
            if std_amounts:
                features['amount_avg'] = sum(std_amounts) / len(std_amounts)
                features['amount_max'] = max(std_amounts)
                features['amount_count'] = len(std_amounts)
        return features

    def _combine_features(self, semantic_vector: np.ndarray, numerical_features: Dict[str, float]) -> np.ndarray:
        """Fuse the numerical features with the semantic vector."""
        if not numerical_features:
            return semantic_vector

        numerical_vector = np.zeros(10)
        feature_index = {
            'percentage_avg': 0,
            'percentage_count': 1,
            'amount_avg': 2,
            'amount_max': 3,
            'amount_count': 4,
        }
        for feature, value in numerical_features.items():
            if feature in feature_index:
                numerical_vector[feature_index[feature]] = value

        num_max = np.max(numerical_vector) if np.max(numerical_vector) > 0 else 1.0
        numerical_vector = numerical_vector / num_max

        # Keep the output dimension equal to the base model's by truncating the semantic
        # vector. (PCA cannot be fit on a single vector; a corpus-level projection would be
        # needed for a proper dimensionality reduction.)
        original_dim = semantic_vector.shape[0]
        new_dim = original_dim - len(numerical_vector)
        semantic_reduced = semantic_vector[:new_dim]

        semantic_weight = 1 - self.numerical_weight
        combined = np.concatenate([
            semantic_reduced * semantic_weight,
            numerical_vector * self.numerical_weight
        ])
        return combined / np.linalg.norm(combined)
```
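A usage sketch, again assuming the bge-m3 base model can be loaded through the factory; the sample sentence is illustrative:

```python
import numpy as np

fin_vec = FinancialDataVectorizer(numerical_weight=0.3)

vector = fin_vec.vectorize("公司营收同比增长15.3%,净利润达到2.4亿元")
print(vector.shape, round(float(np.linalg.norm(vector)), 3))  # unit-length fused vector
```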
8. Summary and Outlook
8.1 Best Practices for Extending a RAG System
This article has shown how to build highly customized extensions for the RAG system. The best practices can be summarized as follows:
- 抽象基类设计:使用抽象基类定义统一接口,确保各组件遵循相同约定工厂模式解耦:通过工厂模式创建组件实例,降低组件间耦合配置驱动初始化:使用环境变量和配置文件驱动组件初始化,提高灵活性专业化处理器:针对特定领域或文档类型开发专用处理器,提高处理质量混合检索策略:结合多种检索方法,平衡关键词匹配和语义检索的优势用户反馈闭环:收集用户反馈并应用于结果优化,持续提升检索质量查询意图分析:根据查询意图选择专用处理器,提供更精准的回答
8.2 Future Directions
There are still several directions in which the RAG system can be extended:
- Multimodal processing: extend to images, audio, video, and other modalities
- Time-aware retrieval: support time-series data and trend analysis
- Adaptive retrieval: automatically adjust the retrieval strategy based on a user's query history
- Multi-source fusion: intelligently merge results from multiple data sources
- Explainable retrieval: explain why results were retrieved so users understand their provenance
- Offline precomputation: precompute results for common query paths to reduce response latency
Next in the series, "Seven Key Practices for Improving RAG System Performance" will cover:
- Chunking strategy optimization (differentiated handling of tables, code, and plain text)
- Cache design (vector cache, result cache, model cache)
- Asynchronous processing (optimizing the document-processing pipeline)
- Security hardening (input filtering, access control)
- Evaluation methods (retrieval accuracy, response time, QPS)
Project repository: github.com/bikeread/ra…