Juejin · AI · April 2, 10:12
Extending Your RAG System: Custom Processors and Vectorization Methods

 

This article takes a deep dive into how to extend a RAG (Retrieval-Augmented Generation) system to meet diverse business needs. It covers modular design principles, the development of custom document processors, and the implementation of custom vectorization methods. Design patterns such as abstract base classes, the factory pattern, and the strategy pattern give the system good extensibility. The article also provides concrete implementations of a custom academic paper processor and an OpenAI vectorizer, and explains how to register these custom components to enhance the RAG system's capabilities.

💡 **Extensibility design principles**: The RAG system relies on core design patterns such as abstract base classes, the factory pattern, the strategy pattern, the decorator pattern, and dependency injection, which keep the system extensible and make it easy to add new functional modules.

📄 **Custom document processors**: The article shows how to develop a custom document processor such as `AcademicPaperProcessor`, which handles academic paper PDFs in a specific format. The processor extracts the PDF text, parses the paper's structure, and splits the content into chunks for downstream processing.

⚙️ **Processor registration**: For a custom processor to take effect, it must be registered with the processor factory. Once registered, the system can select the appropriate processor based on a file's MIME type, enabling flexible handling of different document types.

🚀 **Custom vectorization methods**: The article explains how to implement a custom vectorization method such as `OpenAIVectorizer`, which converts text into embeddings provided by OpenAI. This lets the system use different vectorization techniques for different application scenarios.

🏭 **Vectorizer factory**: To make it easy to create different types of vectorizers, the article introduces a vectorizer factory. Using the factory pattern, the appropriate vectorizer instance can be created dynamically from configuration, improving maintainability and extensibility.

1. Modular Design and Extension Architecture

The previous articles in this series covered the RAG system's architecture, core modules, and deployment. This article digs into how to extend the system with custom processors and vectorization methods so it can serve a wider range of business needs.

1.1 Extensibility Design Principles

Our RAG system follows a set of extensibility design principles.

The core design patterns are:

- Abstract base classes: define a unified interface and keep implementations consistent
- Factory pattern: dynamically create the appropriate component instances from configuration
- Strategy pattern: select different algorithm strategies at runtime
- Decorator pattern: extend functionality without modifying existing code
- Dependency injection: inject dependencies via configuration to reduce coupling between components
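To make these principles concrete, here is a minimal, illustrative sketch of how an abstract base class, a factory with a registry, and configuration-driven dependency injection fit together. The class and method names below are placeholders for illustration, not the project's actual code:

```python
from abc import ABC, abstractmethod
from typing import Dict, Optional, Type


class BaseComponent(ABC):
    """Abstract base class: every implementation exposes the same interface."""

    def __init__(self, config: Optional[dict] = None):
        # Dependency injection: configuration is passed in, not hard-coded
        self.config = config or {}

    @abstractmethod
    def run(self, payload: str) -> str:
        """Do the component's work on the given payload."""


class ComponentFactory:
    """Factory: selects a concrete implementation (strategy) at runtime."""

    _registry: Dict[str, Type[BaseComponent]] = {}

    @classmethod
    def register(cls, name: str, component_cls: Type[BaseComponent]) -> None:
        """Register a new implementation without touching existing code."""
        cls._registry[name] = component_cls

    @classmethod
    def create(cls, name: str, config: Optional[dict] = None) -> BaseComponent:
        """Create the component selected by `name`, injecting its configuration."""
        if name not in cls._registry:
            raise ValueError(f"Unknown component: {name}")
        return cls._registry[name](config)
```

The processors and vectorizers in the rest of this article follow exactly this shape: a base class fixes the interface, and a factory chooses the implementation from configuration.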

2. Developing a Custom Document Processor

Source file: `src/data_processing/processors/base.py`

2.1 The Processor Abstract Interface

All processors inherit from the `BaseDocumentProcessor` abstract base class:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from langchain.schema import Document


class BaseDocumentProcessor(ABC):
    """Base class for document processors."""

    def __init__(self, config: Optional[ProcessorConfig] = None):
        """Initialize the processor with a configuration."""
        self.config = config or ProcessorConfig()
        self._setup_logging()

    @abstractmethod
    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Process a single file."""
        pass

    @abstractmethod
    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Process raw text."""
        pass
```

2.2 Implementing a Custom Academic Paper Processor

Suppose we need to process academic paper PDFs in a particular format. The processor can be implemented like this:

```python
class AcademicPaperProcessor(BaseDocumentProcessor):
    """Academic paper processor, specialised for paper PDFs."""

    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Process a paper PDF file."""
        if mime_type != "application/pdf":
            raise ValueError(f"Unsupported file type: {mime_type}; only PDF is supported")

        text = self._extract_pdf_text(file_content)
        sections = self._parse_paper_structure(text)

        documents = []
        for section_name, section_text in sections.items():
            metadata = {
                "source": filename,
                "section": section_name,
                "paper_type": "academic",
                "file_type": "pdf"
            }
            # _split_section() (per-section chunking) is not shown in this excerpt
            chunks = self._split_section(section_text, section_name)
            for i, chunk in enumerate(chunks):
                doc = Document(
                    page_content=chunk,
                    metadata={
                        **metadata,
                        "chunk_id": i,
                        "total_chunks": len(chunks)
                    }
                )
                documents.append(self._add_metadata(doc))
        return documents

    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Process raw text."""
        sections = self._parse_paper_structure(text)
        documents = []
        # Abridged in the original article: the parsed sections are not chunked here
        return documents

    def _extract_pdf_text(self, pdf_bytes: bytes) -> str:
        """Extract text from the PDF while preserving the paper layout."""
        import io
        import PyPDF2

        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
        full_text = []
        for page_num in range(len(pdf_reader.pages)):
            page = pdf_reader.pages[page_num]
            full_text.append(page.extract_text())
        return "\n\n".join(full_text)

    def _parse_paper_structure(self, text: str) -> Dict[str, str]:
        """Parse the paper structure: title, abstract, introduction, methods, results, discussion, etc."""
        import re

        sections = {}
        # Title: the first block before a blank line
        title_match = re.search(r'^(.+?)(?=\n\n)', text)
        if title_match:
            sections["title"] = title_match.group(1).strip()

        # Abstract: everything up to the first numbered heading or the Introduction
        abstract_match = re.search(r'Abstract[:.\s]+(.+?)(?=\n\n\d+\.|\n\nIntroduction)', text, re.DOTALL)
        if abstract_match:
            sections["abstract"] = abstract_match.group(1).strip()

        section_patterns = [
            (r'Introduction(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "introduction"),
            (r'Methods(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "methods"),
            (r'Results(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "results"),
            (r'Discussion(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z])', "discussion"),
            (r'Conclusion(?:\s|\n)+(.+?)(?=\n\n\d+\.|\n\n[A-Z]|$)', "conclusion"),
            (r'References(?:\s|\n)+(.+?)(?=$)', "references")
        ]
        for pattern, section_name in section_patterns:
            section_match = re.search(pattern, text, re.DOTALL)
            if section_match:
                sections[section_name] = section_match.group(1).strip()
        return sections
```

2.3 Registering the Custom Processor

Once the processor is implemented, register it with the processor factory:

```python
from src.data_processing.processors.academic_paper_processor import AcademicPaperProcessor


class DocumentProcessor:
    """Document processor factory: selects the right processor by MIME type."""

    def __init__(self, config=None):
        """Initialize the processor factory."""
        self.config = config or ProcessorConfig()
        self._processors = {}
        self._register_default_processors()

    def _register_default_processors(self):
        """Register the default processors."""
        self._processors.update({
            "application/pdf": PDFProcessor(self.config),
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document": WordProcessor(self.config),
            # ... other default processors ...
        })
        # When the document type is configured as academic paper, use the specialised PDF processor
        if self.config.doc_type == DocumentType.ACADEMIC_PAPER:
            self._processors["application/pdf"] = AcademicPaperProcessor(self.config)

    def register_processor(self, mime_type, processor):
        """Register a custom processor."""
        self._processors[mime_type] = processor
```
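A possible usage sketch, assuming (as the code above implies) that `ProcessorConfig` exposes a `doc_type` field and that the factory later dispatches to `self._processors` by MIME type when ingesting files:

```python
# Enable the academic paper processor via configuration
config = ProcessorConfig(doc_type=DocumentType.ACADEMIC_PAPER)
processor = DocumentProcessor(config)

# Or register it explicitly for the PDF MIME type
processor.register_processor("application/pdf", AcademicPaperProcessor(config))
```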

3. Implementing a Custom Vectorization Method

Source file: `src/data_processing/vectorization/base.py`

3.1 The Vectorizer Interface and Factory Pattern

All vectorizers inherit from `BaseVectorizer`:

```python
class BaseVectorizer(ABC):
    """Base class defining the vectorizer interface."""

    def __init__(self, cache_dir: str = './cache/vectorization'):
        """Initialize the vectorizer."""
        self.cache_dir = cache_dir
        self._ensure_cache_dir()

    @abstractmethod
    def vectorize(self, text: str) -> np.ndarray:
        """Convert a single text into a vector."""
        pass

    @abstractmethod
    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Convert a batch of texts into vectors."""
        pass
```

A factory is used to create the different types of vectorizers:

```python
class VectorizationFactory:
    """Factory class for creating different types of vectorizers."""

    _vectorizers = {
        'tfidf': TfidfVectorizer,
        'word2vec': Word2VecVectorizer,
        'bert': BertVectorizer,
        'bge-m3': BgeVectorizer
    }

    @classmethod
    def create_vectorizer(cls, method: str = 'tfidf', **kwargs) -> BaseVectorizer:
        """Create a vectorizer."""
        method = method.lower()
        if method not in cls._vectorizers:
            supported_methods = ", ".join(cls._vectorizers.keys())
            raise ValueError(f"Unsupported vectorization method: {method}. Supported methods: {supported_methods}")
        # The instantiation step was truncated in the original excerpt; presumably:
        return cls._vectorizers[method](**kwargs)
```

3.2 Integrating an OpenAI Embedding Model

Below we implement an OpenAI vectorizer that converts text into embeddings provided by OpenAI:

```python
import numpy as np
import os
import time
import logging
from openai import OpenAI
from typing import List
from .base import BaseVectorizer


class OpenAIVectorizer(BaseVectorizer):
    """Vectorizer backed by the OpenAI embeddings API."""

    def __init__(self, model_name="text-embedding-3-small", batch_size=32,
                 api_key=None, dimensions=1536, cache_dir='./cache/vectorization'):
        """Initialize the OpenAI vectorizer.

        Args:
            model_name: Name of the OpenAI embedding model
            batch_size: Batch size
            api_key: OpenAI API key
            dimensions: Embedding dimensionality
            cache_dir: Cache directory
        """
        super().__init__(cache_dir=cache_dir)
        self.model_name = model_name
        self.batch_size = batch_size
        self.dimensions = dimensions
        self.client = OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
        self.logger = logging.getLogger(self.__class__.__name__)

    def vectorize(self, text: str) -> np.ndarray:
        """Convert a single text into a vector.

        Args:
            text: Text to vectorize

        Returns:
            Vector representation of the text
        """
        if not text.strip():
            return np.zeros(self.dimensions)
        try:
            response = self.client.embeddings.create(
                model=self.model_name,
                input=text,
                dimensions=self.dimensions
            )
            embedding = response.data[0].embedding
            return np.array(embedding)
        except Exception as e:
            self.logger.error(f"OpenAI vectorization failed: {str(e)}")
            return np.zeros(self.dimensions)

    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Convert a batch of texts into vectors.

        Args:
            texts: List of texts to vectorize

        Returns:
            List of vector representations
        """
        texts = [text for text in texts if text.strip()]
        if not texts:
            return [np.zeros(self.dimensions)]
        results = []
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i:i + self.batch_size]
            try:
                response = self.client.embeddings.create(
                    model=self.model_name,
                    input=batch,
                    dimensions=self.dimensions
                )
                batch_embeddings = [np.array(data.embedding) for data in response.data]
                results.extend(batch_embeddings)
                # Light rate limiting between batches
                if len(texts) > self.batch_size and i + self.batch_size < len(texts):
                    time.sleep(0.5)
            except Exception as e:
                self.logger.error(f"OpenAI batch vectorization failed: {str(e)}")
                results.extend([np.zeros(self.dimensions) for _ in batch])
        return results

    def get_dimensions(self) -> int:
        """Return the embedding dimensionality."""
        return self.dimensions
```
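A quick usage sketch for the vectorizer above. It requires the `openai` package and an `OPENAI_API_KEY` environment variable; the printed shape assumes the default 1536-dimensional model:

```python
vectorizer = OpenAIVectorizer(model_name="text-embedding-3-small", dimensions=1536)

vec = vectorizer.vectorize("Retrieval-Augmented Generation combines search with generation.")
print(vec.shape)   # (1536,)

vecs = vectorizer.batch_vectorize(["first document", "second document"])
print(len(vecs))   # 2
```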

3.3 Registering the Custom Vectorizer

The vectorization factory needs to register the newly added vectorizer:

```python
import os
from typing import Any, Dict

from src.data_processing.vectorization.openai_vectorizer import OpenAIVectorizer


class VectorizationFactory:
    """Factory class for creating different types of vectorizers."""

    _vectorizers = {
        'tfidf': TfidfVectorizer,
        'word2vec': Word2VecVectorizer,
        'bert': BertVectorizer,
        'bge-m3': BgeVectorizer,
        'openai': OpenAIVectorizer
    }

    @staticmethod
    def _get_config_from_env(method: str) -> Dict[str, Any]:
        """Read vectorizer configuration from environment variables."""
        config = {}
        # Branches for tfidf / word2vec / bert / bge-m3 are elided in this excerpt
        if method == 'openai':
            config['model_name'] = os.getenv('OPENAI_EMBEDDING_MODEL', 'text-embedding-3-small')
            config['batch_size'] = int(os.getenv('OPENAI_BATCH_SIZE', '32'))
            config['dimensions'] = int(os.getenv('OPENAI_EMBEDDING_DIMENSIONS', '1536'))
        return config
```
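With the registration in place, an `OpenAIVectorizer` can be created through the factory. The sketch below assumes that `create_vectorizer()` merges the values from `_get_config_from_env()` with any explicit keyword arguments; that merging code is not shown in the excerpt above:

```python
import os

# Environment-driven configuration (values shown are the defaults)
os.environ.setdefault("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")
os.environ.setdefault("OPENAI_EMBEDDING_DIMENSIONS", "1536")

vectorizer = VectorizationFactory.create_vectorizer("openai")
```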

4. Implementing a Hybrid Retrieval Strategy

4.1 A Multi-Method Hybrid Retriever

In practice, a single retrieval method rarely satisfies every requirement. We can implement a hybrid retrieval strategy that combines the strengths of several methods:

```python
import asyncio
import logging
from typing import List, Tuple, Dict, Any

import jieba
from langchain.schema import Document

from src.data_processing.vectorization.factory import VectorizationFactory


class HybridSearchRetriever:
    """Hybrid retriever that combines the strengths of multiple retrieval methods."""

    def __init__(self,
                 vector_store,
                 keyword_weight=0.3,
                 semantic_weight=0.7,
                 rerank_model=None):
        """Initialize the hybrid retriever."""
        self.vector_store = vector_store
        self.keyword_weight = keyword_weight
        self.semantic_weight = semantic_weight
        self.rerank_model = rerank_model
        self.logger = logging.getLogger(self.__class__.__name__)
        # Build the BM25 keyword index
        self._initialize_bm25_index()
        # Semantic vectorizer for query embeddings
        self.vectorizer = VectorizationFactory.create_vectorizer('bge-m3')

    def _initialize_bm25_index(self):
        """Initialize the BM25 keyword index."""
        from rank_bm25 import BM25Okapi

        docs = self.vector_store.get_all_documents()
        texts = [doc.page_content for doc in docs]
        tokenized_corpus = [list(jieba.cut(text)) for text in texts]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.doc_ids = [doc.metadata.get('doc_id') for doc in docs]
        self.documents = docs

    async def retrieve(self, query: str, top_k: int = 5, threshold: float = 0.0):
        """Run the hybrid retrieval."""
        keyword_results = await self._keyword_search(query, top_k * 2)
        vector_results = await self._vector_search(query, top_k * 2)
        merged_results = self._merge_results(keyword_results, vector_results)
        # _rerank_results() (optional reranking model) is not shown in this excerpt
        if self.rerank_model and len(merged_results) > top_k:
            merged_results = await self._rerank_results(query, merged_results, top_k)
        filtered_results = [
            (doc, score) for doc, score in merged_results
            if score >= threshold
        ]
        return filtered_results[:top_k]

    async def _keyword_search(self, query: str, top_k: int):
        """BM25 keyword search."""
        tokenized_query = list(jieba.cut(query))
        bm25_scores = self.bm25.get_scores(tokenized_query)
        results = []
        for i, score in enumerate(bm25_scores):
            if score > 0:
                results.append((self.documents[i], score))
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    async def _vector_search(self, query: str, top_k: int):
        """Vector similarity search."""
        return await self.vector_store.asimilarity_search_with_score(query, top_k)

    def _merge_results(self, keyword_results, vector_results):
        """Merge keyword and vector search results."""
        merged_map = {}
        # Normalize BM25 scores by the maximum keyword score
        max_keyword_score = max([score for _, score in keyword_results]) if keyword_results else 1.0
        for doc, score in keyword_results:
            doc_id = doc.metadata.get("id")
            normalized_score = score / max_keyword_score
            merged_map[doc_id] = {
                "doc": doc,
                "keyword_score": normalized_score,
                "vector_score": 0.0
            }
        for doc, score in vector_results:
            doc_id = doc.metadata.get("id")
            if doc_id in merged_map:
                merged_map[doc_id]["vector_score"] = score
            else:
                merged_map[doc_id] = {
                    "doc": doc,
                    "keyword_score": 0.0,
                    "vector_score": score
                }
        result_list = []
        for item in merged_map.values():
            final_score = (
                self.keyword_weight * item["keyword_score"] +
                self.semantic_weight * item["vector_score"]
            )
            result_list.append((item["doc"], final_score))
        result_list.sort(key=lambda x: x[1], reverse=True)
        return result_list
```

4.2 Extending to Cross-Modal Retrieval

Besides text, a RAG system can be extended to handle images, audio, and other multi-modal data:

```python
import os
import asyncio
import numpy as np
from typing import List, Tuple, Dict, Any

from langchain.schema import Document


class MultiModalRetriever:
    """Multi-modal retriever supporting text, images and other modalities."""

    def __init__(self, vector_stores, embedding_models):
        """Initialize the multi-modal retriever."""
        self.vector_stores = vector_stores
        self.embedding_models = embedding_models

    async def retrieve(self, query, modal_type=None, top_k=5):
        """Run the multi-modal retrieval."""
        # _detect_modal_type() and _merge_modal_results() are not shown in this excerpt
        if modal_type == "auto":
            modal_type = self._detect_modal_type(query)
        if modal_type == "text":
            # Embed the query with both CLIP (for image search) and BGE (for text search)
            clip_embedding = self.embedding_models["clip"].encode_text(query)
            bge_embedding = self.embedding_models["bge"].vectorize(query)
            tasks = [
                self.vector_stores["text"].asimilarity_search_by_vector(bge_embedding, top_k),
                self.vector_stores["image"].asimilarity_search_by_vector(clip_embedding, top_k)
            ]
            text_results, image_results = await asyncio.gather(*tasks)
            return self._merge_modal_results(text_results, image_results, top_k)
        elif modal_type == "image":
            image_embedding = self.embedding_models["clip"].encode_image(query)
            results = await self.vector_stores["image"].asimilarity_search_by_vector(
                image_embedding, top_k
            )
            return results
```

5. A User-Feedback Optimization Mechanism

5.1 Collecting and Storing Feedback Data

To further improve the RAG system's retrieval quality, we can add a user feedback mechanism:

```python
import sqlite3
import logging
from datetime import datetime
from typing import List, Dict, Any


class FeedbackOptimizer:
    """Optimize RAG retrieval results based on user feedback."""

    def __init__(self, vector_store, feedback_db=None):
        """Initialize the feedback optimizer."""
        self.vector_store = vector_store
        self.feedback_db = feedback_db or self._initialize_feedback_db()
        self.logger = logging.getLogger(self.__class__.__name__)

    def _initialize_feedback_db(self):
        """Initialize the feedback database."""
        conn = sqlite3.connect('data/feedback.db')
        c = conn.cursor()
        c.execute('''
        CREATE TABLE IF NOT EXISTS feedback (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            query_text TEXT,
            doc_id TEXT,
            is_relevant INTEGER,
            timestamp TEXT
        )
        ''')
        c.execute('''
        CREATE TABLE IF NOT EXISTS query_log (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            query_text TEXT,
            results_count INTEGER,
            timestamp TEXT
        )
        ''')
        conn.commit()
        return conn

    def record_feedback(self, query, doc_id, is_relevant):
        """Record a piece of user feedback."""
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "INSERT INTO feedback (query_text, doc_id, is_relevant, timestamp) VALUES (?, ?, ?, datetime('now'))",
            (query, doc_id, 1 if is_relevant else 0)
        )
        self.feedback_db.commit()

    def record_query(self, query, results_count):
        """Record a query log entry."""
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "INSERT INTO query_log (query_text, results_count, timestamp) VALUES (?, ?, datetime('now'))",
            (query, results_count)
        )
        self.feedback_db.commit()

    def get_relevance_feedback_for_query(self, query, limit=10):
        """Fetch relevance feedback recorded for a specific query."""
        cursor = self.feedback_db.cursor()
        cursor.execute(
            "SELECT doc_id, is_relevant, COUNT(*) FROM feedback WHERE query_text = ? GROUP BY doc_id, is_relevant",
            (query,)
        )
        return cursor.fetchall()

    def optimize_results(self, query, initial_results, top_k=5):
        """Re-rank retrieval results using historical feedback."""
        feedback = self.get_relevance_feedback_for_query(query)
        if not feedback:
            return initial_results[:top_k]
        feedback_dict = {}
        for doc_id, is_relevant, count in feedback:
            if doc_id not in feedback_dict:
                feedback_dict[doc_id] = {"relevant": 0, "irrelevant": 0}
            if is_relevant:
                feedback_dict[doc_id]["relevant"] += count
            else:
                feedback_dict[doc_id]["irrelevant"] += count
        adjusted_results = []
        for doc, score in initial_results:
            doc_id = doc.metadata.get("id")
            adjustment = 0
            if doc_id in feedback_dict:
                relevant = feedback_dict[doc_id]["relevant"]
                irrelevant = feedback_dict[doc_id]["irrelevant"]
                # Positive feedback nudges the score up, negative feedback nudges it down (at most ±0.2)
                if relevant + irrelevant > 0:
                    adjustment = (relevant - irrelevant) / (relevant + irrelevant) * 0.2
            adjusted_score = min(1.0, max(0.0, score + adjustment))
            adjusted_results.append((doc, adjusted_score))
        adjusted_results.sort(key=lambda x: x[1], reverse=True)
        self.record_query(query, len(initial_results))
        return adjusted_results[:top_k]
```
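A usage sketch for the feedback loop; `vector_store` and `initial_results` (a list of `(Document, score)` pairs from the retriever) are assumed to come from the rest of the pipeline:

```python
optimizer = FeedbackOptimizer(vector_store)

# A user marks a document as relevant for a query
optimizer.record_feedback("什么是向量化?", doc_id="doc-42", is_relevant=True)

# Later, re-rank fresh retrieval results for the same query using the accumulated feedback
optimized = optimizer.optimize_results("什么是向量化?", initial_results, top_k=5)
```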

5.2 Implementing the Feedback Interface

To collect user feedback, we add feedback buttons to the frontend; they call the following API endpoints:

```python
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from typing import List, Optional

router = APIRouter()


class FeedbackRequest(BaseModel):
    query: str
    doc_id: str
    is_relevant: bool


@router.post("/feedback")
async def submit_feedback(request: FeedbackRequest):
    """Submit document relevance feedback."""
    try:
        feedback_optimizer = get_feedback_optimizer()
        feedback_optimizer.record_feedback(
            request.query,
            request.doc_id,
            request.is_relevant
        )
        return {"status": "success", "message": "Feedback recorded"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to submit feedback: {str(e)}")


@router.get("/feedback/stats")
async def get_feedback_stats(query: Optional[str] = None):
    """Return feedback statistics."""
    try:
        feedback_optimizer = get_feedback_optimizer()
        if query:
            stats = feedback_optimizer.get_relevance_feedback_for_query(query)
        else:
            stats = feedback_optimizer.get_global_feedback_stats()
        return {"status": "success", "data": stats}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch feedback statistics: {str(e)}")
```
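The endpoints above rely on a `get_feedback_optimizer()` helper that is not shown in the excerpt. A minimal wiring sketch follows; the helper below is hypothetical, and the real vector store used by your pipeline should be injected where indicated:

```python
from fastapi import FastAPI

_feedback_optimizer = None


def get_feedback_optimizer():
    """Return a lazily created, process-wide FeedbackOptimizer instance."""
    global _feedback_optimizer
    if _feedback_optimizer is None:
        _feedback_optimizer = FeedbackOptimizer(vector_store=None)  # replace with the real vector store
    return _feedback_optimizer


app = FastAPI()
app.include_router(router, prefix="/api")
```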

6. Extending the Query Intent Analyzer

Source file: `src/chains/processors/recommendation.py`

6.1 Custom Intent Processors

We already have a recommendation query processor; now we can extend the system with more specialized intent processors, starting with a classifier that identifies query intent via vector similarity:

```python
import numpy as np
from typing import List, Dict, Any

from src.data_processing.vectorization.factory import VectorizationFactory


class QueryIntentClassifier:
    """Query intent classifier that identifies intent via vector similarity."""

    def __init__(self, embedding_model=None):
        """Initialize the intent classifier."""
        self.embedding_model = embedding_model or VectorizationFactory.create_vectorizer("bge-m3")
        # Example queries (in Chinese) that anchor each intent category
        self.intent_examples = {
            "比较分析": [
                "A和B有什么区别?",
                "哪一个更好,X还是Y?",
                "比较一下P和Q的优缺点"
            ],
            "因果解释": [
                "为什么会出现这种情况?",
                "导致X的原因是什么?",
                "这个问题的根源是什么?"
            ],
            "列举信息": [
                "列出所有的X",
                "有哪些方法可以做Y?",
                "X包含哪些组成部分?"
            ],
            "概念解释": [
                "什么是X?",
                "X的定义是什么?",
                "如何理解Y概念?"
            ],
            "数据统计": [
                "X的平均值是多少?",
                "Y的增长率是多少?",
                "Z的分布情况怎样?"
            ],
            "操作指导": [
                "如何做X?",
                "执行Y的步骤是什么?",
                "使用Z的方法有哪些?"
            ],
            "推荐建议": [
                "推荐几款好用的X",
                "有什么适合Y的工具?",
                "帮我选择一个合适的Z"
            ]
        }
        self.intent_vectors = self._compute_intent_vectors()

    def _compute_intent_vectors(self):
        """Pre-compute the mean vector for each intent."""
        intent_vectors = {}
        for intent, examples in self.intent_examples.items():
            vectors = self.embedding_model.batch_vectorize(examples)
            avg_vector = np.mean(vectors, axis=0)
            # Normalize so the dot product below equals cosine similarity
            avg_vector = avg_vector / np.linalg.norm(avg_vector)
            intent_vectors[intent] = avg_vector
        return intent_vectors

    def classify_intent(self, query: str):
        """Classify the intent of a query."""
        query_vector = self.embedding_model.vectorize(query)
        similarities = {}
        for intent, vector in self.intent_vectors.items():
            similarity = np.dot(query_vector, vector)
            similarities[intent] = similarity
        max_intent = max(similarities, key=similarities.get)
        max_similarity = similarities[max_intent]
        # Fall back to a generic intent when no category is a clear match
        if max_similarity < 0.5:
            return "一般信息查询", max_similarity
        return max_intent, max_similarity
```
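A usage sketch for the intent classifier (assumes the `bge-m3` vectorizer is available through the factory):

```python
classifier = QueryIntentClassifier()
intent, score = classifier.classify_intent("MySQL和PostgreSQL有什么区别?")
print(intent, round(score, 3))   # expected to fall into the "比较分析" (comparison) intent
```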

6.2 A Custom Comparison Analysis Processor

Each intent needs its own dedicated processor. Taking comparison analysis as an example:

```python
import re
from typing import List, Dict, Any

from .base import QueryProcessor


class ComparisonQueryProcessor(QueryProcessor):
    """Processor for queries that involve comparing entities."""

    def process(self, query: str, documents: List[Any], **kwargs) -> Dict[str, Any]:
        """Handle a comparison query."""
        if not documents:
            return {
                "answer": "抱歉,没有找到相关的比较信息。",
                "sources": []
            }
        entities = self._extract_comparison_entities(query)
        if len(entities) < 2:
            entities = self._extract_entities_from_documents(documents)
        entity_info = self._collect_entity_information(entities, documents)
        comparison_table = self._generate_comparison_table(entity_info)
        conclusion = self._generate_comparison_conclusion(entity_info, query)
        answer = f"根据您的比较请求,以下是{', '.join(entities)}的对比分析:\n\n{comparison_table}\n\n{conclusion}"
        return {
            "answer": answer,
            "sources": documents,
            "entities": entities,
            "comparison_table": comparison_table
        }

    def can_handle(self, query: str) -> bool:
        """Determine whether the query is a comparison query."""
        comparison_keywords = ["比较", "区别", "差异", "优缺点", "对比", "相比", "VS", "好坏"]
        comparison_patterns = [
            r"(.+)和(.+)的区别",
            r"(.+)与(.+)的(差异|不同)",
            r"(.+)相比(.+)怎么样",
            r"(.+)还是(.+)更好"
        ]
        if any(keyword in query for keyword in comparison_keywords):
            return True
        for pattern in comparison_patterns:
            if re.search(pattern, query):
                return True
        return False

    def _extract_comparison_entities(self, query: str) -> List[str]:
        """Extract the entities to be compared from the query."""
        patterns = [
            r"(.+)和(.+)的区别",
            r"(.+)与(.+)的(差异|不同)",
            r"(.+)相比(.+)怎么样",
            r"(.+)还是(.+)更好"
        ]
        for pattern in patterns:
            match = re.search(pattern, query)
            if match:
                entities = [match.group(1).strip(), match.group(2).strip()]
                return entities
        return []

    def _extract_entities_from_documents(self, documents: List[Any]) -> List[str]:
        """Extract candidate comparison entities from the documents."""
        # Abridged in the original article
        return []

    def _collect_entity_information(self, entities: List[str], documents: List[Any]) -> Dict[str, Dict]:
        """Collect information about each entity from the documents."""
        entity_info = {}
        for entity in entities:
            entity_info[entity] = {
                "advantages": [],
                "disadvantages": [],
                "features": {},
                "mentions": 0
            }
            for doc in documents:
                text = doc.page_content
                if entity in text:
                    entity_info[entity]["mentions"] += text.count(entity)
                advantages = self._extract_advantages(text, entity)
                entity_info[entity]["advantages"].extend(advantages)
                disadvantages = self._extract_disadvantages(text, entity)
                entity_info[entity]["disadvantages"].extend(disadvantages)
                features = self._extract_features(text, entity)
                for feature, value in features.items():
                    if feature in entity_info[entity]["features"]:
                        entity_info[entity]["features"][feature].append(value)
                    else:
                        entity_info[entity]["features"][feature] = [value]
        return entity_info

    def _extract_advantages(self, text: str, entity: str) -> List[str]:
        """Extract the advantages of an entity."""
        patterns = [
            f"{entity}的优点",
            f"{entity}的好处",
            f"{entity}的优势"
        ]
        # Abridged in the original article
        return []

    def _extract_disadvantages(self, text: str, entity: str) -> List[str]:
        """Extract the disadvantages of an entity."""
        return []

    def _extract_features(self, text: str, entity: str) -> Dict[str, str]:
        """Extract the features of an entity."""
        return {}

    def _generate_comparison_table(self, entity_info: Dict[str, Dict]) -> str:
        """Generate a Markdown comparison table."""
        entities = list(entity_info.keys())
        table = f"| 特性 | {' | '.join(entities)} |\n"
        table += f"| --- | {' | '.join(['---' for _ in entities])} |\n"
        all_features = set()
        for entity, info in entity_info.items():
            all_features.update(info["features"].keys())
        for feature in sorted(all_features):
            row = f"| {feature} | "
            for entity in entities:
                if feature in entity_info[entity]["features"]:
                    values = entity_info[entity]["features"][feature]
                    row += f"{values[0]} | "
                else:
                    row += "- | "
            table += row + "\n"
        table += f"| 优点 | {' | '.join([', '.join(info['advantages'][:3]) or '-' for _, info in entity_info.items()])} |\n"
        table += f"| 缺点 | {' | '.join([', '.join(info['disadvantages'][:3]) or '-' for _, info in entity_info.items()])} |\n"
        return table

    def _generate_comparison_conclusion(self, entity_info: Dict[str, Dict], query: str) -> str:
        """Generate a comparison conclusion."""
        entities = list(entity_info.keys())
        if len(entities) < 2:
            return "无法生成比较结论,找不到足够的实体信息。"
        return "根据以上对比,每个选项都有各自的优缺点,具体选择取决于您的具体需求和场景。"
```

7. Practical Case Studies

7.1 A Custom Legal Document Processor

Legal documents have their own peculiarities; below is an example of a processor dedicated to legal documents:

```python
import re
from typing import List, Dict, Any, Optional

from langchain.schema import Document

from src.data_processing.processors.base import BaseDocumentProcessor, ProcessorConfig


class LegalDocumentProcessor(BaseDocumentProcessor):
    """Processor specialised for legal documents."""

    def __init__(self, config: Optional[ProcessorConfig] = None):
        """Initialize the legal document processor."""
        super().__init__(config)
        self.legal_terms = self._load_legal_terms()

    def _load_legal_terms(self):
        """Load the legal term glossary."""
        return {
            "原告": "起诉方,请求法院裁判的一方",
            "被告": "被起诉方,被请求法院裁判的一方",
            "诉讼": "通过法院解决纠纷的法律程序",
            # ... more terms ...
        }

    def process_file(self, file_content: bytes, filename: str, mime_type: str) -> List[Document]:
        """Process a legal document file."""
        if mime_type == "application/pdf":
            text = self._extract_pdf_text(file_content)
        elif mime_type in ("application/msword",
                           "application/vnd.openxmlformats-officedocument.wordprocessingml.document"):
            text = self._extract_word_text(file_content)
        else:
            text = file_content.decode(self.config.encoding, errors='ignore')
        return self.process_text(text, {"source": filename, "mime_type": mime_type})

    def process_text(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
        """Process legal document text."""
        sections = self._parse_legal_document_structure(text)
        documents = []
        for section_name, section_content in sections.items():
            section_metadata = {
                **(metadata or {}),
                "section": section_name,
                "document_type": "legal"
            }
            annotated_content = self._annotate_legal_terms(section_content)
            doc = Document(
                page_content=annotated_content,
                metadata=section_metadata
            )
            documents.append(self._add_metadata(doc))
        return documents

    def _parse_legal_document_structure(self, text: str) -> Dict[str, str]:
        """Parse the structure of a legal document."""
        sections = {}
        section_patterns = [
            (r"案号[::]\s*(.+?)(?=\n)", "case_number"),
            (r"原告[::]\s*(.+?)(?=\n被告)", "plaintiff"),
            (r"被告[::]\s*(.+?)(?=\n)", "defendant"),
            (r"诉讼请求[::]\s*(.+?)(?=\n)", "claims"),
            (r"事实与理由[::]\s*(.+?)(?=\n)", "facts_and_reasons"),
            (r"裁判结果[::]\s*(.+?)(?=\n)", "judgment"),
            (r"裁判理由[::]\s*(.+?)(?=\n)", "reasoning")
        ]
        for pattern, section_name in section_patterns:
            match = re.search(pattern, text, re.DOTALL)
            if match:
                sections[section_name] = match.group(1).strip()
        # Fall back to paragraph splitting when no structured section is found
        if not sections:
            paragraphs = re.split(r'\n\s*\n', text)
            for i, para in enumerate(paragraphs):
                sections[f"paragraph_{i+1}"] = para.strip()
        return sections

    def _annotate_legal_terms(self, text: str) -> str:
        """Annotate legal terms with their definitions."""
        annotated_text = text
        for term, definition in self.legal_terms.items():
            if term in annotated_text:
                # Annotate only the first occurrence to avoid cluttering the text
                term_with_note = f"{term}[注: {definition}]"
                annotated_text = annotated_text.replace(term, term_with_note, 1)
        return annotated_text
```
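A usage sketch for the legal document processor on plain text; the PDF/Word branches require the corresponding extraction helpers, and `_add_metadata()` is assumed to be provided by the base class:

```python
processor = LegalDocumentProcessor()
docs = processor.process_text("案号:(2023)京01民终1234号\n原告:张三\n被告:李四\n...")
for doc in docs:
    print(doc.metadata["section"], len(doc.page_content))
```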

7.2 A Financial Data Vectorizer

To account for the particularities of financial data, we can implement a dedicated vectorizer:

```python
import numpy as np
from typing import List, Dict, Any
import re

from src.data_processing.vectorization.base import BaseVectorizer


class FinancialDataVectorizer(BaseVectorizer):
    """Vectorizer specialised for financial text and data."""

    def __init__(self, cache_dir='./cache/vectorization',
                 base_model='bge-m3',
                 numerical_weight=0.3):
        """Initialize the financial data vectorizer."""
        super().__init__(cache_dir)
        from src.data_processing.vectorization.factory import VectorizationFactory
        self.base_vectorizer = VectorizationFactory.create_vectorizer(base_model)
        self.numerical_weight = numerical_weight
        self.financial_terms = self._load_financial_terms()

    def _load_financial_terms(self):
        """Load the financial term list."""
        return [
            "股票", "债券", "基金", "期货", "期权", "保险", "理财",
            "利率", "汇率", "通货膨胀", "GDP", "PPI", "CPI", "PMI",
            "资产", "负债", "股东", "投资", "风险", "收益", "波动"
        ]

    def vectorize(self, text: str) -> np.ndarray:
        """Convert financial text into a vector."""
        numerical_features = self._extract_numerical_features(text)
        semantic_vector = self.base_vectorizer.vectorize(text)
        combined_vector = self._combine_features(semantic_vector, numerical_features)
        return combined_vector

    def batch_vectorize(self, texts: List[str]) -> List[np.ndarray]:
        """Convert a batch of financial texts into vectors."""
        results = []
        for text in texts:
            vector = self.vectorize(text)
            results.append(vector)
        return results

    def _extract_numerical_features(self, text: str) -> Dict[str, float]:
        """Extract numerical features from the text."""
        features = {}
        # Percentages, e.g. "12.5%"
        percentage_pattern = r'(\d+\.?\d*)%'
        percentages = re.findall(percentage_pattern, text)
        if percentages:
            features['percentage_avg'] = sum(float(p) for p in percentages) / len(percentages)
            features['percentage_count'] = len(percentages)
        # Monetary amounts, e.g. "3.2 亿"
        amount_pattern = r'(\d+\.?\d*)\s*(万|亿|千|百万|美元|元|美金|英镑|欧元)'
        amounts = re.findall(amount_pattern, text)
        if amounts:
            std_amounts = []
            for amount, unit in amounts:
                value = float(amount)
                if unit == '万':
                    value *= 10000
                elif unit == '亿':
                    value *= 100000000
                std_amounts.append(value)
            if std_amounts:
                features['amount_avg'] = sum(std_amounts) / len(std_amounts)
                features['amount_max'] = max(std_amounts)
                features['amount_count'] = len(std_amounts)
        return features

    def _combine_features(self, semantic_vector: np.ndarray, numerical_features: Dict[str, float]) -> np.ndarray:
        """Fuse numerical features with the semantic vector."""
        if not numerical_features:
            return semantic_vector
        numerical_vector = np.zeros(10)
        feature_index = {
            'percentage_avg': 0,
            'percentage_count': 1,
            'amount_avg': 2,
            'amount_max': 3,
            'amount_count': 4,
        }
        for feature, value in numerical_features.items():
            if feature in feature_index:
                numerical_vector[feature_index[feature]] = value
        # Normalize the numerical features
        num_max = np.max(numerical_vector) if np.max(numerical_vector) > 0 else 1.0
        numerical_vector = numerical_vector / num_max
        # Reduce the semantic vector so the combined dimensionality stays unchanged.
        # Note: as written (kept from the original), the PCA is fitted on a single sample,
        # which cannot actually yield new_dim components; in practice the reducer should be
        # fitted on a corpus ahead of time and reused here.
        original_dim = semantic_vector.shape[0]
        new_dim = original_dim - len(numerical_vector)
        from sklearn.decomposition import PCA
        pca = PCA(n_components=new_dim)
        semantic_reduced = pca.fit_transform(semantic_vector.reshape(1, -1)).flatten()
        semantic_weight = 1 - self.numerical_weight
        combined = np.concatenate([
            semantic_reduced * semantic_weight,
            numerical_vector * self.numerical_weight
        ])
        combined = combined / np.linalg.norm(combined)
        return combined
```

8. Summary and Outlook

8.1 Best Practices for Extending a RAG System

This article has shown how to build highly customized extensions into a RAG system. The best practices can be summarized as follows:

- Abstract base class design: define unified interfaces with abstract base classes so all components follow the same contract
- Factory-based decoupling: create component instances through factories to reduce coupling between components
- Configuration-driven initialization: drive component initialization with environment variables and configuration files for greater flexibility
- Specialized processors: build dedicated processors for specific domains or document types to improve processing quality
- Hybrid retrieval strategy: combine multiple retrieval methods to balance keyword matching and semantic search
- User feedback loop: collect user feedback and apply it to result optimization to keep improving retrieval quality
- Query intent analysis: choose specialized processors based on query intent to provide more precise answers

8.2 Future Directions

There are still several extension directions worth exploring:

- Multi-modal processing: extend to images, audio, video and other multi-modal data
- Time-aware retrieval: support time-series data and trend analysis in queries
- Adaptive retrieval: automatically adjust the retrieval strategy based on a user's query history
- Multi-source fusion: intelligently merge query results from multiple data sources
- Explainable retrieval: make retrieval results explainable so users understand where they come from
- Offline pre-computation: pre-compute results for common query paths to improve response time

Coming next: "Seven Key Practices for Improving RAG System Performance" will cover in detail:

- Chunking strategy optimization (differentiated handling of tables, code, and plain text)
- Cache design (vector cache, result cache, model cache)
- Asynchronous processing (optimizing the document processing pipeline)
- Security hardening (input filtering, access control)
- Evaluation methods (retrieval accuracy, response time, QPS)

Project repository: github.com/bikeread/ra…

