虽然没进决赛,但总结一下oceanbase ai hackthon中所实现的多路召回
1. 重排序(Rerank)实现
实现了基于 bge-reranker-large 模型的重排序功能,通过对检索到的文档进行二次排序,提高检索质量。
核心实现代码:
在 rerank.py
中,项目定义了 rerank_topn 函数:
def rerank_topn(question,docs,N=5): pairs = [] for i in docs: pairs.append([question,i.page_content]) with torch.no_grad(): inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512) scores = model(**inputs, return_dict=True).logits.view(-1, ).float() scores = scores.argsort().numpy()[::-1][:N] bk = [] for i in scores: bk.append(docs[i]) return bk
这个函数首先将问题和每个文档组成对,然后使用预训练的 bge-reranker-large 模型计算相关性分数,最后返回得分最高的 N 个文档。
在 rag_class.py
中通过 rerank_chain 方法调用这个重排序功能:
def rerank_chain(self,question): retriever = self.vectstore.as_retriever(search_kwargs={"k": 10}) docs = retriever.invoke(question) docs = rerank_topn(question,docs,N=5) _chain = ( self.prompts | self.llm | StrOutputParser() ) answer = _chain.invoke({"context":self.format_docs(docs),"question": question}) return answer
这个方法首先检索 10 个候选文档,然后使用 rerank_topn 筛选出最相关的 5 个文档,最后将这些文档作为上下文传递给大模型生成回答。
2. 融合实现
项目实现了多种文档融合策略,特别是"复杂召回方式"通过问题扩展和递归融合提高了全面性:
# 多问题递归召回,每次召回后,问题和答案同时作为下一次召回的参考,再次用新问题召回 def rag_chain(self, questions): q_a_pairs = "" for q in questions: _chain = ( {"context": itemgetter("question") | self.retriever, "question": itemgetter("question"), "q_a_pairs": itemgetter("q_a_paris") } | self.decomposition_prompt | self.llm | StrOutputParser() ) answer = _chain.invoke({"question": q, "q_a_paris": q_a_pairs}) q_a_pairs = self.format_qa_pairs(q, answer) q_a_pairs = q_a_pairs + "\n----\n" + q_a_pairs return answer
首先通过 decomposition_chain 方法生成多个相关问题:
# 获取问题的 扩展问题 def decomposition_chain(self, question): _chain = ( {"question": RunnablePassthrough()} | self.prompt_questions | self.llm | StrOutputParser() | (lambda x: x.split("\n")) ) questions = _chain.invoke({"question": question}) + [question] return questions
这种方式使用 LLM 生成多个相关问题,然后对每个问题分别进行检索,并将所有问答对积累起来作为上下文,实现了多角度信息的融合。
3. 过滤实现
过滤功能主要体现在两个方面:文档检索过滤和重排序过滤。
在 rerank_topn 函数中实现了基于相关性得分的过滤:
scores = scores.argsort().numpy()[::-1][:N] bk = [] for i in scores: bk.append(docs[i]) return bk
在向量数据库实现中也有文件级别的过滤功能:
# 删除 某个collection中的 某个文件 def del_files(self, del_files_name, c_name): vectorstore = self.chromadb._client.get_collection(c_name) del_ids = [] vec_dict = vectorstore.get() for id, md in zip(vec_dict["ids"], vec_dict["metadatas"]): for dl in del_files_name: if dl in md["source"]: del_ids.append(id) vectorstore.delete(ids=del_ids) print("数据块总量:", vectorstore.count()) return vectorstore
这些过滤机制确保了只有最相关的内容会被用于回答生成。
4. 摘要和总结实现
项目中的摘要和总结功能主要通过精心设计的提示模板和 LLM 调用实现:
基本的问答提示模板:
template = """ 根据上下文回答以下问题,不要自己发挥,要根据以下参考内容总结答案,如果以下内容无法得到答案,就返回无法根据参考内容获取答案, 参考内容为:{context} 问题: {question} """
对于网络搜索结果的总结:
def summarize_with_ollama(model_dropdown,text, question): prompt = """ 根据下边的内容,回答用户问题, 内容为:'{0}'\n 问题为:{1} """.format(text, question) ollama_url = 'http://localhost:11434/api/generate' # 替换为你的Ollama实例URL data = { 'model': model_dropdown, "prompt": prompt, "stream": False } response = requests.post(ollama_url, json=data) response.raise_for_status() return response.json()
更复杂的包含背景问答对的提示模板:
template2 = """ 以下是您需要回答的问题: \n--\n {question} \n---\n 以下是任何可用的背景问答对: \n--\n {q_a_pairs} \n---\n 以下是与该问题相关的其他上下文: \n--\n {context} \n---\n 使用以上上下文和背景问答对来回答问题,问题是:{question} ,答案是: """ self.decomposition_prompt = ChatPromptTemplate.from_template(template2)
5. 系统整合
这些组件在 webui.py 的 chat_response 函数中被整合起来:
def chat_response(model_dropdown, vector_dropdown, chat_knowledge_base_dropdown, chain_dropdown, message): global chat_history if message: chat_history.append(("User", message)) if chat_knowledge_base_dropdown == "仅使用模型": rag = RAG_class(model=model_dropdown,persist_directory=DB_directory) answer = rag.mult_chat(chat_history) if chat_knowledge_base_dropdown and chat_knowledge_base_dropdown != "仅使用模型": rag = RAG_class(model=model_dropdown, embed=vector_dropdown, c_name=chat_knowledge_base_dropdown, persist_directory=DB_directory) if chain_dropdown == "复杂召回方式": questions = rag.decomposition_chain(message) answer = rag.rag_chain(questions) elif chain_dropdown == "简单召回方式": answer = rag.simple_chain(message) else: answer = rag.rerank_chain(message) response = f" {answer}" chat_history.append(("Bot", response)) return format_chat_history(chat_history), ""
用户可以选择三种不同的召回方式:
- 复杂召回方式:使用问题扩展和递归融合简单召回方式:直接检索相关文档并生成回答rerank:使用重排序提高检索质量
那么如果用户想要对召回的质量进行评估,应该怎么做呢?
现有的重排序功能进行评估
如果实现了重排序功能,这本身就是一种评估和改进召回质量的方法。可以比较使用重排序前后的结果差异
通过修改逻辑在重排序前后分别保存结果,然后比较两者的差异,评估重排序的效果。
2. 实现标准评估指标
可以在项目中添加以下常用的信息检索评估指标:
a. 精确率和召回率
def evaluate_precision_recall(retrieved_docs, relevant_docs): """ 评估检索结果的精确率和召回率 Args: retrieved_docs: 系统检索到的文档ID列表 relevant_docs: 标注的相关文档ID列表 Returns: precision: 精确率 recall: 召回率 """ if not retrieved_docs: return 0, 0 relevant_retrieved = set(retrieved_docs).intersection(set(relevant_docs)) precision = len(relevant_retrieved) / len(retrieved_docs) recall = len(relevant_retrieved) / len(relevant_docs) if relevant_docs else 0 return precision, recall
b. 平均精度均值 (Mean Average Precision, MAP)
def calculate_map(all_queries_results, all_queries_relevant): """ 计算平均精度均值 Args: all_queries_results: 每个查询的检索结果 {query_id: [doc_id1, doc_id2, ...]} all_queries_relevant: 每个查询的相关文档 {query_id: [doc_id1, doc_id2, ...]} Returns: map_score: MAP分数 """ average_precisions = [] for query_id in all_queries_results: if query_id not in all_queries_relevant: continue retrieved = all_queries_results[query_id] relevant = set(all_queries_relevant[query_id]) if not relevant: continue precisions = [] relevant_count = 0 for i, doc_id in enumerate(retrieved): if doc_id in relevant: relevant_count += 1 precisions.append(relevant_count / (i + 1)) if precisions: average_precisions.append(sum(precisions) / len(relevant)) return sum(average_precisions) / len(average_precisions) if average_precisions else 0
c. 归一化折损累积增益 (NDCG)
def calculate_ndcg(retrieved_docs, relevant_docs, k=None): """ 计算NDCG Args: retrieved_docs: 检索结果列表 relevant_docs: 相关文档字典 {doc_id: relevance_score} k: 截断位置,默认为None表示使用所有结果 Returns: ndcg: NDCG分数 """ import numpy as np if k is not None: retrieved_docs = retrieved_docs[:k] dcg = 0 for i, doc_id in enumerate(retrieved_docs): if doc_id in relevant_docs: # 使用2^rel-1公式,也可以使用其他公式 rel = relevant_docs[doc_id] dcg += (2 ** rel - 1) / np.log2(i + 2) # i+2 因为log_2(1)=0 # 计算理想DCG ideal_ranking = sorted(relevant_docs.items(), key=lambda x: x[1], reverse=True) ideal_dcg = 0 for i, (doc_id, rel) in enumerate(ideal_ranking[:len(retrieved_docs)]): ideal_dcg += (2 ** rel - 1) / np.log2(i + 2) return dcg / ideal_dcg if ideal_dcg > 0 else 0
3. 集成到现有系统中
您可以在 rag_class.py
中添加一个评估方法,例如:
def evaluate_retrieval(self, test_questions, ground_truth): """ 评估检索效果 Args: test_questions: 测试问题列表 ground_truth: 每个问题的标准答案 {question: [relevant_doc_ids]} Returns: metrics: 评估指标字典 """ results = {} for question in test_questions: # 使用不同的检索方法 # 1. 简单检索 simple_retriever = self.vectstore.as_retriever(search_kwargs={"k": 10}) simple_docs = simple_retriever.invoke(question) simple_doc_ids = [doc.metadata.get('id') for doc in simple_docs] # 2. 重排序检索 rerank_docs = rerank_topn(question, simple_docs, N=5) rerank_doc_ids = [doc.metadata.get('id') for doc in rerank_docs] # 3. 复杂检索 questions = self.decomposition_chain(question) complex_docs = [] for q in questions: docs = self.retriever.invoke(q) complex_docs.extend(docs) complex_doc_ids = [doc.metadata.get('id') for doc in complex_docs] results[question] = { 'simple': simple_doc_ids, 'rerank': rerank_doc_ids, 'complex': complex_doc_ids } # 计算评估指标 metrics = { 'simple': {}, 'rerank': {}, 'complex': {} } for method in metrics: precisions = [] recalls = [] for question, doc_ids in results.items(): if question in ground_truth: p, r = evaluate_precision_recall(doc_ids[method], ground_truth[question]) precisions.append(p) recalls.append(r) metrics[method]['precision'] = sum(precisions) / len(precisions) if precisions else 0 metrics[method]['recall'] = sum(recalls) / len(recalls) if recalls else 0 metrics[method]['f1'] = 2 * metrics[method]['precision'] * metrics[method]['recall'] / (metrics[method]['precision'] + metrics[method]['recall']) if (metrics[method]['precision'] + metrics[method]['recall']) > 0 else 0 return metrics
4. 创建测试数据集
为了进行评估,需要创建一个测试数据集,包含问题和相关文档的标注。:
- 手动创建一组测试问题和标准答案从现有知识库中抽取一部分作为测试集使用 LLM 生成测试问题和答案
例如,可以创建一个测试数据集文件 test_dataset.json
:
{ "questions": [ "什么是RAG系统?", "Easy-RAG支持哪些向量数据库?", "如何使用rerank功能提高检索质量?" ], "ground_truth": { "什么是RAG系统?": ["doc_id_1", "doc_id_5", "doc_id_10"], "Easy-RAG支持哪些向量数据库?": ["doc_id_3", "doc_id_7"], "如何使用rerank功能提高检索质量?": ["doc_id_2", "doc_id_8", "doc_id_12"] } }
5. 添加可视化评估界面
可以在 webui.py 中添加一个评估标签页,用于可视化评估结果:
with gr.TabItem("评估"): test_file = gr.File(label="上传测试数据集") eval_knowledge_base_dropdown = gr.Dropdown(choices=["仅使用模型"] + vectordb.get_all_collections_name(), label="选择知识库") eval_model_dropdown = gr.Dropdown(choices=get_llm(), label="选择模型") eval_vector_dropdown = gr.Dropdown(choices=get_embeding_model(), label="选择向量模型") eval_btn = gr.Button("开始评估") eval_result = gr.DataFrame(label="评估结果") def evaluate_system(test_file, knowledge_base, model, vector_model): # 加载测试数据 import json with open(test_file.name, 'r', encoding='utf-8') as f: test_data = json.load(f) questions = test_data.get('questions', []) ground_truth = test_data.get('ground_truth', {}) # 初始化RAG系统 rag = RAG_class(model=model, embed=vector_model, c_name=knowledge_base, persist_directory=DB_directory) # 评估 metrics = rag.evaluate_retrieval(questions, ground_truth) # 格式化结果为DataFrame results = [] for method in metrics: for metric, value in metrics[method].items(): results.append({ "方法": method, "指标": metric, "值": f"{value:.4f}" }) return pd.DataFrame(results) eval_btn.click(evaluate_system, inputs=[test_file, eval_knowledge_base_dropdown, eval_model_dropdown, eval_vector_dropdown], outputs=eval_result)
6. 利用现有的数据流架构进行评估
根据项目数据流架构,可以在检索过程中插入评估代码, 在不同的检索方法之间添加评估逻辑,比较它们的效果。
7. 人工评估
对于主观质量评估,可以添加一个反馈机制,让用户对检索结果进行评分:
with gr.Row(): feedback_radio = gr.Radio(choices=["非常相关", "相关", "部分相关", "不相关"], label="请评价检索结果的相关性") feedback_btn = gr.Button("提交反馈") def collect_feedback(feedback, question, answer): # 存储用户反馈 with open("feedback_log.jsonl", "a", encoding="utf-8") as f: import json feedback_data = { "question": question, "answer": answer, "feedback": feedback, "timestamp": datetime.now().isoformat() } f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n") return "感谢您的反馈!" feedback_btn.click(collect_feedback, inputs=[feedback_radio, chat_input, chat_di